Ta strona została przetłumaczona przez Cloud Translation API.
Switch to English

Klasyfikacja tekstu za pomocą TensorFlow Lite Model Maker

Zobacz na TensorFlow.org Wyświetl źródło na GitHub Pobierz notatnik

Biblioteka TensorFlow Lite Model Maker upraszcza proces dostosowywania i konwertowania modelu TensorFlow do określonych danych wejściowych podczas wdrażania tego modelu w aplikacjach ML na urządzeniu.

Ten notatnik przedstawia kompleksowy przykład wykorzystujący bibliotekę Model Maker do zilustrowania adaptacji i konwersji powszechnie używanego modelu klasyfikacji tekstu do klasyfikowania recenzji filmów na urządzeniu mobilnym. Model klasyfikacji tekstu klasyfikuje tekst do predefiniowanych kategorii. Dane wejściowe powinny być wstępnie przetworzonym tekstem, a wyjściami są prawdopodobieństwa kategorii. Zbiór danych użyty w tym samouczku to pozytywne i negatywne recenzje filmów.

Wymagania wstępne

Zainstaluj wymagane pakiety

Aby uruchomić ten przykład, zainstaluj wymagane pakiety, w tym pakiet Model Maker z repozytorium GitHub .

pip install tflite-model-maker
Collecting tflite-model-maker
[?25l  Downloading https://files.pythonhosted.org/packages/13/bc/4c23b9cb9ef612a1f48bac5543bd531665de5eab8f8231111aac067f8c30/tflite_model_maker-0.1.2-py3-none-any.whl (104kB)
[K     |████████████████████████████████| 112kB 8.2MB/s 
[?25hRequirement already satisfied: tensorflow-hub>=0.8.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (0.9.0)
Collecting fire
[?25l  Downloading https://files.pythonhosted.org/packages/34/a7/0e22e70778aca01a52b9c899d9c145c6396d7b613719cd63db97ffa13f2f/fire-0.3.1.tar.gz (81kB)
[K     |████████████████████████████████| 81kB 7.7MB/s 
[?25hCollecting flatbuffers==1.12
  Downloading https://files.pythonhosted.org/packages/eb/26/712e578c5f14e26ae3314c39a1bdc4eb2ec2f4ddc89b708cf8e0a0d20423/flatbuffers-1.12-py2.py3-none-any.whl
Collecting tf-models-nightly
[?25l  Downloading https://files.pythonhosted.org/packages/d3/e9/c4e5a451c268a5a75a27949562364f6086f6bb33b226a065a8beceefa9ba/tf_models_nightly-2.3.0.dev20200914-py2.py3-none-any.whl (993kB)
[K     |████████████████████████████████| 1.0MB 17.6MB/s 
[?25hCollecting sentencepiece
[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)
[K     |████████████████████████████████| 1.1MB 31.6MB/s 
[?25hRequirement already satisfied: numpy>=1.17.3 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (1.18.5)
Collecting tf-nightly
[?25l  Downloading https://files.pythonhosted.org/packages/33/d4/61c47ae889b490b9c5f07f4f61bdc057c158a1a1979c375fa019d647a19e/tf_nightly-2.4.0.dev20200914-cp36-cp36m-manylinux2010_x86_64.whl (390.1MB)
[K     |████████████████████████████████| 390.2MB 46kB/s 
[?25hRequirement already satisfied: pillow in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (7.0.0)
Requirement already satisfied: absl-py in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (0.10.0)
Requirement already satisfied: tensorflow-datasets>=2.1.0 in /usr/local/lib/python3.6/dist-packages (from tflite-model-maker) (2.1.0)
Collecting tflite-support==0.1.0rc3.dev2
[?25l  Downloading https://files.pythonhosted.org/packages/fa/c5/5e9ee3abd5b4ef8294432cd714407f49a66befa864905b66ee8bdc612795/tflite_support-0.1.0rc3.dev2-cp36-cp36m-manylinux2010_x86_64.whl (1.0MB)
[K     |████████████████████████████████| 1.0MB 50.0MB/s 
[?25hRequirement already satisfied: protobuf>=3.8.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub>=0.8.0->tflite-model-maker) (3.12.4)
Requirement already satisfied: six>=1.12.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-hub>=0.8.0->tflite-model-maker) (1.15.0)
Requirement already satisfied: termcolor in /usr/local/lib/python3.6/dist-packages (from fire->tflite-model-maker) (1.1.0)
Requirement already satisfied: pycocotools in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (2.0.2)
Collecting tensorflow-model-optimization>=0.4.1
[?25l  Downloading https://files.pythonhosted.org/packages/55/38/4fd48ea1bfcb0b6e36d949025200426fe9c3a8bfae029f0973d85518fa5a/tensorflow_model_optimization-0.5.0-py2.py3-none-any.whl (172kB)
[K     |████████████████████████████████| 174kB 57.7MB/s 
[?25hCollecting tf-slim>=1.1.0
[?25l  Downloading https://files.pythonhosted.org/packages/02/97/b0f4a64df018ca018cc035d44f2ef08f91e2e8aa67271f6f19633a015ff7/tf_slim-1.1.0-py2.py3-none-any.whl (352kB)
[K     |████████████████████████████████| 358kB 54.9MB/s 
[?25hRequirement already satisfied: tensorflow-addons in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.8.3)
Requirement already satisfied: kaggle>=1.3.9 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.5.8)
Requirement already satisfied: oauth2client in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (4.1.3)
Collecting seqeval
  Downloading https://files.pythonhosted.org/packages/34/91/068aca8d60ce56dd9ba4506850e876aba5e66a6f2f29aa223224b50df0de/seqeval-0.0.12.tar.gz
Requirement already satisfied: scipy>=0.19.1 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.4.1)
Requirement already satisfied: dataclasses in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.7)
Requirement already satisfied: matplotlib in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (3.2.2)
Collecting pyyaml>=5.1
[?25l  Downloading https://files.pythonhosted.org/packages/64/c2/b80047c7ac2478f9501676c988a5411ed5572f35d1beff9cae07d321512c/PyYAML-5.3.1.tar.gz (269kB)
[K     |████████████████████████████████| 276kB 55.1MB/s 
[?25hRequirement already satisfied: Cython in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.29.21)
Requirement already satisfied: google-cloud-bigquery>=0.31.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.21.0)
Collecting opencv-python-headless
[?25l  Downloading https://files.pythonhosted.org/packages/b6/2a/496e06fd289c01dc21b11970be1261c87ce1cc22d5340c14b516160822a7/opencv_python_headless-4.4.0.42-cp36-cp36m-manylinux2014_x86_64.whl (36.6MB)
[K     |████████████████████████████████| 36.6MB 88kB/s 
[?25hRequirement already satisfied: psutil>=5.4.3 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (5.4.8)
Requirement already satisfied: gin-config in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (0.3.0)
Requirement already satisfied: google-api-python-client>=1.6.7 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.7.12)
Requirement already satisfied: pandas>=0.22.0 in /usr/local/lib/python3.6/dist-packages (from tf-models-nightly->tflite-model-maker) (1.0.5)
Collecting py-cpuinfo>=3.3.0
[?25l  Downloading https://files.pythonhosted.org/packages/f6/f5/8e6e85ce2e9f6e05040cf0d4e26f43a4718bcc4bce988b433276d4b1a5c1/py-cpuinfo-7.0.0.tar.gz (95kB)
[K     |████████████████████████████████| 102kB 11.1MB/s 
[?25hRequirement already satisfied: grpcio>=1.8.6 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.32.0)
Requirement already satisfied: opt-einsum>=2.3.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (3.3.0)
Requirement already satisfied: gast==0.3.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.3.3)
Requirement already satisfied: wheel>=0.26 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.35.1)
Requirement already satisfied: astunparse==1.6.3 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.6.3)
Requirement already satisfied: h5py<2.11.0,>=2.10.0 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (2.10.0)
Requirement already satisfied: typing-extensions>=3.7.4.2 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (3.7.4.3)
Collecting tb-nightly<3.0.0a0,>=2.4.0a0
[?25l  Downloading https://files.pythonhosted.org/packages/fc/cb/4dfe0d65bffb5e9663261ff664e6f5a2d37672b31dae27a0f14721ac00d3/tb_nightly-2.4.0a20200914-py3-none-any.whl (10.1MB)
[K     |████████████████████████████████| 10.1MB 46.1MB/s 
[?25hRequirement already satisfied: google-pasta>=0.1.8 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (0.2.0)
Collecting tf-estimator-nightly
[?25l  Downloading https://files.pythonhosted.org/packages/bd/9a/3bfb9994eda11e426c809ebdf434e2ac5824a0784d980018bb53fd1620ec/tf_estimator_nightly-2.4.0.dev2020091401-py2.py3-none-any.whl (460kB)
[K     |████████████████████████████████| 460kB 51.7MB/s 
[?25hRequirement already satisfied: keras-preprocessing<1.2,>=1.1.1 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.1.2)
Requirement already satisfied: wrapt>=1.11.1 in /usr/local/lib/python3.6/dist-packages (from tf-nightly->tflite-model-maker) (1.12.1)
Requirement already satisfied: tensorflow-metadata in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.24.0)
Requirement already satisfied: future in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.16.0)
Requirement already satisfied: promise in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (2.3)
Requirement already satisfied: requests>=2.19.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (2.23.0)
Requirement already satisfied: attrs>=18.1.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (20.2.0)
Requirement already satisfied: tqdm in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (4.41.1)
Requirement already satisfied: dill in /usr/local/lib/python3.6/dist-packages (from tensorflow-datasets>=2.1.0->tflite-model-maker) (0.3.2)
Collecting pybind11>=2.4
[?25l  Downloading https://files.pythonhosted.org/packages/89/e3/d576f6f02bc75bacbc3d42494e8f1d063c95617d86648dba243c2cb3963e/pybind11-2.5.0-py2.py3-none-any.whl (296kB)
[K     |████████████████████████████████| 296kB 55.2MB/s 
[?25hRequirement already satisfied: setuptools in /usr/local/lib/python3.6/dist-packages (from protobuf>=3.8.0->tensorflow-hub>=0.8.0->tflite-model-maker) (50.3.0)
Requirement already satisfied: dm-tree~=0.1.1 in /usr/local/lib/python3.6/dist-packages (from tensorflow-model-optimization>=0.4.1->tf-models-nightly->tflite-model-maker) (0.1.5)
Requirement already satisfied: typeguard in /usr/local/lib/python3.6/dist-packages (from tensorflow-addons->tf-models-nightly->tflite-model-maker) (2.7.1)
Requirement already satisfied: python-slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (4.0.1)
Requirement already satisfied: urllib3<1.25,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (1.24.3)
Requirement already satisfied: certifi in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (2020.6.20)
Requirement already satisfied: slugify in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (0.0.1)
Requirement already satisfied: python-dateutil in /usr/local/lib/python3.6/dist-packages (from kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (2.8.1)
Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (4.6)
Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (0.4.8)
Requirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (0.2.8)
Requirement already satisfied: httplib2>=0.9.1 in /usr/local/lib/python3.6/dist-packages (from oauth2client->tf-models-nightly->tflite-model-maker) (0.17.4)
Requirement already satisfied: Keras>=2.2.4 in /usr/local/lib/python3.6/dist-packages (from seqeval->tf-models-nightly->tflite-model-maker) (2.4.3)
Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (0.10.0)
Requirement already satisfied: kiwisolver>=1.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (1.2.0)
Requirement already satisfied: pyparsing!=2.0.4,!=2.1.2,!=2.1.6,>=2.0.1 in /usr/local/lib/python3.6/dist-packages (from matplotlib->tf-models-nightly->tflite-model-maker) (2.4.7)
Requirement already satisfied: google-resumable-media!=0.4.0,<0.5.0dev,>=0.3.1 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (0.4.1)
Requirement already satisfied: google-cloud-core<2.0dev,>=1.0.3 in /usr/local/lib/python3.6/dist-packages (from google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (1.0.3)
Requirement already satisfied: google-auth-httplib2>=0.0.3 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (0.0.4)
Requirement already satisfied: uritemplate<4dev,>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (3.0.1)
Requirement already satisfied: google-auth>=1.4.1 in /usr/local/lib/python3.6/dist-packages (from google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (1.17.2)
Requirement already satisfied: pytz>=2017.2 in /usr/local/lib/python3.6/dist-packages (from pandas>=0.22.0->tf-models-nightly->tflite-model-maker) (2018.9)
Requirement already satisfied: markdown>=2.6.8 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.2.2)
Requirement already satisfied: google-auth-oauthlib<0.5,>=0.4.1 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (0.4.1)
Requirement already satisfied: tensorboard-plugin-wit>=1.6.0 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.7.0)
Requirement already satisfied: werkzeug>=0.11.15 in /usr/local/lib/python3.6/dist-packages (from tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.0.1)
Requirement already satisfied: googleapis-common-protos<2,>=1.52.0 in /usr/local/lib/python3.6/dist-packages (from tensorflow-metadata->tensorflow-datasets>=2.1.0->tflite-model-maker) (1.52.0)
Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->tensorflow-datasets>=2.1.0->tflite-model-maker) (2.10)
Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests>=2.19.0->tensorflow-datasets>=2.1.0->tflite-model-maker) (3.0.4)
Requirement already satisfied: text-unidecode>=1.3 in /usr/local/lib/python3.6/dist-packages (from python-slugify->kaggle>=1.3.9->tf-models-nightly->tflite-model-maker) (1.3)
Requirement already satisfied: google-api-core<2.0.0dev,>=1.14.0 in /usr/local/lib/python3.6/dist-packages (from google-cloud-core<2.0dev,>=1.0.3->google-cloud-bigquery>=0.31.0->tf-models-nightly->tflite-model-maker) (1.16.0)
Requirement already satisfied: cachetools<5.0,>=2.0.0 in /usr/local/lib/python3.6/dist-packages (from google-auth>=1.4.1->google-api-python-client>=1.6.7->tf-models-nightly->tflite-model-maker) (4.1.1)
Requirement already satisfied: importlib-metadata; python_version < "3.8" in /usr/local/lib/python3.6/dist-packages (from markdown>=2.6.8->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.7.0)
Requirement already satisfied: requests-oauthlib>=0.7.0 in /usr/local/lib/python3.6/dist-packages (from google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (1.3.0)
Requirement already satisfied: zipp>=0.5 in /usr/local/lib/python3.6/dist-packages (from importlib-metadata; python_version < "3.8"->markdown>=2.6.8->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.1.0)
Requirement already satisfied: oauthlib>=3.0.0 in /usr/local/lib/python3.6/dist-packages (from requests-oauthlib>=0.7.0->google-auth-oauthlib<0.5,>=0.4.1->tb-nightly<3.0.0a0,>=2.4.0a0->tf-nightly->tflite-model-maker) (3.1.0)
Building wheels for collected packages: fire, seqeval, pyyaml, py-cpuinfo
  Building wheel for fire (setup.py) ... [?25l[?25hdone
  Created wheel for fire: filename=fire-0.3.1-py2.py3-none-any.whl size=111005 sha256=9eaa2d36e17621d136f8ab1707a5a4e8994c53d5076a9edde21aab7696ba3e09
  Stored in directory: /root/.cache/pip/wheels/c1/61/df/768b03527bf006b546dce284eb4249b185669e65afc5fbb2ac
  Building wheel for seqeval (setup.py) ... [?25l[?25hdone
  Created wheel for seqeval: filename=seqeval-0.0.12-cp36-none-any.whl size=7423 sha256=1ce4604da2a395f0304db708bf2e2c1831033ed8b1f7c23927d70ed9ed7b7110
  Stored in directory: /root/.cache/pip/wheels/4f/32/0a/df3b340a82583566975377d65e724895b3fad101a3fb729f68
  Building wheel for pyyaml (setup.py) ... [?25l[?25hdone
  Created wheel for pyyaml: filename=PyYAML-5.3.1-cp36-cp36m-linux_x86_64.whl size=44619 sha256=d51b6ef3e90de74d0c1cee8f7aafe0a6d8674348c8437cd89ad5c60a6c3dc726
  Stored in directory: /root/.cache/pip/wheels/a7/c1/ea/cf5bd31012e735dc1dfea3131a2d5eae7978b251083d6247bd
  Building wheel for py-cpuinfo (setup.py) ... [?25l[?25hdone
  Created wheel for py-cpuinfo: filename=py_cpuinfo-7.0.0-cp36-none-any.whl size=20071 sha256=096439bff3cb3e4cc21b86472c629017fd9c972d6e2ed231e1a91d2096fc687d
  Stored in directory: /root/.cache/pip/wheels/f1/93/7b/127daf0c3a5a49feb2fecd468d508067c733fba5192f726ad1
Successfully built fire seqeval pyyaml py-cpuinfo
Installing collected packages: fire, flatbuffers, tensorflow-model-optimization, tf-slim, seqeval, pyyaml, opencv-python-headless, sentencepiece, tb-nightly, tf-estimator-nightly, tf-nightly, py-cpuinfo, tf-models-nightly, pybind11, tflite-support, tflite-model-maker
  Found existing installation: PyYAML 3.13
    Uninstalling PyYAML-3.13:
      Successfully uninstalled PyYAML-3.13
Successfully installed fire-0.3.1 flatbuffers-1.12 opencv-python-headless-4.4.0.42 py-cpuinfo-7.0.0 pybind11-2.5.0 pyyaml-5.3.1 sentencepiece-0.1.91 seqeval-0.0.12 tb-nightly-2.4.0a20200914 tensorflow-model-optimization-0.5.0 tf-estimator-nightly-2.4.0.dev2020091401 tf-models-nightly-2.3.0.dev20200914 tf-nightly-2.4.0.dev20200914 tf-slim-1.1.0 tflite-model-maker-0.1.2 tflite-support-0.1.0rc3.dev2

Zaimportuj wymagane pakiety.

import numpy as np
import os

import tensorflow as tf
assert tf.__version__.startswith('2')

from tflite_model_maker import configs
from tflite_model_maker import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import text_classifier
from tflite_model_maker import TextClassifierDataLoader

Uzyskaj ścieżkę danych

Pobierz zestaw danych do tego samouczka.

data_dir = tf.keras.utils.get_file(
      fname='SST-2.zip',
      origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
      extract=True)
data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')
Downloading data from https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8
7446528/7439277 [==============================] - 0s 0us/step

Możesz również przesłać własny zestaw danych, aby pracować z tym samouczkiem. Prześlij swój zestaw danych za pomocą lewego paska bocznego w Colab.

Przesyłanie pliku

Jeśli wolisz nie przesyłać zestawu danych do chmury, możesz również uruchomić bibliotekę lokalnie, postępując zgodnie z instrukcjami .

Kompleksowy przepływ pracy

Ten przepływ pracy składa się z pięciu kroków opisanych poniżej:

Krok 1. Wybierz specyfikację modelu, która reprezentuje model klasyfikacji tekstu.

W tym samouczku jako przykład wykorzystano MobileBERT .

spec = model_spec.get('mobilebert_classifier')

Krok 2. Załaduj pociąg i przetestuj dane specyficzne dla aplikacji ML na urządzeniu i model_spec wstępnie dane zgodnie z określonym model_spec .

train_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=True)
test_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(os.path.join(data_dir, 'dev.tsv')),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=False)

Krok 3. Dostosuj model TensorFlow.

model = text_classifier.create(train_data, model_spec=spec)

Krok 4. Oceń model.

loss, acc = model.evaluate(test_data)

Krok 5. Eksportuj jako model TensorFlow Lite z metadanymi .

Ponieważ MobileBERT jest zbyt duży dla aplikacji na urządzeniu, użyj kwantyzacji zakresu dynamicznego w modelu, aby skompresować go prawie czterokrotnie przy minimalnym spadku wydajności.

config = configs.QuantizationConfig.create_dynamic_range_quantization(optimizations=[tf.lite.Optimize.OPTIMIZE_FOR_LATENCY])
config._experimental_new_quantizer = True
model.export(export_dir='mobilebert/', quantization_config=config)

Możesz także pobrać model za pomocą lewego paska bocznego w Colab.

Po wykonaniu powyższych 5 kroków można dalej używać pliku modelu TensorFlow Lite w aplikacjach na urządzeniu przy użyciu interfejsu API BertNLClassifier w bibliotece zadań TensorFlow Lite .

W poniższych sekcjach krok po kroku omówiono przykład, aby pokazać więcej szczegółów.

Wybierz model_spec który reprezentuje model klasyfikatora tekstu

Każdy obiekt model_spec reprezentuje określony model klasyfikatora tekstu. TensorFlow Lite Model Maker obsługuje obecnie MobileBERT , osadzanie słów uśredniających i modele BERT-Base .

Obsługiwany model Nazwa modelu_spec Opis modelu
MobileBERT „mobilebert_classifier” 4,3x mniejszy i 5,5x szybszy niż BERT-Base przy jednoczesnym osiągnięciu konkurencyjnych wyników, odpowiednich do zastosowań na urządzeniu.
Baza BERT „bert_classifier” Standardowy model BERT, który jest szeroko stosowany w zadaniach NLP.
uśrednianie osadzania słów „Average_word_vec” Uśrednianie osadzania słów w tekście z aktywacją RELU.

W tym samouczku zastosowano mniejszy model, average_word_vec , który można wielokrotnie szkolić, aby zademonstrować proces.

spec = model_spec.get('average_word_vec')

Załaduj dane wejściowe specyficzne dla aplikacji ML na urządzeniu

SST-2 (Stanford Sentiment Treebank) jest jednym z zadań w benchmarku GLUE . Zawiera 67 349 recenzji filmów do celów szkoleniowych i 872 recenzji filmów do weryfikacji. Zbiór danych ma dwie klasy: pozytywne i negatywne recenzje filmów.

Pobierz zarchiwizowaną wersję zbioru danych i rozpakuj ją.

data_dir = tf.keras.utils.get_file(
      fname='SST-2.zip',
      origin='https://firebasestorage.googleapis.com/v0/b/mtl-sentence-representations.appspot.com/o/data%2FSST-2.zip?alt=media&token=aabc5f6b-e466-44a2-b9b4-cf6337f84ac8',
      extract=True)
data_dir = os.path.join(os.path.dirname(data_dir), 'SST-2')

train.tsv danych SST-2 zawiera train.tsv do szkolenia i dev.tsv do walidacji. Pliki mają następujący format:

zdanie etykieta
to urocza i często poruszająca podróż. 1
niezachwianie ponury i zdesperowany 0

Pozytywna recenzja jest oznaczona jako 1, a negatywna jako 0.

Użyj metody TestClassifierDataLoader.from_csv , aby załadować dane.

train_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=True)
test_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(os.path.join(data_dir, 'dev.tsv')),
      text_column='sentence',
      label_column='label',
      model_spec=spec,
      delimiter='\t',
      is_training=False)

Biblioteka Model Maker obsługuje również metodę from_folder() do ładowania danych. Zakłada, że ​​dane tekstowe tej samej klasy znajdują się w tym samym podkatalogu, a nazwa podfolderu jest nazwą klasy. Każdy plik tekstowy zawiera jedną próbkę recenzji filmu. Parametr class_labels służy do określenia, które podfoldery.

Dostosuj model TensorFlow

Utwórz niestandardowy model klasyfikatora tekstu na podstawie załadowanych danych.

model = text_classifier.create(train_data, model_spec=spec, epochs=10)

Sprawdź szczegółową strukturę modelu.

model.summary()

Oceń dostosowany model

Oceń model za pomocą danych testowych i uzyskaj jego utratę i dokładność.

loss, acc = model.evaluate(test_data)

Eksportuj jako model TensorFlow Lite

Przekonwertuj istniejący model na format modelu TensorFlow Lite z metadanymi , których możesz później użyć w aplikacji ML na urządzeniu. Plik etykiety i plik vocab są osadzone w metadanych. Domyślna nazwa pliku model.tflite to model.tflite .

model.export(export_dir='average_word_vec/')

Plik modelu TensorFlow Lite może być używany w aplikacji referencyjnej do klasyfikacji tekstu przy użyciu interfejsu API NLClassifier w bibliotece zadań TensorFlow Lite .

Dozwolonymi formatami eksportu może być jeden z następujących formatów lub ich lista:

  • ExportFormat.TFLITE
  • ExportFormat.LABEL
  • ExportFormat.VOCAB
  • ExportFormat.SAVED_MODEL

Domyślnie po prostu eksportuje model TensorFlow Lite z metadanymi. Możesz także selektywnie eksportować różne pliki. Na przykład eksportowanie tylko pliku etykiety i pliku słownika w następujący sposób:

model.export(export_dir='average_word_vec/', export_format=[ExportFormat.LABEL, ExportFormat.VOCAB])

Można oszacowania możliwych model tflite z evaluate_tflite metoda, aby uzyskać jego dokładność.

accuracy = model.evaluate_tflite('average_word_vec/model.tflite', test_data)

Zaawansowane użycie

Funkcja create jest funkcją sterownika używaną przez bibliotekę Model Maker do tworzenia modeli. Parametr model_spec definiuje specyfikację modelu. Obecnie obsługiwane są klasy AverageWordVecModelSpec i BertClassifierModelSpec . Funkcja create składa się z następujących kroków:

  1. Tworzy model dla klasyfikatora tekstu zgodnie z model_spec .
  2. Trenuje model klasyfikatora. Domyślne epoki i domyślny rozmiar wsadu są ustawiane przez zmienne default_training_epochs i default_batch_size w obiekcie model_spec .

W tej sekcji omówiono zaawansowane zagadnienia dotyczące użytkowania, takie jak dostosowywanie modelu i hiperparametry szkoleniowe.

Dostosuj model

Infrastrukturę modelu można dostosować, na przykład wordvec_dim i seq_len w klasie AverageWordVecModelSpec .

Na przykład możesz uczyć model z większą wartością wordvec_dim . Zauważ, że musisz utworzyć nową model_spec jeśli modyfikujesz model.

new_model_spec = model_spec.AverageWordVecModelSpec(wordvec_dim=32)

Uzyskaj wstępnie przetworzone dane.

new_train_data = TextClassifierDataLoader.from_csv(
      filename=os.path.join(os.path.join(data_dir, 'train.tsv')),
      text_column='sentence',
      label_column='label',
      model_spec=new_model_spec,
      delimiter='\t',
      is_training=True)

Wytrenuj nowy model.

model = text_classifier.create(new_train_data, model_spec=new_model_spec)

Możesz także dostosować model MobileBERT.

Parametry modelu, które możesz dostosować, to:

  • seq_len : Długość sekwencji podawanej do modelu.
  • initializer_range : odchylenie standardowe truncated_normal_initializer do inicjalizacji wszystkich macierzy wagi.
  • trainable : Boolean, który określa, czy wstępnie wytrenowana warstwa jest możliwa do trenowania.

Parametry potoku szkoleniowego, które możesz dostosować, to:

  • model_dir : lokalizacja plików punktów kontrolnych modelu. Jeśli nie jest ustawiona, zostanie użyty katalog tymczasowy.
  • dropout_rate : współczynnik porzucania.
  • learning_rate : początkowy współczynnik uczenia się optymalizatora Adama.
  • tpu : adres TPU do połączenia.

Na przykład, możesz ustawić seq_len=256 (domyślnie 128). Dzięki temu model może klasyfikować dłuższy tekst.

new_model_spec = model_spec.get('mobilebert_classifier')
new_model_spec.seq_len = 256

Dostrój hiperparametry treningu

Możesz również dostroić hiperparametry batch_size , takie jak epochs i batch_size które mają wpływ na dokładność modelu. Na przykład,

  • epochs : więcej epok mogłoby zapewnić lepszą dokładność, ale może prowadzić do nadmiernego dopasowania.
  • batch_size : liczba próbek do wykorzystania w jednym kroku szkoleniowym.

Na przykład możesz trenować z większą liczbą epok.

model = text_classifier.create(train_data, model_spec=spec, epochs=20)

Oceń nowo wyszkolony model z 20 okresami treningowymi.

loss, accuracy = model.evaluate(test_data)

Zmień architekturę modelu

Możesz zmienić model, zmieniając model_spec . Poniżej pokazano, jak przejść do modelu BERT-Base.

Zmień model_spec na model BERT-Base dla klasyfikatora tekstu.

spec = model_spec.get('bert_classifier')

Pozostałe kroki są takie same.