num_iters = len(data) if args.last_batch else len(data) // args.batch_size * args.batch_size
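# e.g. with len(data) == 103 and batch_size == 20, num_iters is 103 when the trailing partial batch is kept and 100 (five full batches) when it is dropped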
if args.parallel_featurization:
    batch_queue = Queue(args.batch_queue_max_size)
    exit_queue = Queue(1)
    batch_process = Process(target=async_mol2graph, args=(batch_queue, data, args, num_iters, args.batch_size, exit_queue, args.last_batch))
    batch_process.start()
    currently_loaded_batches = []

iter_size = 1 if args.maml else args.batch_size

for i in trange(0, num_iters, iter_size):
    if args.moe:
        if not args.batch_domain_encs:
            model.compute_domain_encs(train_smiles)  # want to recompute every batch
        mol_batch = [MoleculeDataset(d[i:i + args.batch_size]) for d in data]
        train_batch, train_targets = [], []
        for b in mol_batch:
            tb, tt = b.smiles(), b.targets()
            train_batch.append(tb)
            train_targets.append(tt)
        test_batch = test_smiles[i:i + args.batch_size]
        loss = model.compute_loss(train_batch, train_targets, test_batch)
        model.zero_grad()
        loss_sum += loss.item()
        iter_count += len(mol_batch)
    elif args.maml:
        task_train_data, task_test_data, task_idx = data.sample_maml_task(args)
        mol_batch = task_test_data
        smiles_batch, features_batch, target_batch = task_train_data.smiles(), task_train_data.features(), task_train_data.targets(task_idx)
        # no mask since we only picked data points that have the desired target
raise ValueError(f'similarity_measure "{args.similarity_measure}" not supported or not implemented yet.')

# Load model and scalers
model = load_checkpoint(args.checkpoint_path)
scaler, features_scaler = load_scalers(args.checkpoint_path)
data.normalize_features(features_scaler)

# Random seed
if args.seed is not None:
    random.seed(args.seed)

# Generate visualizations
for i in trange(args.num_examples):
    # Get three random molecules with similar properties
    index = random.randint(1, len(data) - 2)
    molecules = MoleculeDataset(data[index - 1:index + 2])
    molecule_targets = [t[args.task_index] for t in molecules.targets()]

    # Encode the three molecules
    molecule_encodings = model.encoder(molecules.smiles())

    # Define interpolation
    def predict_property(point: List[int]) -> float:
        # Return the true value on the endpoints of the triangle
        argmax = np.argmax(point)
        if point[argmax] == 1:
            return molecule_targets[argmax]

        # Interpolate the encodings and predict the task value
        encoding = sum(point[j] * molecule_encodings[j] for j in range(len(molecule_encodings)))
        pred = model.ffn(encoding).data.cpu().numpy()
        pred = scaler.inverse_transform(pred)
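
# predict_property above takes a barycentric point over the three encoded molecules:
# at a vertex it returns the true target, and in the interior it decodes a convex
# combination of the encodings. A minimal, self-contained sketch of sweeping such
# points over the triangle (NumPy only; barycentric_grid and its step value are
# illustrative stand-ins, not the repo's plotting code):
import numpy as np

def barycentric_grid(step: float = 0.05):
    """Yield points [a, b, c] with a + b + c == 1 on a uniform grid over the triangle."""
    ticks = np.arange(0.0, 1.0 + 1e-9, step)
    for a in ticks:
        for b in ticks:
            c = 1.0 - a - b
            if c >= -1e-9:
                yield [a, b, max(c, 0.0)]

# Example usage (requires the model and scaler loaded above):
# values = [predict_property(point) for point in barycentric_grid()]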
def async_mol2graph(q: Queue,
                    data: MoleculeDataset,
                    args: Namespace,
                    num_iters: int,
                    iter_size: int,
                    exit_q: Queue,
                    last_batch: bool = False):
    batches = []
    for i in range(0, num_iters, iter_size):  # will only go up to max size of queue, then yield
        if not last_batch and i + args.batch_size > len(data):
            break
        batch = MoleculeDataset(data[i:i + args.batch_size])
        batches.append(batch)
        if len(batches) == args.batches_per_queue_group:  # many at a time, since synchronization is expensive
            with Pool() as pool:
                processed_batches = pool.map(mol2graph_helper, [(batch, args) for batch in batches])
            q.put(processed_batches)
            batches = []
    if len(batches) > 0:
        with Pool() as pool:
            processed_batches = pool.map(mol2graph_helper, [(batch, args) for batch in batches])
        q.put(processed_batches)
    exit_q.get()  # prevent this process from exiting until the main process tells it to; otherwise we apparently can't read the end of the queue and we crash
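
# The consumer side of this producer/consumer pair appears in the training loop above
# (batch_queue.get(), then popping (mol_batch, featurized_mol_batch) pairs). A minimal
# sketch of that side, with hypothetical names (consume_batches, num_batches) rather
# than the repo's actual code; the final exit_queue.put()/join() pair is what lets the
# worker past its blocking exit_q.get():
from multiprocessing import Process, Queue

def consume_batches(batch_queue: Queue, exit_queue: Queue, worker: Process, num_batches: int):
    loaded = []
    for _ in range(num_batches):
        if len(loaded) == 0:
            loaded = batch_queue.get()  # blocks until the worker has featurized the next group
        mol_batch, featurized_mol_batch = loaded.pop(0)
        # ... train / predict on featurized_mol_batch here ...
    exit_queue.put(0)  # signal the worker that everything has been consumed
    worker.join()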
# Learning rate schedulers
scheduler = build_lr_scheduler(optimizer, args)

# Run training
best_score = float('inf') if args.minimize_score else -float('inf')
best_epoch, n_iter = 0, 0
for epoch in trange(args.epochs):
    debug(f'Epoch {epoch}')

    if args.prespecified_chunk_dir is not None:
        # load some different random chunks each epoch
        train_data, val_data = load_prespecified_chunks(args, logger)
        debug('Loaded prespecified chunks for epoch')

    if args.dataset_type == 'unsupervised':  # won't work with moe
        full_data = MoleculeDataset(train_data.data + val_data.data)
        generate_unsupervised_cluster_labels(build_model(args), full_data, args)  # cluster with a new random init
        model.create_ffn(args)  # reset the ffn since we're changing targets -- we're just pretraining the encoder
        optimizer.param_groups.pop()  # remove ffn parameters
        optimizer.add_param_group({'params': model.ffn.parameters(), 'lr': args.init_lr[1], 'weight_decay': args.weight_decay[1]})
        if args.cuda:
            model.ffn.cuda()

    if args.gradual_unfreezing:
        if epoch % args.epochs_per_unfreeze == 0:
            unfroze_layer = model.unfreeze_next()  # consider just stopping early after we have nothing left to unfreeze?
            if unfroze_layer:
                debug('Unfroze last frozen layer')

    n_iter = train(
        model=model,
        data=train_data,
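
# model.unfreeze_next() above unfreezes one more block of layers every
# args.epochs_per_unfreeze epochs. A generic sketch of that idea (plain PyTorch,
# hypothetical names; not the repo's implementation):
import torch

class GraduallyUnfrozen:
    """Freeze every block up front, then unfreeze them one at a time, last block first."""

    def __init__(self, blocks: list):
        self.blocks = blocks  # e.g. [encoder.layer1, encoder.layer2, ffn]
        for block in blocks:
            for p in block.parameters():
                p.requires_grad = False

    def unfreeze_next(self) -> bool:
        # Walk from the top of the network downwards and unfreeze the first frozen block
        for block in reversed(self.blocks):
            params = list(block.parameters())
            if params and not params[0].requires_grad:
                for p in params:
                    p.requires_grad = True
                return True
        return False  # nothing left to unfreeze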
# Inside the training loop: MAML fast-weight update in the args.maml branch, otherwise standard batch preparation
    loss = loss.sum() / len(smiles_batch)
    grad = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad])
    theta = [p for p in model.named_parameters() if p[1].requires_grad]  # comes in same order as grad
    theta_prime = {p[0]: p[1] - args.maml_lr * grad[i] for i, p in enumerate(theta)}
    for name, nongrad_param in [p for p in model.named_parameters() if not p[1].requires_grad]:
        theta_prime[name] = nongrad_param + torch.zeros(nongrad_param.size()).to(nongrad_param)
else:
    # Prepare batch
    if args.parallel_featurization:
        if len(currently_loaded_batches) == 0:
            currently_loaded_batches = batch_queue.get()
        mol_batch, featurized_mol_batch = currently_loaded_batches.pop()
    else:
        if not args.last_batch and i + args.batch_size > len(data):
            break
        mol_batch = MoleculeDataset(data[i:i + args.batch_size])
    smiles_batch, features_batch, target_batch = mol_batch.smiles(), mol_batch.features(), mol_batch.targets()

    if args.dataset_type == 'bert_pretraining':
        batch = mol2graph(smiles_batch, args)
        mask = mol_batch.mask()
        batch.bert_mask(mask)
        mask = 1 - torch.FloatTensor(mask)  # num_atoms
        features_targets = torch.FloatTensor(target_batch['features']) if target_batch['features'] is not None else None  # num_molecules x features_size
        targets = torch.FloatTensor(target_batch['vocab'])  # num_atoms
        if args.bert_vocab_func == 'feature_vector':
            mask = mask.reshape(-1, 1)
        else:
            targets = targets.long()
    else:
        batch = smiles_batch
        mask = torch.Tensor([[x is not None for x in tb] for tb in target_batch])
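
# The theta/theta_prime pattern above builds "fast weights" for the MAML inner loop
# without mutating the model's own parameters. A minimal, self-contained sketch of the
# same idea on a toy module (plain PyTorch; torch.func.functional_call is one way to
# run a forward pass with such a parameter dict, not necessarily what this repo does):
import torch
from torch.func import functional_call

toy_model = torch.nn.Linear(8, 1)
x, y = torch.randn(4, 8), torch.randn(4, 1)
maml_lr = 0.1  # stand-in for args.maml_lr

loss = torch.nn.functional.mse_loss(toy_model(x), y)
named = [(n, p) for n, p in toy_model.named_parameters() if p.requires_grad]
grads = torch.autograd.grad(loss, [p for _, p in named])  # same order as `named`

# Fast weights: one inner-loop SGD step, original parameters left untouched
theta_prime = {n: p - maml_lr * g for (n, p), g in zip(named, grads)}

# Forward pass with the adapted weights, without modifying toy_model in place
adapted_preds = functional_call(toy_model, theta_prime, (x,))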
# Update args with training arguments
for key, value in vars(train_args).items():
    if not hasattr(args, key):
        setattr(args, key, value)

print('Loading data')
if smiles is not None:
    test_data = get_data_from_smiles(smiles=smiles, skip_invalid_smiles=False)
else:
    test_data = get_data(path=args.test_path, args=args, use_compound_names=args.compound_names, skip_invalid_smiles=False)

print('Validating SMILES')
valid_indices = [i for i in range(len(test_data)) if test_data[i].mol is not None]
full_data = test_data
test_data = MoleculeDataset([test_data[i] for i in valid_indices])

# Edge case if an empty list of SMILES is provided
if len(test_data) == 0:
    return [None] * len(full_data)

test_smiles = test_data.smiles()

if args.compound_names:
    compound_names = test_data.compound_names()
print(f'Test size = {len(test_data):,}')

# Normalize features
if train_args.features_scaling:
    test_data.normalize_features(features_scaler)

# Predict with each model individually and sum predictions
    return [0]  # rest of this is meaningless when unsupervised

# Evaluate on test set using model with best validation score
info(f'Model {model_idx} best validation {args.metric} = {best_score:.6f} on epoch {best_epoch}')
model = load_checkpoint(os.path.join(save_dir, 'model.pt'), cuda=args.cuda, logger=logger)

if args.split_test_by_overlap_dataset is not None:
    overlap_data = get_data(path=args.split_test_by_overlap_dataset, logger=logger)
    overlap_smiles = set(overlap_data.smiles())
    test_data_intersect, test_data_nonintersect = [], []
    for d in test_data.data:
        if d.smiles in overlap_smiles:
            test_data_intersect.append(d)
        else:
            test_data_nonintersect.append(d)
    test_data_intersect, test_data_nonintersect = MoleculeDataset(test_data_intersect), MoleculeDataset(test_data_nonintersect)
    for name, td in [('Intersect', test_data_intersect), ('Nonintersect', test_data_nonintersect)]:
        test_preds = predict(
            model=model,
            data=td,
            args=args,
            scaler=scaler,
            logger=logger
        )
        test_scores = evaluate_predictions(
            preds=test_preds,
            targets=td.targets(),
            metric_func=metric_func,
            dataset_type=args.dataset_type,
            args=args,
            logger=logger
        )
for i in trange(0, num_iters, iter_step):
    if args.maml:
        task_train_data, task_test_data, task_idx = data.sample_maml_task(args, seed=0)
        mol_batch = task_test_data
        smiles_batch, features_batch, targets_batch = task_train_data.smiles(), task_train_data.features(), task_train_data.targets(task_idx)
        targets = torch.Tensor(targets_batch).unsqueeze(1)
        if args.cuda:
            targets = targets.cuda()
    else:
        # Prepare batch
        if args.parallel_featurization:
            if len(currently_loaded_batches) == 0:
                currently_loaded_batches = batch_queue.get()
            mol_batch, featurized_mol_batch = currently_loaded_batches.pop(0)
        else:
            mol_batch = MoleculeDataset(data[i:i + args.batch_size])
        smiles_batch, features_batch = mol_batch.smiles(), mol_batch.features()

    # Run model
    if args.dataset_type == 'bert_pretraining':
        batch = mol2graph(smiles_batch, args)
        batch.bert_mask(mol_batch.mask())
    else:
        batch = smiles_batch

    if args.maml:  # TODO refactor with train loop
        model.zero_grad()
        intermediate_preds = model(batch, features_batch)
        loss = get_loss_func(args)(intermediate_preds, targets)
        loss = loss.sum() / len(batch)
        grad = torch.autograd.grad(loss, [p for p in model.parameters() if p.requires_grad])
        theta = [p for p in model.named_parameters() if p[1].requires_grad]  # comes in same order as grad