tfp.substrates.jax.stats.expected_calibration_error_quantiles

Expected calibration error via quantiles(exp(pred_log_prob), num_buckets).

Calibration is a measure of how well a model reports its own uncertainty. A model is said to be "calibrated" if, within each bucket of predicted probabilities, the average predicted probability (confidence) equals the average accuracy. The expected calibration error is the absolute difference between bucket average accuracy and bucket average confidence, averaged over buckets with each bucket weighted by its share of the observations. That is:

bucket_weight = bucket_count / tf.reduce_sum(bucket_count, axis=0)
bucket_error = abs(bucket_accuracy - bucket_confidence)
ece = tf.reduce_sum(bucket_weight * bucket_error, axis=0)

where bucket_accuracy, bucket_confidence, and bucket_count are statistics aggregated over the buckets defined by the num_buckets-quantiles of tf.math.exp(pred_log_prob) (or of pred_log_prob when log_space_buckets=True). Note: the bucket_* statistics always have size num_buckets in their zeroth dimension.
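The following is a minimal pure-Python sketch of quantile bucketing for 1-D inputs with the default log_space_buckets=False. It is not the TFP implementation. Several details are assumptions made for illustration: the helper name ece_quantiles, the nearest-rank rule for choosing the quantile edges, and the half-open bucket intervals, with out-of-range values clamped into the first or last bucket.

```python
import bisect
import math


def ece_quantiles(hit, pred_log_prob, num_buckets=20):
    """Hypothetical sketch of quantile-bucketed ECE for non-empty 1-D inputs."""
    # Bucket in probability space (the log_space_buckets=False behaviour).
    prob = [math.exp(lp) for lp in pred_log_prob]
    xs = sorted(prob)
    n = len(xs)
    # Assumed edge rule: nearest-rank quantiles at 0, 1/num_buckets, ..., 1.
    edges = [xs[round(k * (n - 1) / num_buckets)] for k in range(num_buckets + 1)]

    count = [0] * num_buckets
    hit_sum = [0.0] * num_buckets
    conf_sum = [0.0] * num_buckets
    bucket = []
    for h, p in zip(hit, prob):
        # Assumed half-open buckets [edge_k, edge_{k+1}); values outside the
        # edge range are clamped into the first or last bucket.
        b = min(max(bisect.bisect_right(edges, p) - 1, 0), num_buckets - 1)
        bucket.append(b)
        count[b] += 1
        hit_sum[b] += float(h)
        conf_sum[b] += p

    # Within-bucket averages; empty buckets report 0 and carry zero weight.
    acc = [hit_sum[b] / count[b] if count[b] else 0.0 for b in range(num_buckets)]
    conf = [conf_sum[b] / count[b] if count[b] else 0.0 for b in range(num_buckets)]
    ece = sum(abs(a - c) * k for a, c, k in zip(acc, conf, count)) / sum(count)
    return ece, acc, conf, count, bucket


hit = [False, False, True, False, True, True]
log_pred = [math.log(p) for p in [0.1, 0.05, 0.5, 0.2, 0.99, 0.99]]
ece, acc, conf, cnt, bucket = ece_quantiles(hit, log_pred, num_buckets=3)
print(cnt)            # [2, 1, 3]
print(acc)            # [0.0, 0.0, 1.0]
print(round(ece, 6))  # 0.145
```

Under these assumptions, the sketch reproduces the counts, accuracies, and ece reported in Example 1. A different edge or interval rule can place tied or boundary values in a different bucket.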

Args

hit: bool Tensor where True means the model prediction was correct and False means it was incorrect. Shape must broadcast with pred_log_prob.
pred_log_prob: Tensor representing the model's predicted log probability for the given hit. Shape must broadcast with hit.
num_buckets: int representing the number of buckets over which to aggregate hits. Buckets are quantiles of exp(pred_log_prob). Default value: 20.
axis: Dimension over which to compute buckets and aggregate stats. Default value: 0.
log_space_buckets: When False, bucket edges are computed from tf.math.exp(pred_log_prob); when True, bucket edges are computed from pred_log_prob. Default value: False.
name: Python str name used for ops created by this function. Default value: None (i.e., "expected_calibration_error_quantiles").

Returns

ece: Expected calibration error; tf.reduce_sum(abs(bucket_accuracy - bucket_confidence) * bucket_count, axis=0) / tf.reduce_sum(bucket_count, axis=0).
bucket_accuracy: Tensor representing the within-bucket average hits, i.e., total bucket hits divided by bucket count. Has shape tf.concat([[num_buckets], tf.shape(tf.reduce_sum(pred_log_prob, axis=axis))], axis=0).
bucket_confidence: Tensor representing the within-bucket average probability, i.e., total bucket predicted probability divided by bucket count. Has shape tf.concat([[num_buckets], tf.shape(tf.reduce_sum(pred_log_prob, axis=axis))], axis=0).
bucket_count: Tensor representing the total number of observations in each bucket. Has shape tf.concat([[num_buckets], tf.shape(tf.reduce_sum(pred_log_prob, axis=axis))], axis=0).
bucket_pred_log_prob: Tensor representing the pred_log_prob bucket edges. Always in log space, regardless of the value of log_space_buckets.
bucket: int Tensor giving, for each element of pred_log_prob, the index of the bucket in which it lies.
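The shape rule for bucket_accuracy, bucket_confidence, and bucket_count can be read as follows: drop axis from the shape of pred_log_prob, then prepend num_buckets. The small pure-Python helper below applies that rule to a static shape. The name bucket_stat_shape is hypothetical and not part of the API.

```python
def bucket_stat_shape(input_shape, num_buckets=20, axis=0):
    """Hypothetical helper: shape of the per-bucket statistics."""
    rank = len(input_shape)
    axis = axis % rank  # accept a negative axis, as tf.reduce_sum does
    # Shape of tf.reduce_sum(pred_log_prob, axis=axis): the input minus `axis`.
    reduced = list(input_shape[:axis]) + list(input_shape[axis + 1:])
    # tf.concat([[num_buckets], reduced], axis=0)
    return [num_buckets] + reduced


print(bucket_stat_shape([6], num_buckets=3))         # [3]
print(bucket_stat_shape([1000, 5], num_buckets=10))  # [10, 5]
print(bucket_stat_shape([4, 1000, 5], axis=1))       # [20, 4, 5]
```

For instance, a 1-D input of six predictions with num_buckets=3 gives per-bucket statistics of shape [3].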

Examples

# Example 1: Generic use.
import tensorflow as tf
import tensorflow_probability as tfp
label = tf.cast([0, 0, 1, 0, 1, 1], dtype=tf.bool)
log_pred = tf.math.log([0.1, 0.05, 0.5, 0.2, 0.99, 0.99])
(
  ece,
  acc,
  conf,
  cnt,
  edges,
  bucket,
) = tfp.stats.expected_calibration_error_quantiles(
    label, log_pred, num_buckets=3)
# ece  ==> tf.Tensor(0.145, shape=(), dtype=float32)
# acc  ==> tf.Tensor([0. 0. 1.], shape=(3,), dtype=float32)
# conf ==> tf.Tensor([0.075, 0.2, 0.826665], shape=(3,), dtype=float32)
# cnt  ==> tf.Tensor([2. 1. 3.], shape=(3,), dtype=float32)
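The reported ece can be checked by hand from the other outputs. This plain-Python snippet is not part of the API. It copies acc, conf, and cnt from the output above and evaluates two formulas: the bucket_weight formula from the introduction, and the count-weighted formula given for the ece return value. Both come to about 0.145, with the tiny excess coming from the rounded 0.826665.

```python
# Bucket statistics as printed in Example 1.
acc = [0.0, 0.0, 1.0]
conf = [0.075, 0.2, 0.826665]
cnt = [2.0, 1.0, 3.0]

total = sum(cnt)
bucket_error = [abs(a - c) for a, c in zip(acc, conf)]

# Form 1: bucket_weight = bucket_count / sum(bucket_count).
bucket_weight = [c / total for c in cnt]
ece_weighted = sum(w * e for w, e in zip(bucket_weight, bucket_error))

# Form 2: sum(bucket_error * bucket_count) / sum(bucket_count).
ece_direct = sum(e * c for e, c in zip(bucket_error, cnt)) / total

print(round(ece_weighted, 4), round(ece_direct, 4))  # 0.145 0.145
```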
# Example 2: Categorical classification.
# Assume we have evidence `x`, targets `y`, and model function `dnn`.
tfd = tfp.distributions
d = tfd.Categorical(logits=dnn(x))
def all_categories(d):
  num_classes = tf.shape(d.logits_parameter())[-1]
  batch_ndims = tf.size(d.batch_shape_tensor())
  expand_shape = tf.pad(
      [num_classes], paddings=[[0, batch_ndims]], constant_values=1)
  return tf.reshape(tf.range(num_classes, dtype=d.dtype), expand_shape)
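# Log probability of every class, class dimension first:
# shape [num_classes] + batch_shape.
all_pred_log_prob = d.log_prob(all_categories(d))
# Predicted class: the most probable category for each example.
yhat = tf.argmax(all_pred_log_prob, axis=0)
def rollaxis(x, shift):
  return tf.transpose(x, tf.roll(tf.range(tf.rank(x)), shift=shift, axis=0))
# Move the class dimension last, then pick the log probability of `yhat`.
pred_log_prob = tf.gather(rollaxis(all_pred_log_prob, shift=-1),
                          yhat,
                          batch_dims=len(d.batch_shape))
# A hit is a correct prediction; `y` must have the dtype of `yhat` (tf.int64).
hit = tf.equal(y, yhat)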
(
  ece,
  acc,
  conf,
  cnt,
  edges,
  bucket,
) = tfp.stats.expected_calibration_error_quantiles(
    hit, pred_log_prob, num_buckets=10)
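For readers without TensorFlow at hand, the following pure-Python sketch mirrors what Example 2 computes for a batch with a single batch dimension. The predicted class is the argmax of the log-softmax of the logits, hit records whether it equals the target, and pred_log_prob is the log probability assigned to the predicted class. The helper names and the toy logits and labels are hypothetical.

```python
import math


def log_softmax(logits):
    # Subtract the max before exponentiating for numerical stability.
    m = max(logits)
    log_z = m + math.log(sum(math.exp(v - m) for v in logits))
    return [v - log_z for v in logits]


def hits_and_confidence(batch_logits, labels):
    """Hypothetical helper returning (hit, pred_log_prob) for a batch."""
    hit, pred_log_prob = [], []
    for logits, y in zip(batch_logits, labels):
        log_probs = log_softmax(logits)
        # Predicted class = most probable category (argmax).
        yhat = max(range(len(log_probs)), key=log_probs.__getitem__)
        hit.append(yhat == y)
        pred_log_prob.append(log_probs[yhat])
    return hit, pred_log_prob


batch_logits = [[2.0, 0.5, 0.1], [0.2, 0.3, 3.0], [0.1, 1.5, 0.0]]
labels = [0, 1, 1]
hit, pred_log_prob = hits_and_confidence(batch_logits, labels)
print(hit)  # [True, False, True]
```

The resulting hit and pred_log_prob lists play the same roles as the hit and pred_log_prob arguments passed to expected_calibration_error_quantiles in Example 2. Because the predicted class is the most probable one, each pred_log_prob lies between log(1 / num_classes) and 0.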