tfp.substrates.jax.math.bracket_root

Finds bounds that bracket a root of the objective function.

This method attempts to return an interval bracketing a root of the objective function. It evaluates the objective in parallel at num_points locations, at exponentially increasing distance from the origin, and returns the first pair of adjacent points [low, high] such that the objective is finite and has a different sign at the two points. If no such pair is found, it returns the trivial interval [np.finfo(dtype).min, np.finfo(dtype).max], which contains all finite float values of the specified dtype. If the objective has multiple roots, the returned interval contains at least one (but perhaps not all) of them.
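
A minimal, scalar-only sketch of the bracketing strategy described above follows. It is written in plain Python so it runs without TFP or JAX, and it is not TFP's implementation. Several choices are illustrative assumptions: the grid (magnitudes 2**e with exponents evenly spaced over the float64 exponent range), the ascending left-to-right scan used to pick the "first" pair, and the helper names bracket_root_sketch, _evaluate and _sign. Batching and the dtype argument are omitted.

```python
import math
import sys


def _evaluate(objective_fn, x):
    # Python floats raise on some overflows (e.g. math.exp) where JAX would
    # return inf; treat any such failure as a non-finite value.
    try:
        return float(objective_fn(x))
    except (OverflowError, ZeroDivisionError, ValueError):
        return math.nan


def _sign(v):
    return (v > 0) - (v < 0)


def bracket_root_sketch(objective_fn, num_points=512):
    """Scalar-only illustration of the bracketing idea (not TFP's code).

    Assumptions: Python floats (float64) instead of a dtype argument, no
    batch dimension, evenly spaced exponents for the grid, and an ascending
    scan for the "first" valid pair of adjacent points.
    """
    half = max(num_points // 2, 2)
    max_exp = sys.float_info.max_exp - 1  # 1023 for IEEE-754 float64
    # Magnitudes 2**e with e evenly spaced in [-max_exp, max_exp], so their
    # distance from the origin grows exponentially.
    magnitudes = [2.0 ** (-max_exp + 2 * max_exp * i / (half - 1))
                  for i in range(half)]
    # Ascending grid: mirrored negatives followed by the positives.
    points = [-m for m in reversed(magnitudes)] + magnitudes
    values = [_evaluate(objective_fn, x) for x in points]
    for i in range(len(points) - 1):
        f0, f1 = values[i], values[i + 1]
        # Both endpoints must be finite and have different signs.
        if (math.isfinite(f0) and math.isfinite(f1)
                and _sign(f0) != _sign(f1)):
            return points[i], points[i + 1]
    # No sign change observed: fall back to the whole finite float range.
    return -sys.float_info.max, sys.float_info.max
```

Because this grid is coarse (with 512 points, consecutive exponents differ by about 8), the bracket can be wide: for x - 3.0 this sketch returns roughly [0.062, 16.1]. Only a sign change is guaranteed, not a tight interval.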

Args:
  objective_fn: Python callable for which roots are searched. It must be a continuous function that accepts a scalar Tensor of type dtype and returns a Tensor of shape batch_shape.
  dtype: Optional float dtype of inputs to objective_fn. Default value: tf.float32.
  num_points: Optional Python int number of points at which to evaluate the objective. Default value: 512.
  name: Python str name given to ops created by this method.

Returns:
  low: Float Tensor of shape batch_shape and dtype dtype. Lower bound on a root of objective_fn.
  high: Float Tensor of shape batch_shape and dtype dtype. Upper bound on a root of objective_fn.
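
A bracket is typically the starting point for a bracketing root finder. The following plain-Python sketch narrows [low, high] by bisection. It is not a TFP API: the name bisect_bracket and its max_iterations parameter are hypothetical. It assumes the objective is continuous and finite at both ends, with opposite signs there.

```python
def bisect_bracket(objective_fn, low, high, max_iterations=2200):
    """Narrow a bracket [low, high] by bisection (illustrative sketch).

    Assumes objective_fn is continuous and finite at low and high with
    opposite signs there, which is what a successful bracket provides.
    2200 iterations are enough to shrink even the full float64 range down
    to a pair of adjacent floats.
    """
    f_low = objective_fn(low)
    for _ in range(max_iterations):
        # 0.5*low + 0.5*high cannot overflow, unlike (low + high) / 2 when
        # the bracket spans most of the float range.
        mid = 0.5 * low + 0.5 * high
        if mid == low or mid == high:
            break  # low and high are adjacent floats; no further progress.
        f_mid = objective_fn(mid)
        if f_mid == 0:
            return mid, mid  # exact root found
        if (f_mid < 0) == (f_low < 0):
            low, f_low = mid, f_mid  # root lies in [mid, high]
        else:
            high = mid  # root lies in [low, mid]
    return low, high
```

If the trivial interval [np.finfo(dtype).min, np.finfo(dtype).max] was returned, the objective may not change sign at all. Check that objective_fn(low) and objective_fn(high) are finite and of opposite signs before refining.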