Expected calibration error via `quantiles(exp(pred_log_prob), num_buckets)`.
```python
tfp.substrates.jax.stats.expected_calibration_error_quantiles(
    hit,
    pred_log_prob,
    num_buckets=20,
    axis=0,
    log_space_buckets=False,
    name=None
)
```
Calibration is a measure of how well a model reports its own uncertainty. A model is said to be "calibrated" if, within each bucket of predicted probabilities, the average predicted probability (the bucket confidence) equals the average accuracy. The expected calibration error (ECE) is the count-weighted average of the absolute difference between bucket confidence and bucket accuracy. That is:
```python
bucket_weight = bucket_count / tf.reduce_sum(bucket_count, axis=0)
bucket_error = abs(bucket_accuracy - bucket_confidence)
ece = tf.reduce_sum(bucket_weight * bucket_error, axis=0)
```
where `bucket_accuracy`, `bucket_confidence`, and `bucket_count` are statistics aggregated over buckets whose edges are the `num_buckets`-quantiles of `tf.math.exp(pred_log_prob)`. Note: the `bucket_*` statistics always have size `num_buckets` along their zeroth dimension.
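The three-line formula above can be sketched in plain Python, assuming the per-bucket statistics have already been computed and are passed as equal-length lists with a nonzero total count. The helper name `ece_from_bucket_stats` is hypothetical and not part of TFP.

```python
def ece_from_bucket_stats(bucket_accuracy, bucket_confidence, bucket_count):
    """Hypothetical helper: count-weighted mean of |accuracy - confidence|."""
    total = sum(bucket_count)
    ece = 0.0
    for acc, conf, cnt in zip(bucket_accuracy, bucket_confidence, bucket_count):
        bucket_weight = cnt / total  # fraction of all examples in this bucket
        bucket_error = abs(acc - conf)  # calibration gap within this bucket
        ece += bucket_weight * bucket_error
    return ece


# Two equally sized buckets, each miscalibrated by 0.2, give an ECE of about 0.2.
ece = ece_from_bucket_stats([0.0, 1.0], [0.2, 0.8], [5, 5])
```

Empty buckets contribute nothing, because their weight is zero.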
Args | Description
---|---
`hit` | `bool` `Tensor` where `True` means the model prediction was correct and `False` means the model prediction was incorrect. Shape must broadcast with `pred_log_prob`.
`pred_log_prob` | `Tensor` representing the model's predicted log probability for the given `hit`. Shape must broadcast with `hit`.
`num_buckets` | `int` representing the number of buckets over which to aggregate hits. Buckets are quantiles of `exp(pred_log_prob)`. Default value: `20`.
`axis` | Dimension over which to compute buckets and aggregate stats. Default value: `0`.
`log_space_buckets` | When `False`, bucket edges are computed from `tf.math.exp(pred_log_prob)`; when `True`, bucket edges are computed from `pred_log_prob`. Default value: `False`.
`name` | Optional `str` name used for ops created by this function. Default value: `None` (i.e., `"expected_calibration_error_quantiles"`).
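To see how `hit`, `pred_log_prob`, and `num_buckets` fit together end to end, here is a plain-Python sketch. It is not TFP's implementation and rests on stated assumptions: inputs are nonempty 1-D lists (so `axis` and broadcasting are ignored), `log_space_buckets` is omitted, bucket edges are nearest-rank quantiles of `exp(pred_log_prob)`, and each probability falls in the half-open bin `[edges[b], edges[b + 1])`, with the largest value placed in the last bin. The function name `quantile_bucket_ece` is hypothetical.

```python
import math
from bisect import bisect_right


def quantile_bucket_ece(hit, pred_log_prob, num_buckets=20):
    """Hypothetical sketch: ECE over nearest-rank quantile buckets of exp(pred_log_prob)."""
    probs = [math.exp(lp) for lp in pred_log_prob]
    n = len(probs)
    s = sorted(probs)
    # Assumption: nearest-rank quantile edges at q = j / num_buckets, j = 0..num_buckets.
    edges = [s[round((n - 1) * j / num_buckets)] for j in range(num_buckets + 1)]
    count = [0] * num_buckets
    hit_sum = [0.0] * num_buckets
    conf_sum = [0.0] * num_buckets
    bucket = []
    for h, p in zip(hit, probs):
        # Half-open bins; clamping puts the maximum probability in the last bin.
        b = min(max(bisect_right(edges, p) - 1, 0), num_buckets - 1)
        bucket.append(b)
        count[b] += 1
        hit_sum[b] += float(h)
        conf_sum[b] += p
    # Empty buckets get accuracy and confidence 0; their weight is 0 anyway.
    acc = [hit_sum[b] / count[b] if count[b] else 0.0 for b in range(num_buckets)]
    conf = [conf_sum[b] / count[b] if count[b] else 0.0 for b in range(num_buckets)]
    ece = sum(count[b] / n * abs(acc[b] - conf[b]) for b in range(num_buckets))
    return ece, acc, conf, count, edges, bucket


probs = [0.1, 0.05, 0.5, 0.2, 0.99, 0.99]
hit = [False, False, True, False, True, True]
ece, acc, conf, count, edges, bucket = quantile_bucket_ece(
    hit, [math.log(p) for p in probs], num_buckets=3)
```

On the data of Example 1 below, these assumptions reproduce the reported bucket counts `[2, 1, 3]`, accuracies `[0, 0, 1]`, and an ECE of about 0.145. Other quantile rules (for example, linear interpolation between order statistics) can move points across bucket boundaries and change the result.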
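Examples

The examples below call the TensorFlow-backend alias `tfp.stats.expected_calibration_error_quantiles` with `tf` tensors.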
```python
# Example 1: Generic use.
import tensorflow as tf
import tensorflow_probability as tfp

label = tf.cast([0, 0, 1, 0, 1, 1], dtype=tf.bool)
log_pred = tf.math.log([0.1, 0.05, 0.5, 0.2, 0.99, 0.99])
ece, acc, conf, cnt, edges, bucket = tfp.stats.expected_calibration_error_quantiles(
    label, log_pred, num_buckets=3)
# ece  ==> tf.Tensor(0.145, shape=(), dtype=float32)
# acc  ==> tf.Tensor([0. 0. 1.], shape=(3,), dtype=float32)
# conf ==> tf.Tensor([0.075, 0.2, 0.826665], shape=(3,), dtype=float32)
# cnt  ==> tf.Tensor([2. 1. 3.], shape=(3,), dtype=float32)
```
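As a hand check of Example 1, the reported confidences are the bucket means (0.05 + 0.1)/2, 0.2, and (0.5 + 0.99 + 0.99)/3, and plugging the reported `acc`, `conf`, and `cnt` (six examples in total) into the ECE formula recovers the reported `ece`:

```latex
\mathrm{ECE}
  = \tfrac{2}{6}\,\lvert 0 - 0.075 \rvert
  + \tfrac{1}{6}\,\lvert 0 - 0.2 \rvert
  + \tfrac{3}{6}\,\lvert 1 - 0.82667 \rvert
  = \frac{0.15 + 0.2 + 0.52}{6}
  = 0.145
```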
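```python
# Example 2: Categorical classification.
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# Assume we have evidence `x`, targets `y`, and model function `dnn`.
d = tfd.Categorical(logits=dnn(x))

def all_categories(d):
  # Every class index, shaped [num_classes, 1, ..., 1] so that it
  # broadcasts against the distribution's batch shape.
  num_classes = tf.shape(d.logits_parameter())[-1]
  batch_ndims = tf.size(d.batch_shape_tensor())
  expand_shape = tf.pad(
      [num_classes], paddings=[[0, batch_ndims]], constant_values=1)
  return tf.reshape(tf.range(num_classes, dtype=d.dtype), expand_shape)

# Log probability of every class; shape is [num_classes] + batch_shape.
all_pred_log_prob = d.log_prob(all_categories(d))
# Predicted class for each example.
yhat = tf.argmax(all_pred_log_prob, axis=0)

def rollaxis(x, shift):
  # Cyclically permute the axes of `x`; shift=-1 moves axis 0 to the end.
  return tf.transpose(x, tf.roll(tf.range(tf.rank(x)), shift=shift, axis=0))

# Log probability of the predicted class for each example.
pred_log_prob = tf.gather(rollaxis(all_pred_log_prob, shift=-1),
                          yhat,
                          batch_dims=len(d.batch_shape))
# `tf.argmax` returns int64, so cast `y` to the same dtype before comparing.
hit = tf.equal(tf.cast(y, yhat.dtype), yhat)
ece, acc, conf, cnt, edges, bucket = tfp.stats.expected_calibration_error_quantiles(
    hit, pred_log_prob, num_buckets=10)
```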
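Example 2 reduces a classifier's per-class log probabilities to one `(hit, pred_log_prob)` pair per example: the predicted class is the argmax, and its log probability (equivalently, the maximum log probability over classes) is the model's confidence. Below is a plain-Python sketch of that reduction, assuming logits arrive as a list of per-example lists and labels as integer class indices. `hits_and_confidences` is a hypothetical helper, and a log-softmax stands in for `tfd.Categorical(logits=...).log_prob`.

```python
import math


def hits_and_confidences(logits, labels):
    """Hypothetical helper: per-example (hit, pred_log_prob) from class logits."""
    hit, pred_log_prob = [], []
    for row, y in zip(logits, labels):
        # Log-softmax, stabilized by subtracting the largest logit.
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        log_probs = [v - log_z for v in row]
        # Predicted class is the argmax (first index on ties).
        yhat = max(range(len(row)), key=lambda k: log_probs[k])
        hit.append(yhat == y)
        pred_log_prob.append(log_probs[yhat])
    return hit, pred_log_prob


hit, pred_log_prob = hits_and_confidences(
    [[2.0, 1.0, 0.0], [0.0, 3.0, 1.0]], labels=[0, 2])
```

The two returned lists can then be passed as `hit` and `pred_log_prob` to a calibration routine such as the one documented above.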