# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# NOTE(review): fragment of a per-dataset training driver. `args`,
# `dataset_type`, `dataset_name`, `num_folds`, `metric`, `features_dir`,
# `TRAIN_LOGGER` and `logger` come from an enclosing scope not visible here.

# Point the shared args namespace at this dataset's configuration.
args.dataset_type = dataset_type
args.save_dir = os.path.join(args.save_dir, dataset_name)  # per-dataset output dir
args.num_folds = num_folds
args.metric = metric
if features_dir is not None:
    # Precomputed molecular features are stored as one pickle per dataset.
    args.features_path = [os.path.join(features_dir, dataset_name + '.pckl')]
modify_train_args(args)

# Set up logging for training
os.makedirs(args.save_dir, exist_ok=True)
fh = logging.FileHandler(os.path.join(args.save_dir, args.log_name))
fh.setLevel(logging.DEBUG)

# Cross validate
TRAIN_LOGGER.addHandler(fh)  # attach the per-dataset file handler only for this run
mean_score, std_score = cross_validate(args, TRAIN_LOGGER)
TRAIN_LOGGER.removeHandler(fh)  # detach so later datasets don't also log into this file

# Record results
logger.info(f'{mean_score} +/- {std_score} {metric}')
temp_model = build_model(args)  # built only to report the parameter count
logger.info(f'num params: {param_count(temp_model):,}')
# NOTE(review): fragment of a hyperopt-then-retrain-then-predict driver.
# `args`, `logger` and the helper functions come from scope not visible here.

# Run hyperparameter search; results are written under args.results_dir.
optimize_hyperparameters(args)

# Determine best hyperparameters, update args, and train
results = load_sorted_results(args.results_dir)  # presumably sorted best-first -- TODO confirm
config = results[0]
config.pop('loss')  # 'loss' is the search objective, not a model hyperparameter
print('Best config')
pprint(config)
for key, value in config.items():
    setattr(args, key, value)
args.data_path = args.train_val_save
args.separate_test_set = None
args.split_sizes = [0.8, 0.2, 0.0]  # no need for a test set during training
cross_validate(args, logger)

# Predict on test data
args.checkpoint_dir = args.save_dir  # reuse the checkpoints just trained above
update_args_from_checkpoint_dir(args)
args.compound_names = True  # only if test set has compound names
args.ensemble_size = 5  # might want to make this an arg somehow (w/o affecting hyperparameter optimization)
make_predictions(args)
# NOTE(review): interior of a hyperopt objective. `hyperparams`, `INT_KEYS`,
# `args`, `logger`, `train_logger` and `results` come from enclosing scope.

# Hyperopt proposes numeric hyperparameters as floats; coerce the
# integer-valued ones (depth, hidden size, ...) back to int before use.
for key in INT_KEYS:
    hyperparams[key] = int(hyperparams[key])

# Update args with hyperparams
hyper_args = deepcopy(args)  # never mutate the caller's args between trials
if args.save_dir is not None:
    # One sub-directory per trial, named after its hyperparameter setting.
    # (The original ternary `... if key in INT_KEYS else ...` had identical
    # branches, so the conditional was dead code -- simplified.)
    folder_name = '_'.join(f'{key}_{value}' for key, value in hyperparams.items())
    hyper_args.save_dir = os.path.join(hyper_args.save_dir, folder_name)
for key, value in hyperparams.items():
    setattr(hyper_args, key, value)

# Record hyperparameters
logger.info(hyperparams)

# Cross validate
mean_score, std_score = cross_validate(hyper_args, train_logger)

# Record results
temp_model = build_model(hyper_args)  # built only to report the parameter count
num_params = param_count(temp_model)
logger.info(f'num params: {num_params:,}')
logger.info(f'{mean_score} +/- {std_score} {hyper_args.metric}')
results.append({
    'mean_score': mean_score,
    'std_score': std_score,
    'hyperparams': hyperparams,
    'num_params': num_params
})
if np.isnan(mean_score):
import logging
from chemprop.parsing import parse_train_args
from chemprop.train import cross_validate
from chemprop.utils import set_logger
# Initialize logger
# Module-level 'train' logger; propagate=False keeps its records from also
# reaching the root logger's handlers (set_logger attaches its own).
logger = logging.getLogger('train')
logger.setLevel(logging.DEBUG)
logger.propagate = False

if __name__ == '__main__':
    # Parse command-line training arguments, wire the logger up to the
    # save directory, and run k-fold cross-validated training.
    args = parse_train_args()
    set_logger(logger, args.save_dir, args.quiet)
    cross_validate(args, logger)