Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Get input and output dimension info.
# Both dimensions are read from the first utterance's 'output' list:
# entry 1 provides idim and entry 0 provides odim.
with open(args.valid_json, 'rb') as f:
    valid_json = json.load(f)['utts']
utts = list(valid_json.keys())
idim = int(valid_json[utts[0]]['output'][1]['shape'][1])
odim = int(valid_json[utts[0]]['output'][0]['shape'][1])
logging.info('#input dims : ' + str(idim))
logging.info('#output dims: ' + str(odim))

# Specify model architecture: the class is resolved from args.model_module.
model_class = dynamic_import(args.model_module)
model = model_class(idim, odim, args)
assert isinstance(model, MTInterface)

# Optionally attach a pre-trained RNN language model to the model.
if args.rnnlm is not None:
    rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
    rnnlm = lm_pytorch.ClassifierWithState(
        lm_pytorch.RNNLM(
            len(args.char_list), rnnlm_args.layer, rnnlm_args.unit))
    # Bug fix: this used to be torch.load(args.rnnlm, rnnlm), which passes the
    # module as `map_location` and throws the loaded object away, so `rnnlm`
    # kept its initial weights. torch_load(path, model) is the loader the
    # decoding code in this file uses for the same purpose.
    torch_load(args.rnnlm, rnnlm)
    model.rnnlm = rnnlm

# Write model config as UTF-8 encoded JSON: (idim, odim, vars(args)).
if not os.path.exists(args.outdir):
    os.makedirs(args.outdir)
model_conf = args.outdir + '/model.json'
with open(model_conf, 'wb') as f:
    logging.info('writing a model config file to ' + model_conf)
    f.write(json.dumps((idim, odim, vars(args)),
                       indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))

# Log every argument, sorted by name.
for key in sorted(vars(args).keys()):
    logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))
# Build the model; idim, odim and model_class are defined before this point.
model = model_class(idim, odim, args)
assert isinstance(model, ASRInterface)

# Initialise encoder/decoder from pre-trained modules, but only when training
# is not being resumed and at least one init source was given.
freeze_mode = False
if args.resume is None and \
        (args.enc_init is not None or args.dec_init is not None):
    model = load_pretrained_modules(model, args.rnnt_mode,
                                    args.enc_init, args.dec_init,
                                    args.enc_init_mods, args.dec_init_mods)
    # NOTE(review): the nesting of this check under the pre-trained branch is
    # reconstructed from a whitespace-stripped source -- confirm.
    if args.freeze_modules:
        freeze_mode = freeze_modules(model, args.freeze_modules)

subsampling_factor = model.subsample[0]

# Optionally attach a pre-trained RNN language model to the model.
if args.rnnlm is not None:
    rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
    rnnlm = lm_pytorch.ClassifierWithState(
        lm_pytorch.RNNLM(
            len(args.char_list), rnnlm_args.layer, rnnlm_args.unit))
    # Bug fix: this used to be torch.load(args.rnnlm, rnnlm), which passes the
    # module as `map_location` and throws the loaded object away, so `rnnlm`
    # kept its initial weights. torch_load(path, model) is the loader the
    # decoding code in this file uses for the same purpose.
    torch_load(args.rnnlm, rnnlm)
    model.rnnlm = rnnlm

# Write model config as UTF-8 encoded JSON: (idim, odim, vars(args)).
if not os.path.exists(args.outdir):
    os.makedirs(args.outdir)
model_conf = args.outdir + '/model.json'
with open(model_conf, 'wb') as f:
    logging.info('writing a model config file to ' + model_conf)
    f.write(json.dumps((idim, odim, vars(args)), indent=4,
                       sort_keys=True).encode('utf_8'))

# Log every argument, sorted by name.
for key in sorted(vars(args).keys()):
    logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))
# Load the optional character-level language model.
if not args.rnnlm:
    rnnlm = None
else:
    rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
    lm_module_name = getattr(rnnlm_args, "model_module", "default")
    if lm_module_name != "default":
        raise ValueError("use '--api v2' option to decode with non-default language model")
    rnnlm = lm_pytorch.ClassifierWithState(
        lm_pytorch.RNNLM(len(train_args.char_list),
                         rnnlm_args.layer,
                         rnnlm_args.unit))
    torch_load(args.rnnlm, rnnlm)
    rnnlm.eval()

