注:此为新 API,只有通过 pip 安装 tf-nightly 才能使用。它将在 TensorFlow 2.7 版中提供。另外,此 API 仍处于实验阶段,可能会发生变化。
此 CodeLab 演示了如何使用 Jax 构建 MNIST 识别模型,以及如何将其转换为 TensorFlow Lite。此 CodeLab 还将演示如何使用训练后量化来优化 Jax 转换的 TFLite 模型。
建议在最新的 TensorFlow nightly pip 构建中尝试此功能。
pip install tf-nightly --upgrade
pip install jax --upgrade
pip install jaxlib --upgrade
使用 Keras 数据集下载 MNIST 数据并进行预处理。
import numpy as np
import tensorflow as tf
import functools
import time
import itertools
import numpy.random as npr
import jax.numpy as jnp
from jax import jit, grad, random
from jax.example_libraries import optimizers
from jax.example_libraries import stax
def _one_hot(x, k, dtype=np.float32):
"""Create a one-hot encoding of x of size k."""
return np.array(x[:, None] == np.arange(k), dtype)
# Load MNIST, divide pixel values by 255 and store them as float32, then
# one-hot encode the integer labels over 10 classes.
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_images = (train_images / 255.0).astype(np.float32)
test_images = (test_images / 255.0).astype(np.float32)
train_labels, test_labels = _one_hot(train_labels, 10), _one_hot(test_labels, 10)
使用 Jax 构建 MNIST 模型
def loss(params, batch):
    """Return the mean negative log-likelihood of ``batch`` under ``params``.

    Args:
        params: model parameters accepted by ``predict``.
        batch: ``(inputs, targets)`` pair; ``targets`` is expected to be
            one-hot rows (as produced by ``_one_hot``).

    Returns:
        Scalar JAX array: minus the batch mean of the per-example sum of
        ``predict(params, inputs) * targets``. Because the model ends in
        ``stax.LogSoftmax``, this is the cross-entropy loss.
    """
    inputs, targets = batch
    log_probs = predict(params, inputs)
    per_example = jnp.sum(log_probs * targets, axis=1)
    return -jnp.mean(per_example)
def accuracy(params, batch):
    """Return the fraction of examples in ``batch`` that are classified correctly.

    Args:
        params: model parameters accepted by ``predict``.
        batch: ``(inputs, targets)`` pair; the true class of each row is the
            argmax of ``targets`` (one-hot rows, as produced by ``_one_hot``).

    Returns:
        Scalar JAX array: mean of per-example correctness.
    """
    inputs, targets = batch
    predicted_class = jnp.argmax(predict(params, inputs), axis=1)
    target_class = jnp.argmax(targets, axis=1)
    hits = predicted_class == target_class
    return jnp.mean(hits)
# Three-layer MLP (1024 -> 1024 -> 10 units) with ReLU activations and a
# log-softmax output, so ``predict`` returns per-class log-probabilities.
# ``stax.Flatten`` must come first: the images are fed in as (N, 28, 28)
# arrays, while ``stax.Dense`` contracts the last input axis with a
# (features, units) kernel built for 28 * 28 = 784 features. Without
# flattening, the dot product's contraction dimensions do not match.
init_random_params, predict = stax.serial(
    stax.Flatten,
    stax.Dense(1024), stax.Relu,
    stax.Dense(1024), stax.Relu,
    stax.Dense(10), stax.LogSoftmax)
# Optimisation hyper-parameters.
step_size = 0.001        # learning rate passed to the momentum optimizer
momentum_mass = 0.9      # momentum coefficient
num_epochs = 10
batch_size = 128
rng = random.PRNGKey(0)  # key used to initialise the model parameters

# Mini-batches per epoch: every full batch, plus one more for any remainder.
num_train = train_images.shape[0]
num_complete_batches, leftover = divmod(num_train, batch_size)
num_batches = num_complete_batches + (1 if leftover else 0)
def data_stream():
    """Yield shuffled ``(images, labels)`` training mini-batches forever.

    Every pass over the data draws a fresh permutation from a
    ``RandomState`` seeded with 0, so the batch sequence is reproducible.
    Each pass yields ``num_batches`` slices of at most ``batch_size``
    examples; the last slice of a pass may be shorter.
    """
    shuffler = npr.RandomState(0)
    while True:
        order = shuffler.permutation(num_train)
        for start in range(0, num_batches * batch_size, batch_size):
            idx = order[start:start + batch_size]
            yield train_images[idx], train_labels[idx]
# Infinite iterator of shuffled training mini-batches.
batches = data_stream()

# Momentum optimizer, unpacked into its (init, update, get_params) functions.
opt_init, opt_update, get_params = optimizers.momentum(step_size, mass=momentum_mass)


@jit
def update(i, opt_state, batch):
    """Perform one optimizer step and return the new optimizer state.

    Args:
        i: global step index, forwarded to ``opt_update``.
        opt_state: current optimizer state.
        batch: ``(inputs, targets)`` pair passed to ``loss``.

    Returns:
        Optimizer state after applying the gradient of ``loss`` on ``batch``.
    """
    # ``jit`` compiles the gradient computation and parameter update into one
    # XLA computation; without it every training step is executed op-by-op.
    params = get_params(opt_state)
    return opt_update(i, grad(loss)(params, batch), opt_state)


