Pembelajaran Bahasa Dalam yang Sadar Ketidakpastian dengan BERT-SNGP

Lihat di TensorFlow.org Jalankan di Google Colab Lihat di GitHub Unduh buku catatan Lihat model TF Hub

Dalam tutorial SNGP , Anda mempelajari cara membangun model SNGP di atas jaringan residual yang dalam untuk meningkatkan kemampuannya dalam mengukur ketidakpastiannya. Dalam tutorial ini, Anda akan menerapkan SNGP ke tugas pemahaman bahasa alami (NLU) dengan membangunnya di atas encoder BERT yang dalam untuk meningkatkan kemampuan model NLU yang dalam dalam mendeteksi kueri di luar cakupan.

Secara khusus, Anda akan:

  • Bangun BERT-SNGP, model BERT yang ditambah SNGP .
  • Muat set data deteksi intent CLINC Out-of-scope (OOS) .
  • Latih model BERT-SNGP.
  • Evaluasi kinerja model BERT-SNGP dalam kalibrasi ketidakpastian dan deteksi di luar domain.

Di luar CLINC OOS, model SNGP telah diterapkan pada kumpulan data skala besar seperti deteksi toksisitas Jigsaw , dan pada kumpulan data gambar seperti CIFAR-100 dan ImageNet . Untuk hasil benchmark SNGP dan metode ketidakpastian lainnya, serta implementasi berkualitas tinggi dengan skrip pelatihan / evaluasi end-to-end, Anda dapat melihat benchmark Uncertainty Baselines .

Mempersiapkan

pip uninstall -y tensorflow tf-text
pip install -U tensorflow-text-nightly
pip install -U tf-nightly
pip install -U tf-models-nightly
import matplotlib.pyplot as plt

import sklearn.metrics
import sklearn.calibration

import tensorflow_hub as hub
import tensorflow_datasets as tfds

import numpy as np
import tensorflow as tf

import official.nlp.modeling.layers as layers
import official.nlp.optimization as optimization
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_addons/utils/ensure_tf_install.py:43: UserWarning: You are currently using a nightly version of TensorFlow (2.9.0-dev20220203). 
TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. 
If you encounter a bug, do not file an issue on GitHub.
  UserWarning,

Tutorial ini membutuhkan GPU untuk berjalan secara efisien. Periksa apakah GPU tersedia.

tf.__version__
'2.9.0-dev20220203'
gpus = tf.config.list_physical_devices('GPU')
gpus
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
assert gpus, """
  No GPU(s) found! This tutorial will take many hours to run without a GPU.

  You may hit this error if the installed tensorflow package is not
  compatible with the CUDA and CUDNN versions."""

Pertama-tama terapkan pengklasifikasi BERT standar mengikuti teks klasifikasi dengan tutorial BERT . Kami akan menggunakan encoder berbasis BERT, dan ClassificationHead sebagai classifier.

Model BERT standar

Bangun model SNGP

Untuk mengimplementasikan model BERT-SNGP, Anda hanya perlu mengganti ClassificationHead dengan GaussianProcessClassificationHead . Normalisasi spektral sudah dikemas sebelumnya ke dalam kepala klasifikasi ini. Seperti dalam tutorial SNGP , tambahkan panggilan balik reset kovarians ke model, sehingga model secara otomatis menyetel ulang penaksir kovarians di awal epoch baru untuk menghindari penghitungan data yang sama dua kali.

class ResetCovarianceCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    """Resets covariance matrix at the begining of the epoch."""
    if epoch > 0:
      self.model.classifier.reset_covariance_matrix()
class SNGPBertClassifier(BertClassifier):

  def make_classification_head(self, num_classes, inner_dim, dropout_rate):
    return layers.GaussianProcessClassificationHead(
        num_classes=num_classes, 
        inner_dim=inner_dim,
        dropout_rate=dropout_rate,
        gp_cov_momentum=-1,
        temperature=30.,
        **self.classifier_kwargs)

  def fit(self, *args, **kwargs):
    """Adds ResetCovarianceCallback to model callbacks."""
    kwargs['callbacks'] = list(kwargs.get('callbacks', []))
    kwargs['callbacks'].append(ResetCovarianceCallback())

    return super().fit(*args, **kwargs)

Muat kumpulan data CLINC OOS

Sekarang muat set data deteksi maksud CLINC OOS . Kumpulan data ini berisi 15000 kueri lisan pengguna yang dikumpulkan lebih dari 150 kelas maksud, juga berisi 1000 kalimat di luar domain (OOD) yang tidak tercakup oleh kelas yang dikenal.

