Question Answer with TensorFlow Lite Model Maker


The TensorFlow Lite Model Maker library simplifies the process of adapting and converting a TensorFlow model to particular input data when deploying this model for on-device ML applications.

This notebook shows an end-to-end example that uses the Model Maker library to adapt and convert a commonly used question answer model for a question answer task.

Introduction to Question Answer Task

The supported task in this library is the extractive question answer task: given a passage and a question, the answer is a span of text in the passage. The image below shows an example.

Answers are spans in the passage (image credit: SQuAD blog)

The model for the question answer task takes a preprocessed passage and question pair as input, and outputs a start logit and an end logit for each token in the passage. The input size can be set and adjusted according to the length of the passage and question.
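To make the role of the start and end logits concrete, here is a minimal sketch of how an answer span could be picked from them. The function name best_span, the max_answer_len cap, and the toy logits are assumptions for illustration; this is not the post-processing code used by Model Maker.

```python
from typing import List, Tuple


def best_span(start_logits: List[float],
              end_logits: List[float],
              max_answer_len: int = 30) -> Tuple[int, int]:
    """Return the (start, end) token indices with the highest combined logit."""
    best_score = float("-inf")
    best = (0, 0)
    for start, start_logit in enumerate(start_logits):
        # Only consider ends at or after the start, within the length cap.
        last = min(start + max_answer_len, len(end_logits))
        for end in range(start, last):
            score = start_logit + end_logits[end]
            if score > best_score:
                best_score = score
                best = (start, end)
    return best


# Toy passage tokens with made-up logits (hypothetical values).
tokens = ["the", "capital", "of", "france", "is", "paris", "."]
start_logits = [0.1, 0.2, 0.0, 0.3, 0.1, 4.0, 0.0]
end_logits = [0.0, 0.1, 0.0, 0.2, 0.1, 3.5, 0.4]

start, end = best_span(start_logits, end_logits)
answer = " ".join(tokens[start:end + 1])
print(answer)  # paris
```

The answer is read back from the passage tokens, which is what makes the task extractive: the model never generates text, it only points at a span.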

End-to-End Overview

The following code snippet demonstrates how to get the model within a few lines of code. The overall process includes 5 steps: (1) choose a model, (2) load data, (3) retrain the model, (4) evaluate, and (5) export it to TensorFlow Lite format.

# Chooses a model specification that represents the model.
spec = model_spec.get('mobilebert_qa')

# Gets the training data and validation data.
train_data = QuestionAnswerDataLoader.from_squad(train_data_path, spec, is_training=True)
validation_data = QuestionAnswerDataLoader.from_squad(validation_data_path, spec, is_training=False)

# Fine-tunes the model.
model = question_answer.create(train_data, model_spec=spec)

# Gets the evaluation result.
metric = model.evaluate(validation_data)

# Exports the model to the TensorFlow Lite format with metadata in the export directory.
model.export(export_dir)

The following sections explain the code in more detail.

Prerequisites

To run this example, install the required packages, including the Model Maker package from PyPI.

pip install tflite-model-maker
Successfully installed fire-0.3.1 flatbuffers-1.12 opencv-python-headless-4.4.0.42 py-cpuinfo-7.0.0 pybind11-2.5.0 pyyaml-5.3.1 sentencepiece-0.1.91 seqeval-0.0.12 tb-nightly-2.4.0a20200914 tensorflow-model-optimization-0.5.0 tf-estimator-nightly-2.4.0.dev2020091401 tf-models-nightly-2.3.0.dev20200914 tf-nightly-2.4.0.dev20200914 tf-slim-1.1.0 tflite-model-maker-0.1.2 tflite-support-0.1.0rc3.dev2

Import the required packages.

import numpy as np
import os

import tensorflow as tf
assert tf.__version__.startswith('2')

from tflite_model_maker import configs
from tflite_model_maker import ExportFormat
from tflite_model_maker import model_spec
from tflite_model_maker import question_answer
from tflite_model_maker import QuestionAnswerDataLoader

The "End-to-End Overview" demonstrates a simple end-to-end example. The following sections walk through the example step by step to show more detail.

Choose a model_spec that represents a model for question answer

Each model_spec object represents a specific model for question answer. The Model Maker currently supports MobileBERT and BERT-Base models.

Supported Model | Name of model_spec | Model Description
MobileBERT | 'mobilebert_qa' | 4.3x smaller and 5.5x faster than BERT-Base while achieving competitive results; suitable for on-device scenarios.
MobileBERT-SQuAD | 'mobilebert_qa_squad' | Same model architecture as the MobileBERT model; the initial model is already retrained on SQuAD1.1.
BERT-Base | 'bert_qa' | Standard BERT model that is widely used in NLP tasks.

In this tutorial, MobileBERT-SQuAD is used as an example. Since the model is already retrained on SQuAD1.1, it can converge faster on the question answer task.

spec = model_spec.get('mobilebert_qa_squad')

Load Input Data Specific to an On-device ML App and Preprocess the Data

TriviaQA is a reading comprehension dataset containing over 650K question-answer-evidence triples. In this tutorial, you will use a subset of this dataset to learn how to use the Model Maker library.

To load the data, convert the TriviaQA dataset to the SQuAD1.1 format by running the converter Python script with --sample_size=8000 and a set of web data. Modify the conversion code slightly to:

  • Skip samples for which no answer can be found in the context document;
  • Get the original answer from the context, without converting it to uppercase or lowercase.
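The two modifications above can be sketched as follows. This is an illustrative reading of the bullets, not the actual converter script: the helper name to_squad_example is hypothetical, and the case-insensitive search is an assumption about how the answer is located.

```python
from typing import Optional


def to_squad_example(qid: str, question: str, context: str,
                     answer: str) -> Optional[dict]:
    """Build one SQuAD1.1-style paragraph entry, or None if no answer span exists."""
    # Case-insensitive search, so "paris" still matches "Paris" in the context.
    # Assumes lower() keeps the string length unchanged (true for ASCII text).
    start = context.lower().find(answer.lower())
    if start < 0:
        return None  # Skip samples whose answer is not found in the context.
    # Slice the context so the stored answer keeps its original casing.
    original = context[start:start + len(answer)]
    return {
        "context": context,
        "qas": [{
            "id": qid,
            "question": question,
            "answers": [{"text": original, "answer_start": start}],
        }],
    }


context = "Paris is the capital of France."
kept = to_squad_example("q1", "What is the capital of France?", context, "paris")
skipped = to_squad_example("q2", "What is the capital of Germany?", context, "Berlin")
print(kept["qas"][0]["answers"][0])  # {'text': 'Paris', 'answer_start': 0}
print(skipped)  # None
```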

Download the archived version of the already converted dataset.

train_data_path = tf.keras.utils.get_file(
    fname='triviaqa-web-train-8000.json',
    origin='https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-web-train-8000.json')
validation_data_path = tf.keras.utils.get_file(
    fname='triviaqa-verified-web-dev.json',
    origin='https://storage.googleapis.com/download.tensorflow.org/models/tflite/dataset/triviaqa-verified-web-dev.json')

You can also train the MobileBERT model with your own dataset. If you are running this notebook on Colab, upload your data by using the left sidebar.

Upload File

If you prefer not to upload your data to the cloud, you can also run the library offline by following the guide.

Use the QuestionAnswerDataLoader.from_squad method to load and preprocess SQuAD-format data according to a specific model_spec. You can use either the SQuAD2.0 or the SQuAD1.1 format. Setting the parameter version_2_with_negative to True means the format is SQuAD2.0; otherwise, the format is SQuAD1.1. By default, version_2_with_negative is False.
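SQuAD2.0 differs from SQuAD1.1 mainly in that it includes unanswerable questions, marked with an is_impossible field. If you are unsure which format your own file uses, a small check like the sketch below may help you decide what to pass as version_2_with_negative. The helper name has_unanswerable_questions and the inline sample are assumptions for illustration.

```python
import json


def has_unanswerable_questions(squad: dict) -> bool:
    """Return True if any question in a SQuAD-style dict is marked is_impossible."""
    for article in squad.get("data", []):
        for paragraph in article.get("paragraphs", []):
            for qa in paragraph.get("qas", []):
                if qa.get("is_impossible", False):
                    return True
    return False


# A tiny hand-written SQuAD2.0-style sample (not taken from the real dataset).
sample = json.loads("""
{"version": "v2.0", "data": [{"title": "demo", "paragraphs": [{
  "context": "Paris is the capital of France.",
  "qas": [
    {"id": "q1", "question": "What is the capital of France?",
     "is_impossible": false,
     "answers": [{"text": "Paris", "answer_start": 0}]},
    {"id": "q2", "question": "What is the capital of Atlantis?",
     "is_impossible": true, "answers": []}
  ]}]}]}
""")

print(has_unanswerable_questions(sample))  # True
```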

train_data = QuestionAnswerDataLoader.from_squad(train_data_path, spec, is_training=True)
validation_data = QuestionAnswerDataLoader.from_squad(validation_data_path, spec, is_training=False)

Customize the TensorFlow Model

Create a custom question answer model based on the loaded data. The create function comprises the following steps:

  1. Creates the model for question answer according to model_spec.
  2. Trains the question answer model. The default number of epochs and the default batch size are set by the default_training_epochs and default_batch_size variables in the model_spec object.

model = question_answer.create(train_data, model_spec=spec)
INFO:tensorflow:Retraining the models...

Epoch 1/2

Have a look at the detailed model structure.

model.summary()

Evaluate the Customized Model

Evaluate the model on the validation data to get a dict of metrics, including the F1 score and exact match. Note that the metrics differ between SQuAD1.1 and SQuAD2.0.

model.evaluate(validation_data)
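For intuition about the exact match and F1 metrics mentioned above, the sketch below computes both for a single prediction in the style of the SQuAD evaluation: normalize the text, then compare tokens. It is a simplified illustration, not the code Model Maker runs, and the function names are chosen here for the example.

```python
import collections
import re
import string

_PUNCTUATION = set(string.punctuation)


def normalize(text: str) -> str:
    """Lowercase, drop punctuation and English articles, collapse whitespace."""
    text = text.lower()
    text = "".join(ch for ch in text if ch not in _PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, truth: str) -> float:
    """1.0 if the normalized strings are identical, else 0.0."""
    return float(normalize(prediction) == normalize(truth))


def f1_score(prediction: str, truth: str) -> float:
    """Token-level F1 between the normalized prediction and ground truth."""
    pred_tokens = normalize(prediction).split()
    truth_tokens = normalize(truth).split()
    if not pred_tokens or not truth_tokens:
        # If either side is empty, score 1.0 only when both are empty.
        return float(pred_tokens == truth_tokens)
    common = collections.Counter(pred_tokens) & collections.Counter(truth_tokens)
    overlap = sum(common.values())
    if overlap == 0:
        return 0.0
    precision = overlap / len(pred_tokens)
    recall = overlap / len(truth_tokens)
    return 2 * precision * recall / (precision + recall)


print(exact_match("the Eiffel Tower", "Eiffel Tower"))    # 1.0
print(exact_match("Eiffel Tower in Paris", "Eiffel Tower"))  # 0.0
print(round(f1_score("Eiffel Tower in Paris", "Eiffel Tower"), 3))  # 0.667
```

Exact match rewards only a perfect span, while F1 gives partial credit when the predicted span overlaps the reference.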

Export to TensorFlow Lite Model

Convert the existing model to TensorFlow Lite model format that you can later use in an on-device ML application.

Since MobileBERT is too big for on-device applications as is, use dynamic range quantization to compress the model by 4x with minimal loss of performance. First, define the quantization configuration:

config = configs.QuantizationConfig.create_dynamic_range_quantization(optimizations=[tf.lite.Optimize.OPTIMIZE_FOR_LATENCY])
config._experimental_new_quantizer = True
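The 4x figure follows from storing weights as 8-bit integers instead of 32-bit floats. The sketch below illustrates the idea with a symmetric, per-tensor scheme on a toy weight list. It is a simplified assumption-laden illustration of weight quantization in general, not the TensorFlow Lite converter's actual algorithm, and the function names are hypothetical.

```python
import array


def quantize_int8(weights):
    """Symmetric per-tensor quantization: floats -> int8 values plus one scale."""
    max_abs = max(abs(w) for w in weights)
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    # array type "b" stores signed 8-bit integers (1 byte each).
    quantized = array.array("b", (int(round(w / scale)) for w in weights))
    return quantized, scale


def dequantize(quantized, scale):
    """Recover approximate float weights from int8 values."""
    return [q * scale for q in quantized]


weights = [0.5, -1.27, 0.003, 1.0, -0.25, 0.75, -0.9, 0.1]
float_weights = array.array("f", weights)  # 32-bit floats (4 bytes each)
quantized, scale = quantize_int8(weights)

float_size = float_weights.itemsize * len(float_weights)
int8_size = quantized.itemsize * len(quantized)
ratio = float_size / int8_size
max_error = max(abs(w - r) for w, r in zip(weights, dequantize(quantized, scale)))

print(ratio)  # 4.0
print(max_error <= scale / 2 + 1e-12)  # True
```

Each weight is off by at most half a quantization step, which is why the accuracy loss tends to be small relative to the size reduction.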

Export the quantized TFLite model according to the quantization config with metadata. The default TFLite model filename is model.tflite.

model.export(export_dir='.', quantization_config=config)

You can use the TensorFlow Lite model file in the bert_qa reference app, which uses the BertQuestionAnswerer API in the TensorFlow Lite Task Library. On Colab, download the file from the left sidebar.

The allowed export formats can be one or a list of the following:

  • ExportFormat.TFLITE
  • ExportFormat.VOCAB
  • ExportFormat.SAVED_MODEL

By default, only the TensorFlow Lite model with metadata is exported. You can also selectively export different files. For instance, export only the vocab file as follows:

model.export(export_dir='.', export_format=ExportFormat.VOCAB)

You can also evaluate the TFLite model with the evaluate_tflite method. This step is expected to take a long time.

model.evaluate_tflite('model.tflite', validation_data)

Advanced Usage

The create function is the critical part of this library, in which the model_spec parameter defines the model specification. The BertQAModelSpec class is currently supported, with 2 models: MobileBERT and BERT-Base. The create function comprises the following steps:

  1. Creates the model for question answer according to model_spec.
  2. Trains the question answer model.

This section describes several advanced topics, including adjusting the model and tuning the training hyperparameters.

Adjust the model

You can adjust model parameters such as seq_len and query_len in the BertQAModelSpec class.

Adjustable parameters for model:

  • seq_len: Length of the passage to feed into the model.
  • query_len: Length of the question to feed into the model.
  • doc_stride: The stride when doing a sliding window approach to take chunks of the documents.
  • initializer_range: The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  • trainable: Boolean, whether the pre-trained layer is trainable.

Adjustable parameters for training pipeline:

  • model_dir: The location of the model checkpoint files. If not set, a temporary directory will be used.
  • dropout_rate: The rate for dropout.
  • learning_rate: The initial learning rate for Adam.
  • predict_batch_size: Batch size for prediction.
  • tpu: TPU address to connect to. Only used when training on a TPU.
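To see how seq_len, query_len and doc_stride interact, the sketch below splits a long passage into overlapping windows, following the sliding-window approach commonly used for BERT-style question answer preprocessing. The function name doc_spans and the numeric values are illustrative assumptions, as is the simplification that the question uses its full query_len budget; the library's own preprocessing may differ in detail.

```python
from typing import List, Tuple


def doc_spans(num_doc_tokens: int, max_tokens_for_doc: int,
              doc_stride: int) -> List[Tuple[int, int]]:
    """Return (start, length) windows that together cover the whole passage."""
    spans = []
    start = 0
    while start < num_doc_tokens:
        length = min(num_doc_tokens - start, max_tokens_for_doc)
        spans.append((start, length))
        if start + length == num_doc_tokens:
            break  # The last window reached the end of the passage.
        start += min(length, doc_stride)
    return spans


# Illustrative values only (assumed for this example).
seq_len = 384
query_len = 64
doc_stride = 128
# Room left for passage tokens after the question and 3 special tokens
# ([CLS] and two [SEP]), assuming the question fills its whole budget.
max_tokens_for_doc = seq_len - query_len - 3

spans = doc_spans(700, max_tokens_for_doc, doc_stride)
print(spans)  # [(0, 317), (128, 317), (256, 317), (384, 316)]
```

A smaller doc_stride produces more overlapping windows per passage, which increases the number of training features but makes it less likely that an answer is cut at a window boundary.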

For example, you can train the model with a longer sequence length. If you change the model, you must first construct a new model_spec.

new_spec = model_spec.get('mobilebert_qa')
new_spec.seq_len = 512

The remaining steps are the same. Note that you must rerun both the dataloader and create parts as different model specs may have different preprocessing steps.

Tune training hyperparameters

You can also tune training hyperparameters such as epochs and batch_size, which affect model performance. For instance,

  • epochs: more epochs could achieve better performance, but may lead to overfitting.
  • batch_size: number of samples to use in one training step.

For example, you can train with more epochs and with a bigger batch size like:

model = question_answer.create(train_data, model_spec=spec, epochs=5, batch_size=64)
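As a rough guide to what these two settings cost, the number of training steps can be estimated from the dataset size. The sketch below assumes about 8000 training features and that a final partial batch still counts as a step; both are assumptions, since long passages can be split into several features and the library may handle the last batch differently.

```python
import math


def total_training_steps(num_examples: int, batch_size: int, epochs: int) -> int:
    """Estimate optimizer steps, counting a partial final batch as one step."""
    steps_per_epoch = math.ceil(num_examples / batch_size)
    return steps_per_epoch * epochs


num_examples = 8000  # hypothetical feature count
print(total_training_steps(num_examples, batch_size=64, epochs=5))  # 625
print(total_training_steps(num_examples, batch_size=32, epochs=5))  # 1250
```

Doubling the batch size halves the number of steps per epoch, so the learning rate and number of epochs may need to be revisited together with batch_size.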

Change the Model Architecture

You can change the base model that is trained on your data by changing the model_spec. For example, to change to the BERT-Base model, run:

spec = model_spec.get('bert_qa')

The remaining steps are the same.