# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def train(args: argparse.Namespace):
# Training entry point, driven entirely by parsed CLI arguments.
if args.dry_run:
# Modify arguments so that we write to a temporary directory and
# perform 0 training iterations
temp_dir = tempfile.TemporaryDirectory() # Will be automatically removed
# NOTE(review): temp_dir is a local; the directory is removed when the
# object is garbage-collected, so args.output may point at a deleted
# path later on — confirm the object stays alive for the whole run.
args.output = temp_dir.name
args.max_updates = 0
# Seed all RNGs up front so runs are reproducible for a given --seed.
utils.seed_rngs(args.seed)
check_arg_compatibility(args)
output_folder = os.path.abspath(args.output)
# True when we are resuming a previously interrupted run in output_folder.
resume_training = check_resume(args, output_folder)
global logger
# Re-bind the module-level logger so it also writes a log file into the
# output folder; console output is suppressed under --quiet.
logger = setup_main_logger(__name__,
file_logging=True,
console=not args.quiet, path=os.path.join(output_folder, C.LOG_NAME))
utils.log_basic_info(args)
# Persist the resolved arguments so the run can be inspected or resumed.
arguments.save_args(args, os.path.join(output_folder, C.ARGS_STATE_NAME))
max_seq_len_source, max_seq_len_target = args.max_seq_len
# The maximum length is the length before we add the BOS/EOS symbols
max_seq_len_source = max_seq_len_source + C.SPACE_FOR_XOS
max_seq_len_target = max_seq_len_target + C.SPACE_FOR_XOS
logger.info("Adjusting maximum length to reserve space for a BOS/EOS marker. New maximum length: (%d, %d)",
max_seq_len_source, max_seq_len_target)
# ExitStack manages cleanup (e.g. device locks) for everything acquired below.
with ExitStack() as exit_stack:
context = utils.determine_context(device_ids=args.device_ids,
use_cpu=args.use_cpu,
disable_device_locking=args.disable_device_locking,
def prepare_vocab(args: argparse.Namespace):
    """Build a vocabulary from the input paths and write it to disk as JSON."""
    # Both positions of --num-words must agree; a non-positive value means "no limit".
    limit_primary, limit_secondary = args.num_words
    limit_primary = None if limit_primary <= 0 else limit_primary
    limit_secondary = None if limit_secondary <= 0 else limit_secondary
    utils.check_condition(limit_primary == limit_secondary,
                          "Vocabulary CLI only allows a common value for --num-words")
    # Likewise, --word-min-count must be a single shared value.
    min_count_primary, min_count_secondary = args.word_min_count
    utils.check_condition(min_count_primary == min_count_secondary,
                          "Vocabulary CLI only allows a common value for --word-min-count")
    setup_main_logger(file_logging=not args.no_logfile, console=not args.quiet,
                      path="%s.%s" % (args.output, C.LOG_NAME))
    built_vocab = build_from_paths(args.inputs,
                                   num_words=limit_primary,
                                   min_count=min_count_primary,
                                   pad_to_multiple_of=args.pad_vocab_to_multiple_of)
    logger.info("Vocabulary size: %d ", len(built_vocab))
    vocab_to_json(built_vocab, args.output)
def train(args: argparse.Namespace):
# Training entry point for the image-captioning variant.
# TODO: make training compatible with full net
args.image_preextracted_features = True # override this for now
# Seed RNGs before any stochastic setup so runs are reproducible.
utils.seed_rngs(args.seed)
check_arg_compatibility(args)
output_folder = os.path.abspath(args.output)
# True when we are resuming a previously interrupted run in output_folder.
resume_training = check_resume(args, output_folder)
setup_main_logger(file_logging=True,
console=not args.quiet, path=os.path.join(output_folder, C.LOG_NAME))
utils.log_basic_info(args)
# Persist the resolved arguments as JSON so the run can be reproduced.
with open(os.path.join(output_folder, C.ARGS_STATE_NAME), "w") as fp:
json.dump(vars(args), fp)
max_seq_len_source, max_seq_len_target = args.max_seq_len
# The maximum length is the length before we add the BOS/EOS symbols
max_seq_len_source = max_seq_len_source + C.SPACE_FOR_XOS
max_seq_len_target = max_seq_len_target + C.SPACE_FOR_XOS
logger.info("Adjusting maximum length to reserve space for a BOS/EOS marker. New maximum length: (%d, %d)",
max_seq_len_source, max_seq_len_target)
# ExitStack manages cleanup (e.g. device locks) for everything acquired below.
with ExitStack() as exit_stack:
context = utils.determine_context(device_ids=args.device_ids,
use_cpu=args.use_cpu,
disable_device_locking=args.disable_device_locking,
def create(args):
    """Build a top-k lexicon from an input lexical table and save it to disk."""
    setup_main_logger(console=not args.quiet, file_logging=not args.no_logfile, path=args.output + ".log")
    global logger
    logger = logging.getLogger('create')
    log_sockeye_version(logger)
    logger.info("Creating top-k lexicon from \"%s\"", args.input)
    logger.info("Reading source and target vocab from \"%s\"", args.model)
    # The first source vocabulary is the primary (non-factor) one.
    src_vocab = vocab.load_source_vocabs(args.model)[0]
    trg_vocab = vocab.load_target_vocab(args.model)
    logger.info("Building top-%d lexicon", args.k)
    top_k_lexicon = TopKLexicon(src_vocab, trg_vocab)
    top_k_lexicon.create(args.input, args.k)
    top_k_lexicon.save(args.output)
def prepare_data(args: argparse.Namespace):
# Prepare (shard and serialize) training data into args.output.
output_folder = os.path.abspath(args.output)
os.makedirs(output_folder, exist_ok=True)
setup_main_logger(console=not args.quiet,
file_logging=not args.no_logfile,
path=os.path.join(output_folder, C.LOG_NAME))
# Seed RNGs so data preparation (e.g. shard assignment) is reproducible.
utils.seed_rngs(args.seed)
minimum_num_shards = args.min_num_shards
samples_per_shard = args.num_samples_per_shard
# Bucketing by sequence length is on by default; --no-bucketing disables it.
bucketing = not args.no_bucketing
bucket_width = args.bucket_width
# The primary source file is followed by any additional source factor files.
source_paths = [args.source] + args.source_factors
# Pad the factor-vocabulary list with None for factors that have no explicit
# vocabulary path, keeping it index-aligned with args.source_factors.
source_factor_vocab_paths = [args.source_factor_vocabs[i] if i < len(args.source_factor_vocabs)
else None for i in range(len(args.source_factors))]
source_vocab_paths = [args.source_vocab] + source_factor_vocab_paths
num_words_source, num_words_target = args.num_words
def caption(args: argparse.Namespace):
# CLI entry point for image-captioning inference.
# Features are pre-extracted unless --extract-image-features is given.
image_preextracted_features = not args.extract_image_features
if args.output is not None:
# With an output file, also log to a "<output>.<LOG_NAME>" file.
setup_main_logger(console=not args.quiet,
file_logging=True,
path="%s.%s" % (args.output, C.LOG_NAME))
else:
setup_main_logger(file_logging=False)
if args.checkpoints is not None:
# When checkpoints are specified, one must be given per model.
check_condition(len(args.checkpoints) == len(args.models),
"must provide checkpoints for each model")
log_basic_info(args)
out_handler = output_handler.get_output_handler(args.output_type,
args.output,
args.sure_align_threshold)
# ExitStack manages cleanup (e.g. device locks) for everything acquired below.
with ExitStack() as exit_stack:
context = determine_context(device_ids=args.device_ids,
use_cpu=args.use_cpu,
disable_device_locking=args.disable_device_locking,
lock_dir=args.lock_dir,
def main():
    """
    Command-line tool to inspect model embeddings.
    """
    setup_main_logger(file_logging=False)
    parser = argparse.ArgumentParser(description='Shows nearest neighbours of input tokens in the embedding space.')
    parser.add_argument('--model', '-m', required=True,
                        help='Model folder to load config from.')
    parser.add_argument('--checkpoint', '-c', required=False, type=int, default=None,
                        help='Optional specific checkpoint to load parameters from. Best params otherwise.')
    parser.add_argument('--side', '-s', required=True, choices=['source', 'target'], help='what embeddings to look at')
    parser.add_argument('--norm', '-n', action='store_true', help='normalize embeddings to unit length')
    parser.add_argument('-k', type=int, default=5, help='Number of neighbours to print')
    parser.add_argument('--gamma', '-g', type=float, default=1.0, help='Softmax distribution steepness.')
    # Parse the CLI arguments and hand off to the actual inspection routine.
    embeddings(parser.parse_args())
def main():
    """
    Commandline interface to extract parameters.
    """
    setup_main_logger(console=True, file_logging=False)
    arg_parser = argparse.ArgumentParser(description="Extract specific parameters.")
    arguments.add_extract_args(arg_parser)
    # Parse the CLI arguments and run the extraction.
    extract_parameters(arg_parser.parse_args())
def run_translate(args: argparse.Namespace):
# CLI entry point for translation/inference.
# Seed randomly unless a seed has been passed
utils.seed_rngs(args.seed if args.seed is not None else int(time.time()))
if args.output is not None:
# With an output file, also log to a "<output>.<LOG_NAME>" file.
setup_main_logger(console=not args.quiet,
file_logging=not args.no_logfile,
path="%s.%s" % (args.output, C.LOG_NAME),
level=args.loglevel)
else:
setup_main_logger(file_logging=False, level=args.loglevel)
log_basic_info(args)
if args.nbest_size > 1:
# n-best output only makes sense as JSON: force that handler and warn
# the user if a different --output-type was requested.
if args.output_type != C.OUTPUT_HANDLER_JSON:
logger.warning("For nbest translation, you must specify `--output-type '%s'; overriding your setting of '%s'.",
C.OUTPUT_HANDLER_JSON, args.output_type)
args.output_type = C.OUTPUT_HANDLER_JSON
output_handler = get_output_handler(args.output_type,
args.output,
args.sure_align_threshold)