Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
"""
# read training config
idim, odim, train_args = get_model_conf(
model_path, os.path.join(os.path.dirname(model_path), 'model.json'))
# load trained model parameters
logging.info('reading model parameters from ' + model_path)
# To be compatible with v.0.3.0 models
if hasattr(train_args, "model_module"):
model_module = train_args.model_module
else:
model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
model_class = dynamic_import(model_module)
model = model_class(idim, odim, train_args)
torch_load(model_path, model)
return model, train_args
raise ValueError("use '--api v2' option to decode with non-default language model")
rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(
len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit))
torch_load(args.rnnlm, rnnlm)
rnnlm.eval()
else:
rnnlm = None
if args.word_rnnlm:
rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
word_dict = rnnlm_args.char_list_dict
char_dict = {x: i for i, x in enumerate(train_args.char_list)}
word_rnnlm = lm_pytorch.ClassifierWithState(lm_pytorch.RNNLM(
len(word_dict), rnnlm_args.layer, rnnlm_args.unit))
torch_load(args.word_rnnlm, word_rnnlm)
word_rnnlm.eval()
if rnnlm is not None:
rnnlm = lm_pytorch.ClassifierWithState(
extlm_pytorch.MultiLevelLM(word_rnnlm.predictor,
rnnlm.predictor, word_dict, char_dict))
else:
rnnlm = lm_pytorch.ClassifierWithState(
extlm_pytorch.LookAheadWordLM(word_rnnlm.predictor,
word_dict, char_dict))
# gpu
if args.ngpu == 1:
gpu_id = list(range(args.ngpu))
logging.info('gpu id: ' + str(gpu_id))
model.cuda()
set_deterministic_pytorch(args)
# load the trained model together with its training config
model, train_args = load_trained_model(args.model)
# assert isinstance(model, STInterface)
# TODO(hirofumi0810) fix this for after supporting Transformer
# ctc_weight is forced to 0.0 on the program args (presumably to disable CTC
# scoring during decoding -- confirm), and the mutated args are attached to
# the model as trans_args
args.ctc_weight = 0.0
model.trans_args = args
# read rnnlm
# only the default RNNLM module is supported here; a config naming another
# model_module raises ValueError pointing the user to '--api v2'
if args.rnnlm:
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
if getattr(rnnlm_args, "model_module", "default") != "default":
raise ValueError("use '--api v2' option to decode with non-default language model")
# the LM vocabulary size is taken from the loaded model's char_list;
# layer/unit sizes come from the LM's own config
rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(
len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit))
torch_load(args.rnnlm, rnnlm)
rnnlm.eval()
else:
rnnlm = None
# optional word-level LM; note rnnlm_args is reassigned to the word LM config
if args.word_rnnlm:
rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
# word vocabulary comes from the word LM config; char_dict maps each token
# of the model's char_list to its index
word_dict = rnnlm_args.char_list_dict
char_dict = {x: i for i, x in enumerate(train_args.char_list)}
word_rnnlm = lm_pytorch.ClassifierWithState(lm_pytorch.RNNLM(
len(word_dict), rnnlm_args.layer, rnnlm_args.unit))
torch_load(args.word_rnnlm, word_rnnlm)
word_rnnlm.eval()
if rnnlm is not None:
rnnlm = lm_pytorch.ClassifierWithState(
extlm_pytorch.MultiLevelLM(word_rnnlm.predictor,
Args:
args (namespace): The program arguments.
"""
set_deterministic_pytorch(args)
# read training config
idim, odim, train_args = get_model_conf(args.model, args.model_conf)
# TODO(ruizhili): implement enhance for multi-encoder model
assert args.num_encs == 1, "number of encoder should be 1 ({} is given)".format(args.num_encs)
# load trained model parameters
logging.info('reading model parameters from ' + args.model)
# NOTE(review): train_args.model_module is used directly, with no fallback
# for configs that lack it (AttributeError in that case) -- confirm that
# older configs never reach this path
model_class = dynamic_import(train_args.model_module)
model = model_class(idim, odim, train_args)
assert isinstance(model, ASRInterface)
torch_load(args.model, model)
model.recog_args = args
# gpu
# NOTE(review): only ngpu == 1 calls model.cuda(); any other value (including
# ngpu > 1) leaves the model where it was constructed
if args.ngpu == 1:
gpu_id = list(range(args.ngpu))
logging.info('gpu id: ' + str(gpu_id))
model.cuda()
# read json data
# only the 'utts' entry of the recognition json is kept
with open(args.recog_json, 'rb') as f:
js = json.load(f)['utts']
# loader for inputs only (load_output=False), not sorted by input length;
# preprocessing is deliberately left to the caller (see inline comment)
load_inputs_and_targets = LoadInputsAndTargets(
mode='asr', load_output=False, sort_in_input_length=False,
preprocess_conf=None # Apply pre_process in outer func
)
Args:
model_path(str): Path to model.***.best. The training config is read from
'model.json' located in the same directory as this file.

