BERT Experts from TF-Hub


This Colab demonstrates how to:

  • Load BERT models from TensorFlow Hub that have been trained on different tasks, including MNLI, SQuAD, and PubMed
  • Use a matching preprocessing model to tokenize raw text and convert it to IDs
  • Generate the pooled and sequence output from the token input IDs using the loaded model
  • Look at the semantic similarity of the pooled outputs of different sentences

Note: This colab should be run with a GPU runtime.

Set up and imports

pip install --quiet "tensorflow-text==2.8.*"
import seaborn as sns
from sklearn.metrics import pairwise

import tensorflow as tf
import tensorflow_hub as hub
import tensorflow_text as text  # Imports TF ops for preprocessing.

Configure the model
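
The model-configuration cell is collapsed in the original notebook, so the handles used below (BERT_MODEL and PREPROCESS_MODEL) are defined here as a minimal sketch. It assumes the Wikipedia-and-BooksCorpus BERT expert and its matching English uncased preprocessing model on TF Hub; any of the other experts (for example the MNLI, SQuAD, or PubMed variants) can be substituted as long as the preprocessing model matches the encoder.

# A BERT "expert" from TF Hub and the preprocessing model that matches it.
BERT_MODEL = "https://tfhub.dev/google/experts/bert/wiki_books/2"
PREPROCESS_MODEL = "https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/3"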

Sentences

Let's take some sentences from Wikipedia to run through the model.

sentences = [
  "Here We Go Then, You And I is a 1999 album by Norwegian pop artist Morten Abel. It was Abel's second CD as a solo artist.",
  "The album went straight to number one on the Norwegian album chart, and sold to double platinum.",
  "Among the singles released from the album were the songs \"Be My Lover\" and \"Hard To Stay Awake\".",
  "Riccardo Zegna is an Italian jazz musician.",
  "Rajko Maksimović is a composer, writer, and music pedagogue.",
  "One of the most significant Serbian composers of our time, Maksimović has been and remains active in creating works for different ensembles.",
  "Ceylon spinach is a common name for several plants and may refer to: Basella alba Talinum fruticosum",
  "A solar eclipse occurs when the Moon passes between Earth and the Sun, thereby totally or partly obscuring the image of the Sun for a viewer on Earth.",
  "A partial solar eclipse occurs in the polar regions of the Earth when the center of the Moon's shadow misses the Earth.",
]

Run the model

We'll load the BERT model from TF-Hub, tokenize our sentences using the matching preprocessing model from TF-Hub, then feed the tokenized sentences into the model. To keep this colab fast and simple, we recommend running it on GPU.

Go to Runtime → Change runtime type to make sure that a GPU is selected.

preprocess = hub.load(PREPROCESS_MODEL)  # Matching text-preprocessing model from TF Hub.
bert = hub.load(BERT_MODEL)              # BERT encoder from TF Hub.
inputs = preprocess(sentences)           # Tokenize: input_word_ids, input_mask, input_type_ids.
outputs = bert(inputs)                   # Run BERT: pooled_output and sequence_output.
print("Sentences:")
print(sentences)

print("\nBERT inputs:")
print(inputs)

print("\nPooled embeddings:")
print(outputs["pooled_output"])

print("\nPer token embeddings:")
print(outputs["sequence_output"])
Sentences:
["Here We Go Then, You And I is a 1999 album by Norwegian pop artist Morten Abel. It was Abel's second CD as a solo artist.", 'The album went straight to number one on the Norwegian album chart, and sold to double platinum.', 'Among the singles released from the album were the songs "Be My Lover" and "Hard To Stay Awake".', 'Riccardo Zegna is an Italian jazz musician.', 'Rajko Maksimović is a composer, writer, and music pedagogue.', 'One of the most significant Serbian composers of our time, Maksimović has been and remains active in creating works for different ensembles.', 'Ceylon spinach is a common name for several plants and may refer to: Basella alba Talinum fruticosum', 'A solar eclipse occurs when the Moon passes between Earth and the Sun, thereby totally or partly obscuring the image of the Sun for a viewer on Earth.', "A partial solar eclipse occurs in the polar regions of the Earth when the center of the Moon's shadow misses the Earth."]

BERT inputs:
{'input_word_ids': <tf.Tensor: shape=(9, 128), dtype=int32, numpy=
array([[  101,  2182,  2057, ...,     0,     0,     0],
       [  101,  1996,  2201, ...,     0,     0,     0],
       [  101,  2426,  1996, ...,     0,     0,     0],
       ...,
       [  101, 16447,  6714, ...,     0,     0,     0],
       [  101,  1037,  5943, ...,     0,     0,     0],
       [  101,  1037,  7704, ...,     0,     0,     0]], dtype=int32)>, 'input_mask': <tf.Tensor: shape=(9, 128), dtype=int32, numpy=
array([[1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       ...,
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0],
       [1, 1, 1, ..., 0, 0, 0]], dtype=int32)>, 'input_type_ids': <tf.Tensor: shape=(9, 128), dtype=int32, numpy=
array([[0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       ...,
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0],
       [0, 0, 0, ..., 0, 0, 0]], dtype=int32)>}

Pooled embeddings:
tf.Tensor(
[[ 0.7975983  -0.4858047   0.49781665 ... -0.34488207  0.3972758
  -0.20639578]
 [ 0.5712035  -0.41205317  0.70489097 ... -0.35185057  0.19032365
  -0.40419084]
 [-0.6993837   0.1586686   0.06569945 ... -0.06232291 -0.81550217
  -0.07923597]
 ...
 [-0.3572722   0.77089787  0.15756367 ...  0.44185576 -0.86448324
   0.04504809]
 [ 0.9107701   0.41501644  0.5606342  ... -0.49263844  0.3964056
  -0.05036103]
 [ 0.90502876 -0.15505227  0.72672117 ... -0.34734455  0.50526446
  -0.19542967]], shape=(9, 768), dtype=float32)

Per token embeddings:
tf.Tensor(
[[[ 1.0919763e+00 -5.3055435e-01  5.4639924e-01 ... -3.5962319e-01
    4.2041004e-01 -2.0940384e-01]
  [ 1.0143832e+00  7.8078997e-01  8.5375911e-01 ...  5.5282390e-01
   -1.1245768e+00  5.6027830e-01]
  [ 7.8862834e-01  7.7776447e-02  9.5150828e-01 ... -1.9075394e-01
    5.9206229e-01  6.1910677e-01]
  ...
  [-3.2203096e-01 -4.2521316e-01 -1.2823755e-01 ... -3.9094931e-01
   -7.9097426e-01  4.2236397e-01]
  [-3.1037472e-02  2.3985589e-01 -2.1994336e-01 ... -1.1440081e-01
   -1.2680490e+00 -1.6136405e-01]
  [-4.2063668e-01  5.4972923e-01 -3.2444507e-01 ... -1.8478569e-01
   -1.1342961e+00 -5.8976438e-02]]

 [[ 6.4930725e-01 -4.3808180e-01  8.7695575e-01 ... -3.6755425e-01
    1.9267297e-01 -4.2864799e-01]
  [-1.1248751e+00  2.9931432e-01  1.1799647e+00 ...  4.8729539e-01
    5.3400397e-01  2.2836086e-01]
  [-2.7057484e-01  3.2353774e-02  1.0425684e+00 ...  5.8993781e-01
    1.5367906e+00  5.8425695e-01]
  ...
  [-1.4762504e+00  1.8239306e-01  5.5877924e-02 ... -1.6733217e+00
   -6.7398900e-01 -7.2449714e-01]
  [-1.5138137e+00  5.8184761e-01  1.6141929e-01 ... -1.2640836e+00
   -4.0272185e-01 -9.7197187e-01]
  [-4.7152787e-01  2.2817361e-01  5.2776086e-01 ... -7.5483733e-01
   -9.0903133e-01 -1.6954741e-01]]

 [[-8.6609292e-01  1.6002062e-01  6.5794230e-02 ... -6.2403791e-02
   -1.1432397e+00 -7.9402432e-02]
  [ 7.7118009e-01  7.0804596e-01  1.1350013e-01 ...  7.8830987e-01
   -3.1438011e-01 -9.7487241e-01]
  [-4.4002396e-01 -3.0059844e-01  3.5479474e-01 ...  7.9736769e-02
   -4.7393358e-01 -1.1001850e+00]
  ...
  [-1.0205296e+00  2.6938295e-01 -4.7310317e-01 ... -6.6319406e-01
   -1.4579906e+00 -3.4665293e-01]
  [-9.7003269e-01 -4.5014530e-02 -5.9779799e-01 ... -3.0526215e-01
   -1.2744255e+00 -2.8051612e-01]
  [-7.3144299e-01  1.7699258e-01 -4.6257949e-01 ... -1.6062324e-01
   -1.6346085e+00 -3.2060498e-01]]

 ...

 [[-3.7375548e-01  1.0225370e+00  1.5888736e-01 ...  4.7453445e-01
   -1.3108220e+00  4.5078602e-02]
  [-4.1589195e-01  5.0019342e-01 -4.5844358e-01 ...  4.1482633e-01
   -6.2065941e-01 -7.1554971e-01]
  [-1.2504396e+00  5.0936830e-01 -5.7103878e-01 ...  3.5491806e-01
    2.4368122e-01 -2.0577202e+00]
  ...
  [ 1.3393565e-01  1.1859145e+00 -2.2170596e-01 ... -8.1946641e-01
   -1.6737353e+00 -3.9692396e-01]
  [-3.3662772e-01  1.6556194e+00 -3.7813133e-01 ... -9.6745455e-01
   -1.4801090e+00 -8.3330792e-01]
  [-2.2649661e-01  1.6178432e+00 -6.7044818e-01 ... -4.9078292e-01
   -1.4535757e+00 -7.1707249e-01]]

 [[ 1.5320230e+00  4.4165635e-01  6.3375759e-01 ... -5.3953838e-01
    4.1937724e-01 -5.0403673e-02]
  [ 8.9377761e-01  8.9395475e-01  3.0627429e-02 ...  5.9038877e-02
   -2.0649567e-01 -8.4811318e-01]
  [-1.8558376e-02  1.0479058e+00 -1.3329605e+00 ... -1.3869658e-01
   -3.7879506e-01 -4.9068686e-01]
  ...
  [ 1.4275625e+00  1.0696868e-01 -4.0634036e-02 ... -3.1777412e-02
   -4.1459864e-01  7.0036912e-01]
  [ 1.1286640e+00  1.4547867e-01 -6.1372513e-01 ...  4.7491822e-01
   -3.9852142e-01  4.3124473e-01]
  [ 1.4393290e+00  1.8030715e-01 -4.2854571e-01 ... -2.5022799e-01
   -1.0000539e+00  3.5985443e-01]]

 [[ 1.4993387e+00 -1.5631306e-01  9.2174339e-01 ... -3.6242083e-01
    5.5635023e-01 -1.9797631e-01]
  [ 1.1110525e+00  3.6651248e-01  3.5505861e-01 ... -5.4297489e-01
    1.4471433e-01 -3.1676081e-01]
  [ 2.4048671e-01  3.8116074e-01 -5.9182751e-01 ...  3.7410957e-01
   -5.9829539e-01 -1.0166274e+00]
  ...
  [ 1.0158602e+00  5.0260085e-01  1.0736975e-01 ... -9.5642674e-01
   -4.1039643e-01 -2.6760373e-01]
  [ 1.1848910e+00  6.5479511e-01  1.0155141e-03 ... -8.6154616e-01
   -8.8041753e-02 -3.0636895e-01]
  [ 1.2669089e+00  4.7767794e-01  6.6289604e-03 ... -1.1585804e+00
   -7.0679039e-02 -1.8678637e-01]]], shape=(9, 128, 768), dtype=float32)

Semantic similarity

Now let's take a look at the pooled_output embeddings of our sentences and compare how similar they are across sentences.

Helper functions
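
The helper cell is collapsed in the original notebook. Below is a minimal sketch of what plot_similarity might look like, assuming it computes pairwise cosine similarity with sklearn's pairwise module and renders it as a seaborn heatmap (both are imported in the setup cell above); the exact styling in the original may differ.

def plot_similarity(features, labels):
  """Plot a pairwise cosine-similarity heatmap of the sentence embeddings."""
  # Cosine similarity between every pair of pooled embeddings, shape (9, 9).
  cos_sim = pairwise.cosine_similarity(features)
  sns.set(font_scale=1.2)
  g = sns.heatmap(
      cos_sim, xticklabels=labels, yticklabels=labels,
      vmin=0, vmax=1, cmap="YlOrRd")
  g.set_yticklabels(labels, rotation=0)
  g.set_title("Semantic Textual Similarity")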

plot_similarity(outputs["pooled_output"], sentences)

[Figure: heatmap of pairwise semantic textual similarity between the pooled sentence embeddings]

Learn more