Secure your code as it's written. Use Snyk Code to scan source code in minutes — no build needed — and fix issues immediately.
def fit(self, epochs, lr, validate=True, schedule_type="warmup_linear"):
num_train_steps = int(len(self.data.train_dl) / self.grad_accumulation_steps * epochs)
if self.optimizer is None:
self.optimizer, self.schedule = self.get_optimizer(lr , num_train_steps)
t_total = num_train_steps
if self.multi_gpu == False:
t_total = t_total // torch.distributed.get_world_size()
global_step = 0
pbar = master_bar(range(epochs))
for epoch in pbar:
self.model.train()
tr_loss = 0
nb_tr_examples, nb_tr_steps = 0, 0
for step, batch in enumerate(progress_bar(self.data.train_dl, parent=pbar)):
batch = tuple(t.to(self.device) for t in batch)
input_ids, input_mask, segment_ids, label_ids = batch
if self.is_fp16 and self.multi_label:
label_ids = label_ids.half()
loss = self.model(input_ids, segment_ids, input_mask, label_ids)
if self.multi_gpu:
def fit(self, epochs: int, lr: float,
params_opt_dict: Optional[Dict] = None):
"Main training loop"
# Print logger at the start of the training loop
self.logger.info(self.cfg)
# Initialize the progress_bar
mb = master_bar(range(epochs))
# Initialize optimizer
# Prepare Optimizer may need to be re-written as per use
self.optimizer = self.prepare_optimizer(params_opt_dict)
# Initialize scheduler
# Prepare scheduler may need to re-written as per use
self.lr_scheduler = self.prepare_scheduler(self.optimizer)
# Write the top row display
# mb.write(self.log_keys, table=True)
self.master_bar_write(mb, line=self.log_keys, table=True)
exception = False
met_to_use = None
# Keep record of time until exit
st_time = time.time()
try:
# Loop over epochs
self.logger.info(" Num examples = %d", len(train_dataloader.dataset))
self.logger.info(" Num Epochs = %d", epochs)
self.logger.info(
" Total train batch size (w. parallel, distributed & accumulation) = %d",
self.data.train_batch_size * self.grad_accumulation_steps,
)
self.logger.info(
" Gradient Accumulation steps = %d", self.grad_accumulation_steps
)
self.logger.info(" Total optimization steps = %d", t_total)
global_step = 0
epoch_step = 0
tr_loss, logging_loss, epoch_loss = 0.0, 0.0, 0.0
self.model.zero_grad()
pbar = master_bar(range(epochs))
for epoch in pbar:
epoch_step = 0
epoch_loss = 0.0
for step, batch in enumerate(progress_bar(train_dataloader, parent=pbar)):
self.model.train()
batch = tuple(t.to(self.device) for t in batch)
inputs = {
"input_ids": batch[0],
"attention_mask": batch[1],
"labels": batch[3],
}
if self.model_type in ["bert", "xlnet"]:
inputs["token_type_ids"] = batch[2]
self.model = torch.nn.DataParallel(self.model)
# Start Training
self.logger.info("***** Running training *****")
self.logger.info(" Num examples = %d", len(train_dataloader.dataset))
self.logger.info(" Num Epochs = %d", epochs)
self.logger.info(" Total train batch size (w. parallel, distributed & accumulation) = %d",
self.data.train_batch_size * self.grad_accumulation_steps)
self.logger.info(" Gradient Accumulation steps = %d", self.grad_accumulation_steps)
self.logger.info(" Total optimization steps = %d", t_total)
global_step = 0
epoch_step = 0
tr_loss, logging_loss, epoch_loss = 0.0, 0.0, 0.0
self.model.zero_grad()
pbar = master_bar(range(epochs))
for epoch in pbar:
epoch_step = 0
epoch_loss = 0.0
for step, batch in enumerate(progress_bar(train_dataloader, parent=pbar)):
inputs, labels = self.data.mask_tokens(batch)
cpu_device = torch.device('cpu')
inputs = inputs.to(self.device)
labels = labels.to(self.device)
self.model.train()
outputs = self.model(inputs, masked_lm_labels=labels)
loss = outputs[0] # model outputs are always tuple in pytorch-transformers (see doc)