Returns:
tuple: (model, train_args) -- the model instance with parameters loaded
from model_path, and the training config returned by get_model_conf.
"""
# read the training config stored next to the snapshot
idim, odim, train_args = get_model_conf(
model_path, os.path.join(os.path.dirname(model_path), 'model.json'))
logging.warning('reading model parameters from ' + model_path)
# configs without a "model_module" attribute (presumably from older
# versions) fall back to the default E2E ASR class
if hasattr(train_args, "model_module"):
model_module = train_args.model_module
else:
model_module = "espnet.nets.pytorch_backend.e2e_asr:E2E"
# import the "module:Class" named above, build it from the config,
# then load the saved parameters into it
model_class = dynamic_import(model_module)
model = model_class(idim, odim, train_args)
torch_load(model_path, model)
return model, train_args
# NOTE(review): the BLEU curves below are written to 'cer.png'. This looks
# like a copy-paste leftover (something like 'bleu.png' would be expected);
# confirm it does not collide with another plot's output file.
trainer.extend(extensions.PlotReport(['main/bleu', 'validation/main/bleu'],
'epoch', file_name='cer.png'))
# Save best models
# 'model.loss.best' is snapshotted on a new minimum of validation loss,
# 'model.acc.best' on a new maximum of validation accuracy
trainer.extend(snapshot_object(model, 'model.loss.best'),
trigger=training.triggers.MinValueTrigger('validation/main/loss'))
trainer.extend(snapshot_object(model, 'model.acc.best'),
trigger=training.triggers.MaxValueTrigger('validation/main/acc'))
# save snapshot which contains model and optimizer states
trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))
# epsilon decay in the optimizer
# For Adadelta, two extensions share the same CompareValueTrigger condition:
# the best snapshot for the chosen criterion is reloaded from args.outdir,
# and adadelta_eps_decay(args.eps_decay) is applied. The trigger fires when its
# lambda(best_value, current_value) is true, i.e. current accuracy below
# best_value ('acc') or current loss above best_value ('loss').
# Any other criterion registers neither extension.
if args.opt == 'adadelta':
if args.criterion == 'acc':
trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best', load_fn=torch_load),
trigger=CompareValueTrigger(
'validation/main/acc',
lambda best_value, current_value: best_value > current_value))
trainer.extend(adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
'validation/main/acc',
lambda best_value, current_value: best_value > current_value))
elif args.criterion == 'loss':
trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best', load_fn=torch_load),
trigger=CompareValueTrigger(
'validation/main/loss',
lambda best_value, current_value: best_value < current_value))
trainer.extend(adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
'validation/main/loss',
lambda best_value, current_value: best_value < current_value))
'epoch', file_name='cer.png'))
# Save best models
trainer.extend(snapshot_object(model, 'model.loss.best'),
trigger=training.triggers.MinValueTrigger('validation/main/loss'))
# the accuracy-based snapshot is skipped when mtl_mode is 'ctc' (presumably
# no 'validation/main/acc' is reported in that mode -- confirm)
if mtl_mode != 'ctc':
trainer.extend(snapshot_object(model, 'model.acc.best'),
trigger=training.triggers.MaxValueTrigger('validation/main/acc'))
# save snapshot which contains model and optimizer states
trainer.extend(torch_snapshot(), trigger=(1, 'epoch'))
# epsilon decay in the optimizer
# For Adadelta, the best snapshot for the chosen criterion is reloaded from
# args.outdir and adadelta_eps_decay(args.eps_decay) is applied whenever the
# trigger's lambda(best_value, current_value) is true.
# NOTE(review): with criterion 'acc' and mtl_mode 'ctc' neither branch below
# matches, so no restore/eps-decay extension is registered at all; confirm
# this silent no-op is intended (falling back to 'loss' may be expected).
if args.opt == 'adadelta':
if args.criterion == 'acc' and mtl_mode != 'ctc':
trainer.extend(restore_snapshot(model, args.outdir + '/model.acc.best', load_fn=torch_load),
trigger=CompareValueTrigger(
'validation/main/acc',
lambda best_value, current_value: best_value > current_value))
trainer.extend(adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
'validation/main/acc',
lambda best_value, current_value: best_value > current_value))
elif args.criterion == 'loss':
trainer.extend(restore_snapshot(model, args.outdir + '/model.loss.best', load_fn=torch_load),
trigger=CompareValueTrigger(
'validation/main/loss',
lambda best_value, current_value: best_value < current_value))
trainer.extend(adadelta_eps_decay(args.eps_decay),
trigger=CompareValueTrigger(
'validation/main/loss',
lambda best_value, current_value: best_value < current_value))
def _load_teacher_model(self, model_path):
    """Load a pretrained Transformer teacher model and freeze its parameters.

    Args:
        model_path (str): Path to the teacher model snapshot. Its config is
            read with ``get_model_conf(model_path)``.

    Returns:
        Transformer: Teacher model whose parameters all have
            ``requires_grad`` set to False.

    Raises:
        AssertionError: If the teacher's idim, odim or reduction_factor
            differs from this (student) model's.

    """
    # get teacher model config
    idim, odim, args = get_model_conf(model_path)

    # the teacher must be dimensionally compatible with the student; the
    # messages report which value differs (asserts are skipped under -O)
    assert idim == self.idim, \
        "teacher idim ({}) != student idim ({})".format(idim, self.idim)
    assert odim == self.odim, \
        "teacher odim ({}) != student odim ({})".format(odim, self.odim)
    assert args.reduction_factor == self.reduction_factor, \
        "teacher reduction_factor ({}) != student reduction_factor ({})".format(
            args.reduction_factor, self.reduction_factor)

    # load teacher model
    model = Transformer(idim, odim, args)
    torch_load(model_path, model)

    # freeze teacher model parameters so autograd does not track them
    for p in model.parameters():
        p.requires_grad = False

    return model
"""Decode with the given args
:param Namespace args: The program arguments
"""
set_deterministic_pytorch(args)
model, train_args = load_trained_model(args.model)
assert isinstance(model, MTInterface)
model.recog_args = args
# read rnnlm
if args.rnnlm:
rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
rnnlm = lm_pytorch.ClassifierWithState(
lm_pytorch.RNNLM(
len(train_args.char_list), rnnlm_args.layer, rnnlm_args.unit))
torch_load(args.rnnlm, rnnlm)
rnnlm.eval()
else:
rnnlm = None
# gpu
if args.ngpu == 1:
gpu_id = list(range(args.ngpu))
logging.info('gpu id: ' + str(gpu_id))
model.cuda()
if rnnlm:
rnnlm.cuda()
# read json data
with open(args.recog_json, 'rb') as f:
js = json.load(f)['utts']
new_js = {}