with testing_utils.capture_output() as _:
    modfn = os.path.join(tmp, 'model')
    with open(modfn, 'w') as f:
        f.write('Test.')
    optfn = modfn + '.opt'
    base_opt = {
        'model': 'tests.test_params:_ExampleUpgradeOptAgent',
        'dict_file': modfn + '.dict',
        'model_file': modfn,
    }
    with open(optfn, 'w') as f:
        json.dump(base_opt, f)
    pp = ParlaiParser(True, True)
    opt = pp.parse_args(['--model-file', modfn])
    agents.create_agent(opt)
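
# Hedged aside (not from the original test): a minimal, self-contained sketch of the
# same round trip. The JSON written to '<model_file>.opt' is what ParlAI consults when
# only --model-file is supplied. The temporary directory and the 'repeat_label' model
# name below are assumptions for illustration.
import json
import os
import tempfile

with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, 'model')
    with open(model_path + '.opt', 'w') as f:
        json.dump({'model': 'repeat_label', 'model_file': model_path}, f)
    with open(model_path + '.opt') as f:
        restored = json.load(f)
    assert restored['model'] == 'repeat_label'  # agent class recorded at save time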
if not opt.get("model_file"):
    raise RuntimeError("Please specify a model file")

if opt.get("fixed_cands_path") is None:
    fcp = os.path.join(
        "/".join(opt.get("model_file").split("/")[:-1]), "candidates.txt"
    )
    opt["fixed_cands_path"] = fcp
    opt["override"]["fixed_cands_path"] = fcp

opt["task"] = "parlai.agents.local_human.local_human:LocalHumanAgent"
opt["image_mode"] = "resnet152"
opt["no_cuda"] = True
opt["override"]["no_cuda"] = True

SHARED["opt"] = opt
SHARED["image_loader"] = ImageLoader(opt)

# Create model and assign it to the specified task
SHARED["agent"] = create_agent(opt, requireModelExists=True)
SHARED["world"] = create_task(opt, SHARED["agent"])

# Dialog history
SHARED["dialog_history"] = []
def _load_model(self):
    """
    Load model if necessary.
    """
    if 'model_file' in self.opt or 'model' in self.opt:
        self.runner_opt['shared_bot_params'] = create_agent(self.runner_opt).share()
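
# Hedged sketch: shared params saved by .share() above are typically used to clone
# lightweight per-conversation copies of the bot without reloading model weights.
# The helper below is illustrative and not taken from the original class.
from parlai.core.agents import create_agent_from_shared

def new_bot_copy(runner_opt):
    return create_agent_from_shared(runner_opt['shared_bot_params'])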
def run(opt):
    opt = copy.deepcopy(opt)
    _path(opt)
    opt['model_file'] = opt['tourist_model_file']
    tourist = create_agent(opt)
    opt['model_file'] = opt['guide_model_file']
    guide = create_agent(opt)
    world = SimulateWorld(opt, [tourist, guide])
    if opt.get('numthreads', 1) > 1:
        # use hogwild world if more than one thread requested
        # hogwild world will create sub batch worlds as well if bsz > 1
        world = HogwildWorld(opt, world)
    elif opt.get('batchsize', 1) > 1:
        # otherwise check if should use batchworld
        world = BatchWorld(opt, world)

    # Show some example dialogs:
    cnt = 0
    while not world.epoch_done():
        cnt += opt.get('batchsize', 1)
        world.parley()
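
# Hedged usage sketch of how run() might be invoked. The two *_model_file keys are
# exactly what the function reads; registering them as command-line flags here is an
# assumption about the surrounding script, not code from the original.
from parlai.core.params import ParlaiParser

if __name__ == '__main__':
    parser = ParlaiParser(True, True)
    parser.add_argument('--tourist-model-file', type=str)
    parser.add_argument('--guide-model-file', type=str)
    run(parser.parse_args())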
def __train_single_model(opt):
    """Train single model.

    opt is a dictionary returned by arg_parse
    """
    # Create model and assign it to the specified task
    agent = create_agent(opt)
    world = create_task(opt, agent)

    print('[ training... ]')
    train_dict = {
        'train_time': Timer(),
        'validate_time': Timer(),
        'log_time': Timer(),
        'new_epoch': None,
        'epochs_done': 0,
        'max_exs': opt['num_epochs'] * len(world),
        'total_exs': 0,
        'parleys': 0,
        'max_parleys': math.ceil(opt['num_epochs'] * len(world) / opt['batchsize']),
        'best_metrics': opt['chosen_metrics'],
        'best_metrics_value': 0,
        'impatience': 0,
        'lr_drop_impatience': 0,
        'prev_response_filter': True,
        'person_tokens': True,
        'history_size': 2,
        'eval_candidates': 'fixed',
        'fixed_candidates_path': 'data/convai2_cands.txt',
        'fixed_candidate_vecs': opt['fixed_candidate_vecs'],
        # Pull these from current opt dictionary
        'rating_frequency': opt['rating_frequency'],
        'rating_gap': opt['rating_gap'],
        'rating_threshold': opt['rating_threshold'],
        'request_explanation': opt['request_explanation'],
        'request_rating': opt['request_rating'],
    }
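
    # Hedged sketch (not from the original source): how the counters above are
    # typically advanced during training; the loop below is illustrative only and
    # mirrors the 'max_parleys' stopping condition.
    while train_dict['parleys'] < train_dict['max_parleys']:
        world.parley()
        train_dict['parleys'] += 1
        train_dict['total_exs'] += opt['batchsize']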
# Create model and assign it to the specified task
agent = create_agent(opt, requireModelExists=True)
world = create_task(opt, agent)

if print_parser:
    # Show arguments after loading model
    print_parser.opt = agent.opt
    print_parser.print_args()

# Show some example dialogs:
while True:
    world.parley()
    if opt.get('display_examples'):
        print("---")
        print(world.display())
    if world.epoch_done():
        print("EPOCH DONE")
        break
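
# Hedged follow-up (not in the original snippet): after the display loop, aggregate
# evaluation metrics are usually pulled from the world and printed.
report = world.report()
print(report)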
kwargs is interpreted by appending '--' to it and replacing underscores
with hyphens, so 'dict_file=/tmp/dict.tsv' would be interpreted as
'--dict-file /tmp/dict.tsv'.
"""
from parlai.core.params import ParlaiParser
from parlai.core.agents import create_agent

if args is None:
    args = []
for k, v in kwargs.items():
    args.append('--' + str(k).replace('_', '-'))
    args.append(str(v))

parser = ParlaiParser(True, True)
self.opt = parser.parse_args(args)
self.agent = create_agent(self.opt)
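
# Worked example of the kwargs-to-flags conversion described in the docstring above;
# the sample values are assumptions for illustration only.
example_kwargs = {'dict_file': '/tmp/dict.tsv', 'batchsize': 4}
example_args = []
for k, v in example_kwargs.items():
    example_args.append('--' + str(k).replace('_', '-'))
    example_args.append(str(v))
# example_args is now ['--dict-file', '/tmp/dict.tsv', '--batchsize', '4']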
def interactive_rank(opt, print_parser=None):
    # Create model and assign it to the specified task
    human = create_agent(opt)
    task = create_task_agent_from_taskname(opt)[0]

    metrics = Metrics(opt)
    episodes = 0

    def print_metrics():
        report = metrics.report()
        report['episodes'] = episodes
        print(report)

    # Show some example dialogs:
    try:
        while not task.epoch_done():
            msg = task.act()
            print('[{id}]: {text}'.format(id=task.getID(), text=msg.get('text', '')))
            cands = list(msg.get('label_candidates', []))
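
# Hedged usage sketch: interactive_rank is normally launched from a parser set up
# like the one below; wiring it this way is an assumption, not original code.
from parlai.core.params import ParlaiParser

if __name__ == '__main__':
    parser = ParlaiParser(True, True)
    interactive_rank(parser.parse_args(), print_parser=parser)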
def _set_up_tfidf_retriever(self, opt):
    retriever_opt = {
        'model_file': opt['retriever_model_file'],
        'remove_title': False,
        'datapath': opt['datapath'],
        'override': {'remove_title': False},
    }
    self.retriever = create_agent(retriever_opt)

    self._set_up_sent_tok()
    wiki_map_path = os.path.join(self.model_path, 'chosen_topic_to_passage.json')
    self.wiki_map = json.load(open(wiki_map_path, 'r'))
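
# Hedged usage sketch (not in the original class): a retriever built with
# create_agent is normally queried through the standard observe/act agent API;
# the method name and query text below are assumptions for illustration.
def _retrieve_passage(self, query):
    self.retriever.observe({'text': query, 'episode_done': True})
    return self.retriever.act().get('text', '')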