# Load the optional word-level language model and wrap it so that `rnnlm`
# ends up holding the final LM: combined with the character LM when one was
# loaded above, otherwise the word LM alone through LookAheadWordLM.
if args.word_rnnlm:
    rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
    word_dict = rnnlm_args.char_list_dict
    char_dict = {x: i for i, x in enumerate(train_args.char_list)}
    word_rnnlm = lm_pytorch.ClassifierWithState(
        lm_pytorch.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit))
    torch_load(args.word_rnnlm, word_rnnlm)
    word_rnnlm.eval()

    if rnnlm is None:
        rnnlm = lm_pytorch.ClassifierWithState(
            extlm_pytorch.LookAheadWordLM(
                word_rnnlm.predictor, word_dict, char_dict))
    else:
        rnnlm = lm_pytorch.ClassifierWithState(
            extlm_pytorch.MultiLevelLM(
                word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict))
def _load_teacher_model(self, model_path):
    """Build a Transformer teacher from ``model_path`` and freeze it.

    The teacher configuration is read with ``get_model_conf`` and must agree
    with this object's ``idim``, ``odim`` and ``reduction_factor``.

    Args:
        model_path: Path passed to ``get_model_conf`` and ``torch_load``.

    Returns:
        The loaded Transformer with ``requires_grad`` disabled on every
        parameter.

    Raises:
        AssertionError: If the teacher and student settings differ.
    """
    # get teacher model config
    teacher_idim, teacher_odim, teacher_args = get_model_conf(model_path)

    # the teacher and the student must share dimensions and reduction factor
    assert teacher_idim == self.idim
    assert teacher_odim == self.odim
    assert teacher_args.reduction_factor == self.reduction_factor

    # load teacher model
    teacher = Transformer(teacher_idim, teacher_odim, teacher_args)
    torch_load(model_path, teacher)

    # freeze teacher model parameters
    for param in teacher.parameters():
        param.requires_grad = False

    return teacher
# Select the training mode from args.mtlalpha: exactly 1.0 is pure CTC,
# exactly 0.0 is pure attention, anything else is multitask learning.
if args.mtlalpha == 1.0:
    mtl_mode = 'ctc'
    logging.info('Pure CTC mode')
elif args.mtlalpha == 0.0:
    mtl_mode = 'att'
    logging.info('Pure attention mode')
else:
    mtl_mode = 'mtl'
    logging.info('Multitask learning mode')

# specify model architecture
model = E2E(idim, odim, args)
subsampling_factor = model.subsample[0]

# Optionally attach a pre-trained RNN language model to the model.
if args.rnnlm is not None:
    rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
    rnnlm = lm_pytorch.ClassifierWithState(
        lm_pytorch.RNNLM(
            len(args.char_list), rnnlm_args.layer, rnnlm_args.unit))
    # Bug fix: this used to be torch.load(args.rnnlm, rnnlm), which passes the
    # module as `map_location` and throws the loaded object away, so `rnnlm`
    # kept its initial weights. torch_load(path, model) is the loader the
    # decoding code in this file uses for the same purpose.
    torch_load(args.rnnlm, rnnlm)
    model.rnnlm = rnnlm

# Write model config as UTF-8 encoded JSON: (idim, odim, vars(args)).
if not os.path.exists(args.outdir):
    os.makedirs(args.outdir)
model_conf = args.outdir + '/model.json'
with open(model_conf, 'wb') as f:
    logging.info('writing a model config file to ' + model_conf)
    f.write(json.dumps((idim, odim, vars(args)),
                       indent=4, ensure_ascii=False, sort_keys=True).encode('utf_8'))

# Log every argument, sorted by name.
for key in sorted(vars(args).keys()):
    logging.info('ARGS: ' + key + ': ' + str(vars(args)[key]))
# Rebuild the trained network, wrap it in its loss, and restore parameters.
logging.info('reading model parameters from ' + args.model)
e2e = E2E(idim, odim, train_args)
model = Loss(e2e, train_args.mtlalpha)
chainer_load(args.model, model)

# Load the optional character-level language model.
if not args.rnnlm:
    rnnlm = None
else:
    rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
    rnnlm = lm_chainer.ClassifierWithState(
        lm_chainer.RNNLM(len(train_args.char_list),
                         rnnlm_args.layer,
                         rnnlm_args.unit))
    chainer_load(args.rnnlm, rnnlm)

