def test_csv_dataset_quotechar(self):
    # Based on issue #349
    example_data = [("text", "label"),
                    ('" hello world', "0"),
                    ('goodbye " world', "1"),
                    ('this is a pen " ', "0")]

    with tempfile.NamedTemporaryFile(dir=self.test_dir) as f:
        for example in example_data:
            f.write(six.b("{}\n".format(",".join(example))))

        TEXT = data.Field(lower=True, tokenize=lambda x: x.split())
        fields = {
            "label": ("label", data.Field(use_vocab=False,
                                          sequential=False)),
            "text": ("text", TEXT)
        }

        f.seek(0)

        dataset = data.TabularDataset(
            path=f.name, format="csv",
            skip_header=False, fields=fields,
            csv_reader_params={"quotechar": None})

        TEXT.build_vocab(dataset)

        self.assertEqual(len(dataset), len(example_data) - 1)
        for i, example in enumerate(dataset):
            # Quote characters should survive unmodified in the loaded examples.
            self.assertEqual(example.text,
                             example_data[i + 1][0].lower().split())
            self.assertEqual(example.label, example_data[i + 1][1])

def test_dataset_split_arguments(self):
    num_examples, num_labels = 30, 3
    self.write_test_splitting_dataset(num_examples=num_examples,
                                      num_labels=num_labels)
    text_field = data.Field()
    label_field = data.LabelField()
    fields = [('text', text_field), ('label', label_field)]

    dataset = data.TabularDataset(
        path=self.test_dataset_splitting_path, format="csv", fields=fields)

    # Test default split ratio (0.7)
    expected_train_size = 21
    expected_test_size = 9

    train, test = dataset.split()
    assert len(train) == expected_train_size
    assert len(test) == expected_test_size

    # Test array arguments with same ratio
    split_ratio = [0.7, 0.3]
    train, test = dataset.split(split_ratio)
    assert len(train) == expected_train_size
    assert len(test) == expected_test_size

# Excerpt from a test that loads a CSV/TSV file with a header row;
# `example_with_header`, `fields`, `TEXT` and `format_` are defined earlier in that test.
dataset = data.TabularDataset(
    path=self.test_has_header_dataset_path, format=format_,
    skip_header=False, fields=fields)

TEXT.build_vocab(dataset)

for i, example in enumerate(dataset):
    self.assertEqual(example.text,
                     example_with_header[i + 1][0].lower().split())
    self.assertEqual(example.label, example_with_header[i + 1][1])

# check that the vocabulary is built correctly (#225)
expected_freqs = {"hello": 1, "world": 2, "goodbye": 1, "text": 0}
for k, v in expected_freqs.items():
    self.assertEqual(TEXT.vocab.freqs[k], v)

data_iter = data.Iterator(dataset, batch_size=1,
                          sort_within_batch=False, repeat=False)
next(data_iter.__iter__())

def test_build_vocab_from_dataset(self):
    nesting_field = data.Field(tokenize=list, unk_token="<cunk>", pad_token="<cpad>",
                               init_token="<w>", eos_token="</w>")
    CHARS = data.NestedField(nesting_field, init_token="<s>", eos_token="</s>")
    ex1 = data.Example.fromlist(["aaa bbb c"], [("chars", CHARS)])
    ex2 = data.Example.fromlist(["bbb aaa"], [("chars", CHARS)])
    dataset = data.Dataset([ex1, ex2], [("chars", CHARS)])

    CHARS.build_vocab(dataset, min_freq=2)

    expected = "a b <w> </w> <s> </s> <cunk> <cpad>".split()
    assert len(CHARS.vocab) == len(expected)
    for c in expected:
        assert c in CHARS.vocab.stoi

    expected_freqs = Counter({"a": 6, "b": 6, "c": 1})
    assert CHARS.vocab.freqs == CHARS.nesting_field.vocab.freqs == expected_freqs

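# Hypothetical illustration (CHARS as constructed in the test above): NestedField
# first splits the string on whitespace, then the nesting field's tokenize=list
# breaks each word into characters.
print(CHARS.preprocess("aaa bbb c"))
# [['a', 'a', 'a'], ['b', 'b', 'b'], ['c']]
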
def test_vocab_specials_first(self):
    c = Counter("a a b b c c".split())

    # add specials into vocabulary at first
    v = vocab.Vocab(c, max_size=2, specials=['<pad>', '<eos>'])
    expected_itos = ['<pad>', '<eos>', 'a', 'b']
    expected_stoi = {x: index for index, x in enumerate(expected_itos)}
    self.assertEqual(v.itos, expected_itos)
    self.assertEqual(dict(v.stoi), expected_stoi)

    # add specials into vocabulary at last
    v = vocab.Vocab(c, max_size=2, specials=['<pad>', '<eos>'], specials_first=False)
    expected_itos = ['a', 'b', '<pad>', '<eos>']
    expected_stoi = {x: index for index, x in enumerate(expected_itos)}
    self.assertEqual(v.itos, expected_itos)
    self.assertEqual(dict(v.stoi), expected_stoi)

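# Hypothetical note: with specials_first=False the padding symbol is appended
# after the regular tokens, so its index should be looked up from the vocab
# (v here is the second Vocab built above) rather than assumed to be 0 or 1.
pad_idx = v.stoi['<pad>']
print(pad_idx)  # 2, since itos is ['a', 'b', '<pad>', '<eos>']
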
def filter_init(ex_val1, ex_val2, ex_val3):
    text_field = data.Field(sequential=True)
    label_field = data.Field(sequential=False)
    fields = [("text1", text_field), ("text2", text_field),
              ("label", label_field)]

    example1 = data.Example.fromlist(ex_val1, fields)
    example2 = data.Example.fromlist(ex_val2, fields)
    example3 = data.Example.fromlist(ex_val3, fields)
    examples = [example1, example2, example3]

    dataset = data.Dataset(examples, fields)
    text_field.build_vocab(dataset)

    return dataset, text_field

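# Hypothetical usage sketch (the values below are made up): filter_init takes
# three [text1, text2, label] value lists and returns the dataset plus the text
# field with its vocabulary already built.
dataset, text_field = filter_init(
    ["hello world", "goodbye world", "pos"],
    ["this is a pen", "that is a pen", "neg"],
    ["hello goodbye", "pen pen pen", "pos"],
)
print(len(dataset))            # 3 examples
print(dataset[0].text1)        # tokenized text1 of the first example
print(text_field.vocab.freqs)  # counts gathered over both text1 and text2
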
def test_errors(self):
    # Ensure that trying to retrieve a key not in JSON data errors
    self.write_test_ppid_dataset(data_format="json")

    question_field = data.Field(sequential=True)
    label_field = data.Field(sequential=False)
    # "qeustion1" is deliberately misspelled: the key is absent from the JSON,
    # so constructing the dataset should raise ValueError.
    fields = {"qeustion1": ("q1", question_field),
              "question2": ("q2", question_field),
              "label": ("label", label_field)}

    with self.assertRaises(ValueError):
        data.TabularDataset(
            path=self.test_ppid_dataset_path, format="json", fields=fields)

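# Hypothetical counterpart (assumes the JSON written by write_test_ppid_dataset
# actually contains "question1"/"question2"/"label" keys): with the key spelled
# correctly, the same constructor call succeeds.
good_fields = {"question1": ("q1", question_field),
               "question2": ("q2", question_field),
               "label": ("label", label_field)}
dataset = data.TabularDataset(
    path=self.test_ppid_dataset_path, format="json", fields=good_fields)
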
def test_targetfield_specials(self):
    test_path = os.path.dirname(os.path.realpath(__file__))
    data_path = os.path.join(test_path, 'data/eng-fra.txt')

    field = TargetField()
    train = torchtext.data.TabularDataset(
        path=data_path, format='tsv',
        fields=[('src', torchtext.data.Field()), ('trg', field)]
    )
    self.assertTrue(field.sos_id is None)
    self.assertTrue(field.eos_id is None)

    field.build_vocab(train)
    self.assertFalse(field.sos_id is None)
    self.assertFalse(field.eos_id is None)

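# Hypothetical follow-up (not part of the test above): after build_vocab, sos_id
# and eos_id are ordinary vocabulary indices for the TargetField built above, so
# they map back to their symbols through field.vocab.
print(field.vocab.itos[field.sos_id])  # start-of-sequence symbol
print(field.vocab.itos[field.eos_id])  # end-of-sequence symbol
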
def udpos_dataset(batch_size):
    # Setup fields with batch dimension first
    inputs = data.Field(init_token="<bos>", eos_token="<eos>", batch_first=True)
    tags = data.Field(init_token="<bos>", eos_token="<eos>", batch_first=True)

    # Download and load the default data.
    train, val, test = datasets.UDPOS.splits(
        fields=(('inputs_word', inputs), ('labels', tags), (None, None)))

    # Build vocab from the columns named in the field tuple above.
    inputs.build_vocab(train.inputs_word)
    tags.build_vocab(train.labels)

    # Get iterators
    train_iter, val_iter, test_iter = data.BucketIterator.splits(
        (train, val, test), batch_size=batch_size,
        device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu"))
    train_iter.repeat = False
    return train_iter, val_iter, test_iter, inputs, tags

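# Hypothetical usage sketch (not in the original snippet): fetch one batch from
# the iterators built above; attribute names follow the field tuple
# ('inputs_word', 'labels') used in udpos_dataset.
train_iter, val_iter, test_iter, inputs, tags = udpos_dataset(batch_size=32)
batch = next(iter(train_iter))
words = batch.inputs_word   # shape (batch, seq_len) because batch_first=True
pos_tags = batch.labels     # same shape, tag indices
print(words.shape, pos_tags.shape, len(inputs.vocab), len(tags.vocab))
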
from nntoolbox.utils import get_device
from nntoolbox.sequence.models import LanguageModel
from nntoolbox.sequence.learner import LanguageModelLearner
from nntoolbox.sequence.components import AdditiveContextEmbedding
from nntoolbox.sequence.utils import load_embedding
from torch import nn
from torch.optim import Adam
import torch
# Legacy torchtext provides the data.Field / WikiText2 / BPTTIterator API used below.
from torchtext import data
from torchtext.datasets import WikiText2
from nntoolbox.callbacks import *
from nntoolbox.metrics import *
MAX_VOCAB_SIZE = 25000
BATCH_SIZE = 16

TEXT = data.Field(tokenize='spacy')
LABEL = data.LabelField(dtype=torch.float)

# train_iterator, val_iterator, test_iterator = WikiText2.iters()
# for tmp in train_iterator:
#     print(tmp)

train_data, val_data, test_data = WikiText2.splits(TEXT)
train_iterator = data.BPTTIterator(
    train_data,
    batch_size=BATCH_SIZE,
    sort_within_batch=True,
    device=get_device(),
    bptt_len=35,
    shuffle=True
)
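# Hypothetical next step (not part of the original script): build the vocabulary
# and pull one batch. A BPTTIterator batch carries .text and .target tensors,
# where target is the text shifted by one step for language-model training.
TEXT.build_vocab(train_data, max_size=MAX_VOCAB_SIZE)
for batch in train_iterator:
    src, tgt = batch.text, batch.target  # each of shape (bptt_len, batch_size)
    print(src.shape, tgt.shape)
    break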