# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
# Run `num_trials` independent trials. Each trial builds a fresh graph, towers,
# session and runner. It trains (when config.train) or loads (otherwise), then
# evaluates on comb_test_ds.
# val_losses / val_accs are only appended to when config.train is set;
# test_accs gets one entry per trial in both modes.
val_losses = []
val_accs = []
test_accs = []
for trial_idx in range(1, num_trials + 1):
    if config.train:
        print("-" * 80)
        print("Task {} trial {}".format(config.task, trial_idx))
    # NOTE(review): nesting below was reconstructed from a whitespace-stripped
    # paste; mkdirs is assumed to run in both train and eval-only modes --
    # confirm against the original file.
    mkdirs(config, trial_idx)
    # A new graph per trial so trials do not share variables/ops.
    graph = tf.Graph()
    # TODO : initialize BaseTower-subclassed objects
    towers = [Tower(config) for _ in range(config.num_devices)]
    sess = tf.Session(graph=graph, config=tf.ConfigProto(allow_soft_placement=True))
    try:
        # TODO : initialize BaseRunner-subclassed object
        runner = Runner(config, sess, towers)
        # Ops created by the runner go into this trial's graph, placed on CPU
        # by default (allow_soft_placement lets TF relocate ops if needed).
        with graph.as_default(), tf.device("/cpu:0"):
            runner.initialize()
            if config.train:
                if config.load:
                    # presumably resumes from a saved checkpoint -- confirm in Runner.load
                    runner.load()
                val_loss, val_acc = runner.train(
                    comb_train_ds, config.num_epochs, val_data_set=comb_dev_ds,
                    num_batches=config.train_num_batches,
                    val_num_batches=config.val_num_batches,
                    eval_ph_names=eval_ph_names)
                val_losses.append(val_loss)
                val_accs.append(val_acc)
            else:
                # Not training: load() is always called before eval
                # (presumably restoring saved weights -- confirm in Runner.load).
                runner.load()
            test_loss, test_acc = runner.eval(
                comb_test_ds, eval_tensor_names=eval_tensor_names,
                num_batches=config.test_num_batches,
                eval_ph_names=eval_ph_names)
            test_accs.append(test_acc)
    finally:
        # Release this trial's session resources before the next trial creates
        # another graph/session; the original never closed it, so each trial's
        # session lingered until garbage collection.
        sess.close()