# Load the optional word-level language model and wrap it so that `rnnlm`
# ends up holding the final LM: combined with the character LM when one was
# loaded above, otherwise the word LM alone through LookAheadWordLM.
if args.word_rnnlm:
    rnnlm_args = get_model_conf(args.word_rnnlm, args.word_rnnlm_conf)
    word_dict = rnnlm_args.char_list_dict
    char_dict = {x: i for i, x in enumerate(train_args.char_list)}
    word_rnnlm = lm_chainer.ClassifierWithState(
        lm_chainer.RNNLM(len(word_dict), rnnlm_args.layer, rnnlm_args.unit))
    chainer_load(args.word_rnnlm, word_rnnlm)

    if rnnlm is None:
        rnnlm = lm_chainer.ClassifierWithState(
            extlm_chainer.LookAheadWordLM(
                word_rnnlm.predictor, word_dict, char_dict))
    else:
        rnnlm = lm_chainer.ClassifierWithState(
            extlm_chainer.MultiLevelLM(
                word_rnnlm.predictor, rnnlm.predictor, word_dict, char_dict))

# read json data
def trans(args):
    """Set up the trained model and optional language model for decoding.

    :param Namespace args: The program arguments
    """
    set_deterministic_pytorch(args)

    # restore the trained model together with its training-time arguments
    model, train_args = load_trained_model(args.model)
    assert isinstance(model, MTInterface)
    model.recog_args = args

    # read rnnlm (put into eval mode once loaded); None when not requested
    if not args.rnnlm:
        rnnlm = None
    else:
        rnnlm_args = get_model_conf(args.rnnlm, args.rnnlm_conf)
        rnnlm = lm_pytorch.ClassifierWithState(
            lm_pytorch.RNNLM(len(train_args.char_list),
                             rnnlm_args.layer,
                             rnnlm_args.unit))
        torch_load(args.rnnlm, rnnlm)
        rnnlm.eval()

    # gpu: only the single-GPU case moves the model (and the LM) to CUDA
    if args.ngpu == 1:
        device_ids = list(range(args.ngpu))
        logging.info('gpu id: ' + str(device_ids))
        model.cuda()
        if rnnlm:
            rnnlm.cuda()
def load_trained_model(model_path):
    """Load the trained model for recognition.

    The training configuration is read from ``model.json`` in the same
    directory as ``model_path``.

    Args:
        model_path(str): Path to model.***.best

    Returns:
        tuple: ``(model, train_args)`` -- the model with parameters loaded
        from ``model_path`` and the training arguments it was built with.
    """
    conf_path = os.path.join(os.path.dirname(model_path), 'model.json')
    idim, odim, train_args = get_model_conf(model_path, conf_path)

    logging.warning('reading model parameters from ' + model_path)

    # use the default E2E ASR class when the config names no model module
    model_module = getattr(
        train_args, "model_module", "espnet.nets.pytorch_backend.e2e_asr:E2E")
    model = dynamic_import(model_module)(idim, odim, train_args)
    torch_load(model_path, model)

    return model, train_args
def enchance(args):
    """Dumping enhanced speech and mask

    Args:
        args (Namespace): The program arguments
    """
    set_deterministic_pytorch(args)

    # read training config
    idim, odim, train_args = get_model_conf(args.model, args.model_conf)

    # build the model class named in the training config, then load parameters
    logging.info('reading model parameters from ' + args.model)
    model = dynamic_import(train_args.model_module)(idim, odim, train_args)
    assert isinstance(model, ASRInterface)
    torch_load(args.model, model)
    model.recog_args = args

    # gpu: only the single-GPU case moves the model to CUDA
    if args.ngpu == 1:
        device_ids = list(range(args.ngpu))
        logging.info('gpu id: ' + str(device_ids))
        model.cuda()

    # read json data
def enhance(args):
"""Dumping enhanced speech and mask.
Args:
args (namespace): The program arguments.
"""
set_deterministic_pytorch(args)
# read training config
idim, odim, train_args = get_model_conf(args.model, args.model_conf)
# TODO(ruizhili): implement enhance for multi-encoder model
assert args.num_encs == 1, "number of encoder should be 1 ({} is given)".format(args.num_encs)
# load trained model parameters
logging.info('reading model parameters from ' + args.model)
model_class = dynamic_import(train_args.model_module)
model = model_class(idim, odim, train_args)
assert isinstance(model, ASRInterface)
torch_load(args.model, model)
model.recog_args = args
# gpu
if args.ngpu == 1:
gpu_id = list(range(args.ngpu))
logging.info('gpu id: ' + str(gpu_id))