概述
注:此为新 API,只有通过 pip 安装 tf-nightly 才能使用。它将在 TensorFlow 2.7 版中提供。另外,此 API 仍处于实验阶段,可能会发生变化。
此 CodeLab 演示了如何使用 Jax 构建 MNIST 识别模型,以及如何将其转换为 TensorFlow Lite。此 CodeLab 还将演示如何使用训练后量化来优化 Jax 转换的 TFLite 模型。
在 TensorFlow.org 上查看 | 在 Google Colab 运行 | 在 Github 上查看源代码 | 下载笔记本 |
先决条件
建议在最新的 TensorFlow nightly pip 构建中尝试此功能。
pip install tf-nightly --upgrade
pip install jax --upgrade
pip install jaxlib --upgrade
数据准备
使用 Keras 数据集下载 MNIST 数据并进行预处理。
import numpy as np
import tensorflow as tf
import functools
import time
import itertools
import numpy.random as npr
import jax.numpy as jnp
from jax import jit, grad, random
from jax.example_libraries import optimizers
from jax.example_libraries import stax
2023-11-07 21:50:53.094555: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:10778] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered 2023-11-07 21:50:53.094603: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered 2023-11-07 21:50:53.094963: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1533] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
def _one_hot(x, k, dtype=np.float32):
"""Create a one-hot encoding of x of size k."""
return np.array(x[:, None] == np.arange(k), dtype)
# Download MNIST through Keras, then prepare it for training:
# pixel values are scaled by 1/255 and stored as float32, and the integer
# labels are expanded to 10-way one-hot vectors.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_images = (train_images / 255.0).astype(np.float32)
test_images = (test_images / 255.0).astype(np.float32)

train_labels, test_labels = _one_hot(train_labels, 10), _one_hot(test_labels, 10)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz 11490434/11490434 ━━━━━━━━━━━━━━━━━━━━ 0s 0us/step
使用 Jax 构建 MNIST 模型
def loss(params, batch):
    """Compute the mean negative log-likelihood of a batch.

    Args:
        params: Model parameters accepted by ``predict``.
        batch: ``(inputs, targets)`` pair; ``targets`` is multiplied
            element-wise with the model output, so it is expected to be
            one-hot encoded like the labels prepared in this file.

    Returns:
        Scalar: the negated batch mean of ``sum(predict(...) * targets)``
        taken over axis 1. With the LogSoftmax output layer used in this
        file this is the cross-entropy loss.
    """
    inputs, targets = batch
    log_probs = predict(params, inputs)
    per_example = jnp.sum(log_probs * targets, axis=1)
    return -jnp.mean(per_example)
def accuracy(params, batch):
    """Return the fraction of examples in ``batch`` that are classified correctly.

    Args:
        params: Model parameters accepted by ``predict``.
        batch: ``(inputs, targets)`` pair; the true class of each example is
            the arg-max of its ``targets`` row.

    Returns:
        Scalar mean of the element-wise match between predicted and true
        class indices.
    """
    inputs, targets = batch
    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
    target_class = jnp.argmax(targets, axis=1)
    hits = predicted_class == target_class
    return jnp.mean(hits)
# Multi-layer perceptron built with stax: flatten the image, two
# 1024-unit ReLU layers, then a 10-way log-softmax output.
# ``stax.serial`` is unpacked into a parameter initialiser and the forward
# (apply) function used as ``predict`` throughout this file.
init_random_params, predict = stax.serial(
    stax.Flatten,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(10),
    stax.LogSoftmax,
)

# Fixed seed so parameter initialisation is reproducible.
rng = random.PRNGKey(0)
An NVIDIA GPU may be present on this machine, but a CUDA-enabled jaxlib is not installed. Falling back to cpu.
训练并评估模型
# Training hyperparameters.
step_size = 0.001      # step size handed to the momentum optimizer
num_epochs = 10        # full passes over the training set
batch_size = 128       # examples per mini-batch
momentum_mass = 0.9    # momentum coefficient

