
tfp.substrates.jax.random.split_seed

Splits a seed into n derived seeds.

See https://github.com/tensorflow/probability/blob/main/PRNGS.md for details.

Args:
  seed: The seed to split; may be an int, an (int, int) tuple, or a Tensor. int seeds are converted to Tensor seeds using tf.random.uniform stateful sampling. Tuples are converted to Tensor.
  n: The number of splits to return. In TensorFlow, if n is an integer, this function returns a list of seeds; otherwise it returns a Tensor of seeds. In JAX, this function always returns an array of seeds.
  salt: Optional str salt to mix with the seed.
  name: Optional name to scope related ops.
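The splitting behavior described above can be sketched with JAX's own PRNG functions. This is a minimal sketch, not TFP's implementation: it assumes a JAX-substrate seed is a length-2 uint32 array like the one jax.random.PRNGKey returns, and split_seed_sketch is a hypothetical helper. Two choices deliberately differ from the documented behavior: an int seed is converted deterministically with jax.random.PRNGKey rather than by stateful sampling, and the salt is mixed in with jax.random.fold_in over a SHA-256 hash, an illustrative stand-in for however TFP actually mixes salts.

```python
import hashlib

import jax
import jax.numpy as jnp


def split_seed_sketch(seed, n=2, salt=None):
    """Hypothetical analogue of split_seed in JAX; not TFP's implementation.

    Returns a single uint32 array of shape [n, 2], mirroring the documented
    JAX-substrate behavior of always returning an array of seeds.
    """
    if isinstance(seed, int):
        # Assumption: derive the key deterministically. TFP documents stateful
        # sampling for int seeds; JAX has no stateful RNG to draw from.
        key = jax.random.PRNGKey(seed)
    else:
        # (int, int) tuples and existing arrays become a uint32[2] key.
        key = jnp.asarray(seed, dtype=jnp.uint32)
    if salt is not None:
        # Illustrative salting: fold a stable 31-bit hash of the string into the key.
        digest = hashlib.sha256(salt.encode("utf-8")).digest()
        key = jax.random.fold_in(key, int.from_bytes(digest[:4], "little") % (2**31 - 1))
    # Derive n independent child keys; the result has shape [n, 2].
    return jax.random.split(key, n)


seeds = split_seed_sketch(42, n=3)
print(seeds.shape)  # (3, 2)
```

Because the child keys depend only on the parent key and the salt, calling the helper twice with the same arguments returns identical seeds, which is what makes seeded sampling reproducible.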

Returns:
  seeds: If n is a Python int, a tuple of seed values is returned. If n is an int Tensor, a single Tensor of shape [n, 2] is returned. Each individual seed (one row of shape [2]) is suitable to pass as the seed argument of the tf.random.stateless_* ops.
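In JAX, the role the tf.random.stateless_* ops play in TensorFlow is roughly filled by the jax.random samplers, which take their key explicitly. The sketch below again assumes a raw uint32[2] key stands in for a JAX-substrate seed: each row of the [n, 2] array is one seed, and jax.vmap passes every row to a stateless sampler at once.

```python
import jax
import jax.numpy as jnp

# Assumption: a raw uint32[2] key stands in for a JAX-substrate seed.
seeds = jax.random.split(jax.random.PRNGKey(0), 3)  # shape [3, 2]

# Each row is one seed; vmap hands each row to a stateless sampler.
draws = jax.vmap(lambda s: jax.random.uniform(s, (2,)))(seeds)  # shape [3, 2]

# Stateless sampling: reusing a seed reproduces the same values.
again = jax.random.uniform(seeds[0], (2,))
print(jnp.allclose(draws[0], again))  # True
```

Drawing through jax.vmap keeps the seeds as one array, matching the single-array return of the JAX substrate, instead of unpacking them into a Python tuple.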