Finds bounds that bracket a root of the objective function.
tfp.substrates.jax.math.bracket_root(
    objective_fn,
    dtype=tf.float32,
    num_points=512,
    name='bracket_root'
)
This method attempts to return an interval bracketing a root of the objective
function. It evaluates the objective in parallel at num_points locations, at
exponentially increasing distance from the origin, and returns the first pair
of adjacent points [low, high] such that the objective is finite and has a
different sign at the two points. If no such pair is found, it returns the
trivial interval [np.finfo(dtype).min, np.finfo(dtype).max], which contains
all float values of the specified dtype. If the objective has multiple roots,
the returned interval will contain at least one (but perhaps not all) of the
roots.
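The scan described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not TFP's implementation: it works on Python floats (float64) instead of Tensors, assumes a scalar-valued objective (an empty batch_shape), starts the logarithmic grid at an assumed smallest magnitude of exp(-10), and treats a Python OverflowError or ZeroDivisionError as a non-finite result. The function name bracket_root_sketch is hypothetical.

```python
import math
import sys


def bracket_root_sketch(objective_fn, num_points=512):
    """Return (low, high) bracketing a root of a scalar objective.

    Hypothetical pure-Python illustration of the documented strategy;
    not TFP's code. Assumes Python floats (float64) and a scalar output.
    """
    if num_points < 4:
        raise ValueError("num_points must be at least 4")
    big = sys.float_info.max  # largest finite float64, i.e. -np.finfo(float64).min
    half = num_points // 2
    # Positive magnitudes evenly spaced in log space, from exp(-10) (an
    # assumed smallest magnitude) up to the largest finite float. The last
    # point is set to `big` directly so math.exp cannot overflow.
    lo_log, hi_log = -10.0, math.log(big)
    step = (hi_log - lo_log) / (half - 1)
    mags = [math.exp(lo_log + i * step) for i in range(half - 1)] + [big]
    # Mirror the magnitudes to get a sorted grid symmetric about the origin.
    xs = [-m for m in reversed(mags)] + mags

    def evaluate(x):
        # Treat Python arithmetic errors like IEEE overflow: a non-finite value.
        try:
            return float(objective_fn(x))
        except (OverflowError, ZeroDivisionError):
            return math.nan

    ys = [evaluate(x) for x in xs]

    def sign(y):
        return (y > 0) - (y < 0)

    # Scan adjacent pairs from the most negative point upward and return the
    # first pair where both values are finite and their signs differ.
    for i in range(len(xs) - 1):
        y0, y1 = ys[i], ys[i + 1]
        if math.isfinite(y0) and math.isfinite(y1) and sign(y0) != sign(y1):
            return xs[i], xs[i + 1]
    # No bracketing pair found: fall back to the trivial interval.
    return -big, big
```

With the objective x * x - 2.0, whose roots are ±√2, this sketch meets the sign change near -√2 first and returns a bracket around that root only, which illustrates the caveat that not every root need be contained.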
Args
objective_fn: Python callable for which roots are searched. It must be a
  continuous function that accepts a scalar Tensor of type dtype and returns a
  Tensor of shape batch_shape.
dtype: Optional float dtype of inputs to objective_fn.
  Default value: tf.float32.
num_points: Optional Python int number of points at which to evaluate the
  objective.
  Default value: 512.
name: Python str name given to ops created by this method.
  Default value: 'bracket_root'.
Returns
low: Float Tensor of shape batch_shape and dtype dtype. Lower bound on a root
  of objective_fn.
high: Float Tensor of shape batch_shape and dtype dtype. Upper bound on a root
  of objective_fn.
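Because low and high only bound a root, a common follow-up is to shrink the interval with a bracketing root finder. The sketch below uses plain bisection on Python floats; it is an illustrative assumption, not part of this API, and the name refine_bracket is hypothetical. It assumes a scalar objective that is finite at both endpoints and changes sign between them, that is, a successful bracket rather than the trivial fallback interval.

```python
import math


def refine_bracket(objective_fn, low, high, tol=1e-12, max_iter=200):
    """Hypothetical helper: shrink a sign-changing bracket by bisection.

    Assumes a continuous scalar objective that is finite at `low` and `high`
    and changes sign between them (or is exactly zero at one endpoint).
    """
    f_low = objective_fn(low)
    f_high = objective_fn(high)
    if not (math.isfinite(f_low) and math.isfinite(f_high)):
        raise ValueError("objective must be finite at both endpoints")
    if f_low == 0:
        return low, low
    if f_high == 0:
        return high, high
    if (f_low < 0) == (f_high < 0):
        raise ValueError("objective does not change sign on [low, high]")
    for _ in range(max_iter):
        if high - low <= tol:
            break
        # Averaging halves avoids overflow when the endpoints are very large.
        mid = 0.5 * low + 0.5 * high
        f_mid = objective_fn(mid)
        if f_mid == 0:
            return mid, mid
        # Keep the half whose endpoints still have opposite signs.
        if (f_mid < 0) == (f_low < 0):
            low, f_low = mid, f_mid
        else:
            high = mid
    return low, high
```

For example, refine_bracket(lambda x: x * x - 2.0, 0.0, 2.0) narrows [0, 2] to an interval of width at most 1e-12 that still contains √2.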