# load dict
word_ids_dict = load_dict(word_dict_path)
label_ids_dict = load_dict(label_dict_path)
# read data to index
word_ids = vectorize_data(train_word_path, word_ids_dict)
label_ids = vectorize_data(train_label_path, label_ids_dict)
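# `vectorize_data` is not defined in this snippet; a minimal sketch, assuming each line of
# the input file is one whitespace-tokenized sentence and unknown tokens map to id 0:
def vectorize_data(path, token2id, unk_id=0):
    sequences = []
    with open(path, 'r', encoding='utf-8') as f:
        for line in f:
            tokens = line.strip().split()
            if tokens:
                sequences.append([token2id.get(t, unk_id) for t in tokens])
    return sequences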
max_len = np.max([len(i) for i in word_ids])
print('max_len:', max_len)
# pad sequence
word_seq = pad_sequence(word_ids, maxlen=max_len)
label_seq = pad_sequence(label_ids, maxlen=max_len)
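# `pad_sequence` is assumed to be a thin wrapper around Keras' pad_sequences; a sketch:
from keras.preprocessing.sequence import pad_sequences

def pad_sequence(sequences, maxlen):
    # Post-pad (and truncate) with 0 so every sequence has length `maxlen`.
    return pad_sequences(sequences, maxlen=maxlen, padding='post', truncating='post', value=0)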
# reshape label for crf model use
label_seq = np.reshape(label_seq, (label_seq.shape[0], label_seq.shape[1], 1))
print(word_seq.shape)
print(label_seq.shape)
logger.info("Data loaded.")
# model
logger.info("Training BILSTM_CRF model...")
model = create_model(word_ids_dict, label_ids_dict,
                     embedding_dim, rnn_hidden_dim, dropout)
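# `create_model` is defined elsewhere; a minimal BiLSTM-CRF sketch with keras_contrib,
# assuming sparse integer labels of shape (batch, seq_len, 1) as prepared above.
# Layer sizes and names here are illustrative, not the project's actual definition.
from keras.models import Sequential
from keras.layers import Embedding, Bidirectional, LSTM, Dropout
from keras_contrib.layers import CRF

def create_model(word_ids_dict, label_ids_dict, embedding_dim, rnn_hidden_dim, dropout):
    model = Sequential()
    model.add(Embedding(len(word_ids_dict), embedding_dim, mask_zero=True))
    model.add(Bidirectional(LSTM(rnn_hidden_dim, return_sequences=True)))
    model.add(Dropout(dropout))
    crf = CRF(len(label_ids_dict), sparse_target=True)
    model.add(crf)
    model.compile(optimizer='adam', loss=crf.loss_function, metrics=[crf.accuracy])
    return model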
# callback
callbacks_list = callback(save_model_path, logger)
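# `callback` is defined elsewhere; a hedged sketch of a typical callback list that
# checkpoints the best weights and stops early (monitor/patience values are assumptions):
from keras.callbacks import ModelCheckpoint, EarlyStopping

def callback(save_model_path, logger):
    logger.info('Checkpoints will be written to %s', save_model_path)
    return [
        ModelCheckpoint(save_model_path, monitor='val_loss', save_best_only=True, verbose=1),
        EarlyStopping(monitor='val_loss', patience=3, verbose=1),
    ]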
# fit
model.fit(word_seq,
          label_seq,
          batch_size=batch_size,
          epochs=epoch,
          validation_split=0.2,
          callbacks=callbacks_list)
logger.info("Training has finished.")
def load_word_dict(save_path):
    """Load a `token id` mapping file (one pair per line) into a dict."""
    dict_data = dict()
    with open(save_path, 'r', encoding='utf-8') as f:
        for line in f:
            items = line.strip().split()
            try:
                dict_data[items[0]] = int(items[1])
            except IndexError:
                logger.error('Skipping malformed dict line: %s', line)
    return dict_data
# Zero-pad up to the sequence length.
padding = [0] * (max_seq_length - len(input_ids))
input_ids += padding
input_mask += padding
segment_ids += padding
assert len(input_ids) == max_seq_length
assert len(input_mask) == max_seq_length
assert len(segment_ids) == max_seq_length
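# Illustration (assumed values): with max_seq_length = 8 and 5 real word-piece tokens,
# the padded tensors would look like (101/102 are the usual [CLS]/[SEP] ids, the rest
# are illustrative only):
#   input_ids   -> [101, 2769, 812, 1962, 102, 0, 0, 0]
#   input_mask  -> [1, 1, 1, 1, 1, 0, 0, 0]
#   segment_ids -> [0, 0, 0, 0, 0, 0, 0, 0]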
if example_index < 5:
    logger.info("*** Example ***")
    logger.info("example_index: %s" % (example_index))
    logger.info("guid: %s" % (example.guid))
    logger.info("tokens: %s" % " ".join([str(x) for x in tokens]))
    logger.info("input_ids: %s" % " ".join([str(x) for x in input_ids]))
    logger.info("input_mask: %s" % " ".join([str(x) for x in input_mask]))
    logger.info("segment_ids: %s" % " ".join([str(x) for x in segment_ids]))
features.append(
    InputFeatures(input_ids=input_ids,
                  input_mask=input_mask,
                  mask_positions=mask_positions,
                  segment_ids=segment_ids,
                  input_tokens=tokens))
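# `InputFeatures` is assumed to be a plain container matching the keyword arguments above:
class InputFeatures(object):
    def __init__(self, input_ids, input_mask, mask_positions, segment_ids, input_tokens):
        self.input_ids = input_ids
        self.input_mask = input_mask
        self.mask_positions = mask_positions
        self.segment_ids = segment_ids
        self.input_tokens = input_tokens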
# Mask each word
# features = create_sequential_mask(tokens, input_ids, input_mask, segment_ids, mask_id, tokenizer)
# all_features.extend(features)
# all_tokens.extend(tokens)
# return all_features, all_tokens
device = torch.device('cpu')
print('device:', device)
# load vocab
self.vocab2id = load_word_dict(vocab_path)
self.id2vocab = {v: k for k, v in self.vocab2id.items()}
logger.debug('Loaded vocabulary file:%s, size: %s' % (vocab_path, len(self.vocab2id)))
# load model
start_time = time.time()
self.model = self._create_model(self.vocab2id, device)
if use_gpu:
    self.model.load_state_dict(torch.load(model_path))
else:
    # Load all tensors onto the CPU
    self.model.load_state_dict(torch.load(model_path, map_location=lambda storage, loc: storage))
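    # Equivalent, more concise form on recent PyTorch: torch.load(model_path, map_location='cpu')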
logger.info("Loaded model:%s, spend:%s s" % (model_path, time.time() - start_time))
self.model.eval()
self.src_seq_lens = src_seq_lens
self.trg_seq_lens = trg_seq_lens
self.beam_size = beam_size
self.batch_size = batch_size
self.device = device
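# Standard backward/step pattern (it's assumed optimizer.zero_grad() ran earlier in the
# batch loop); clip_grad_norm_ rescales gradients in place so their global norm <= grad_clip.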
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip)
optimizer.step()
end_time = time.time()
losses.append([
    epoch, batch_id,
    loss.data.cpu().numpy(),
    loss_cv.data.cpu().numpy()[0],
    (end_time - start_time) / 3600.0])
if batch_id % save_model_batch_num == 0:
    model_path = os.path.join(save_model_dir, 'seq2seq_' + str(epoch) + '_' + str(batch_id) + '.model')
    with open(model_path, 'wb') as f:
        torch.save(model.state_dict(), f)
    logger.info("Model saved to " + model_path)
if batch_id % 1 == 0:  # always true: log progress on every batch
    end_time = time.time()
    sen_pred = [id2vocab[x] if x in id2vocab else ext_id2oov[x] for x in word_prob[0]]
    print('epoch={}, batch={}, loss={}, loss_cv={}, time_elapsed={}s={}h'.format(
        epoch,
        batch_id,
        loss.data.cpu().numpy(),
        loss_cv.data.cpu().numpy()[0],
        end_time - start_time, (end_time - start_time) / 3600.0
    ))
    print(' '.join(sen_pred))
del logits, attn_, p_gen, loss_cv, loss
with open(os.path.join(save_model_dir, 'loss.txt'), 'a', encoding='utf-8') as f:
    for i in losses:
        f.write(str(i) + '\n')
model_path = os.path.join(save_model_dir, 'seq2seq_' + str(epoch) + '_' + str(batch_id) + '.model')
with open(model_path, 'wb') as f:
    torch.save(model.state_dict(), f)
logger.info("Model saved to " + model_path)
last_model_path = model_path
logger.info("Training has finished.")
# Eval model
eval(model, last_model_path, val_path, output_dir, batch_size, vocab2id, src_seq_lens, trg_seq_lens, device)
logger.info("Eval has finished.")