# Initialise parameters for an input shape of (-1, 28 * 28) and wrap them in
# the optimizer state; ``itercount`` supplies the step index ``i``.
_, init_params = init_random_params(rng, (-1, 28 * 28))
opt_state = opt_init(init_params)
itercount = itertools.count()
print("\nStarting training...")
for epoch in range(num_epochs):
    # One epoch = num_batches optimizer steps on freshly shuffled mini-batches.
    # NOTE(review): JAX dispatches work asynchronously, so this wall-clock time
    # may not include computation still in flight; confirm if precise timing
    # matters.
    start_time = time.time()
    for _ in range(num_batches):
        step = next(itercount)
        opt_state = update(step, opt_state, next(batches))
    epoch_time = time.time() - start_time

    # Evaluate the current parameters on the full train and test splits.
    params = get_params(opt_state)
    train_acc = accuracy(params, (train_images, train_labels))
    test_acc = accuracy(params, (test_images, test_labels))

    print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
    print("Training set accuracy {}".format(train_acc))
    print("Test set accuracy {}".format(test_acc))
Starting training... Epoch 0 in 4.47 sec Training set accuracy 0.8728833198547363 Test set accuracy 0.880299985408783 Epoch 1 in 2.45 sec Training set accuracy 0.8983833193778992 Test set accuracy 0.9047999978065491 Epoch 2 in 2.45 sec Training set accuracy 0.9102333188056946 Test set accuracy 0.9138000011444092 Epoch 3 in 2.47 sec Training set accuracy 0.9172333478927612 Test set accuracy 0.9218999743461609 Epoch 4 in 2.42 sec Training set accuracy 0.9224833250045776 Test set accuracy 0.9253999590873718 Epoch 5 in 2.45 sec Training set accuracy 0.9272000193595886 Test set accuracy 0.9309999942779541 Epoch 6 in 2.44 sec Training set accuracy 0.9328166842460632 Test set accuracy 0.9334999918937683 Epoch 7 in 2.41 sec Training set accuracy 0.9360166788101196 Test set accuracy 0.9370999932289124 Epoch 8 in 2.44 sec Training set accuracy 0.939050018787384 Test set accuracy 0.939300000667572 Epoch 9 in 2.45 sec Training set accuracy 0.9425666928291321 Test set accuracy 0.9429000020027161
转换为 TFLite 模型
请注意,在这里我们:
- 使用 `functools.partial` 将训练好的参数内联到 Jax 的 `predict` 函数中。
- 构建一个 `jnp.zeros`,这是一个用于 Jax 跟踪模型的“占位符”张量。
- 调用 `experimental_from_jax`:
  - `serving_func` 被封装在一个列表中。
  - 输入与给定的名称相关联,并作为封装在列表中的数组传入。
# Inline the trained parameters so the served function takes only the image.
serving_func = functools.partial(predict, params)
# (1, 28, 28) placeholder used only so the converter can trace the JAX function.
x_input = jnp.zeros((1, 28, 28))
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('input1', x_input)]])
tflite_model = converter.convert()
# Write the serialized TFLite model (bytes) to disk.
with open('jax_mnist.tflite', 'wb') as f:
    f.write(tflite_model)
检查转换后的 TFLite 模型
将转换后的模型的结果与 Jax 模型进行比较。
# Reference output from the JAX model for the first training image.
expected = serving_func(train_images[0:1])

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
# Tensors must be allocated before any input can be set or output read.
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
# Execute the graph; the output tensor is only populated after invoke().
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Assert if the result of TFLite model is consistent with the JAX model.
# Tolerances are passed by keyword: the third positional parameter of
# np.testing.assert_almost_equal is `decimal` (an int), not a tolerance.
np.testing.assert_allclose(expected, result, rtol=1e-5, atol=1e-5)
我们将提供一个 `representative_dataset`,用于执行训练后量化,从而优化模型。
def representative_dataset():
    """Yield calibration inputs for post-training quantization.

    Produces 1000 items; item ``i`` is a one-element list holding
    ``train_images[i:i + 1]`` (a batch containing a single image). The
    converter consumes these samples to calibrate quantization ranges.
    """
    yield from ([train_images[i:i + 1]] for i in range(1000))
# Build a fresh converter and enable full-integer post-training quantization.
# (No intermediate convert() is needed before setting these options.)
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [serving_func], [[('x', x_input)]])
converter.optimizations = [tf.lite.Optimize.DEFAULT]
# Calibration samples used to estimate activation ranges.
converter.representative_dataset = representative_dataset
# Restrict to int8 builtin kernels, so conversion errors out instead of
# keeping ops it cannot quantize in float.
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
tflite_quant_model = converter.convert()
# Write the serialized quantized model (bytes) to disk.
with open('jax_mnist_quant.tflite', 'wb') as f:
    f.write(tflite_quant_model)
# Reference output from the JAX model for the first training image.
expected = serving_func(train_images[0:1])

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_quant_model)
# Tensors must be allocated before any input can be set or output read.
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], train_images[0:1, :, :])
# Execute the graph; the output tensor is only populated after invoke().
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])

# Assert if the result of TFLite model is consistent with the Jax model.
# int8 quantization introduces error, so only a coarse check is meaningful.
# decimal=0 requires abs(expected - result) < 1.5 element-wise. It is passed
# by keyword because assert_almost_equal's third parameter is the integer
# `decimal`, not an absolute tolerance.
np.testing.assert_almost_equal(expected, result, decimal=0)
du -h jax_mnist.tflite
du -h jax_mnist_quant.tflite
7.2M jax_mnist.tflite 1.8M jax_mnist_quant.tflite