# Batches per epoch: every complete batch plus one extra, smaller batch
# when the training-set size is not a multiple of ``batch_size``.
num_train = len(train_images)
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + (1 if leftover else 0)
def data_stream():
    """Yield ``(images, labels)`` training mini-batches forever.

    The training set is reshuffled at the start of every pass using a NumPy
    ``RandomState`` seeded with 0, then emitted in ``num_batches`` consecutive
    slices of ``batch_size`` indices (the last slice of a pass may be
    shorter).
    """
    shuffler = npr.RandomState(0)
    while True:
        order = shuffler.permutation(num_train)
        for start in range(0, num_batches * batch_size, batch_size):
            chosen = order[start:start + batch_size]
            yield train_images[chosen], train_labels[chosen]
# Endless generator of shuffled training mini-batches.
batches = data_stream()
# Momentum optimizer, unpacked into: state initialiser, per-step update
# function, and a getter that extracts parameters from the optimizer state.
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)
@jit
def update(i, opt_state, batch):
    """Run one optimizer step and return the new optimizer state.

    Args:
        i: Step index forwarded to ``opt_update``.
        opt_state: Current optimizer state.
        batch: ``(inputs, targets)`` mini-batch passed to ``loss``.

    Returns:
        The optimizer state produced by ``opt_update`` after applying the
        gradient of ``loss`` with respect to the current parameters.
    """
    current_params = get_params(opt_state)
    gradients = grad(loss)(current_params, batch)
    return opt_update(i, gradients, opt_state)
# Initialise the network for flattened 28x28 inputs and wrap the parameters
# in optimizer state. ``itercount`` supplies a global step index that keeps
# increasing across epochs.
_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
itercount = itertools.count()

print("\nStarting training...")
for epoch in range(num_epochs):
    start_time = time.time()
    # Take exactly one epoch's worth of batches from the endless stream.
    for batch in itertools.islice(batches, num_batches):
        opt_state = update(next(itercount), opt_state, batch)
    epoch_time = time.time() - start_time

    # Evaluate on the full training and test sets after every epoch.
    params = get_params(opt_state)
    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))
    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
Starting training... Epoch 0 in 4.47 sec Training set accuracy 0.8728833198547363 Test set accuracy 0.880299985408783 Epoch 1 in 2.45 sec Training set accuracy 0.8983833193778992 Test set accuracy 0.9047999978065491 Epoch 2 in 2.45 sec Training set accuracy 0.9102333188056946 Test set accuracy 0.9138000011444092 Epoch 3 in 2.47 sec Training set accuracy 0.9172333478927612 Test set accuracy 0.9218999743461609 Epoch 4 in 2.42 sec Training set accuracy 0.9224833250045776 Test set accuracy 0.9253999590873718 Epoch 5 in 2.45 sec Training set accuracy 0.9272000193595886 Test set accuracy 0.9309999942779541 Epoch 6 in 2.44 sec Training set accuracy 0.9328166842460632 Test set accuracy 0.9334999918937683 Epoch 7 in 2.41 sec Training set accuracy 0.9360166788101196 Test set accuracy 0.9370999932289124 Epoch 8 in 2.44 sec Training set accuracy 0.939050018787384 Test set accuracy 0.939300000667572 Epoch 9 in 2.45 sec Training set accuracy 0.9425666928291321 Test set accuracy 0.9429000020027161
转换为 TFLite 模型
请注意,我们需要执行以下操作:
- 使用 `functools.partial` 将参数内联到 Jax `predict` 函数。
- 构建一个 `jnp.zeros`,这是一个用于 Jax 跟踪模型的“占位符”张量。
- 调用 `experimental_from_jax`:
  - `serving_func` 被封装在一个列表中。
  - 输入与给定的名称相关联,并作为封装在列表中的数组传入。
# Bind the trained parameters so the converter receives a function of the
# input tensor only.
serving_func = functools.partial(predict, params)

# Placeholder with the serving input shape (one 28x28 image); it is handed
# to the converter together with the input name.
x_input = jnp.zeros((1, 28, 28))

# NOTE: the converter output recorded in this document reports
# `experimental_from_jax` as deprecated in favour of `jax2tf.convert`
# combined with `from_saved_model` / `from_concrete_functions`.
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func],
    [[('input1', x_input)]],
)
tflite_model = converter.convert()

# Persist the float model so its size can be compared later.
with open('jax_mnist.tflite', 'wb') as model_file:
    model_file.write(tflite_model)
