Jax Model Conversion For TFLite

Overview

Note: This is a new API and it is only available via the tf-nightly pip package. It is planned to be available in TensorFlow version 2.7. The API is still experimental and subject to change.

This CodeLab demonstrates how to build a model for MNIST recognition using Jax, and how to convert it to TensorFlow Lite. It also demonstrates how to optimize the Jax-converted TFLite model with post-training quantization.

Prerequisites

It's recommended to try this feature with the newest TensorFlow nightly pip build.

pip install tf-nightly --upgrade
pip install jax --upgrade
pip install jaxlib --upgrade
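
After installing, you can confirm which versions were picked up. This is a hypothetical check, not part of the original notebook.

import jax
import jaxlib
import tensorflow as tf

# Hypothetical version check; any recent TF nightly and recent jax should work.
print('TF:', tf.__version__)
print('JAX:', jax.__version__)
print('jaxlib:', jaxlib.__version__)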

Data Preparation

Download the MNIST data with the Keras dataset and pre-process it.

import numpy as np
import tensorflow as tf
import functools

import time
import itertools

import numpy.random as npr

import jax.numpy as jnp
from jax import jit, grad, random
from jax.example_libraries import optimizers
from jax.example_libraries import stax
2023-11-07 21:50:53.094555: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10778] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-07 21:50:53.094603: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-07 21:50:53.094963: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1533] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
def _one_hot(x, k, dtype=np.float32):
  """Create a one-hot encoding of x of size k."""
  return np.array(x[:, None] == np.arange(k), dtype)

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images, test_images = train_images / 255.0, test_images / 255.0
train_images = train_images.astype(np.float32)
test_images = test_images.astype(np.float32)

train_labels = _one_hot(train_labels, 10)
test_labels = _one_hot(test_labels, 10)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz
11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
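
As a quick sanity check of the pre-processing, you can verify the shapes and the one-hot encoding. This snippet is a hypothetical addition, not part of the original notebook.

# Hypothetical sanity check of the pre-processed data.
assert train_images.shape == (60000, 28, 28)  # 60k float32 images in [0, 1]
assert train_labels.shape == (60000, 10)      # one-hot label rows
assert train_labels[0].sum() == 1.0           # exactly one class is hot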

Build the MNIST model with Jax

def loss(params, batch):
  # predict returns log-probabilities (the network ends in LogSoftmax),
  # so this is the negative log-likelihood, i.e. cross-entropy loss.
  inputs, targets = batch
  preds = predict(params, inputs)
  return -jnp.mean(jnp.sum(preds * targets, axis=1))

def accuracy(params, batch):
  inputs, targets = batch
  target_class = jnp.argmax(targets, axis=1)
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  return jnp.mean(predicted_class == target_class)

init_random_params, predict = stax.serial(
    stax.Flatten,
    stax.Dense(1024), stax.Relu,
    stax.Dense(1024), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)

rng = random.PRNGKey(0)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
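
To make the stax API concrete, here is a hypothetical smoke test (not in the original notebook) that initializes parameters and runs one forward pass; the demo_* names are placeholders.

# Hypothetical smoke test: init_random_params returns (output_shape, params),
# and predict maps a batch of images to per-class log-probabilities.
_, demo_params = init_random_params(rng, (-1, 28 * 28))
demo_logits = predict(demo_params, jnp.zeros((2, 28, 28)))
print(demo_logits.shape)  # (2, 10)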

Train & Evaluate the model

step_size = 0.001
num_epochs = 10
batch_size = 128
momentum_mass = 0.9


num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + bool(leftover)

def data_stream():
  # Yield shuffled mini-batches indefinitely, reshuffling on each pass
  # over the training set.
  rng = npr.RandomState(0)
  while True:
    perm = rng.permutation(num_train)
    for i in range(num_batches):
      batch_idx = perm[i * batch_size:(i + 1) * batch_size]
      yield train_images[batch_idx], train_labels[batch_idx]
batches = data_stream()

opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)

@jit
def update(i, opt_state, batch):
  params = get_params(opt_state)
  return opt_update(i, grad(loss)(params, batch), opt_state)

_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
  start_time = time.time()
  for _ in range(num_batches):
    opt_state = update(next(itercount), opt_state, next(batches))
  epoch_time = time.time() - start_time

  params = get_params(opt_state)
  train_acc = accuracy(params, (train_images, train_labels))
  test_acc = accuracy(params, (test_images, test_labels))
  print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
  print("Training set accuracy {}".format(train_acc))
  print("Test set accuracy {}".format(test_acc))
Starting training...
Epoch 0 in 4.47 sec
Training set accuracy 0.8728833198547363
Test set accuracy 0.880299985408783
Epoch 1 in 2.45 sec
Training set accuracy 0.8983833193778992
Test set accuracy 0.9047999978065491
Epoch 2 in 2.45 sec
Training set accuracy 0.9102333188056946
Test set accuracy 0.9138000011444092
Epoch 3 in 2.47 sec
Training set accuracy 0.9172333478927612
Test set accuracy 0.9218999743461609
Epoch 4 in 2.42 sec
Training set accuracy 0.9224833250045776
Test set accuracy 0.9253999590873718
Epoch 5 in 2.45 sec
Training set accuracy 0.9272000193595886
Test set accuracy 0.9309999942779541
Epoch 6 in 2.44 sec
Training set accuracy 0.9328166842460632
Test set accuracy 0.9334999918937683
Epoch 7 in 2.41 sec
Training set accuracy 0.9360166788101196
Test set accuracy 0.9370999932289124
Epoch 8 in 2.44 sec
Training set accuracy 0.939050018787384
Test set accuracy 0.939300000667572
Epoch 9 in 2.45 sec
Training set accuracy 0.9425666928291321
Test set accuracy 0.9429000020027161

Convert to the TFLite model

Note that we need to:

  1. Inline the params into the Jax predict function with functools.partial.
  2. Build a jnp.zeros, a "placeholder" tensor that Jax uses to trace the model.
  3. Call experimental_from_jax:
  • serving_func is wrapped in a list.
  • The input is associated with a given name and passed in as an array wrapped in a list.

serving_func = functools.partial(predict, params)
x_input = jnp.zeros((1, 28, 28))
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('input1', x_input)]])
tflite_model = converter.convert()
with open('jax_mnist.tflite', 'wb') as f:
  f.write(tflite_model)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_38895/3025848042.py:3: TFLiteConverterV2.experimental_from_jax (from tensorflow.lite.python.lite) is deprecated and will be removed in a future version.
Instructions for updating:
Use `jax2tf.convert` and (`lite.TFLiteConverter.from_saved_model` or `lite.TFLiteConverter.from_concrete_functions`) instead.
2023-11-07 21:51:31.217318: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2023-11-07 21:51:31.217364: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2023-11-07 21:51:31.217371: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:387] Ignored change_concat_input_ranges.
Summary on the non-converted ops:
---------------------------------

 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 7, Total Ops 18, % non-converted = 38.89 %
 * 7 ARITH ops

- arith.constant:    7 occurrences  (f32: 6, i32: 1)



  (f32: 2)
  (f32: 3)
  (f32: 1)

  (f32: 1)
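
As the deprecation warning above notes, experimental_from_jax will eventually be replaced by the jax2tf path. The following is a minimal sketch of that route, assuming jax.experimental.jax2tf is available; it is not part of the original notebook and produces an equivalent float model.

# Sketch only: convert via jax2tf, the path recommended by the deprecation
# warning above.
from jax.experimental import jax2tf

tf_predict = tf.function(
    jax2tf.convert(serving_func, enable_xla=False),
    input_signature=[
        tf.TensorSpec(shape=[1, 28, 28], dtype=tf.float32, name='input1')
    ],
    autograph=False)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_predict.get_concrete_function()], tf_predict)
tflite_model_jax2tf = converter.convert()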

Check the Converted TFLite model

Compare the results of the converted model with the Jax model.

expected = serving_func(train_images[0:1])

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Assert that the results of the TFLite model match the Jax model to 5
# decimal places (numpy's third argument is "decimal", not a tolerance).
np.testing.assert_almost_equal(expected, result, decimal=5)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
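
To inspect the prediction itself (a hypothetical addition, not in the original notebook), take the argmax of the returned log-probabilities.

# The index of the largest log-probability is the predicted digit; both
# models should agree (the first MNIST training image is a 5).
print('Jax prediction:', np.argmax(expected))
print('TFLite prediction:', np.argmax(result))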

Optimize the Model

We will provide a representative_dataset to do post-training quantization to optimize the model.

def representative_dataset():
  for i in range(1000):
    x = train_images[i:i+1]
    yield [x]

converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('x', x_input)]])
tflite_model = converter.convert()  # baseline float conversion
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_quant_model = converter.convert()  # quantized conversion
with open('jax_mnist_quant.tflite', 'wb') as f:
  f.write(tflite_quant_model)
2023-11-07 21:51:31.978225: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2023-11-07 21:51:31.978276: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2023-11-07 21:51:31.978283: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:387] Ignored change_concat_input_ranges.
Summary on the non-converted ops:
---------------------------------

 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 7, Total Ops 18, % non-converted = 38.89 %
 * 7 ARITH ops

- arith.constant:    7 occurrences  (f32: 6, i32: 1)



  (f32: 2)
  (f32: 3)
  (f32: 1)

  (f32: 1)
2023-11-07 21:51:32.321877: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format.
2023-11-07 21:51:32.321927: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency.
2023-11-07 21:51:32.321934: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:387] Ignored change_concat_input_ranges.
Summary on the non-converted ops:
---------------------------------

 * Accepted dialects: tfl, builtin, func
 * Non-Converted Ops: 7, Total Ops 18, % non-converted = 38.89 %
 * 7 ARITH ops

- arith.constant:    7 occurrences  (f32: 6, i32: 1)



  (f32: 2)
  (f32: 3)
  (f32: 1)

  (f32: 1)
fully_quantize: 0, inference_type: 6, input_inference_type: FLOAT32, output_inference_type: FLOAT32
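
The log line above shows that the input and output tensors are still FLOAT32; only the weights and internal activations were quantized. For accelerators that require fully integer models, a hypothetical variant (not in the original notebook) would additionally set the input/output types before converting.

# Hypothetical variant: also quantize the model's input/output tensors to
# int8 (requires the representative dataset and INT8 ops set configured above).
converter.inference_input_type = tf.int8
converter.inference_output_type = tf.int8
tflite_int8_io_model = converter.convert()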

Evaluate the Optimized Model

expected = serving_func(train_images[0:1])

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Quantization introduces numerical error, so the raw log-probabilities no
# longer match to 5 decimal places; check that both models agree on the
# predicted class instead.
np.testing.assert_equal(np.argmax(expected), np.argmax(result))

Compare the Quantized Model size

We should be able to see that the quantized model is about a quarter of the size of the original model.

du -h jax_mnist.tflite
du -h jax_mnist_quant.tflite
7.2M    jax_mnist.tflite
1.8M    jax_mnist_quant.tflite