Uncertainty-aware Deep Language Learning with BERT-SNGP


In the SNGP tutorial, you learned how to build an SNGP model on top of a deep residual network to improve its ability to quantify its uncertainty. In this tutorial, you will apply SNGP to a natural language understanding (NLU) task by building it on top of a deep BERT encoder, improving the NLU model's ability to detect out-of-scope queries.

Specifically, you will:

  • Build BERT-SNGP, an SNGP-augmented BERT model.
  • Load the CLINC Out-of-scope (OOS) intent detection dataset.
  • Train the BERT-SNGP model.
  • Evaluate the BERT-SNGP model's performance in uncertainty calibration and out-of-domain detection.

Beyond CLINC OOS, the SNGP model has been applied to large-scale datasets such as Jigsaw toxicity detection, and to image datasets such as CIFAR-100 and ImageNet. For benchmark results of SNGP and other uncertainty methods, as well as high-quality implementations with end-to-end training and evaluation scripts, check out the Uncertainty Baselines benchmark.


pip uninstall -y tensorflow tf-text
pip install "tensorflow-text==2.11.*"
pip install -U tf-models-official==2.11.0

import matplotlib.pyplot as plt

import sklearn.metrics
import sklearn.calibration

import tensorflow_hub as hub
import tensorflow_datasets as tfds

import numpy as np
import tensorflow as tf

import official.nlp.modeling.layers as layers
import official.nlp.optimization as optimization

This tutorial needs a GPU to run efficiently. Check that a GPU is available.

gpus = tf.config.list_physical_devices('GPU')
assert gpus, """
  No GPU(s) found! This tutorial will take many hours to run without a GPU.

  You may hit this error if the installed tensorflow package is not
  compatible with the CUDA and CUDNN versions."""

First, implement a standard BERT classifier following the Classify text with BERT tutorial. We will use the BERT-base encoder and the built-in ClassificationHead as the classifier.

Standard BERT model

Build SNGP model

To implement a BERT-SNGP model, you only need to replace the ClassificationHead with the built-in GaussianProcessClassificationHead. Spectral normalization is already pre-packaged into this classification head. As in the SNGP tutorial, add a covariance reset callback to the model so that it automatically resets the covariance estimator at the beginning of each new epoch, which avoids counting the same data twice.
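To make the spectral normalization step concrete, here is a minimal NumPy sketch of the idea, not the code inside GaussianProcessClassificationHead: power iteration estimates the largest singular value \(\sigma\) of a weight matrix, and the matrix is rescaled by \(\min(1, c/\sigma)\) so that its spectral norm stays at or below a bound \(c\). The function name `spectral_normalize` and the bound `norm_bound=0.95` are illustrative assumptions.

```python
import numpy as np


def spectral_normalize(weight, norm_bound=0.95, num_iters=50, seed=0):
    """Hypothetical sketch: rescale `weight` so its spectral norm <= norm_bound.

    Power iteration estimates the largest singular value sigma; the matrix is
    scaled down by norm_bound / sigma only when sigma exceeds the bound.
    """
    rng = np.random.default_rng(seed)
    u = rng.normal(size=weight.shape[0])
    for _ in range(num_iters):
        # Alternate between the right and left singular directions.
        v = weight.T @ u
        v /= np.linalg.norm(v)
        u = weight @ v
        u /= np.linalg.norm(u)
    sigma = float(u @ weight @ v)  # Estimated largest singular value.
    scale = min(1.0, norm_bound / sigma)
    return weight * scale, sigma


w = np.array([[3.0, 0.0], [0.0, 1.0]])
w_sn, sigma = spectral_normalize(w)  # sigma is close to 3.0
```

Matrices whose spectral norm is already below the bound are returned unchanged, so the constraint only acts on layers that would otherwise stretch distances too much.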

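class ResetCovarianceCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    """Resets covariance matrix at the beginning of the epoch."""
    if epoch > 0:
      # Assumes the model stores its classification head as `self.classifier`.
      self.model.classifier.reset_covariance_matrix()


class SNGPBertClassifier(BertClassifier):

  def make_classification_head(self, num_classes, inner_dim, dropout_rate):
    return layers.GaussianProcessClassificationHead(
        num_classes=num_classes,
        inner_dim=inner_dim,
        dropout_rate=dropout_rate,
        # A non-positive momentum accumulates the exact covariance over the
        # epoch instead of a moving average, hence the per-epoch reset.
        gp_cov_momentum=-1)

  def fit(self, *args, **kwargs):
    """Adds ResetCovarianceCallback to model callbacks."""
    kwargs['callbacks'] = list(kwargs.get('callbacks', []))
    kwargs['callbacks'].append(ResetCovarianceCallback())

    return super().fit(*args, **kwargs)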
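Why the reset matters can be sketched in NumPy. The class below is a hypothetical, simplified stand-in for the head's covariance estimator, assuming a Gaussian likelihood and exact (momentum-free) accumulation: the precision matrix starts at `ridge * I`, every training batch of random features \(\Phi\) adds \(\Phi^\top \Phi\), and the predictive variance of a test feature \(\phi\) is `ridge` times \(\phi^\top \Sigma \phi\), where \(\Sigma\) is the inverse precision. Training for three epochs without a reset adds every batch three times and shrinks the variance, which is exactly the double counting the callback prevents.

```python
import numpy as np


class PrecisionAccumulator:
    """Hypothetical sketch of an exact GP precision estimate (Gaussian likelihood)."""

    def __init__(self, num_features, ridge=1.0):
        self.ridge = ridge
        self.num_features = num_features
        self.reset()

    def reset(self):
        # Analogue of what ResetCovarianceCallback triggers at each new epoch.
        self.precision = self.ridge * np.eye(self.num_features)

    def update(self, features):
        # features: [batch_size, num_features] random-feature embeddings.
        self.precision += features.T @ features

    def predictive_variance(self, features):
        # ridge * diag(Phi Sigma Phi^T), with Sigma the inverse precision.
        covariance = np.linalg.inv(self.precision)
        return self.ridge * np.einsum('ij,jk,ik->i', features, covariance, features)


rng = np.random.default_rng(0)
train_batches = [rng.normal(size=(8, 4)) for _ in range(3)]
test_features = rng.normal(size=(5, 4))


def run_epochs(acc, num_epochs, reset_each_epoch):
    for epoch in range(num_epochs):
        if reset_each_epoch and epoch > 0:
            acc.reset()
        for batch in train_batches:
            acc.update(batch)
    return acc.predictive_variance(test_features)


var_one_epoch = run_epochs(PrecisionAccumulator(4), 1, reset_each_epoch=True)
var_with_reset = run_epochs(PrecisionAccumulator(4), 3, reset_each_epoch=True)
var_no_reset = run_epochs(PrecisionAccumulator(4), 3, reset_each_epoch=False)
```

With the reset, three epochs leave the same variance as one pass over the data; without it, every test point looks more certain than the data justifies.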

Load CLINC OOS dataset

Now load the CLINC OOS intent detection dataset. This dataset contains 15,000 spoken user queries collected over 150 intent classes. It also contains 1,000 out-of-domain (OOD) sentences that are not covered by any of the known classes.

(clinc_train, clinc_test, clinc_test_oos), ds_info = tfds.load(
    'clinc_oos', split=['train', 'test', 'test_oos'], with_info=True, batch_size=-1)

Make the train and test data.

train_examples = clinc_train['text']
train_labels = clinc_train['intent']

# Makes the in-domain (IND) evaluation data.
ind_eval_data = (clinc_test['text'], clinc_test['intent'])

Create an OOD evaluation dataset. For this, combine the in-domain test data clinc_test and the out-of-domain data clinc_test_oos. We will also assign label 0 to the in-domain examples and label 1 to the out-of-domain examples.

test_data_size = ds_info.splits['test'].num_examples
oos_data_size = ds_info.splits['test_oos'].num_examples

# Combines the in-domain and out-of-domain test examples.
oos_texts = tf.concat([clinc_test['text'], clinc_test_oos['text']], axis=0)
oos_labels = tf.constant([0] * test_data_size + [1] * oos_data_size)

# Converts into a TF dataset.
ood_eval_dataset = tf.data.Dataset.from_tensor_slices(
    {"text": oos_texts, "label": oos_labels})

Train and evaluate

First set up the basic training configurations.
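
# Example hyperparameter values (assumed). bert_optimizer() is defined with the standard BERT model.
TRAIN_EPOCHS = 3
TRAIN_BATCH_SIZE = 32
EVAL_BATCH_SIZE = 256

optimizer = bert_optimizer(learning_rate=1e-4)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.metrics.SparseCategoricalAccuracy()
fit_configs = dict(batch_size=TRAIN_BATCH_SIZE, epochs=TRAIN_EPOCHS,
                   validation_batch_size=EVAL_BATCH_SIZE,
                   validation_data=ind_eval_data)

sngp_model = SNGPBertClassifier()
sngp_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
sngp_model.fit(train_examples, train_labels, **fit_configs)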

Evaluate OOD performance

Evaluate how well the model can detect the unfamiliar out-of-domain queries. For rigorous evaluation, use the OOD evaluation dataset ood_eval_dataset built earlier.

Compute the OOD probability as \(1 - p(x)\), where \(p(x) = \max_k \mathrm{softmax}(\mathrm{logit}(x))_k\) is the model's predictive probability for its most likely class.

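# oos_predict (not shown here) is assumed to return each example's maximum
# softmax probability together with its OOD label.
sngp_probs, ood_labels = oos_predict(sngp_model, ood_eval_dataset)
ood_probs = 1 - sngp_probs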
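A minimal NumPy sketch of the OOD score itself, assuming that `oos_predict` returns the maximum softmax probability per example as the formula states; `ood_scores_from_logits` is a hypothetical helper, not part of the tutorial.

```python
import numpy as np


def ood_scores_from_logits(logits):
    """Hypothetical helper: returns 1 - max_k softmax(logits)_k per row."""
    logits = np.asarray(logits, dtype=np.float64)
    # Subtract the row maximum before exponentiating for numerical stability.
    shifted = logits - logits.max(axis=-1, keepdims=True)
    probs = np.exp(shifted)
    probs /= probs.sum(axis=-1, keepdims=True)
    return 1.0 - probs.max(axis=-1)


scores = ood_scores_from_logits([[10.0, 0.0, 0.0],   # confident prediction
                                 [1.0, 1.0, 1.0]])   # uniform prediction
```

With \(K\) classes the score ranges from near 0 (almost all mass on one class) up to \(1 - 1/K\) (a uniform prediction), so the second row above scores 2/3.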

Now evaluate how well the model's uncertainty score ood_probs predicts the out-of-domain label. First, compute the area under the precision-recall curve (AUPRC) of the OOD probability against the OOD labels.

precision, recall, _ = sklearn.metrics.precision_recall_curve(ood_labels, ood_probs)
auprc = sklearn.metrics.auc(recall, precision)
print(f'SNGP AUPRC: {auprc:.4f}')
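To see what the two scikit-learn calls compute, here is a simplified NumPy sketch with hypothetical function names: it sorts examples by OOD score, records precision and recall at each distinct threshold, prepends the point (recall 0, precision 1) as sklearn.metrics.precision_recall_curve does, and integrates with the trapezoidal rule as sklearn.metrics.auc does. Which intermediate points are kept may differ from scikit-learn across versions, so treat it as an illustration rather than a drop-in replacement.

```python
import numpy as np


def precision_recall_points(labels, scores):
    """Hypothetical sketch: precision/recall at each distinct score threshold.

    Assumes binary labels (1 = OOD) with at least one positive example.
    """
    labels = np.asarray(labels)
    scores = np.asarray(scores, dtype=np.float64)
    order = np.argsort(-scores, kind='stable')  # Highest score first.
    labels, scores = labels[order], scores[order]
    tp = np.cumsum(labels == 1)
    fp = np.cumsum(labels == 0)
    # Keep one point per threshold: the last index of each run of tied scores.
    last_of_tie = np.append(scores[1:] != scores[:-1], True)
    tp, fp = tp[last_of_tie], fp[last_of_tie]
    precision = tp / (tp + fp)
    recall = tp / tp[-1]
    # Start the curve at (recall=0, precision=1).
    return np.r_[0.0, recall], np.r_[1.0, precision]


def trapezoid_auprc(labels, scores):
    recall, precision = precision_recall_points(labels, scores)
    # Trapezoidal rule over the recall axis.
    return float(np.sum(np.diff(recall) * (precision[1:] + precision[:-1]) / 2.0))


example = trapezoid_auprc([1, 0, 1, 0], [0.9, 0.8, 0.7, 0.1])  # 19/24, about 0.792
```

A detector that ranks every OOD query above every in-domain query reaches an AUPRC of 1.0.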

This matches the SNGP performance reported on the CLINC OOS benchmark in Uncertainty Baselines.

Next, examine the model's uncertainty calibration, i.e., whether the model's predictive probability corresponds to its predictive accuracy. A well-calibrated model is considered trustworthy because, for example, among the inputs for which it predicts \(p(x)=0.8\), it is correct about 80% of the time.

prob_true, prob_pred = sklearn.calibration.calibration_curve(
    ood_labels, ood_probs, n_bins=10, strategy='quantile')
plt.plot(prob_pred, prob_true)

plt.plot([0., 1.], [0., 1.], c='k', linestyle="--")
plt.xlabel('Predictive Probability')
plt.ylabel('Predictive Accuracy')
plt.title('Calibration Plots, SNGP')
plt.show()
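The calibration curve can also be sketched without scikit-learn. The hypothetical helper below sorts examples by predicted probability, splits them into `n_bins` groups of nearly equal size (close in spirit to `strategy='quantile'`, although tie handling may differ), and compares the mean predicted probability with the observed frequency of label 1 in each group. On synthetic data where each label is drawn with exactly its predicted probability, the curve should lie close to the diagonal.

```python
import numpy as np


def equal_count_calibration_curve(labels, probs, n_bins=10):
    """Hypothetical sketch of a reliability curve with equal-count bins.

    Returns (observed_frequency, mean_predicted_probability) per bin, in the
    same order as sklearn.calibration.calibration_curve returns its outputs.
    """
    labels = np.asarray(labels, dtype=np.float64)
    probs = np.asarray(probs, dtype=np.float64)
    order = np.argsort(probs, kind='stable')
    # Split the sorted examples into n_bins chunks of nearly equal size.
    bins = np.array_split(order, n_bins)
    prob_true = np.array([labels[idx].mean() for idx in bins if len(idx)])
    prob_pred = np.array([probs[idx].mean() for idx in bins if len(idx)])
    return prob_true, prob_pred


# Synthetic, perfectly calibrated data: each label is 1 with probability p.
rng = np.random.default_rng(0)
p = rng.uniform(size=20_000)
y = (rng.uniform(size=p.size) < p).astype(int)
demo_true, demo_pred = equal_count_calibration_curve(y, p)
```

Plotting `demo_pred` against `demo_true` gives a line near the dashed diagonal; systematic gaps from that diagonal in the SNGP plot indicate over- or under-confidence.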


Resources and further reading