(clinc_train, clinc_test, clinc_test_oos), ds_info = tfds.load(
    'clinc_oos', split=['train', 'test', 'test_oos'], with_info=True, batch_size=-1)

Buat data kereta dan uji.

train_examples = clinc_train['text']
train_labels = clinc_train['intent']

# Makes the in-domain (IND) evaluation data.
ind_eval_data = (clinc_test['text'], clinc_test['intent'])

Buat kumpulan data evaluasi OOD. Untuk ini, gabungkan data uji dalam domain clinc_test dan data di luar domain clinc_test_oos . Kami juga akan menetapkan label 0 untuk contoh dalam domain, dan memberi label 1 untuk contoh di luar domain.

test_data_size = ds_info.splits['test'].num_examples
oos_data_size = ds_info.splits['test_oos'].num_examples

# Combines the in-domain and out-of-domain test examples.
oos_texts = tf.concat([clinc_test['text'], clinc_test_oos['text']], axis=0)
oos_labels = tf.constant([0] * test_data_size + [1] * oos_data_size)

# Converts into a TF dataset.
ood_eval_dataset = tf.data.Dataset.from_tensor_slices(
    {"text": oos_texts, "label": oos_labels})

Latih dan evaluasi

Pertama-tama atur konfigurasi pelatihan dasar.

TRAIN_EPOCHS = 3
TRAIN_BATCH_SIZE = 32
EVAL_BATCH_SIZE = 256

optimizer = bert_optimizer(learning_rate=1e-4)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.metrics.SparseCategoricalAccuracy()
fit_configs = dict(batch_size=TRAIN_BATCH_SIZE,
                   epochs=TRAIN_EPOCHS,
                   validation_batch_size=EVAL_BATCH_SIZE, 
                   validation_data=ind_eval_data)
sngp_model = SNGPBertClassifier()
sngp_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
sngp_model.fit(train_examples, train_labels, **fit_configs)
Epoch 1/3
469/469 [==============================] - 219s 427ms/step - loss: 1.0725 - sparse_categorical_accuracy: 0.7870 - val_loss: 0.4358 - val_sparse_categorical_accuracy: 0.9380
Epoch 2/3
469/469 [==============================] - 198s 422ms/step - loss: 0.0885 - sparse_categorical_accuracy: 0.9797 - val_loss: 0.2424 - val_sparse_categorical_accuracy: 0.9518
Epoch 3/3
469/469 [==============================] - 199s 424ms/step - loss: 0.0259 - sparse_categorical_accuracy: 0.9951 - val_loss: 0.1927 - val_sparse_categorical_accuracy: 0.9642
<keras.callbacks.History at 0x7fe24c0a7090>

Evaluasi kinerja OOD

Evaluasi seberapa baik model dapat mendeteksi kueri di luar domain yang tidak dikenal. Untuk evaluasi yang ketat, gunakan dataset evaluasi OOD ood_eval_dataset dibuat sebelumnya.

Menghitung probabilitas OOD sebagai \(1 - p(x)\), di mana \(p(x)=softmax(logit(x))\) adalah probabilitas prediktif.

sngp_probs, ood_labels = oos_predict(sngp_model, ood_eval_dataset)
ood_probs = 1 - sngp_probs

Sekarang evaluasi seberapa baik skor ketidakpastian model ood_probs memprediksi label di luar domain. Pertama-tama hitung Area di bawah kurva precision-recall (AUPRC) untuk probabilitas OOD vs akurasi deteksi OOD.

precision, recall, _ = sklearn.metrics.precision_recall_curve(ood_labels, ood_probs)
auprc = sklearn.metrics.auc(recall, precision)
print(f'SNGP AUPRC: {auprc:.4f}')
SNGP AUPRC: 0.9039

Ini cocok dengan kinerja SNGP yang dilaporkan pada benchmark CLINC OOS di bawah Uncertainty Baselines .

Selanjutnya, periksa kualitas model dalam kalibrasi ketidakpastian , yaitu, apakah probabilitas prediksi model sesuai dengan akurasi prediksinya. Model yang dikalibrasi dengan baik dianggap layak dipercaya, karena, misalnya, probabilitas prediktifnya \(p(x)=0.8\) berarti bahwa model tersebut benar 80% sepanjang waktu.

prob_true, prob_pred = sklearn.calibration.calibration_curve(
    ood_labels, ood_probs, n_bins=10, strategy='quantile')
plt.plot(prob_pred, prob_true)

plt.plot([0., 1.], [0., 1.], c='k', linestyle="--")
plt.xlabel('Predictive Probability')
plt.ylabel('Predictive Accuracy')
plt.title('Calibration Plots, SNGP')

plt.show()

png

Sumber daya dan bacaan lebih lanjut