tfp.substrates.jax.math.batch_interp_regular_nd_grid

Multi-linear interpolation on a regular (constant spacing) grid.

Given [a batch of] reference values, this function computes a multi-linear interpolant and evaluates it on [a batch of] new x values. This is a multi-dimensional generalization of bilinear interpolation.
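For reference, the standard multilinear interpolation formula can be written as follows. This is a sketch of the mathematics, not a description of how TFP computes it internally, and the symbols a_k, b_k, h_k, u_k, i_k, t_k and w_k are introduced here only for illustration. For a point x = (x_1, ..., x_nd) inside the grid, let a_k = x_ref_min[k-1], b_k = x_ref_max[k-1], and let C_k >= 2 be the number of grid points along dimension k (indices into y_ref are 0-based, as in the grid-point formula below):

```latex
h_k = \frac{b_k - a_k}{C_k - 1}, \qquad
u_k = \frac{x_k - a_k}{h_k}, \qquad
i_k = \min\bigl(\lfloor u_k \rfloor,\; C_k - 2\bigr), \qquad
t_k = u_k - i_k

w_k^{(1)} = t_k, \qquad w_k^{(0)} = 1 - t_k

y(x) = \sum_{\epsilon \in \{0, 1\}^{nd}}
       \Bigl( \prod_{k=1}^{nd} w_k^{(\epsilon_k)} \Bigr)\,
       y_{\mathrm{ref}}[i_1 + \epsilon_1, \ldots, i_{nd} + \epsilon_{nd}]
```

With nd = 2 the sum has four terms, which is ordinary bilinear interpolation; in general it has 2^nd terms, one per corner of the grid cell containing x.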

The interpolant is built from the reference values in y_ref indexed by the nd dimensions starting at axis, where nd = x.shape[-1] is the number of interpolation dimensions.

The x grid span is defined by x_ref_min and x_ref_max. The number of grid points along each dimension is inferred from the shape of y_ref.

For example, take the case of a 2-D scalar-valued function with no leading batch dimensions. Then y_ref.shape = [C1, C2], and y_ref[i, j] is the reference value corresponding to the grid point

[x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
 x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1)]
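The grid-point formula can be checked with a small NumPy sketch. The bounds and grid sizes below are made up for illustration, and grid_point and fractional_index are hypothetical helpers, not part of TFP. The second helper inverts the formula, which shows where a query point x falls relative to the rows and columns of y_ref.

```python
import numpy as np

# Made-up 2-D grid: C1 = 5 points on [0, 2], C2 = 3 points on [-1, 1].
x_ref_min = np.array([0.0, -1.0])
x_ref_max = np.array([2.0, 1.0])
num_points = np.array([5, 3])  # [C1, C2]; the real API reads these from y_ref.shape


def grid_point(i, j):
    """Coordinates of the grid point that y_ref[i, j] corresponds to."""
    index = np.array([i, j])
    return x_ref_min + index * (x_ref_max - x_ref_min) / (num_points - 1)


def fractional_index(x):
    """Inverse of grid_point: position of x measured in grid steps per axis."""
    return (np.asarray(x) - x_ref_min) * (num_points - 1) / (x_ref_max - x_ref_min)


print(grid_point(2, 1))               # [1. 0.]
print(fractional_index([1.25, 0.5]))  # [2.5 1.5]
```

A fractional index of 2.5 along the first axis means the point lies halfway between y_ref[2, :] and y_ref[3, :].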

In the general case, the dimensions of y_ref to the left of axis are broadcast against the leading (batch) dimensions of x, x_ref_min, and x_ref_max.
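The broadcasting rule, together with the output shape [..., D, B1, ..., BM] stated for y_interp, can be sketched as a pure shape calculation. expected_output_shape is a hypothetical helper written for illustration; it assumes NumPy 1.20 or later (for np.broadcast_shapes) and only tracks shapes, it does no interpolation.

```python
import numpy as np


def expected_output_shape(x_shape, x_ref_shape, y_ref_shape, axis):
    """Hypothetical helper: shape of y_interp under the rules described here.

    x: [..., D, nd]; x_ref_min / x_ref_max: [..., nd];
    y_ref: [..., C1, ..., Cnd, B1, ..., BM] with C1 at position `axis`.
    """
    nd = x_shape[-1]
    d = x_shape[-2]
    if axis < 0:
        axis += len(y_ref_shape)                  # normalise a negative axis
    y_batch = tuple(y_ref_shape[:axis])           # dims of y_ref left of `axis`
    value_shape = tuple(y_ref_shape[axis + nd:])  # B1, ..., BM
    # Leading dims of x, x_ref_min/max and y_ref broadcast together.
    batch = np.broadcast_shapes(tuple(x_shape[:-2]), tuple(x_ref_shape[:-1]), y_batch)
    return tuple(batch) + (d,) + value_shape


# Batch of 4, 10 query points in 2-D, vector-valued (B1 = 3) table.
print(expected_output_shape((4, 10, 2), (2,), (1, 100, 100, 3), axis=-3))  # (4, 10, 3)
```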

Args

x: Numeric Tensor. The x-coordinates of the interpolated output values for each batch. Shape [..., D, nd], designating [a batch of] D coordinates in nd space. D must be >= 1 and is not a batch dimension.
x_ref_min: Tensor of the same dtype as x. The minimum values of the (implicitly defined) reference x_ref. Shape [..., nd].
x_ref_max: Tensor of the same dtype as x. The maximum values of the (implicitly defined) reference x_ref. Shape [..., nd].
y_ref: Tensor of the same dtype as x. The reference output values. Shape [..., C1, ..., Cnd, B1, ..., BM], designating [a batch of] reference values indexed by nd dimensions, of a function whose values have shape [B1, ..., BM] (for M >= 0).
axis: Scalar integer Tensor. Dimensions [axis, axis + nd) of y_ref index the interpolation table. For example, 3-D interpolation of a scalar-valued function requires axis=-3, and 3-D interpolation of a matrix-valued function requires axis=-5.
fill_value: Determines the output values for x below x_ref_min or above x_ref_max. Either a scalar Tensor, whose value is used at those points, or 'constant_extension', which extends the interpolant as a constant function. Default value: 'constant_extension'.
name: A name to prepend to created ops. Default value: 'batch_interp_regular_nd_grid'.
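In one dimension, the two fill_value modes have direct NumPy analogues, which may help build intuition. This sketch uses np.interp on a made-up table; it is an analogy, not TFP's implementation. By default np.interp holds the endpoint values outside the grid (like 'constant_extension'), while its left and right arguments substitute a fixed value (like a scalar fill_value).

```python
import numpy as np

# Made-up 1-D table: 3 grid points on [0, 2].
x_ref = np.linspace(0.0, 2.0, num=3)  # [0., 1., 2.]
y_ref = np.array([10.0, 20.0, 40.0])
x = np.array([-1.0, 0.5, 1.5, 3.0])   # first and last points are outside the grid

# Analogue of 'constant_extension': out-of-range x take the nearest end value.
const_ext = np.interp(x, x_ref, y_ref)

# Analogue of a scalar fill_value: out-of-range x take that scalar instead.
scalar_fill = np.interp(x, x_ref, y_ref, left=np.nan, right=np.nan)

print(const_ext)    # [10. 15. 30. 40.]
print(scalar_fill)  # [nan 15. 30. nan]
```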

Returns

y_interp: Interpolation between members of y_ref at points x. Tensor of the same dtype as x, with shape [..., D, B1, ..., BM].

Raises

Exceptions are raised if shapes are statically determined to be wrong:

ValueError: If rank(x) < 2.
ValueError: If axis is not a scalar.
ValueError: If axis + nd > rank(y_ref).
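The three static checks can be mirrored in plain Python as a rough sketch. check_static_shapes is a hypothetical helper, not TFP's validation code; in particular, it assumes a negative axis is first normalised by adding rank(y_ref) before the axis + nd > rank(y_ref) comparison, which the text above does not spell out.

```python
import numpy as np


def check_static_shapes(x_shape, y_ref_shape, axis):
    """Hypothetical mirror of the documented static shape checks.

    Assumes a negative `axis` is normalised by adding rank(y_ref) first;
    TFP's own validation may differ in detail.
    """
    if len(x_shape) < 2:
        raise ValueError(f"rank(x) must be >= 2, got {len(x_shape)}.")
    if np.ndim(axis) != 0:
        raise ValueError("axis must be a scalar.")
    nd = x_shape[-1]
    rank_y = len(y_ref_shape)
    start = int(axis) + rank_y if axis < 0 else int(axis)
    if start + nd > rank_y:
        raise ValueError(f"axis + nd = {start + nd} exceeds rank(y_ref) = {rank_y}.")


check_static_shapes((10, 2), (100, 100), -2)  # the 2-D example's shapes pass
```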

Examples

Interpolate a function of one variable.

import tensorflow as tf
import tensorflow_probability as tfp

y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

tfp.math.batch_interp_regular_nd_grid(
    # x.shape = [3, 1], x_ref_min/max.shape = [1].  Trailing `1` for `1-D`.
    x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref,
    axis=0)
# ==> approx [exp(6.0), exp(0.5), exp(3.3)]
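Because the grid is regular, the cell containing each x can be found with arithmetic rather than a search. The following NumPy sketch uses that idea to reproduce the one-variable example; interp_regular_1d is a hypothetical helper, not TFP's implementation, and it only covers the 'constant_extension' behaviour.

```python
import numpy as np


def interp_regular_1d(x, x_ref_min, x_ref_max, y_ref):
    """Linear interpolation on a regular 1-D grid with constant extension."""
    y_ref = np.asarray(y_ref, dtype=float)
    c = y_ref.shape[0]
    # Fractional grid index; clipping makes out-of-range x use the end values.
    u = (np.asarray(x, dtype=float) - x_ref_min) * (c - 1) / (x_ref_max - x_ref_min)
    u = np.clip(u, 0.0, c - 1)
    # Left neighbour of each point, capped so that i + 1 is always a valid index.
    i = np.minimum(np.floor(u).astype(int), c - 2)
    t = u - i
    return (1.0 - t) * y_ref[i] + t * y_ref[i + 1]


y_ref = np.exp(np.linspace(0.0, 10.0, num=20))
x = np.array([6.0, 0.5, 3.3])
approx = interp_regular_1d(x, 0.0, 10.0, y_ref)
print(approx / np.exp(x))  # each ratio is within a few percent of 1
```

With only 20 grid points on [0, 10], linear interpolation of exp overshoots by a few percent between grid points, which is why the example's output is described as approximate.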

Interpolate a scalar function of two variables.

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

x_ref_min = [0., 0.]
x_ref_max = [2 * np.pi, 2 * np.pi]

# Build y_ref.
x0s, x1s = tf.meshgrid(
    tf.linspace(x_ref_min[0], x_ref_max[0], num=100),
    tf.linspace(x_ref_min[1], x_ref_max[1], num=100),
    indexing='ij')

def func(x0, x1):
  return tf.sin(x0) * tf.cos(x1)

y_ref = func(x0s, x1s)

# 10 random query points in [0, 2*pi)^2; stateless_uniform requires a seed.
x = 2 * np.pi * tf.random.stateless_uniform(shape=(10, 2), seed=(0, 0))

tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref, axis=-2)
# ==> approx tf.sin(x[:, 0]) * tf.cos(x[:, 1])
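To sanity-check the two-variable example without TFP, here is a NumPy sketch of multilinear interpolation that follows the same conventions (the grid is implied by x_ref_min, x_ref_max and y_ref.shape). interp_regular_nd is a hypothetical helper, not TFP's implementation: it handles only a scalar-valued table with no batch or trailing value dimensions, and it draws query points with np.random.default_rng instead of tf.random.stateless_uniform, so its random points differ from the TF example's.

```python
import itertools

import numpy as np


def interp_regular_nd(x, x_ref_min, x_ref_max, y_ref, fill_value="constant_extension"):
    """Multilinear interpolation on a regular grid (hypothetical helper).

    x: [D, nd] query points. y_ref: [C1, ..., Cnd] scalar table with Ck >= 2.
    """
    x = np.asarray(x, dtype=float)
    x_ref_min = np.asarray(x_ref_min, dtype=float)
    x_ref_max = np.asarray(x_ref_max, dtype=float)
    y_ref = np.asarray(y_ref, dtype=float)
    nd = x.shape[-1]
    c = np.array(y_ref.shape[:nd])  # grid points per dimension
    out_of_range = np.any((x < x_ref_min) | (x > x_ref_max), axis=-1)
    # Fractional grid index per coordinate, clipped to the grid
    # (the clipping is what gives constant extension outside the grid).
    u = (x - x_ref_min) * (c - 1) / (x_ref_max - x_ref_min)
    u = np.clip(u, 0.0, c - 1)
    i = np.minimum(np.floor(u).astype(int), c - 2)  # lower corner of the cell
    t = u - i                                       # position inside the cell
    result = np.zeros(x.shape[0])
    # Sum over the 2**nd cell corners; each weight is the product over
    # dimensions of t_k (upper neighbour) or 1 - t_k (lower neighbour).
    for corner in itertools.product((0, 1), repeat=nd):
        corner = np.array(corner)
        weight = np.prod(np.where(corner == 1, t, 1.0 - t), axis=-1)
        result += weight * y_ref[tuple((i + corner).T)]
    if isinstance(fill_value, str):
        if fill_value != "constant_extension":
            raise ValueError(f"Unsupported fill_value: {fill_value!r}")
    else:
        result = np.where(out_of_range, fill_value, result)
    return result


# Same setup as the two-variable example.
x_ref_min = [0.0, 0.0]
x_ref_max = [2 * np.pi, 2 * np.pi]
x0s, x1s = np.meshgrid(np.linspace(x_ref_min[0], x_ref_max[0], num=100),
                       np.linspace(x_ref_min[1], x_ref_max[1], num=100),
                       indexing="ij")
y_ref = np.sin(x0s) * np.cos(x1s)

rng = np.random.default_rng(0)
x = 2 * np.pi * rng.uniform(size=(10, 2))
approx = interp_regular_nd(x, x_ref_min, x_ref_max, y_ref)
exact = np.sin(x[:, 0]) * np.cos(x[:, 1])
print(np.max(np.abs(approx - exact)))  # small; the grid spacing is about 0.063
```

The loop visits all 2^nd corners of each cell, so the cost grows exponentially with nd; that is inherent to multilinear interpolation rather than a quirk of this sketch.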