def test_qangaroo(self):
    from parlai.core.params import ParlaiParser
    from parlai.tasks.qangaroo.agents import DefaultTeacher

    opt = ParlaiParser().parse_args(args=self.args)
    opt['datatype'] = 'train'
    teacher = DefaultTeacher(opt)
    reply = teacher.act()
    check(opt, reply)
    shutil.rmtree(self.TMP_PATH)

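# A standalone sketch of the same pattern outside the test harness. The
# '--task qangaroo' / '--datatype train' flags are assumptions about how the
# teacher is normally selected (the test gets its flags from self.args), and
# the dataset is downloaded on first use.
from parlai.core.params import ParlaiParser
from parlai.tasks.qangaroo.agents import DefaultTeacher

opt = ParlaiParser().parse_args(args=['--task', 'qangaroo', '--datatype', 'train'])
teacher = DefaultTeacher(opt)
print(teacher.act())  # one training example as an action dict
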
def test_file_inference(self):
    """
    Test --inference with older model files.
    """
    testing_utils.download_unittest_models()
    with testing_utils.capture_output():
        pp = ParlaiParser(True, True)
        opt = pp.parse_args(
            ['--model-file', 'zoo:unittest/transformer_generator2/model']
        )
        agent = create_agent(opt, True)
        self.assertEqual(agent.opt['inference'], 'greedy')

    with testing_utils.capture_output():
        pp = ParlaiParser(True, True)
        opt = pp.parse_args(
            [
                '--model-file',
                'zoo:unittest/transformer_generator2/model',
                '--beam-size',
                '5',
            ],
            print_args=False,
        )
        agent = create_agent(opt, True)
        self.assertEqual(agent.opt['inference'], 'beam')

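# Hedged variant of the same check: the mode can also be requested directly
# with the '--inference' flag (a TorchGeneratorAgent option); treat the exact
# flag set as an assumption for the ParlAI version pinned by this test.
with testing_utils.capture_output():
    pp = ParlaiParser(True, True)
    opt = pp.parse_args(
        ['--model-file', 'zoo:unittest/transformer_generator2/model',
         '--inference', 'beam', '--beam-size', '10']
    )
    agent = create_agent(opt, True)
    assert agent.opt['inference'] == 'beam'
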
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Evaluate a model')
    parser.add_pytorch_datateacher_args()
    # Get command line arguments
    parser.add_argument(
        '-rf',
        '--report-filename',
        type=str,
        default='',
        help='Saves a json file of the evaluation report either as an '
        'extension to the model-file (if begins with a ".") or a whole '
        'file path. Set to the empty string to not save at all.',
    )
    parser.add_argument(
        '--save-world-logs',
        type='bool',
        default=False,
        help='Saves a jsonl file containing all of the task examples and '

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. An additional grant
# of patent rights can be found in the PATENTS file in the same directory.
from parlai.core.params import ParlaiParser
from core import manage_hit

argparser = ParlaiParser(False, False)
argparser.add_parlai_data_path()
argparser.add_mturk_args()
opt = argparser.parse_args()

task_module_name = 'parlai.mturk.tasks.' + opt['task']
Agent = __import__(task_module_name + '.agents', fromlist=['']).default_agent_class
task_config = __import__(task_module_name + '.task_config', fromlist=['']).task_config

print("Creating HIT tasks for " + task_module_name + " ...")
manage_hit.create_hits(
    opt=opt,
    task_config=task_config,
    task_module_name=task_module_name,
    bot=Agent(opt=opt),
)

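# The __import__(name, fromlist=['']) calls above are the older idiom for
# importing a submodule by its dotted name; an equivalent sketch using the
# standard library:
import importlib

agents_module = importlib.import_module(task_module_name + '.agents')
Agent = agents_module.default_agent_class
task_config = importlib.import_module(task_module_name + '.task_config').task_config
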
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'compute statistics from model predictions')
    parser.add_pytorch_datateacher_args()
    DictionaryAgent.add_cmdline_args(parser)
    # Get command line arguments
    parser.add_argument('-ne', '--num-examples', type=int, default=-1)
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument(
        '-ed',
        '--external-dict',
        type=str,
        default=None,
        help='External dictionary for stat computation',
    )
    parser.add_argument(
        '-fb',
        '--freq-bins',
        type=str,

def main():
    completed_workers = []
    argparser = ParlaiParser(False, False)
    argparser.add_parlai_data_path()
    argparser.add_mturk_args()
    opt = argparser.parse_args()
    opt['task'] = os.path.basename(os.path.dirname(os.path.abspath(__file__)))
    opt.update(task_config)

    mturk_agent_id = 'Worker'
    mturk_manager = MTurkManager(opt=opt, mturk_agent_ids=[mturk_agent_id])
    mturk_manager.setup_server()

    qual_name = 'ParlAIExcludeQual{}t{}'.format(
        random.randint(10000, 99999), random.randint(10000, 99999)
    )
    qual_desc = (
        'Qualification for a worker not correctly completing the '
        'first iteration of a task. Used to filter to different task pools.'
    )

def interactive(opt, print_parser=None):
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print('[ Deprecated Warning: interactive should be passed opt not Parser ]')
        opt = opt.parse_args()
    opt['task'] = 'parlai.agents.local_human.local_human:LocalHumanAgent'
    # Set the task to dialog, since that's the type we want its outputs to be
    print("Warning: hardcoding history_size=2")
    opt['override'] = {
        'no_cuda': True,
        'subtasks': ['dialog', 'sentiment'],
        'interactive': True,
        'prev_response_filter': True,
        'person_tokens': True,
        'history_size': 2,

def interactive(opt, print_parser=None):
    if print_parser is not None:
        if print_parser is True and isinstance(opt, ParlaiParser):
            print_parser = opt
        elif print_parser is False:
            print_parser = None
    if isinstance(opt, ParlaiParser):
        print('[ Deprecated Warning: interactive should be passed opt not Parser ]')
        opt = opt.parse_args()
    opt['task'] = 'parlai.agents.local_human.local_human:LocalHumanAgent'
    # Create model and assign it to the specified task
    agent = create_agent(opt, requireModelExists=True)
    world = create_task(opt, agent)
    if print_parser:
        # Show arguments after loading model
        print_parser.opt = agent.opt
        print_parser.print_args()

    # Create ConvAI2 data so we can assign personas.
    convai2_opt = opt.copy()
    convai2_opt['task'] = 'convai2:both'
    convai2_agent = RepeatLabelAgent(convai2_opt)
    convai2_world = create_task(convai2_opt, convai2_agent)

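    # Hedged sketch of how the ConvAI2 world is typically consumed from here:
    # one parley() yields a teacher act whose text carries the "your persona:"
    # lines, which can then be shown to (or prepended for) the chat agents.
    convai2_world.parley()
    persona_text = convai2_world.get_acts()[0].get('text', '')
    personas = [line for line in persona_text.split('\n') if 'persona:' in line]
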
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Interactive chat with a model')
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.add_argument(
        '--display-prettify',
        type='bool',
        default=False,
        help='Set to use a prettytable when displaying '
        'examples with text candidates',
    )
    parser.add_argument(
        '--display-ignore-fields',
        type=str,
        default='label_candidates,text_candidates',
        help='Do not display these fields',
    )
    parser.set_defaults(model_file='models:convai2/kvmemnn/model')
    LocalHumanAgent.add_cmdline_args(parser)

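# A minimal usage sketch, assuming this setup_args() ends by returning the
# parser (as the other setup_args variants in this collection do) and is
# paired with an interactive() entry point like the one above:
parser = setup_args()
interactive(parser.parse_args(print_args=False), print_parser=parser)
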
def setup_args(parser=None):
    if parser is None:
        parser = ParlaiParser(True, True, 'Lint for ParlAI tasks')
    # Get command line arguments
    parser.add_argument('-ltim', '--log-every-n-secs', type=float, default=2)
    parser.add_argument('-d', '--display-examples', type='bool', default=False)
    parser.set_defaults(datatype='train:stream')
    return parser

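# Typical driver for this parser (a sketch, not the original script's main):
# RepeatLabelAgent simply echoes the teacher's labels, which is enough to
# stream through a task's data and display each example.
from parlai.agents.repeat_label.repeat_label import RepeatLabelAgent
from parlai.core.worlds import create_task

parser = setup_args()
opt = parser.parse_args()
agent = RepeatLabelAgent(opt)
world = create_task(opt, agent)
while not world.epoch_done():
    world.parley()
    if opt['display_examples']:
        print(world.display() + '\n~~')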