Học ngôn ngữ sâu với nhận thức không chắc chắn với BERT-SNGP

Xem trên TensorFlow.org Chạy trong Google Colab Xem trên GitHub Tải xuống sổ ghi chép Xem mô hình TF Hub

Trong hướng dẫn về SNGP , bạn đã học cách xây dựng mô hình SNGP trên đầu của một mạng dư sâu để cải thiện khả năng định lượng độ không đảm bảo của nó. Trong hướng dẫn này, bạn sẽ áp dụng SNGP cho nhiệm vụ hiểu ngôn ngữ tự nhiên (NLU) bằng cách xây dựng nó trên đầu bộ mã hóa BERT sâu để cải thiện khả năng của mô hình NLU sâu trong việc phát hiện các truy vấn ngoài phạm vi.

Cụ thể, bạn sẽ:

  • Xây dựng BERT-SNGP, một mô hình BERT tăng cường SNGP.
  • Tải tập dữ liệu phát hiện ý định ngoài phạm vi (OOS) CLINC .
  • Đào tạo mô hình BERT-SNGP.
  • Đánh giá hoạt động của mô hình BERT-SNGP trong hiệu chuẩn độ không đảm bảo và phát hiện ngoài miền.

Ngoài CLINC OOS, mô hình SNGP đã được áp dụng cho các bộ dữ liệu quy mô lớn như phát hiện độc tính của Jigsaw và cho các bộ dữ liệu hình ảnh như CIFAR-100ImageNet . Để biết kết quả điểm chuẩn của SNGP và các phương pháp độ không đảm bảo khác, cũng như việc triển khai chất lượng cao với các kịch bản đào tạo / đánh giá từ đầu đến cuối, bạn có thể xem tiêu chuẩn Đường cơ sở về độ không đảm bảo.

Thành lập

pip uninstall -y tensorflow tf-text
pip install -U tensorflow-text-nightly
pip install -U tf-nightly
pip install -U tf-models-nightly
import matplotlib.pyplot as plt

import sklearn.metrics
import sklearn.calibration

import tensorflow_hub as hub
import tensorflow_datasets as tfds

import numpy as np
import tensorflow as tf

import official.nlp.modeling.layers as layers
import official.nlp.optimization as optimization
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow_addons/utils/ensure_tf_install.py:43: UserWarning: You are currently using a nightly version of TensorFlow (2.9.0-dev20220203). 
TensorFlow Addons offers no support for the nightly versions of TensorFlow. Some things might work, some other might not. 
If you encounter a bug, do not file an issue on GitHub.
  UserWarning,

Hướng dẫn này cần GPU để chạy hiệu quả. Kiểm tra xem GPU có khả dụng không.

tf.__version__
'2.9.0-dev20220203'
gpus = tf.config.list_physical_devices('GPU')
gpus
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
assert gpus, """
  No GPU(s) found! This tutorial will take many hours to run without a GPU.

  You may hit this error if the installed tensorflow package is not
  compatible with the CUDA and CUDNN versions."""

Đầu tiên hãy triển khai một trình phân loại BERT tiêu chuẩn theo sau văn bản phân loại với hướng dẫn BERT . Chúng tôi sẽ sử dụng bộ mã hóa cơ sở BERT và phần mềm ClassificationHead tích hợp sẵn làm bộ phân loại.

Mô hình BERT tiêu chuẩn

Xây dựng mô hình SNGP

Để triển khai mô hình BERT-SNGP, bạn chỉ cần thay thế ClassificationHead loạiHead bằng GaussianProcessClassificationHead tích hợp sẵn. Chuẩn hóa phổ đã được đóng gói sẵn trong đầu phân loại này. Giống như trong hướng dẫn SNGP , thêm lệnh gọi lại đặt lại hiệp phương sai vào mô hình, để mô hình tự động đặt lại công cụ ước tính hiệp phương sai khi bắt đầu một kỷ nguyên mới để tránh đếm cùng một dữ liệu hai lần.

class ResetCovarianceCallback(tf.keras.callbacks.Callback):

  def on_epoch_begin(self, epoch, logs=None):
    """Resets covariance matrix at the begining of the epoch."""
    if epoch > 0:
      self.model.classifier.reset_covariance_matrix()
class SNGPBertClassifier(BertClassifier):

  def make_classification_head(self, num_classes, inner_dim, dropout_rate):
    return layers.GaussianProcessClassificationHead(
        num_classes=num_classes, 
        inner_dim=inner_dim,
        dropout_rate=dropout_rate,
        gp_cov_momentum=-1,
        temperature=30.,
        **self.classifier_kwargs)

  def fit(self, *args, **kwargs):
    """Adds ResetCovarianceCallback to model callbacks."""
    kwargs['callbacks'] = list(kwargs.get('callbacks', []))
    kwargs['callbacks'].append(ResetCovarianceCallback())

    return super().fit(*args, **kwargs)

Tải tập dữ liệu CLINC OOS

Bây giờ tải tập dữ liệu phát hiện ý định CLINC OOS . Tập dữ liệu này chứa 15000 câu truy vấn được nói của người dùng được thu thập trên 150 lớp ý định, nó cũng chứa 1000 câu ngoài miền (OOD) không thuộc bất kỳ lớp nào đã biết.

