Các mô hình JAX với TensorFlow Lite

Trang này cung cấp đường dẫn cho người dùng muốn đào tạo các mô hình trong JAX và triển khai trên thiết bị di động để suy luận ( ví dụ colab ).

Các phương pháp trong hướng dẫn này tạo ra một tflite_model có thể được sử dụng trực tiếp với ví dụ về mã trình thông dịch TFLite hoặc được lưu vào tệp TFLite FlatBuffer.

Điều kiện tiên quyết

Bạn nên thử tính năng này với gói Python hàng đêm mới nhất của TensorFlow.

pip install tf-nightly --upgrade

Chúng tôi sẽ sử dụng thư viện Xuất Orbax để xuất các mô hình JAX. Đảm bảo phiên bản JAX của bạn ít nhất là 0.4.20 trở lên.

pip install jax --upgrade
pip install orbax-export --upgrade

Chuyển đổi mô hình JAX sang TensorFlow Lite

Chúng tôi sử dụng TensorFlow SavingModel làm định dạng trung gian giữa JAX và TensorFlow Lite. Sau khi bạn có SavingModel thì bạn có thể sử dụng API TensorFlow Lite hiện có để hoàn tất quá trình chuyển đổi.

# This code snippet converts a JAX model to TFLite through TF SavedModel.
from orbax.export import ExportManager
from orbax.export import JaxModule
from orbax.export import ServingConfig
import tensorflow as tf
import jax.numpy as jnp

def model_fn(_, x):
  return jnp.sin(jnp.cos(x))

jax_module = JaxModule({}, model_fn, input_polymorphic_shape='b, ...')

# Option 1: Simply save the model via `tf.saved_model.save` if no need for pre/post
# processing.
tf.saved_model.save(
    jax_module,
    '/some/directory',
    signatures=jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
        tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
    ),
    options=tf.saved_model.SaveOptions(experimental_custom_gradients=True),
)
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 2: Define pre/post processing TF functions (e.g. (de)?tokenize).
serving_config = ServingConfig(
    'Serving_default',
    # Corresponds to the input signature of `tf_preprocessor`
    input_signature=[tf.TensorSpec(shape=(None,), dtype=tf.float32, name='input')],
    tf_preprocessor=lambda x: x,
    tf_postprocessor=lambda out: {'output': out}
)
export_mgr = ExportManager(jax_module, [serving_config])
export_mgr.save('/some/directory')
converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')
tflite_model = converter.convert()

# Option 3: Convert from TF concrete function directly
converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [
        jax_module.methods[JaxModule.DEFAULT_METHOD_KEY].get_concrete_function(
            tf.TensorSpec(shape=(None,), dtype=tf.float32, name="input")
        )
    ]
)
tflite_model = converter.convert()

Kiểm tra mô hình TFLite đã chuyển đổi

Sau khi mô hình được chuyển đổi sang TFLite, bạn có thể chạy API trình thông dịch TFLite để kiểm tra kết quả đầu ra của mô hình.

# Run the model with TensorFlow Lite
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors() input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
interpreter.set_tensor(input_details[0]["index"], input_data)
interpreter.invoke()
result = interpreter.get_tensor(output_details[0]["index"])