Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def eval_accuracies(pred_s, target_s, pred_e, target_e):
    """An unofficial evaluation helper.
    Compute exact start/end/complete match accuracies for a batch.
    """
    # NOTE(review): this copy of the function is truncated — the loop body
    # stops at the "Both start and end match" comment, so the `em` meter is
    # never updated and nothing is returned. A fuller copy of the same
    # function appears later in this file; confirm which copy is canonical
    # before relying on this one.
    # Convert 1D tensors to lists of lists (compatibility)
    if torch.is_tensor(target_s):
        target_s = [[e.item()] for e in target_s]
        target_e = [[e.item()] for e in target_e]
    # Compute accuracies from targets
    batch_size = len(pred_s)
    start = utils.AverageMeter()
    end = utils.AverageMeter()
    em = utils.AverageMeter()
    for i in range(batch_size):
        # Start matches: predicted start index is among the gold starts
        if pred_s[i] in target_s[i]:
            start.update(1)
        else:
            start.update(0)
        # End matches: predicted end index is among the gold ends
        if pred_e[i] in target_e[i]:
            end.update(1)
        else:
            end.update(0)
        # Both start and end match
def eval_accuracies(pred_s, target_s, pred_e, target_e):
    """An unofficial evaluation helper.

    Compute exact start/end/complete match accuracies for a batch.

    Args:
        pred_s: predicted start index per example (membership test below
            also tolerates per-example candidate containers).
        target_s: gold start indices; a 1D tensor or a list of lists.
        pred_e: predicted end index per example.
        target_e: gold end indices; same layout as target_s.
    Returns:
        Tuple (start, end, em) of batch accuracies scaled to [0, 100].
    """
    # Convert 1D tensors to lists of lists (compatibility)
    if torch.is_tensor(target_s):
        target_s = [[e.item()] for e in target_s]
        target_e = [[e.item()] for e in target_e]

    # Compute accuracies from targets
    batch_size = len(pred_s)
    start = utils.AverageMeter()
    end = utils.AverageMeter()
    em = utils.AverageMeter()
    for i in range(batch_size):
        # Start matches
        if pred_s[i] in target_s[i]:
            start.update(1)
        else:
            start.update(0)
        # End matches
        if pred_e[i] in target_e[i]:
            end.update(1)
        else:
            end.update(0)
        # Both start and end match the SAME gold (start, end) pair.
        # (Original was truncated mid-`if`: the body below restores the
        # em-metric update that the dangling condition clearly intended.)
        if any(_s == pred_s[i] and _e == pred_e[i]
               for _s, _e in zip(target_s[i], target_e[i])):
            em.update(1)
        else:
            em.update(0)
    # Scale to percentages — callers record these directly as 0-100
    # accuracies; confirm scaling against the upstream version.
    return start.avg * 100, end.avg * 100, em.avg * 100
def validate_unofficial(args, data_loader, model, global_stats, mode):
    """Run one full unofficial validation.

    Unofficial = doesn't use SQuAD script.

    Args:
        args: run configuration (unused here; kept for caller parity).
        data_loader: iterable of batches; ex[0] is a batch-first tensor and
            ex[-3:-1] are the gold start/end targets.
        model: object exposing predict(ex) -> ((pred_s, pred_e, _), loss).
        global_stats: shared run statistics (e.g. epoch counter).
        mode: 'train' or 'dev'; train mode evaluates at most ~10k examples.
    Returns:
        Dict with the joint exact-match average under key 'exact_match'.
    """
    eval_time = utils.Timer()
    start_acc = utils.AverageMeter()
    end_acc = utils.AverageMeter()
    exact_match = utils.AverageMeter()
    loss = utils.AverageMeter()

    # Make predictions
    examples = 0
    for ex in data_loader:
        batch_size = ex[0].size(0)
        (pred_s, pred_e, _), batch_loss = model.predict(ex)
        target_s, target_e = ex[-3:-1]

        # We get metrics for independent start/end and joint start/end
        accuracies = eval_accuracies(pred_s, target_s, pred_e, target_e)
        start_acc.update(accuracies[0], batch_size)
        end_acc.update(accuracies[1], batch_size)
        exact_match.update(accuracies[2], batch_size)
        loss.update(batch_loss, batch_size)

        # If getting train accuracies, sample max 10k
        examples += batch_size
        if mode == 'train' and examples >= 1e4:
            break

    # NOTE(review): the original tail was truncated here; it presumably
    # logged start/end/em/loss with eval_time before returning. eval_time
    # and loss are kept so logging can be restored — confirm against the
    # upstream version.
    return {'exact_match': exact_match.avg}
