tff.jax_computation

Decorates/wraps Python functions containing JAX code as TFF computations.

This wrapper can be used in much the same way as tff.tf_computation, with the following exceptions:

  • The code in the wrapped Python function must be JAX code that can be compiled to XLA (e.g., code that one would expect to be able to annotate with @jax.jit).

  • The inputs and outputs must be tensors, or (possibly recursively) nested structures of tensors. Sequences are currently not supported.
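As a rough sanity check for the first two constraints, a function body can be tried under plain jax.jit before it is wrapped. The sketch below does not call TFF at all. It assumes that "compilable by jax.jit" is a reasonable proxy for "compilable to XLA" as described above, and that a dict of arrays is an acceptable stand-in for a nested structure of tensors. The names scale_and_sum, not_jittable, and is_jittable are hypothetical. Passing this check does not guarantee that tff.jax_computation will accept the function, because TFF also checks the declared parameter type.

```python
import jax
import jax.numpy as jnp
import numpy as np


# Hypothetical body over a nested structure (a dict) of tensors.
# Every operation is traceable JAX, so jax.jit can compile it to XLA.
def scale_and_sum(params):
    scaled = jnp.multiply(params["weights"], params["scale"])
    return {"scaled": scaled, "total": jnp.sum(scaled)}


# Hypothetical body that would NOT meet the constraint: Python control flow
# that depends on a traced value cannot be staged out to XLA.
def not_jittable(x):
    if x > 0:  # calling bool() on a traced array raises an error under jit
        return x
    return -x


def is_jittable(fn, *example_args):
    """Returns True if jax.jit can trace and run `fn` on the example args."""
    try:
        jax.jit(fn)(*example_args)
        return True
    except jax.errors.ConcretizationTypeError:
        return False


params = {
    "weights": np.array([1.0, 2.0, 3.0], dtype=np.float32),
    "scale": np.float32(2.0),
}
print(is_jittable(scale_and_sum, params))       # True
print(is_jittable(not_jittable, np.int32(3)))   # False
```

The bool-conversion error raised by not_jittable is a subclass of ConcretizationTypeError in recent JAX releases, so catching the base class covers it.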

Example:

import jax
import numpy as np
import tensorflow_federated as tff
@tff.jax_computation(np.int32)
def comp(x):
  return jax.numpy.add(x, np.int32(10))  # adds 10 to the int32 input