# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def save_checkpoint(checkpoint_dir,
                    current_epoch,
                    global_step,
                    best_score,
                    exe,
                    main_program=None):
    """Persist model parameters and training progress under checkpoint_dir.

    Saves the persistable variables of ``main_program`` into a
    ``step_<global_step>`` subdirectory, then writes a protobuf metadata
    file (epoch, step, best score, latest model dir) next to it.

    Args:
        checkpoint_dir: Directory that holds the checkpoint metadata file
            and the per-step model directories.
        current_epoch: Epoch counter to record in the metadata.
        global_step: Global training step; also names the model directory.
        best_score: Best evaluation score seen so far.
        exe: Fluid executor used to save the persistable variables.
        main_program: Program whose persistables are saved. Defaults to
            the *current* ``fluid.default_main_program()``.
    """
    # Resolve the default at call time. The original signature evaluated
    # fluid.default_main_program() once at import time, freezing whichever
    # program happened to be the default then (mutable-default pitfall).
    if main_program is None:
        main_program = fluid.default_main_program()
    ckpt_meta_path = os.path.join(checkpoint_dir, CKPT_FILE_NAME)
    ckpt = checkpoint_pb2.CheckPoint()
    model_saved_dir = os.path.join(checkpoint_dir, "step_%d" % global_step)
    logger.info("Saving model checkpoint to {}".format(model_saved_dir))
    fluid.io.save_persistables(
        exe, dirname=model_saved_dir, main_program=main_program)
    # Record training progress so a later run can resume from here.
    ckpt.current_epoch = current_epoch
    ckpt.global_step = global_step
    ckpt.latest_model_dir = model_saved_dir
    ckpt.best_score = best_score
    with open(ckpt_meta_path, "wb") as f:
        f.write(ckpt.SerializeToString())