def validate_unofficial(args, data_loader, model, global_stats, mode):
    """Run one full unofficial validation.
    Unofficial = doesn't use SQuAD script.
    """
    # NOTE(review): this copy of the function is a truncated duplicate —
    # it ends after the per-batch loss update with no sampling cap, no
    # logging, and no return. Another copy of the same function exists in
    # this file; confirm which copy is canonical before relying on this one.
    eval_time = utils.Timer()
    start_acc = utils.AverageMeter()
    end_acc = utils.AverageMeter()
    exact_match = utils.AverageMeter()
    loss = utils.AverageMeter()
    # Make predictions
    examples = 0
    for ex in data_loader:
        batch_size = ex[0].size(0)
        (pred_s, pred_e, _), batch_loss = model.predict(ex)
        # Gold start/end targets sit just before the last element of ex.
        target_s, target_e = ex[-3:-1]
        # We get metrics for independent start/end and joint start/end
        accuracies = eval_accuracies(pred_s, target_s, pred_e, target_e)
        start_acc.update(accuracies[0], batch_size)
        end_acc.update(accuracies[1], batch_size)
        exact_match.update(accuracies[2], batch_size)
        loss.update(batch_loss, batch_size)
def validate_official(args, data_loader, model, global_stats,
                      offsets, texts, answers):
    """Run one full official validation. Uses exact spans and same
    exact match/F1 score computation as in the SQuAD script.

    Extra arguments:
        offsets: The character start/end indices for the tokens in each context.
        texts: Map of qid --> raw text of examples context (matches offsets).
        answers: Map of qid --> list of accepted answers.
    Returns:
        Dict with 'exact_match' and 'f1' averages scaled to [0, 100].
    """
    eval_time = utils.Timer()
    f1 = utils.AverageMeter()
    exact_match = utils.AverageMeter()

    # Run through examples
    examples = 0
    for ex in data_loader:
        ex_id, batch_size = ex[-1], ex[0].size(0)
        (pred_s, pred_e, _), _ = model.predict(ex)

        for i in range(batch_size):
            # Map the top token-level span back to a character-level slice
            # of the raw context, then score the extracted text.
            s_offset = offsets[ex_id[i]][pred_s[i][0]][0]
            e_offset = offsets[ex_id[i]][pred_e[i][0]][1]
            prediction = texts[ex_id[i]][s_offset:e_offset]

            # Compute metrics
            ground_truths = answers[ex_id[i]]
            # NOTE(review): the original was truncated mid-call here. The
            # scorers below follow the SQuAD-script convention implied by
            # the f1/exact_match meters — confirm helper names in utils.
            exact_match.update(utils.metric_max_over_ground_truths(
                utils.exact_match_score, prediction, ground_truths))
            f1.update(utils.metric_max_over_ground_truths(
                utils.f1_score, prediction, ground_truths))

        examples += batch_size

    # eval_time retained for the (truncated) timing log — restore logging
    # from the upstream version if needed.
    return {'exact_match': exact_match.avg * 100, 'f1': f1.avg * 100}
def validate_unofficial(args, data_loader, model, global_stats, mode):
"""Run one full unofficial validation.
Unofficial = doesn't use SQuAD script.
"""
eval_time = utils.Timer()
start_acc = utils.AverageMeter()
end_acc = utils.AverageMeter()
exact_match = utils.AverageMeter()
loss = utils.AverageMeter()
# Make predictions
examples = 0
for ex in data_loader:
batch_size = ex[0].size(0)
(pred_s, pred_e, _), batch_loss = model.predict(ex)
target_s, target_e = ex[-3:-1]
# We get metrics for independent start/end and joint start/end
accuracies = eval_accuracies(pred_s, target_s, pred_e, target_e)
start_acc.update(accuracies[0], batch_size)
end_acc.update(accuracies[1], batch_size)
exact_match.update(accuracies[2], batch_size)
loss.update(batch_loss, batch_size)
# If getting train accuracies, sample max 10k