Import a JAX model using JAX2TF


This notebook provides a complete, runnable example of creating a model using JAX and bringing it into TensorFlow to continue training. This is made possible by JAX2TF, a lightweight API that provides a pathway from the JAX ecosystem to the TensorFlow ecosystem.

JAX is a high-performance array computing library. To create the model, this notebook uses Flax, a neural network library for JAX. To train it, it uses Optax, an optimization library for JAX.

If you're a researcher using JAX, JAX2TF gives you a path to production using TensorFlow's proven tools.

There are many ways this can be useful; here are just a few:

  • Inference: Taking a model written for JAX and deploying it either on a server using TF Serving, on-device using TFLite, or on the web using TensorFlow.js.

  • Fine-tuning: Taking a model that was trained using JAX, you can bring its components to TF using JAX2TF, and continue training it in TensorFlow with your existing training data and setup.

  • Fusion: Combining parts of models that were trained using JAX with those trained using TensorFlow, for maximum flexibility.

The key to enabling this kind of interoperation between JAX and TensorFlow is jax2tf.convert, which takes model components created on top of JAX (your loss function, prediction function, etc.) and creates equivalent representations of them as TensorFlow functions, which can then be exported as a TensorFlow SavedModel.

Setup

import tensorflow as tf
import numpy as np
import jax
import jax.numpy as jnp
import flax
import optax
import os
from matplotlib import pyplot as plt
from jax.experimental import jax2tf
from threading import Lock # Only used in the visualization utility.
from functools import partial
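# Needed for TensorFlow and JAX to coexist in GPU memory.
# Stop JAX from preallocating most of the GPU memory; this only takes effect if set before JAX first uses the GPU.
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = "false"
# Let TensorFlow grow its GPU memory on demand instead of reserving it all up front.
gpus = tf.config.list_physical_devices('GPU')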
if gpus:
  try:
    for gpu in gpus:
      tf.config.experimental.set_memory_growth(gpu, True)
  except RuntimeError as e:
    # Memory growth must be set before GPUs have been initialized.
    print(e)

Visualization utilities
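The helper cell for this section is not shown, but later cells call Progress(steps).step(reset=...) and display_train_curves(...). The following is a hypothetical minimal stand-in that matches those call sites, not the notebook's original implementation: a text progress bar guarded by the threading Lock imported in the setup, and a two-panel loss/accuracy plot.

```python
import sys
from threading import Lock

import numpy as np


class Progress:
    """Hypothetical minimal text progress bar for one epoch of `steps` steps."""

    def __init__(self, steps, width=40):
        self.steps = steps
        self.width = width
        self.count = 0
        self.lock = Lock()  # keeps the counter consistent if step() is called from several threads

    def step(self, reset=False):
        with self.lock:
            if reset:
                self.count = 0
            self.count += 1
            filled = self.width * min(self.count, self.steps) // self.steps
            sys.stdout.write("\r[" + "#" * filled + "." * (self.width - filled) + "]")
            sys.stdout.flush()
            return self.count


def display_train_curves(loss, avg_loss, eval_loss, eval_accuracy, epochs,
                         steps_per_epoch, ignore_first_n=0):
    """Hypothetical plotting helper matching the calls in this notebook."""
    from matplotlib import pyplot as plt  # imported here so Progress works without matplotlib

    step_axis = np.arange(ignore_first_n, len(loss))
    epoch_axis = np.arange(1, epochs + 1) * steps_per_epoch  # epoch-level values sit at each epoch's last step
    keep = epoch_axis > ignore_first_n  # hide epoch-level losses inside the ignored warm-up
    _, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(12, 4))
    ax_loss.plot(step_axis, np.asarray(loss, dtype=np.float32)[ignore_first_n:], alpha=0.3, label="train loss (per step)")
    ax_loss.plot(epoch_axis[keep], np.asarray(avg_loss, dtype=np.float32)[keep], "o-", label="train loss (epoch mean)")
    ax_loss.plot(epoch_axis[keep], np.asarray(eval_loss, dtype=np.float32)[keep], "o-", label="eval loss")
    ax_loss.set_xlabel("training step")
    ax_loss.legend()
    ax_acc.plot(epoch_axis, np.asarray(eval_accuracy, dtype=np.float32), "o-", label="eval accuracy")
    ax_acc.set_xlabel("training step")
    ax_acc.legend()
    plt.show()
```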

Download and prepare the MNIST dataset

(x_train, train_labels), (x_test, test_labels) = tf.keras.datasets.mnist.load_data()

def preprocess(image, label):
  # Scale pixels to [0, 1], add a channel axis ((28, 28) -> (28, 28, 1)), and one-hot encode the label.
  return tf.expand_dims(tf.cast(image, tf.float32)/255.0, axis=-1), tf.one_hot(label, depth=10)

train_data = tf.data.Dataset.from_tensor_slices((x_train, train_labels))
train_data = train_data.map(preprocess)

BATCH_SIZE = 256
train_data = train_data.batch(BATCH_SIZE, drop_remainder=True)
train_data = train_data.cache()
# Shuffling after batching reorders whole batches each epoch (the 5000-element buffer holds all of them).
train_data = train_data.shuffle(5000, reshuffle_each_iteration=True)

test_data = tf.data.Dataset.from_tensor_slices((x_test, test_labels))
test_data = test_data.map(preprocess)
test_data = test_data.batch(10000)
test_data = test_data.cache()

(one_batch, one_batch_labels) = next(iter(train_data)) # just one batch
(all_test_data, all_test_labels) = next(iter(test_data)) # all in one batch since batch size is 10000
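For reference, the map step above turns each (28, 28) uint8 image into a (28, 28, 1) float32 array in [0, 1] (the channels-last layout flax.linen.Conv expects) and each integer label into a one-hot vector of length 10 (the format optax.softmax_cross_entropy expects). A NumPy sketch of the same transformation for a single example, for illustration only (preprocess_example is a hypothetical helper):

```python
import numpy as np


def preprocess_example(image, label, num_classes=10):
    """NumPy version of the tf.data map step, for a single example."""
    # uint8 pixels in [0, 255] -> float32 in [0, 1], plus a trailing channel axis.
    pixels = (image.astype(np.float32) / 255.0)[..., np.newaxis]
    # Integer class index -> one-hot row of the identity matrix.
    one_hot = np.eye(num_classes, dtype=np.float32)[label]
    return pixels, one_hot


demo_image = np.full((28, 28), 255, dtype=np.uint8)  # an all-white dummy image
demo_pixels, demo_one_hot = preprocess_example(demo_image, 3)
print(demo_pixels.shape, demo_one_hot)
```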

Configure training

This notebook will create and train a simple model for demonstration purposes.

# Training hyperparameters.
JAX_EPOCHS = 3
TF_EPOCHS = 7
STEPS_PER_EPOCH = len(train_labels)//BATCH_SIZE
LEARNING_RATE = 0.01
LEARNING_RATE_EXP_DECAY = 0.6

# The learning rate schedule for JAX (with Optax).
jlr_decay = optax.exponential_decay(LEARNING_RATE, transition_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)

# The learning rate schedule for TensorFlow.
tflr_decay = tf.keras.optimizers.schedules.ExponentialDecay(initial_learning_rate=LEARNING_RATE, decay_steps=STEPS_PER_EPOCH, decay_rate=LEARNING_RATE_EXP_DECAY, staircase=True)
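With staircase=True, both schedules compute initial_lr * decay_rate ** floor(step / decay_steps), so the learning rate is constant within an epoch and decays once per epoch. Because both are configured with the same parameters and counted in steps, the TensorFlow schedule can resume where the JAX one stopped. A plain-Python sketch of the formula (staircase_exponential_decay is a hypothetical helper; 60,000 MNIST training images and a batch size of 256 give 234 steps per epoch):

```python
def staircase_exponential_decay(step, initial_lr, decay_steps, decay_rate):
    """lr = initial_lr * decay_rate ** floor(step / decay_steps)."""
    return initial_lr * decay_rate ** (step // decay_steps)


demo_steps_per_epoch = 60000 // 256  # 234 full batches per MNIST epoch with BATCH_SIZE = 256

# The learning rate stays constant within an epoch and drops by 40% at each epoch boundary.
for demo_epoch in range(4):
    first_step = demo_epoch * demo_steps_per_epoch
    print(demo_epoch, staircase_exponential_decay(first_step, 0.01, demo_steps_per_epoch, 0.6))
```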

Create the model using Flax

class ConvModel(flax.linen.Module):

  @flax.linen.compact
  def __call__(self, x, train):
    x = flax.linen.Conv(features=12, kernel_size=(3,3), padding="SAME", use_bias=False)(x)
    x = flax.linen.BatchNorm(use_running_average=not train, use_scale=False, use_bias=True)(x)
    x = x.reshape((x.shape[0], -1))  # flatten
    x = flax.linen.Dense(features=200, use_bias=True)(x)
    x = flax.linen.BatchNorm(use_running_average=not train, use_scale=False, use_bias=True)(x)
    x = flax.linen.Dropout(rate=0.3, deterministic=not train)(x)
    x = flax.linen.relu(x)
    x = flax.linen.Dense(features=10)(x)
    # No log_softmax here: return raw logits, which optax.softmax_cross_entropy expects.
    return x

  # JAX differentiation requires a function `f(params, other_state, ...)` -> `loss` (a single number).
  # `jax.grad` differentiates it with respect to the first argument. Because this `loss` also returns
  # the updated `batch_stats`, it is differentiated with `has_aux=True` (see `train_step`).
  # The user must split trainable and non-trainable variables into `params` and `other_state`.
  # Must pass a different RNG key each time for the dropout mask to be different.
  def loss(self, params, other_state, rng, data, labels, train):
    logits, batch_stats = self.apply({'params': params, **other_state},
                                     data,
                                     mutable=['batch_stats'],
                                     rngs={'dropout': rng},
                                     train=train)
    # The loss averaged across the batch dimension.
    loss = optax.softmax_cross_entropy(logits, labels).mean()
    return loss, batch_stats

  def predict(self, state, data):
    logits = self.apply(state, data, train=False) # predict and accuracy disable dropout and use accumulated batch norm stats (train=False)
    log_probabilities = flax.linen.log_softmax(logits)  # log-probabilities, not probabilities
    return log_probabilities

  def accuracy(self, state, data, labels):
    log_probabilities = self.predict(state, data)
    predictions = jnp.argmax(log_probabilities, axis=-1)  # argmax is unchanged by log_softmax
    dense_labels = jnp.argmax(labels, axis=-1)
    accuracy = jnp.equal(predictions, dense_labels).mean()
    return accuracy
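ConvModel.loss returns a (loss, batch_stats) pair. With has_aux=True, jax.value_and_grad differentiates only the first element and passes the second through unchanged, which is how the training step gets both gradients and updated BatchNorm statistics from one call. A toy sketch of the same pattern (toy_loss and its data are hypothetical):

```python
import jax
import jax.numpy as jnp


def toy_loss(w, x, y):
    """Mean squared error of a 1-parameter linear model, plus auxiliary output."""
    pred = x * w
    loss = jnp.mean((pred - y) ** 2)
    aux = {"mean_pred": jnp.mean(pred)}  # plays the role of batch_stats: returned, not differentiated
    return loss, aux


demo_x = jnp.array([1.0, 2.0, 3.0])
demo_y = jnp.array([2.0, 4.0, 6.0])

# Differentiates with respect to the first argument (w) only.
(demo_loss, demo_aux), demo_grad = jax.value_and_grad(toy_loss, has_aux=True)(1.0, demo_x, demo_y)
print(demo_loss, demo_aux["mean_pred"], demo_grad)
```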

Write the training step function

# The training step.
@partial(jax.jit, static_argnums=[0]) # `model` is static (a compile-time constant): jax.jit recompiles whenever a different model is passed
def train_step(model, state, optimizer_state, rng, data, labels):

  other_state, params = state.pop('params') # differentiate only against 'params' which represents trainable variables
  (loss, batch_stats), grads = jax.value_and_grad(model.loss, has_aux=True)(params, other_state, rng, data, labels, train=True)

  updates, optimizer_state = optimizer.update(grads, optimizer_state)  # `optimizer` is the global Optax optimizer created below
  params = optax.apply_updates(params, updates)
  new_state = state.copy(add_or_replace={**batch_stats, 'params': params})

  rng, _ = jax.random.split(rng)

  return new_state, optimizer_state, rng, loss
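static_argnums=[0] tells jax.jit to treat the model as a compile-time constant rather than an array to trace: it must be hashable, and JAX compiles one version of the function per distinct static value. A toy sketch that counts traces (apply_power and trace_count are hypothetical names; the Python-side counter only runs while JAX traces, not on cached calls):

```python
from functools import partial

import jax
import jax.numpy as jnp

trace_count = 0


@partial(jax.jit, static_argnums=[0])
def apply_power(exponent, x):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX traces the function
    return x ** exponent


demo_values = jnp.arange(3.0)
squares = apply_power(2, demo_values)  # first call with exponent=2: traced and compiled
apply_power(2, demo_values)            # same static value and input shape: cached, no retrace
cubes = apply_power(3, demo_values)    # new static value: traced and compiled again
print(trace_count)
```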

Write the training loop

def train(model, state, optimizer_state, train_data, epochs, losses, avg_losses, eval_losses, eval_accuracies):
  p = Progress(STEPS_PER_EPOCH)
  rng = jax.random.PRNGKey(0)
  for epoch in range(epochs):

    # This is where the learning rate schedule state is stored in the optimizer state.
    optimizer_step = optimizer_state[1].count

    # Run an epoch of training.
    for step, (data, labels) in enumerate(train_data):
      p.step(reset=(step==0))
      state, optimizer_state, rng, loss = train_step(model, state, optimizer_state, rng, data.numpy(), labels.numpy())
      losses.append(loss)
    avg_loss = np.mean(losses[-(step + 1):])  # `step` is zero-based, so this epoch added step + 1 losses
    avg_losses.append(avg_loss)

    # Run one epoch of evals (10,000 test images in a single batch).
    other_state, params = state.pop('params')
    # Gotcha: must discard modified batch_stats here
    eval_loss, _ = model.loss(params, other_state, rng, all_test_data.numpy(), all_test_labels.numpy(), train=False)
    eval_losses.append(eval_loss)
    eval_accuracy = model.accuracy(state, all_test_data.numpy(), all_test_labels.numpy())
    eval_accuracies.append(eval_accuracy)

    print("\nEpoch", epoch, "train loss:", avg_loss, "eval loss:", eval_loss, "eval accuracy", eval_accuracy, "lr:", jlr_decay(optimizer_step))

  return state, optimizer_state

Create the model and the optimizer (with Optax)

# The model.
model = ConvModel()
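# Freeze the variables: `.pop()` and `.copy(add_or_replace=...)` used in this notebook are FrozenDict
# methods, and recent Flax versions return a plain dict from `init`.
state = flax.core.freeze(model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, train=True)) # Flax allows a separate RNG for "dropout"

# The optimizer.
optimizer = optax.adam(learning_rate=jlr_decay) # learning_rate can be a float or a JAX-traceable schedule (step -> learning rate), such as jlr_decay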
optimizer_state = optimizer.init(state['params'])

losses=[]
avg_losses=[]
eval_losses=[]
eval_accuracies=[]

Train the model

new_state, new_optimizer_state = train(model, state, optimizer_state, train_data, JAX_EPOCHS+TF_EPOCHS, losses, avg_losses, eval_losses, eval_accuracies)
display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=1*STEPS_PER_EPOCH)

Partially train the model

Re-create the model and train it in JAX for only JAX_EPOCHS epochs. You will continue training it in TensorFlow shortly.

model = ConvModel()
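# Frozen so that the FrozenDict methods `.pop()` and `.copy(add_or_replace=...)` work whether `init` returns a dict or a FrozenDict.
state = flax.core.freeze(model.init({'params':jax.random.PRNGKey(0), 'dropout':jax.random.PRNGKey(0)}, one_batch, train=True)) # Flax allows a separate RNG for "dropout"

# The optimizer.
optimizer = optax.adam(learning_rate=jlr_decay) # a float or a JAX-traceable schedule such as jlr_decay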
optimizer_state = optimizer.init(state['params'])

losses, avg_losses, eval_losses, eval_accuracies = [], [], [], []
state, optimizer_state = train(model, state, optimizer_state, train_data, JAX_EPOCHS, losses, avg_losses, eval_losses, eval_accuracies)
display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=1*STEPS_PER_EPOCH)

Save just enough for inference

If your goal is to deploy your JAX model (so you can run inference using model.predict()), simply exporting it to SavedModel is sufficient. This section demonstrates how to accomplish that.

# A batch of 13 examples (a batch size not used in training) to test polymorphic shapes.
x, y = next(iter(train_data.unbatch().batch(13)))

m = tf.Module()
# Wrap the JAX state in `tf.Variable` (needed when calling the converted JAX function; it also keeps the weights out of the graph as constants).
state_vars = tf.nest.map_structure(tf.Variable, state)
# Keep the wrapped state as flat list (needed in TensorFlow fine-tuning).
m.vars = tf.nest.flatten(state_vars)
# Convert the desired JAX function (`model.predict`).
predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=["...", "(b, 28, 28, 1)"])
# Wrap the converted function in `tf.function` with the correct `tf.TensorSpec` (necessary for dynamic shapes to work).
@tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])
def predict(data):
    return predict_fn(state_vars, data)
m.predict = predict
tf.saved_model.save(m, "./")
# Test the converted function.
print("Converted function predictions:", np.argmax(m.predict(x).numpy(), axis=-1))
# Reload the model.
reloaded_model = tf.saved_model.load("./")
# Test the reloaded converted function (the result should be the same).
print("Reloaded  function predictions:", np.argmax(reloaded_model.predict(x).numpy(), axis=-1))

Save everything

If your goal is a comprehensive export (useful if you're planning on bringing the model into TensorFlow for fine-tuning, fusion, etc.), this section demonstrates how to save the model so you can access methods including:

  • model.predict
  • model.accuracy
  • model.loss (including the train=True/False bool, the RNG for dropout, and BatchNorm state updates)

from collections import abc

def _fix_frozen(d):
  """Changes any mappings (e.g. frozendict) back to dict."""
  if isinstance(d, list):
    return [_fix_frozen(v) for v in d]
  elif isinstance(d, tuple):
    return tuple(_fix_frozen(v) for v in d)
  elif not isinstance(d, abc.Mapping):
    return d
  d = dict(d)
  for k, v in d.items():
    d[k] = _fix_frozen(v)
  return d


class TFModel(tf.Module):
  def __init__(self, state, model):
    super().__init__()

    # Special care needed for the train=True/False parameter in the loss
    @jax.jit
    def loss_with_train_bool(state, rng, data, labels, train):
      other_state, params = state.pop('params')
      loss, batch_stats = jax.lax.cond(train,
                                       lambda state, data, labels: model.loss(params, other_state, rng, data, labels, train=True),
                                       lambda state, data, labels: model.loss(params, other_state, rng, data, labels, train=False),
                                       state, data, labels)
      # must use JAX to split the RNG, therefore, must do it in a @jax.jit function
      new_rng, _ = jax.random.split(rng)
      return loss, batch_stats, new_rng

    self.state_vars = tf.nest.map_structure(tf.Variable, state)
    self.vars = tf.nest.flatten(self.state_vars)
    self.jax_rng = tf.Variable(jax.random.PRNGKey(0))

    self.loss_fn = jax2tf.convert(loss_with_train_bool, polymorphic_shapes=["...", "...", "(b, 28, 28, 1)", "(b, 10)", "..."])
    self.accuracy_fn = jax2tf.convert(model.accuracy, polymorphic_shapes=["...", "(b, 28, 28, 1)", "(b, 10)"])
    self.predict_fn = jax2tf.convert(model.predict, polymorphic_shapes=["...", "(b, 28, 28, 1)"])

  # Must specify TensorSpec manually for variable batch size to work
  @tf.function(autograph=False, input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32)])
  def predict(self, data):
    # Pass self.state_vars (the tf.Variables), not the JAX state itself;
    # otherwise all model weights would be embedded in the TF graph as constants.
    return self.predict_fn(self.state_vars, data)

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
               autograph=False)
  def train_loss(self, data, labels):
    loss, batch_stats, new_rng = self.loss_fn(self.state_vars, self.jax_rng, data, labels, True)
    # Update the BatchNorm statistics.
    flat_vars = tf.nest.flatten(self.state_vars['batch_stats'])
    flat_values = tf.nest.flatten(batch_stats['batch_stats'])
    for var, val in zip(flat_vars, flat_values):
      var.assign(val)
    # Update the RNG so the next call uses a different dropout mask.
    self.jax_rng.assign(new_rng)
    return loss

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
               autograph=False)
  def eval_loss(self, data, labels):
    # Eval mode: BatchNorm uses its running statistics, and neither the stats nor the RNG are updated.
    loss, _, _ = self.loss_fn(self.state_vars, self.jax_rng, data, labels, False)
    return loss

  @tf.function(input_signature=[tf.TensorSpec(shape=(None, 28, 28, 1), dtype=tf.float32),
                                tf.TensorSpec(shape=(None, 10), dtype=tf.float32)],
               autograph=False)
  def accuracy(self, data, labels):
    return self.accuracy_fn(self.state_vars, data, labels)


# Instantiate the model.
tf_model = TFModel(state, model)

# Save the model.
tf.saved_model.save(tf_model, "./")
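The train flag gets special treatment in loss_with_train_bool because, once the function is jitted and converted, train arrives as a traced boolean rather than a Python bool. A Python if on it would fail during tracing, while jax.lax.cond keeps both branches in the compiled function and picks one at run time. A toy sketch (scale_if_training is a hypothetical function):

```python
import jax
import jax.numpy as jnp


@jax.jit
def scale_if_training(x, train):
    # `train` is traced here, so `if train:` would raise a tracer conversion error.
    # jax.lax.cond(pred, true_fn, false_fn, *operands) chooses a branch at run time.
    return jax.lax.cond(train,
                        lambda v: v * 0.5,  # branch taken when train is True
                        lambda v: v,        # branch taken when train is False
                        x)


demo_train_out = scale_if_training(jnp.ones(3), True)
demo_eval_out = scale_if_training(jnp.ones(3), False)
print(demo_train_out, demo_eval_out)
```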

Reload the model

reloaded_model = tf.saved_model.load("./")

# Check that the reloaded model works and accepts different batch sizes.
x,y = next(iter(train_data.unbatch().batch(13)))
print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))
x,y = next(iter(train_data.unbatch().batch(20)))
print(np.argmax(reloaded_model.predict(x).numpy(), axis=-1))

print(reloaded_model.accuracy(one_batch, one_batch_labels))
print(reloaded_model.accuracy(all_test_data, all_test_labels))

Continue training the converted JAX model in TensorFlow

optimizer = tf.keras.optimizers.Adam(learning_rate=tflr_decay)

# Set the iteration step for the learning rate to resume from where it left off in JAX.
optimizer.iterations.assign(len(eval_losses)*STEPS_PER_EPOCH)

p = Progress(STEPS_PER_EPOCH)

for epoch in range(JAX_EPOCHS, JAX_EPOCHS+TF_EPOCHS):

  # This is where the learning rate schedule state is stored in the optimizer state.
  optimizer_step = optimizer.iterations

  for step, (data, labels) in enumerate(train_data):
    p.step(reset=(step==0))
    with tf.GradientTape() as tape:
      loss = reloaded_model.train_loss(data, labels)
    # Compute and apply the gradients outside the tape context; only the forward pass needs recording.
    grads = tape.gradient(loss, reloaded_model.vars)
    optimizer.apply_gradients(zip(grads, reloaded_model.vars))
    losses.append(loss)
  avg_loss = np.mean(losses[-(step + 1):])  # `step` is zero-based, so this epoch added step + 1 losses
  avg_losses.append(avg_loss)

  eval_loss = reloaded_model.eval_loss(all_test_data.numpy(), all_test_labels.numpy()).numpy()
  eval_losses.append(eval_loss)
  eval_accuracy = reloaded_model.accuracy(all_test_data.numpy(), all_test_labels.numpy()).numpy()
  eval_accuracies.append(eval_accuracy)

  print("\nEpoch", epoch, "train loss:", avg_loss, "eval loss:", eval_loss, "eval accuracy", eval_accuracy, "lr:", tflr_decay(optimizer.iterations).numpy())
display_train_curves(losses, avg_losses, eval_losses, eval_accuracies, len(eval_losses), STEPS_PER_EPOCH, ignore_first_n=2*STEPS_PER_EPOCH)

# The loss takes a hit when the training restarts, but does not go back to random levels.
# This is likely caused by the optimizer momentum being reinitialized.

Next steps

You can learn more about JAX and Flax on their documentation websites, which contain detailed guides and examples. If you're new to JAX, be sure to explore the JAX 101 tutorials, and check out the Flax quickstart. To learn more about converting JAX models to TensorFlow format, check out the jax2tf utility on GitHub. If you're interested in converting JAX models to run in the browser with TensorFlow.js, visit JAX on the Web with TensorFlow.js. If you'd like to prepare JAX models to run in TensorFlow Lite, visit the JAX Model Conversion For TFLite guide.