# NOTE(review): fragment of a larger fine-tune routine; the enclosing
# function header and the definitions of config, num_epoch, global_step,
# exe, do_eval, task, data_reader and feed_list are outside this chunk,
# so the original nesting cannot be reconstructed here.
# NOTE: the current saved-checkpoint mechanism is not complete -- it cannot
# restore the dataset's training status.
# Persist the final checkpoint after the last epoch.
# NOTE(review): save_checkpoint (defined above) requires a best_score
# argument with no default, but this call omits it -- as written this
# would raise TypeError; verify against the full file.
save_checkpoint(
checkpoint_dir=config.checkpoint_dir,
current_epoch=num_epoch + 1,
global_step=global_step,
exe=exe)
# Final evaluation on the dev and test splits (opt-in via do_eval).
if do_eval:
evaluate_seq_label_task(
task, data_reader, feed_list, phase="dev", config=config)
evaluate_seq_label_task(
task, data_reader, feed_list, phase="test", config=config)
logger.info("PaddleHub finetune finished.")
def __init__(self):
    """Locate the toxic dataset (downloading it on a cache miss) and load all splits."""
    dataset_dir = os.path.join(DATA_HOME, "toxic")
    if os.path.exists(dataset_dir):
        self.dataset_dir = dataset_dir
        logger.info("Dataset {} already cached.".format(self.dataset_dir))
    else:
        # Download and unpack; the downloader reports the final directory.
        _ret, _tips, dataset_dir = default_downloader.download_file_and_uncompress(
            url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
        self.dataset_dir = dataset_dir
    # Populate train/test/dev example lists from the dataset directory.
    for load_split in (self._load_train_examples, self._load_test_examples,
                       self._load_dev_examples):
        load_split()
def _log_interval_event(self, run_states):
    """Push interval loss/metrics to TensorBoard and emit a progress log line."""
    scores, avg_loss, run_speed = self._calculate_metrics(run_states)
    tb_step = self._envs['train'].current_step
    self.tb_writer.add_scalar(
        tag="Loss_{}".format(self.phase),
        scalar_value=avg_loss,
        global_step=tb_step)
    score_parts = []
    for name, value in scores.items():
        # One TensorBoard scalar per metric, tagged with the current phase.
        self.tb_writer.add_scalar(
            tag="{}_{}".format(name, self.phase),
            scalar_value=value,
            global_step=tb_step)
        score_parts.append("%s=%.5f " % (name, value))
    log_scores = "".join(score_parts)
    logger.info("step %d / %d: loss=%.5f %s[step/sec: %.2f]" %
                (self.current_step, self.max_train_steps, avg_loss,
                 log_scores, run_speed))
def _do_memory_optimization(task, config):
if config.enable_memory_optim:
logger.info("Memory optimization start...")
task_var_name = task.metric_variable_names()
logger.info(
"Skip memory optimization on variables: {}".format(task_var_name))
optimize_time_begin = time.time()
fluid.memory_optimize(
input_program=fluid.default_main_program(),
# skip memory optimization on task metric variables
skip_opt_set=task_var_name)
time_used = time.time() - optimize_time_begin
logger.info("Memory optimization done! Time elapsed %f sec" % time_used)
# NOTE(review): fragment of a load-balanced inference client; the opening
# `try` block and the enclosing if/elif chain over self.load_balance are
# outside this chunk, so the original nesting cannot be reconstructed here.
except BaseException as err:
# Log the failing backend before deciding whether to retry or abort.
logger.warning("Infer Error with server {} : {}".format(
self.serving_list[self.con_index], err))
if len(self.serving_list) == 0:
# No backends left to try -- signal hard failure to the caller.
logger.error('All server failed, process will exit')
return 'fail'
else:
# Advance to the next backend and ask the caller to retry.
# NOTE(review): con_index is incremented without wrapping modulo
# len(self.serving_list); presumably the wrap or removal of dead
# servers happens elsewhere -- confirm in the full file.
self.con_index += 1
return 'retry'
elif self.load_balance == 'random':
# Random load balancing: pick an arbitrary backend for each request.
try:
random.seed()
self.con_index = random.randint(0, len(self.serving_list) - 1)
logger.info(self.con_index)
cur_con = httplib.HTTPConnection(
self.serving_list[self.con_index])
# POST the prepared JSON payload to the BertService endpoint and
# decode the JSON response.
cur_con.request('POST', "/BertService/inference", request_msg,
{"Content-Type": "application/json"})
response = cur_con.getresponse()
response_msg = response.read()
response_msg = ujson.loads(response_msg)
return response_msg
except BaseException as err:
logger.warning("Infer Error with server {} : {}".format(
self.serving_list[self.con_index], err))
if len(self.serving_list) == 0:
logger.error('All server failed, process will exit')
return 'fail'
# NOTE(review): unlike the branch above, no 'retry' return is visible for
# the non-empty case here; the continuation lies outside this chunk.
# NOTE(review): identical duplicate of _do_memory_optimization earlier in
# this chunk -- likely a concatenation artifact; confirm against the full
# repository before deduplicating.
def _do_memory_optimization(task, config):
# Run fluid memory optimization on the default main program when enabled
# by config.enable_memory_optim; otherwise do nothing.
if config.enable_memory_optim:
logger.info("Memory optimization start...")
task_var_name = task.metric_variable_names()
logger.info(
"Skip memory optimization on variables: {}".format(task_var_name))
optimize_time_begin = time.time()
fluid.memory_optimize(
input_program=fluid.default_main_program(),
# skip memory optimization on task metric variables so their values
# remain readable after execution
skip_opt_set=task_var_name)
time_used = time.time() - optimize_time_begin
logger.info("Memory optimization done! Time elapsed %f sec" % time_used)
# NOTE(review): fragment of a request builder; the enclosing loop over `si`
# and the definitions of token_list, sent_list, pos_list, mask_list,
# instance_dict and request are outside this chunk, so the original
# nesting cannot be reconstructed here.
# Slice each flat per-example list into the si-th window of
# self.max_seq_len elements for this instance.
instance_dict["token_ids"] = token_list[si * self.max_seq_len:(
si + 1) * self.max_seq_len]
instance_dict["sentence_type_ids"] = sent_list[
si * self.max_seq_len:(si + 1) * self.max_seq_len]
instance_dict["position_ids"] = pos_list[si * self.max_seq_len:(
si + 1) * self.max_seq_len]
instance_dict["input_masks"] = mask_list[si * self.max_seq_len:(
si + 1) * self.max_seq_len]
request.append(instance_dict)
# Wrap the accumulated instances with request-level metadata and serialize
# the whole payload to a JSON string.
request = {"instances": request}
request["max_seq_len"] = self.max_seq_len
request["feed_var_names"] = self.feed_var_names
request_msg = ujson.dumps(request)
if self.show_ids:
# Debug aid: dump the full serialized request when show_ids is set.
logger.info(request_msg)
return request_msg
def __init__(self, version_2_with_negative=False):
    """Locate the SQuAD data (downloading it on a cache miss) and load train/dev splits.

    Args:
        version_2_with_negative: When True, load the SQuAD 2.0 variant that
            includes unanswerable questions.
    """
    self.dataset_dir = os.path.join(DATA_HOME, "squad_data")
    if os.path.exists(self.dataset_dir):
        logger.info("Dataset {} already cached.".format(self.dataset_dir))
    else:
        # Download and unpack; the downloader reports the final directory.
        _ret, _tips, self.dataset_dir = default_downloader.download_file_and_uncompress(
            url=_DATA_URL, save_path=DATA_HOME, print_progress=True)
    self.version_2_with_negative = version_2_with_negative
    self._load_train_examples(version_2_with_negative, if_has_answer=True)
    self._load_dev_examples(version_2_with_negative, if_has_answer=True)