WARNING:tensorflow:From /tmpfs/tmp/ipykernel_38895/3025848042.py:3: TFLiteConverterV2.experimental_from_jax (from tensorflow.lite.python.lite) is deprecated and will be removed in a future version. Instructions for updating: Use `jax2tf.convert` and (`lite.TFLiteConverter.from_saved_model` or `lite.TFLiteConverter.from_concrete_functions`) instead. 2023-11-07 21:51:31.217318: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format. 2023-11-07 21:51:31.217364: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency. 2023-11-07 21:51:31.217371: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:387] Ignored change_concat_input_ranges. Summary on the non-converted ops: --------------------------------- * Accepted dialects: tfl, builtin, func * Non-Converted Ops: 7, Total Ops 18, % non-converted = 38.89 % * 7 ARITH ops - arith.constant: 7 occurrences (f32: 6, i32: 1) (f32: 2) (f32: 3) (f32: 1) (f32: 1)
检查转换后的 TFLite 模型
将转换后的模型的结果与 Jax 模型进行比较。
# Reference output of the JAX model for the first training image.
expected = serving_func(train_images[0:1])

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Assert that the TFLite model's output is consistent with the JAX model.
# The third positional argument of `assert_almost_equal` is `decimal` (a
# number of decimal places), so passing 1e-5 there yields a tolerance of
# about 1.5. `assert_allclose` takes the intended 1e-5 tolerances explicitly.
np.testing.assert_allclose(result, expected, rtol=1e-5, atol=1e-5)
INFO: Created TensorFlow Lite XNNPACK delegate for CPU.
优化模型
我们将提供一个 `representative_dataset` 来进行训练后量化,以优化模型。
def representative_dataset():
    """Yield calibration samples for post-training quantization.

    Produces 1000 items; each is a one-element list holding a single-image
    slice ``train_images[i:i+1]`` (the leading batch axis is kept).
    """
    for start in range(1000):
        yield [train_images[start:start + 1]]
# Build a fresh converter for the quantized export (the input is named 'x').
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('x', x_input)]])

# Configure post-training quantization before the single convert() call.
# Converting first without these options would only produce a throwaway
# float model and overwrite `tflite_model`.
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_quant_model = converter.convert()

# Persist the quantized model so its size can be compared with the float one.
with open('jax_mnist_quant.tflite', 'wb') as f:
    f.write(tflite_quant_model)
2023-11-07 21:51:31.978225: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format. 2023-11-07 21:51:31.978276: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency. 2023-11-07 21:51:31.978283: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:387] Ignored change_concat_input_ranges. Summary on the non-converted ops: --------------------------------- * Accepted dialects: tfl, builtin, func * Non-Converted Ops: 7, Total Ops 18, % non-converted = 38.89 % * 7 ARITH ops - arith.constant: 7 occurrences (f32: 6, i32: 1) (f32: 2) (f32: 3) (f32: 1) (f32: 1) 2023-11-07 21:51:32.321877: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:378] Ignored output_format. 2023-11-07 21:51:32.321927: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:381] Ignored drop_control_dependency. 2023-11-07 21:51:32.321934: W tensorflow/compiler/mlir/lite/python/tf_tfl_flatbuffer_helpers.cc:387] Ignored change_concat_input_ranges. Summary on the non-converted ops: --------------------------------- * Accepted dialects: tfl, builtin, func * Non-Converted Ops: 7, Total Ops 18, % non-converted = 38.89 % * 7 ARITH ops - arith.constant: 7 occurrences (f32: 6, i32: 1) (f32: 2) (f32: 3) (f32: 1) (f32: 1) fully_quantize: 0, inference_type: 6, input_inference_type: FLOAT32, output_inference_type: FLOAT32
评估优化后的模型
# Reference output of the JAX model for the first training image.
expected = serving_func(train_images[0:1])

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Assert that the quantized TFLite model is consistent with the Jax model.
# The third positional argument of `assert_almost_equal` is `decimal` (a
# number of decimal places), so passing 1e-5 there yields a tolerance of
# about 1.5 rather than 1e-5. Quantization adds rounding error, so the check
# here is that both models predict the same class, not near-exact equality.
np.testing.assert_array_equal(
    np.argmax(result, axis=1), np.argmax(expected, axis=1))
比较量化模型大小
我们应该能够看到,量化模型的大小缩减为了原始模型的四分之一。
du -h jax_mnist.tflite
du -h jax_mnist_quant.tflite
7.2M jax_mnist.tflite 1.8M jax_mnist_quant.tflite