Expected calibration error via `quantiles(exp(pred_log_prob), num_buckets)`.
```python
tfp.substrates.jax.stats.expected_calibration_error_quantiles(
    hit,
    pred_log_prob,
    num_buckets=20,
    axis=0,
    log_space_buckets=False,
    name=None
)
```
Calibration is a measure of how well a model reports its own uncertainty. A model is said to be "calibrated" if, within each bucket of predicted probabilities, the average predicted probability matches the average accuracy. The expected calibration error (ECE) is the average absolute difference between the within-bucket average predicted probability and the within-bucket average accuracy, weighted by the fraction of observations in each bucket. That is:
```python
bucket_weight = bucket_count / tf.reduce_sum(bucket_count, axis=0)
bucket_error = abs(bucket_accuracy - bucket_confidence)
ece = tf.reduce_sum(bucket_weight * bucket_error, axis=0)
```
where `bucket_accuracy`, `bucket_confidence`, and `bucket_count` are statistics aggregated over the `num_buckets`-quantiles of `tf.math.exp(pred_log_prob)`. Note: the `bucket_*` tensors always have size `num_buckets` along the zero-th dimension.
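To make the bucketing concrete, here is a minimal pure-Python sketch for a single 1-D batch. It is not TFP's implementation: it assumes the bucket edges are nearest-rank quantiles of the predicted probabilities, that each bucket is half-open `[lo, hi)` except that values at or above the last interior edge fall in the final bucket, and that empty buckets report 0. The helper names `quantile_edges` and `ece_quantiles` are hypothetical.

```python
import bisect
import math


def quantile_edges(values, num_buckets):
    """Assumed scheme: num_buckets + 1 nearest-rank quantiles of `values`."""
    s = sorted(values)
    n = len(s)
    return [s[round(k * (n - 1) / num_buckets)] for k in range(num_buckets + 1)]


def ece_quantiles(hit, pred_log_prob, num_buckets=20):
    """Hypothetical 1-D sketch of quantile-bucketed expected calibration error."""
    probs = [math.exp(lp) for lp in pred_log_prob]
    edges = quantile_edges(probs, num_buckets)
    interior = edges[1:-1]  # edges that separate adjacent buckets

    count = [0] * num_buckets
    hits = [0.0] * num_buckets
    prob_sum = [0.0] * num_buckets
    for h, p in zip(hit, probs):
        b = bisect.bisect_right(interior, p)  # bucket index in [0, num_buckets)
        count[b] += 1
        hits[b] += float(h)
        prob_sum[b] += p

    # Within-bucket averages; empty buckets report 0 in this sketch.
    accuracy = [hs / c if c else 0.0 for hs, c in zip(hits, count)]
    confidence = [ps / c if c else 0.0 for ps, c in zip(prob_sum, count)]
    # Count-weighted mean of |accuracy - confidence| over buckets.
    total = sum(count)
    ece = sum(abs(a - f) * c for a, f, c in zip(accuracy, confidence, count)) / total
    return ece, accuracy, confidence, count, edges
```

With the data from Example 1 (`hit = [False, False, True, False, True, True]`, the logs of `[0.1, 0.05, 0.5, 0.2, 0.99, 0.99]`, and `num_buckets=3`), this sketch yields counts `[2, 1, 3]`, accuracies `[0, 0, 1]`, and an ECE of 0.145, which agrees with the output shown there.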
Args | |
---|---|
`hit` | `bool` `Tensor` where `True` means the model prediction was correct and `False` means the model prediction was incorrect. Shape must broadcast with `pred_log_prob`. |
`pred_log_prob` | `Tensor` representing the model's predicted log probability for the given `hit`. Shape must broadcast with `hit`. |
`num_buckets` | `int` representing the number of buckets over which to aggregate hits. Buckets are quantiles of `exp(pred_log_prob)`. Default value: `20`. |
`axis` | Dimension over which to compute buckets and aggregate stats. Default value: `0`. |
`log_space_buckets` | When `False`, bucket edges are computed from `tf.math.exp(pred_log_prob)`; when `True`, bucket edges are computed from `pred_log_prob`. Default value: `False`. |
`name` | Optional Python `str` name used for ops created by this function. Default value: `None` (i.e., `"expected_calibration_error_quantiles"`). |
Returns | |
---|---|
`ece` | Expected calibration error; `tf.reduce_sum(abs(bucket_accuracy - bucket_confidence) * bucket_count, axis=0) / tf.reduce_sum(bucket_count, axis=0)`. |
`bucket_accuracy` | `Tensor` representing the within-bucket average hits, i.e., total bucket hits divided by bucket count. Has shape `tf.concat([[num_buckets], tf.shape(tf.reduce_sum(pred_log_prob, axis=axis))], axis=0)`. |
`bucket_confidence` | `Tensor` representing the within-bucket average probability, i.e., total bucket predicted probability divided by bucket count. Has shape `tf.concat([[num_buckets], tf.shape(tf.reduce_sum(pred_log_prob, axis=axis))], axis=0)`. |
`bucket_count` | `Tensor` representing the total number of observations in each bucket. Has shape `tf.concat([[num_buckets], tf.shape(tf.reduce_sum(pred_log_prob, axis=axis))], axis=0)`. |
`bucket_pred_log_prob` | `Tensor` representing `pred_log_prob` bucket edges. Always in log space, regardless of the value of `log_space_buckets`. |
`bucket` | `int` `Tensor` representing the bucket within which `pred_log_prob` lies. |
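The shape rule for `bucket_accuracy`, `bucket_confidence`, and `bucket_count` amounts to dropping `axis` from the shape of `pred_log_prob` and prepending `num_buckets`. A small pure-Python sketch of that rule, working on static Python shapes rather than TF tensors (the helper name `bucket_stat_shape` is hypothetical):

```python
def bucket_stat_shape(pred_shape, axis, num_buckets):
    """Shape of the per-bucket outputs: [num_buckets] + pred_shape without `axis`."""
    axis = axis % len(pred_shape)  # normalize negative axes, as reduce_sum does
    reduced = [d for i, d in enumerate(pred_shape) if i != axis]
    return [num_buckets] + reduced
```

For Example 1 (`pred_log_prob` of shape `[6]`, `axis=0`, `num_buckets=3`) this gives `[3]`, the shape of `acc`, `conf`, and `cnt` shown there.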
Examples
```python
import tensorflow as tf
import tensorflow_probability as tfp

# Example 1: Generic use.
label = tf.cast([0, 0, 1, 0, 1, 1], dtype=tf.bool)
log_pred = tf.math.log([0.1, 0.05, 0.5, 0.2, 0.99, 0.99])
(
    ece,
    acc,
    conf,
    cnt,
    edges,
    bucket,
) = tfp.stats.expected_calibration_error_quantiles(
    label, log_pred, num_buckets=3)
# ece ==> tf.Tensor(0.145, shape=(), dtype=float32)
# acc ==> tf.Tensor([0. 0. 1.], shape=(3,), dtype=float32)
# conf ==> tf.Tensor([0.075, 0.2, 0.826665], shape=(3,), dtype=float32)
# cnt ==> tf.Tensor([2. 1. 3.], shape=(3,), dtype=float32)
```
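As a check on Example 1, the reported `ece` follows from `acc`, `conf`, and `cnt` via the formula in the Returns table. The third bucket holds the predictions 0.5, 0.99, and 0.99, so its confidence is 2.48 / 3 (about 0.8267):

$$
\mathrm{ece}
= \frac{\sum_b \lvert \mathrm{acc}_b - \mathrm{conf}_b \rvert \, \mathrm{cnt}_b}{\sum_b \mathrm{cnt}_b}
= \frac{2\,\lvert 0 - 0.075 \rvert + 1\,\lvert 0 - 0.2 \rvert + 3\,\lvert 1 - 2.48/3 \rvert}{2 + 1 + 3}
= \frac{0.15 + 0.2 + 0.52}{6}
= 0.145
$$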
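```python
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions

# Example 2: Categorical classification.
# Assume we have evidence `x`, targets `y`, and model function `dnn`.
d = tfd.Categorical(logits=dnn(x))

def all_categories(d):
  # Every class index, shaped [num_classes, 1, ..., 1] so that it
  # broadcasts against the distribution's batch shape.
  num_classes = tf.shape(d.logits_parameter())[-1]
  batch_ndims = tf.size(d.batch_shape_tensor())
  expand_shape = tf.pad(
      [num_classes], paddings=[[0, batch_ndims]], constant_values=1)
  return tf.reshape(tf.range(num_classes, dtype=d.dtype), expand_shape)

# Log probability of every class; shape [num_classes] + batch_shape.
all_pred_log_prob = d.log_prob(all_categories(d))
# Predicted class per example. tf.argmax returns int64 by default, so `y`
# must also be int64 for the tf.equal below.
yhat = tf.argmax(all_pred_log_prob, axis=0)

def rollaxis(x, shift):
  return tf.transpose(x, tf.roll(tf.range(tf.rank(x)), shift=shift, axis=0))

# Move the class axis last, then pick the log probability of the predicted class.
pred_log_prob = tf.gather(rollaxis(all_pred_log_prob, shift=-1),
                          yhat,
                          batch_dims=len(d.batch_shape))
hit = tf.equal(y, yhat)
(
    ece,
    acc,
    conf,
    cnt,
    edges,
    bucket,
) = tfp.stats.expected_calibration_error_quantiles(
    hit, pred_log_prob, num_buckets=10)
```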
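The preprocessing in Example 2 reduces each example to two values: whether the most probable class equals the target, and the log probability of that class. A minimal pure-Python sketch of the same reduction for plain lists of per-class probabilities follows. The function name `top1_hit_and_log_prob` is hypothetical, and ties go to the lowest class index in this sketch.

```python
import math


def top1_hit_and_log_prob(class_probs, labels):
    """Return (hit, pred_log_prob) lists for top-1 predictions.

    class_probs: one list of per-class probabilities per example.
    labels: the true class index for each example.
    """
    hit, pred_log_prob = [], []
    for probs, y in zip(class_probs, labels):
        # Predicted class = index of the largest probability (first one on ties).
        yhat = max(range(len(probs)), key=lambda k: probs[k])
        hit.append(yhat == y)
        pred_log_prob.append(math.log(probs[yhat]))
    return hit, pred_log_prob
```

The two returned lists play the roles of `hit` and `pred_log_prob` in `expected_calibration_error_quantiles`.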