# Secure your code as it's written. Use Snyk Code to scan source code in minutes - no build needed - and fix issues immediately.
def evaluate(model, dataset, task, global_step, num_ece_bins, num_samples,
ensemble_size, name, steps=-1):
"""Evaluates the model on a dataset."""
nll_metric = tf.keras.metrics.Mean()
ece_metric = ed.metrics.ExpectedCalibrationError(num_bins=num_ece_bins)
aucpr_metric = tf.keras.metrics.AUC(curve="PR")
aucroc_metric = tf.keras.metrics.AUC(curve="ROC")
acc_metric = tf.keras.metrics.Accuracy()
top_k = 5 if task.logits_dimension >= 5 else None
sensitivity_metric = tf.keras.metrics.Recall(top_k=top_k)
ppv_metric = tf.keras.metrics.Precision(top_k=top_k)
for inputs, labels in dataset.take(steps):
logits = tf.reshape(
[model(inputs, training=False) for _ in range(num_samples)],
[num_samples, ensemble_size, -1, task.logits_dimension])
if task.logits_dimension == 1:
label_dist = tfp.distributions.Bernoulli(logits)
labels_1d = labels
probs = tf.sigmoid(logits)
else: