Migrating your TFLite code to TF2


TensorFlow Lite (TFLite) is a set of tools that helps developers run ML inference on-device (mobile, embedded, and IoT devices). The TFLite converter is one such tool that converts existing TF models into an optimized TFLite model format that can be efficiently run on-device.

In this doc, you'll learn what changes you need to make to your TF to TFLite conversion code, followed by a few examples that show how to make them.

Changes to your TF to TFLite conversion code

  • If you're using a legacy TF1 model format (such as a Keras .h5 file, frozen GraphDef, checkpoints, or tf.Session), convert it to a TF1 or TF2 SavedModel and use the TF2 converter API tf.lite.TFLiteConverter.from_saved_model(...) to convert it to a TFLite model (refer to Table 1).

  • Update the converter API flags (refer to Table 2).

  • Remove legacy APIs such as tf.lite.constants (e.g., replace tf.lite.constants.INT8 with tf.int8).
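As a sketch of what these changes look like together, the snippet below runs full-integer quantization using only flags that the TF2 converter still supports, with tf.int8 in place of tf.lite.constants.INT8. The TinyModel module, its random weights, and the representative_dataset generator are illustrative assumptions, not part of any real model.

```python
import tempfile

import numpy as np
import tensorflow as tf


class TinyModel(tf.Module):
  """Hypothetical toy model: relu(x @ w) with a fixed input signature."""

  def __init__(self):
    super().__init__()
    self.w = tf.Variable(tf.random.normal([4, 2]), name='w')

  @tf.function(input_signature=[tf.TensorSpec([1, 4], tf.float32)])
  def __call__(self, x):
    return tf.nn.relu(tf.matmul(x, self.w))


# Save the toy model as a TF2 SavedModel in a temporary directory.
saved_model_dir = tempfile.mkdtemp()
model = TinyModel()
tf.saved_model.save(model, saved_model_dir, signatures=model.__call__)


def representative_dataset():
  # Calibration samples used to pick int8 quantization ranges.
  for _ in range(10):
    yield [np.random.rand(1, 4).astype(np.float32)]


converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
converter.representative_dataset = representative_dataset
converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
converter.inference_input_type = tf.int8   # was tf.lite.constants.INT8 in TF1
converter.inference_output_type = tf.int8
tflite_model = converter.convert()

# Inspect the I/O types of the converted model.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_dtype = interpreter.get_input_details()[0]['dtype']
output_dtype = interpreter.get_output_details()[0]['dtype']
print(input_dtype, output_dtype)
```

Both the input and the output of the converted model should report an int8 dtype.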

Table 1: TFLite Python Converter API Update

TF1 API | TF2 API
tf.lite.TFLiteConverter.from_saved_model('saved_model/', ...) | supported
tf.lite.TFLiteConverter.from_keras_model_file('model.h5', ...) | removed (update to SavedModel format)
tf.lite.TFLiteConverter.from_frozen_graph('model.pb', ...) | removed (update to SavedModel format)
tf.lite.TFLiteConverter.from_session(sess, ...) | removed (update to SavedModel format)
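A quick way to confirm Table 1 against your installed TensorFlow version is to check which factory methods each converter class exposes. This is a small sketch; it only inspects the two classes and does not convert anything.

```python
import tensorflow as tf

tf1_converter = tf.compat.v1.lite.TFLiteConverter  # TF1-style converter
tf2_converter = tf.lite.TFLiteConverter            # TF2 converter

for name in ['from_saved_model', 'from_keras_model_file',
             'from_frozen_graph', 'from_session']:
  print(f"{name:24} TF1: {hasattr(tf1_converter, name)!s:5}  "
        f"TF2: {hasattr(tf2_converter, name)}")
```

Only from_saved_model should be available on both classes; the other three exist only on the TF1 converter.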

Table 2: TFLite Python Converter API Flags Update

TF1 API | TF2 API
allow_custom_ops, optimizations, representative_dataset, target_spec, inference_input_type, inference_output_type, experimental_new_converter, experimental_new_quantizer | supported
input_tensors, output_tensors, input_arrays_with_shape, output_arrays, experimental_debug_info_func | removed (unsupported converter API arguments)
change_concat_input_ranges, default_ranges_stats, get_input_arrays(), inference_type, quantized_input_stats, reorder_across_fake_quant | removed (unsupported quantization workflows)
conversion_summary_dir, dump_graphviz_dir, dump_graphviz_video | removed (instead, visualize models using Netron or visualize.py)
output_format, drop_control_dependency | removed (unsupported features in TF2)
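When migrating a large script, it can help to encode Table 2 as data and scan for flags that need to go. The audit_converter_flags helper below is hypothetical (it is not part of TensorFlow) and simply mirrors the "removed" rows of the table; supported flags such as optimizations pass through unreported.

```python
# Hypothetical lookup table: TF1 converter flags removed in TF2, grouped by
# the reason given in Table 2.
REMOVED_FLAGS = {
    'unsupported converter API arguments': [
        'input_tensors', 'output_tensors', 'input_arrays_with_shape',
        'output_arrays', 'experimental_debug_info_func'],
    'unsupported quantization workflows': [
        'change_concat_input_ranges', 'default_ranges_stats',
        'get_input_arrays', 'inference_type', 'quantized_input_stats',
        'reorder_across_fake_quant'],
    'instead, visualize models using Netron or visualize.py': [
        'conversion_summary_dir', 'dump_graphviz_dir', 'dump_graphviz_video'],
    'unsupported features in TF2': ['output_format', 'drop_control_dependency'],
}


def audit_converter_flags(flag_names):
  """Returns {flag: reason} for each name in flag_names that TF2 removed."""
  reason_by_flag = {flag: reason
                    for reason, flags in REMOVED_FLAGS.items()
                    for flag in flags}
  return {name: reason_by_flag[name]
          for name in flag_names if name in reason_by_flag}


tf1_flags = ['optimizations', 'change_concat_input_ranges', 'inference_type']
for flag, reason in audit_converter_flags(tf1_flags).items():
  print(f'{flag}: removed ({reason})')
```

Here optimizations is kept, while change_concat_input_ranges and inference_type are both reported as unsupported quantization workflows.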

Examples

You'll now walk through some examples to convert legacy TF1 models to TF1/TF2 SavedModels and then convert them to TF2 TFLite models.

Setup

Start with the necessary TensorFlow imports.

import tensorflow as tf
import tensorflow.compat.v1 as tf1
import numpy as np

import logging
logger = tf.get_logger()
logger.setLevel(logging.ERROR)

import shutil
def remove_dir(path):
  # Delete the directory tree; ignore errors such as the path not existing.
  shutil.rmtree(path, ignore_errors=True)

Create all the necessary TF1 model formats.

# Create a TF1 SavedModel
SAVED_MODEL_DIR = "tf_saved_model/"
remove_dir(SAVED_MODEL_DIR)
with tf1.Graph().as_default() as g:
  with tf1.Session() as sess:
    input = tf1.placeholder(tf.float32, shape=(3,), name='input')
    output = input + 2
    # print("result: ", sess.run(output, {input: [0., 2., 4.]}))
    tf1.saved_model.simple_save(
        sess, SAVED_MODEL_DIR,
        inputs={'input': input}, 
        outputs={'output': output})
print("TF1 SavedModel path: ", SAVED_MODEL_DIR)

# Create a TF1 Keras model
KERAS_MODEL_PATH = 'tf_keras_model.h5'
model = tf1.keras.models.Sequential([
    tf1.keras.layers.InputLayer(input_shape=(128, 128, 3,), name='input'),
    tf1.keras.layers.Dense(units=16, input_shape=(128, 128, 3,), activation='relu'),
    tf1.keras.layers.Dense(units=1, name='output')
])
model.save(KERAS_MODEL_PATH, save_format='h5')
print("TF1 Keras Model path: ", KERAS_MODEL_PATH)

# Create a TF1 frozen GraphDef model
GRAPH_DEF_MODEL_PATH = tf.keras.utils.get_file(
    'mobilenet_v1_0.25_128',
    origin='https://storage.googleapis.com/download.tensorflow.org/models/mobilenet_v1_0.25_128_frozen.tgz',
    untar=True,
) + '/frozen_graph.pb'

print("TF1 frozen GraphDef path: ", GRAPH_DEF_MODEL_PATH)

1. Convert a TF1 SavedModel to a TFLite model

Before: Converting with TF1

This is typical code for TF1-style TFLite conversion.

converter = tf1.lite.TFLiteConverter.from_saved_model(
    saved_model_dir=SAVED_MODEL_DIR,
    input_arrays=['input'],
    input_shapes={'input' : [3]}
)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.change_concat_input_ranges = True
tflite_model = converter.convert()
# Ignore warning: "Use '@tf.function' or '@defun' to decorate the function."

After: Converting with TF2

Directly convert the TF1 SavedModel to a TFLite model, using the smaller set of TF2 converter flags.

# Convert TF1 SavedModel to a TFLite model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir=SAVED_MODEL_DIR)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
tflite_model = converter.convert()
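To check that the converted model still computes input + 2, you can load it into tf.lite.Interpreter. This sketch is self-contained: it rebuilds the same toy TF1 graph in a temporary directory (an assumption, so it does not depend on SAVED_MODEL_DIR) and leaves out optimizations for simplicity.

```python
import tempfile

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf1

# Rebuild the toy TF1 SavedModel (output = input + 2) in a fresh directory.
saved_model_dir = tempfile.mkdtemp() + '/tf1_saved_model'
with tf1.Graph().as_default():
  with tf1.Session() as sess:
    x = tf1.placeholder(tf.float32, shape=(3,), name='input')
    y = x + 2
    tf1.saved_model.simple_save(
        sess, saved_model_dir, inputs={'input': x}, outputs={'output': y})

# Convert with the TF2 converter.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

# Run one inference with the TFLite interpreter.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]
interpreter.set_tensor(input_details['index'],
                       np.array([0., 2., 4.], dtype=np.float32))
interpreter.invoke()
result = interpreter.get_tensor(output_details['index'])
print(result)
```

For the input [0., 2., 4.], the interpreter should return [2., 4., 6.].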

2. Convert a TF1 Keras model file to a TFLite model

Before: Converting with TF1

This is typical code for TF1-style TFLite conversion.

converter = tf1.lite.TFLiteConverter.from_keras_model_file(model_file=KERAS_MODEL_PATH)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.change_concat_input_ranges = True
tflite_model = converter.convert()

After: Converting with TF2

First, convert the TF1 Keras model file to a TF2 SavedModel, then convert that SavedModel to a TFLite model using the smaller set of TF2 converter flags.

# Convert TF1 Keras model file to TF2 SavedModel.
model = tf.keras.models.load_model(KERAS_MODEL_PATH)
model.save(filepath='saved_model_2/')

# Convert TF2 SavedModel to a TFLite model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir='saved_model_2/')
converter.optimizations = {tf.lite.Optimize.DEFAULT}
tflite_model = converter.convert()
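A useful sanity check after migrating a Keras model is to compare its predictions with those of the TFLite model on the same input. This sketch rebuilds a model with the same layers as the Setup section, saves it with tf.saved_model.save (an assumption; the example above uses model.save) to a temporary directory, and converts it without optimizations so the two outputs should agree up to floating-point tolerance.

```python
import tempfile

import numpy as np
import tensorflow as tf

# Same layer stack as the Setup section's Keras model.
model = tf.keras.Sequential([
    tf.keras.Input(shape=(128, 128, 3)),
    tf.keras.layers.Dense(units=16, activation='relu'),
    tf.keras.layers.Dense(units=1),
])

saved_model_dir = tempfile.mkdtemp()
tf.saved_model.save(model, saved_model_dir)

converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()[0]
output_details = interpreter.get_output_details()[0]

# Feed the same random batch to both models.
x = np.random.rand(1, 128, 128, 3).astype(np.float32)
interpreter.set_tensor(input_details['index'], x)
interpreter.invoke()
tflite_out = interpreter.get_tensor(output_details['index'])
keras_out = model(x).numpy()

print('max abs difference:', np.max(np.abs(tflite_out - keras_out)))
```

If you enable tf.lite.Optimize.DEFAULT, expect a larger gap from weight quantization, so loosen the tolerance accordingly.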

3. Convert a TF1 frozen GraphDef to a TFLite model

Before: Converting with TF1

This is typical code for TF1-style TFLite conversion.

converter = tf1.lite.TFLiteConverter.from_frozen_graph(
    graph_def_file=GRAPH_DEF_MODEL_PATH,
    input_arrays=['input'],
    input_shapes={'input' : [1, 128, 128, 3]},
    output_arrays=['MobilenetV1/Predictions/Softmax'],
)
converter.optimizations = {tf.lite.Optimize.DEFAULT}
converter.change_concat_input_ranges = True
tflite_model = converter.convert()

After: Converting with TF2

First, convert the TF1 frozen GraphDef to a TF1 SavedModel, then convert that SavedModel to a TFLite model using the smaller set of TF2 converter flags.

## Convert TF1 frozen Graph to TF1 SavedModel.

# Load the graph as a v1.GraphDef
import pathlib
gdef = tf.compat.v1.GraphDef()
gdef.ParseFromString(pathlib.Path(GRAPH_DEF_MODEL_PATH).read_bytes())

# Convert the GraphDef to a tf.Graph
with tf.Graph().as_default() as g:
  tf.graph_util.import_graph_def(gdef, name="")

# Look up the input and output tensors.
input_tensor = g.get_tensor_by_name('input:0') 
output_tensor = g.get_tensor_by_name('MobilenetV1/Predictions/Softmax:0')

# Save the graph as a TF1 SavedModel
remove_dir('saved_model_3/')
with tf.compat.v1.Session(graph=g) as s:
  tf.compat.v1.saved_model.simple_save(
      session=s,
      export_dir='saved_model_3/',
      inputs={'input':input_tensor},
      outputs={'output':output_tensor})

# Convert TF1 SavedModel to a TFLite model.
converter = tf.lite.TFLiteConverter.from_saved_model(saved_model_dir='saved_model_3/')
converter.optimizations = {tf.lite.Optimize.DEFAULT}
tflite_model = converter.convert()
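If you have several frozen graphs to migrate, the GraphDef-to-SavedModel steps can be wrapped in a function. frozen_graph_to_saved_model below is a hypothetical helper, and the tiny GraphDef (y = 3 * x) stands in for a real frozen model such as MobileNet so the sketch runs without a download.

```python
import tempfile

import numpy as np
import tensorflow as tf


def frozen_graph_to_saved_model(graph_def, input_name, output_name, export_dir):
  """Hypothetical helper: wraps a frozen GraphDef in a TF1 SavedModel."""
  with tf.Graph().as_default() as g:
    tf.graph_util.import_graph_def(graph_def, name='')
  input_tensor = g.get_tensor_by_name(input_name + ':0')
  output_tensor = g.get_tensor_by_name(output_name + ':0')
  with tf.compat.v1.Session(graph=g) as sess:
    tf.compat.v1.saved_model.simple_save(
        session=sess, export_dir=export_dir,
        inputs={'input': input_tensor}, outputs={'output': output_tensor})


# Toy stand-in for a frozen graph: a placeholder multiplied by a constant.
with tf.Graph().as_default() as toy_graph:
  x = tf.compat.v1.placeholder(tf.float32, shape=(1, 2), name='x')
  tf.multiply(x, 3.0, name='y')
toy_graph_def = toy_graph.as_graph_def()

export_dir = tempfile.mkdtemp() + '/frozen_as_saved_model'
frozen_graph_to_saved_model(toy_graph_def, 'x', 'y', export_dir)

# Convert the resulting TF1 SavedModel with the TF2 converter and run it.
converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
tflite_model = converter.convert()

interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
interpreter.set_tensor(interpreter.get_input_details()[0]['index'],
                       np.array([[1.0, 2.0]], dtype=np.float32))
interpreter.invoke()
result = interpreter.get_tensor(interpreter.get_output_details()[0]['index'])
print(result)
```

For the input [[1., 2.]], the converted model should return [[3., 6.]].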

Further reading

  • Refer to the TFLite Guide to learn more about the workflows and latest features.
  • If you're using TF1 code or legacy TF1 model formats (Keras .h5 files, frozen GraphDef .pb files, etc.), please update your code and migrate your models to the TF2 SavedModel format.