(clinc_train, clinc_test, clinc_test_oos), ds_info = tfds.load(
    'clinc_oos', split=['train', 'test', 'test_oos'], with_info=True, batch_size=-1)

Thực hiện các chuyến tàu và dữ liệu thử nghiệm.

train_examples = clinc_train['text']
train_labels = clinc_train['intent']

# Makes the in-domain (IND) evaluation data.
ind_eval_data = (clinc_test['text'], clinc_test['intent'])

Tạo tập dữ liệu đánh giá OOD. Đối với điều này, hãy kết hợp dữ liệu kiểm tra trong miền clinc_test và dữ liệu ngoài miền clinc_test_oos . Chúng tôi cũng sẽ gán nhãn 0 cho các ví dụ trong miền và nhãn 1 cho các ví dụ ngoài miền.

test_data_size = ds_info.splits['test'].num_examples
oos_data_size = ds_info.splits['test_oos'].num_examples

# Combines the in-domain and out-of-domain test examples.
oos_texts = tf.concat([clinc_test['text'], clinc_test_oos['text']], axis=0)
oos_labels = tf.constant([0] * test_data_size + [1] * oos_data_size)

# Converts into a TF dataset.
ood_eval_dataset = tf.data.Dataset.from_tensor_slices(
    {"text": oos_texts, "label": oos_labels})

Đào tạo và đánh giá

Đầu tiên hãy thiết lập các cấu hình đào tạo cơ bản.

TRAIN_EPOCHS = 3
TRAIN_BATCH_SIZE = 32
EVAL_BATCH_SIZE = 256

optimizer = bert_optimizer(learning_rate=1e-4)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
metrics = tf.metrics.SparseCategoricalAccuracy()
fit_configs = dict(batch_size=TRAIN_BATCH_SIZE,
                   epochs=TRAIN_EPOCHS,
                   validation_batch_size=EVAL_BATCH_SIZE, 
                   validation_data=ind_eval_data)
sngp_model = SNGPBertClassifier()
sngp_model.compile(optimizer=optimizer, loss=loss, metrics=metrics)
sngp_model.fit(train_examples, train_labels, **fit_configs)
Epoch 1/3
469/469 [==============================] - 219s 427ms/step - loss: 1.0725 - sparse_categorical_accuracy: 0.7870 - val_loss: 0.4358 - val_sparse_categorical_accuracy: 0.9380
Epoch 2/3
469/469 [==============================] - 198s 422ms/step - loss: 0.0885 - sparse_categorical_accuracy: 0.9797 - val_loss: 0.2424 - val_sparse_categorical_accuracy: 0.9518
Epoch 3/3
469/469 [==============================] - 199s 424ms/step - loss: 0.0259 - sparse_categorical_accuracy: 0.9951 - val_loss: 0.1927 - val_sparse_categorical_accuracy: 0.9642
<keras.callbacks.History at 0x7fe24c0a7090>

Đánh giá hiệu suất OOD

Đánh giá mức độ hiệu quả của mô hình có thể phát hiện ra các truy vấn ngoài miền không quen thuộc. Để đánh giá chặt chẽ, hãy sử dụng tập dữ liệu đánh giá OOD ood_eval_dataset được xây dựng trước đó.

Tính xác suất OOD dưới dạng \(1 - p(x)\), trong đó \(p(x)=softmax(logit(x))\) là xác suất dự đoán.

sngp_probs, ood_labels = oos_predict(sngp_model, ood_eval_dataset)
ood_probs = 1 - sngp_probs

Bây giờ, hãy đánh giá xem điểm không chắc chắn của mô hình ood_probs dự đoán nhãn ngoài miền tốt như thế nào. Trước tiên, tính toán Diện tích dưới đường cong truy lại độ chính xác (AUPRC) cho xác suất OOD so với độ chính xác phát hiện OOD.

precision, recall, _ = sklearn.metrics.precision_recall_curve(ood_labels, ood_probs)
auprc = sklearn.metrics.auc(recall, precision)
print(f'SNGP AUPRC: {auprc:.4f}')
SNGP AUPRC: 0.9039

Điều này khớp với hiệu suất SNGP được báo cáo tại điểm chuẩn CLINC OOS theo Đường cơ sở về độ không chắc chắn .

Tiếp theo, kiểm tra chất lượng của mô hình trong hiệu chuẩn độ không đảm bảo , tức là, liệu xác suất dự đoán của mô hình có tương ứng với độ chính xác dự đoán của nó hay không. Một mô hình được hiệu chỉnh tốt được coi là đáng tin cậy, vì chẳng hạn, xác suất dự đoán của nó là \(p(x)=0.8\) có nghĩa là mô hình đó đúng 80% thời gian.

prob_true, prob_pred = sklearn.calibration.calibration_curve(
    ood_labels, ood_probs, n_bins=10, strategy='quantile')
plt.plot(prob_pred, prob_true)

plt.plot([0., 1.], [0., 1.], c='k', linestyle="--")
plt.xlabel('Predictive Probability')
plt.ylabel('Predictive Accuracy')
plt.title('Calibration Plots, SNGP')

plt.show()

png

Tài nguyên và đọc thêm