Expected calibration error via `quantiles(exp(pred_log_prob), num_buckets)`.
tfp.stats.expected_calibration_error_quantiles(
    hit,
    pred_log_prob,
    num_buckets=20,
    axis=0,
    log_space_buckets=False,
    name=None
)
Calibration is a measure of how well a model reports its own uncertainty. A model is said to be "calibrated" if, within each bucket of predicted probabilities, the average accuracy matches the average predicted probability. The expected calibration error is the bucket-weighted average absolute difference between predicted probability and (bucket) average accuracy. That is:
bucket_weight = bucket_count / tf.reduce_sum(bucket_count, axis=0)
bucket_error = abs(bucket_accuracy - bucket_confidence)
ece = tf.reduce_sum(bucket_weight * bucket_error, axis=0)
where `bucket_accuracy`, `bucket_confidence`, and `bucket_count` are statistics
aggregated by `num_buckets`-quantiles of `tf.math.exp(pred_log_prob)`. Note:
the `bucket_*` tensors always have size `num_buckets` along the zeroth dimension.
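The three formulas above can be traced by hand with a small, dependency-free sketch. It is not the library's implementation: the helper name `ece_quantiles_sketch` is hypothetical, it handles only 1-D inputs, and it assumes bucket edges are nearest-rank quantiles with each edge value assigned to the bucket on its right. That assumption was chosen because it reproduces the counts shown in Example 1 of the Examples section; the actual quantile interpolation may differ.

```python
import bisect
import math


def ece_quantiles_sketch(hit, pred_log_prob, num_buckets=20):
    """Plain-Python sketch of ECE over quantile buckets (1-D inputs only)."""
    probs = [math.exp(lp) for lp in pred_log_prob]
    ordered = sorted(probs)
    n = len(ordered)

    # Assumption: bucket edges are nearest-rank quantiles of the probabilities.
    def quantile(q):
        return ordered[int(math.floor(q * (n - 1) + 0.5))]

    interior_edges = [quantile(k / num_buckets) for k in range(1, num_buckets)]

    # Accumulate per-bucket count, number of hits, and summed confidence.
    count = [0] * num_buckets
    hits = [0.0] * num_buckets
    conf_sum = [0.0] * num_buckets
    for h, p in zip(hit, probs):
        b = bisect.bisect_right(interior_edges, p)  # index in 0..num_buckets-1
        count[b] += 1
        hits[b] += float(h)
        conf_sum[b] += p

    total = sum(count)
    ece = 0.0
    accuracy, confidence = [], []
    for c, s_hit, s_conf in zip(count, hits, conf_sum):
        # Empty buckets get zero weight, so their statistics do not matter.
        acc_b = s_hit / c if c else 0.0
        conf_b = s_conf / c if c else 0.0
        accuracy.append(acc_b)
        confidence.append(conf_b)
        # bucket_weight * bucket_error, summed over buckets.
        ece += (c / total) * abs(acc_b - conf_b)
    return ece, accuracy, confidence, count


hit = [False, False, True, False, True, True]
log_pred = [math.log(p) for p in [0.1, 0.05, 0.5, 0.2, 0.99, 0.99]]
ece, acc, conf, cnt = ece_quantiles_sketch(hit, log_pred, num_buckets=3)
print(round(ece, 3), cnt)  # 0.145 [2, 1, 3]
```

With three buckets the weights are 2/6, 1/6, and 3/6, and the per-bucket errors are 0.075, 0.2, and about 0.173, which sum to the 0.145 reported above.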
| Args | |
|---|---|
| `hit` | `bool` `Tensor` where `True` means the model prediction was correct and `False` means the model prediction was incorrect. Shape must broadcast with `pred_log_prob`. |
| `pred_log_prob` | `Tensor` representing the model's predicted log probability for the given `hit`. Shape must broadcast with `hit`. |
| `num_buckets` | `int` representing the number of buckets over which to aggregate hits. Buckets are quantiles of `exp(pred_log_prob)`. Default value: `20`. |
| `axis` | Dimension over which to compute buckets and aggregate stats. Default value: `0`. |
| `log_space_buckets` | When `False`, bucket edges are computed from `tf.math.exp(pred_log_prob)`; when `True`, bucket edges are computed from `pred_log_prob`. Default value: `False`. |
| `name` | Preferred `str` name used for ops created by this function. Default value: `None` (i.e., `"expected_calibration_error_quantiles"`). |
Examples
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions

# Example 1: Generic use.
label = tf.cast([0, 0, 1, 0, 1, 1], dtype=tf.bool)
log_pred = tf.math.log([0.1, 0.05, 0.5, 0.2, 0.99, 0.99])
(
  ece,
  acc,
  conf,
  cnt,
  edges,
  bucket,
) = tfp.stats.expected_calibration_error_quantiles(
    label, log_pred, num_buckets=3)
# ece  ==> tf.Tensor(0.145, shape=(), dtype=float32)
# acc  ==> tf.Tensor([0. 0. 1.], shape=(3,), dtype=float32)
# conf ==> tf.Tensor([0.075, 0.2, 0.826665], shape=(3,), dtype=float32)
# cnt  ==> tf.Tensor([2. 1. 3.], shape=(3,), dtype=float32)
# Example 2: Categorical classification.
# Assume we have evidence `x`, targets `y`, and model function `dnn`.
d = tfd.Categorical(logits=dnn(x))
def all_categories(d):
  num_classes = tf.shape(d.logits_parameter())[-1]
  batch_ndims = tf.size(d.batch_shape_tensor())
  expand_shape = tf.pad(
      [num_classes], paddings=[[0, batch_ndims]], constant_values=1)
  return tf.reshape(tf.range(num_classes, dtype=d.dtype), expand_shape)
all_pred_log_prob = d.log_prob(all_categories(d))
yhat = tf.argmax(all_pred_log_prob, axis=0)
def rollaxis(x, shift):
  return tf.transpose(x, tf.roll(tf.range(tf.rank(x)), shift=shift, axis=0))
pred_log_prob = tf.gather(rollaxis(all_pred_log_prob, shift=-1),
                          yhat,
                          batch_dims=len(d.batch_shape))
hit = tf.equal(y, yhat)
(
  ece,
  acc,
  conf,
  cnt,
  edges,
  bucket,
) = tfp.stats.expected_calibration_error_quantiles(
    hit, pred_log_prob, num_buckets=10)
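For the common case of a single batch dimension, the `hit` and `pred_log_prob` inputs built in Example 2 reduce to "was the top-1 class correct" and "log probability of the top-1 class". The following is a minimal plain-Python sketch of that reduction, assuming `logits` is a list of per-example rows and `labels` holds integer class ids. The helper name `top1_hit_and_log_prob` is hypothetical and is not part of TensorFlow Probability.

```python
import math


def top1_hit_and_log_prob(logits, labels):
    """Returns (hit, pred_log_prob) for rows of logits and integer labels."""
    hit, pred_log_prob = [], []
    for row, label in zip(logits, labels):
        # Numerically stable log-softmax: subtract the row's log-sum-exp.
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))
        log_probs = [v - log_z for v in row]
        # Predicted class is the one with the highest log probability.
        yhat = max(range(len(row)), key=log_probs.__getitem__)
        hit.append(yhat == label)
        pred_log_prob.append(log_probs[yhat])
    return hit, pred_log_prob


logits = [[2.0, 0.5, -1.0], [0.1, 0.2, 3.0], [1.0, 1.5, 0.0]]
labels = [0, 2, 0]
hit, pred_log_prob = top1_hit_and_log_prob(logits, labels)
print(hit)  # [True, True, False]
```

The resulting lists have one entry per example, so they satisfy the requirement that `hit` and `pred_log_prob` broadcast against each other.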