Decorates/wraps Python functions containing JAX code as TFF computations.
tff.jax_computation( *args, tff_internal_types=None, **kwargs )
This wrapper can be used in a similar manner to TFF's TensorFlow computation decorator, with the following exceptions:
The code in the wrapped Python function must be JAX code that can be compiled to XLA (i.e., code that one would expect to compile successfully under JAX's just-in-time (JIT) compilation).
The inputs and outputs must be tensors, or (possibly recursively) nested structures of tensors. Sequences are currently not supported.
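Both constraints can be roughly checked outside TFF by running a candidate function through jax.jit directly. The sketch below is plain JAX and never calls TFF. It assumes only that jax and numpy are installed, treats "passes jax.jit" as a proxy for "can be compiled to XLA" (an assumption, not a documented TFF guarantee), and uses hypothetical function names (scale_and_shift, python_branch, traceable_abs).

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical function: takes a nested structure (a dict holding an array and
# a tuple) and returns another nested structure of arrays.
def scale_and_shift(params):
    w = params["w"]
    b, c = params["bias"]
    # Only jax.numpy-compatible array operations, so XLA can compile this.
    return {"out": w * np.int32(2) + b, "extra": (c - b,)}

# Assumption: if jax.jit can trace and compile the function, its body is the
# kind of XLA-compilable code the constraint above describes.
compiled = jax.jit(scale_and_shift)
params = {
    "w": np.array([1, 2, 3], dtype=np.int32),
    "bias": (np.int32(1), np.array([4, 5, 6], dtype=np.int32)),
}
result = compiled(params)  # same nesting as the Python return value

def python_branch(x):
    # A Python `if` needs a concrete bool, but under jax.jit `x` is an
    # abstract tracer, so tracing this function raises an error.
    if x > 0:
        return x
    return -x

def traceable_abs(x):
    # Same logic expressed with jnp.where, which XLA can compile.
    return jnp.where(x > 0, x, -x)

try:
    jax.jit(python_branch)(np.int32(-3))
    traced_branch_failed = False
except jax.errors.ConcretizationTypeError:
    traced_branch_failed = True
```

The contrast between python_branch and traceable_abs illustrates the usual reason a function fails this requirement: Python control flow that depends on array values.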
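Example:
```python
import jax
import numpy as np
import tensorflow as tf
import tensorflow_federated as tff

@tff.jax_computation(tf.int32)
def comp(x):
  return jax.numpy.add(x, np.int32(10))
```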
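Because the body of comp is ordinary JAX, it can be exercised on its own with jax.jit, without TFF's type checking or execution machinery. This is a minimal sketch assuming only jax and numpy are installed; comp_body is a hypothetical name for the undecorated function.

```python
import jax
import numpy as np

def comp_body(x):
    # Same computation as the decorated `comp`: add 10 to an int32 input.
    return jax.numpy.add(x, np.int32(10))

# Compile the body with XLA via jax.jit and call it on an int32 scalar.
compiled_body = jax.jit(comp_body)
y = compiled_body(np.int32(5))  # y is a 0-d int32 array holding 15
```

Running the body this way is a quick check that it compiles before wrapping it with the TFF decorator.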