
# tfp.math.batch_interp_regular_nd_grid

Multi-linear interpolation on a regular (constant spacing) grid.

Given [a batch of] reference values, this function computes a multi-linear interpolant and evaluates it on [a batch of] new `x` values. This is a multi-dimensional generalization of bilinear interpolation.

The interpolant is built from reference values indexed by `nd` dimensions of `y_ref`, starting at `axis`.

The span of the x grid is defined by `x_ref_min` and `x_ref_max`. The number of grid points along each dimension is inferred from the shape of `y_ref`.
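
Below is a minimal pure-Python sketch of the multi-linear scheme described above. It is not TFP's implementation. It assumes a single scalar-valued table with no batch dimensions (so `axis` plays no role) and stores `y_ref` as a flat row-major list with `shape = [C1, ..., Cnd]`. It also requires at least two grid points per dimension and clamps out-of-range points the way `fill_value='constant_extension'` would. The function name `multilinear_interp` is hypothetical.

```python
import itertools
import math


def multilinear_interp(point, x_ref_min, x_ref_max, y_ref, shape):
    """Evaluate a multi-linear interpolant at one point (illustrative sketch).

    y_ref is a flat, row-major list holding a table of the given shape
    [C1, ..., Cnd]; every C_k is assumed to be >= 2. Out-of-range
    coordinates are clamped to the grid edge (constant extension).
    """
    nd = len(shape)
    lower = []    # index of the enclosing cell's lower corner, per dimension
    weights = []  # fractional position inside that cell, in [0, 1]
    for k in range(nd):
        spacing = (x_ref_max[k] - x_ref_min[k]) / (shape[k] - 1)
        # Continuous grid index, clamped to [0, C_k - 1].
        t = (point[k] - x_ref_min[k]) / spacing
        t = min(max(t, 0.0), shape[k] - 1.0)
        # Keep i + 1 inside the table when t sits exactly on the last point.
        i = min(int(math.floor(t)), shape[k] - 2)
        lower.append(i)
        weights.append(t - i)

    # Row-major strides for indexing into the flat y_ref list.
    strides = [1] * nd
    for k in range(nd - 2, -1, -1):
        strides[k] = strides[k + 1] * shape[k + 1]

    # Weighted sum over the 2**nd corners of the enclosing cell.
    result = 0.0
    for corner in itertools.product((0, 1), repeat=nd):
        w = 1.0
        flat = 0
        for k, c in enumerate(corner):
            w *= weights[k] if c else 1.0 - weights[k]
            flat += (lower[k] + c) * strides[k]
        result += w * y_ref[flat]
    return result
```

Because each corner weight is a product of per-dimension linear weights, this interpolant reproduces exactly any function that is linear in each coordinate separately, for example `1 + 2*x0 + 3*x1 + x0*x1`.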

For example, consider a `2-D` scalar-valued function with no leading batch dimensions. In this case, `y_ref.shape = [C1, C2]`, and `y_ref[i, j]` is the reference value at the grid point:

```
[x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
 x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1)]
```
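
The same formula generalizes to `nd` dimensions, one coordinate at a time. The following small sketch uses a hypothetical helper name, `grid_point`, and plain Python lists, with `shape = [C1, ..., Cnd]`:

```python
def grid_point(index, x_ref_min, x_ref_max, shape):
    """Coordinates of the reference value at `index` (hypothetical helper).

    For each dimension k computes
    x_ref_min[k] + index[k] * (x_ref_max[k] - x_ref_min[k]) / (C_k - 1).
    """
    return [
        lo + i * (hi - lo) / (c - 1)
        for i, lo, hi, c in zip(index, x_ref_min, x_ref_max, shape)
    ]


# A 3 x 5 grid spanning [0, 1] x [0, 2].
print(grid_point([1, 3], [0., 0.], [1., 2.], [3, 5]))  # [0.5, 1.5]
```

Index `0` maps to `x_ref_min` and index `C_k - 1` maps to `x_ref_max`, so reference values lie on the grid boundary as well as in its interior.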

In the general case, the dimensions of `y_ref` to the left of `axis` are broadcast against the leading (batch) dimensions of `x`, `x_ref_min`, and `x_ref_max`.
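
The following sketch computes the resulting static output shape. It assumes ordinary NumPy-style broadcasting applies. The helper `interp_output_shape` is hypothetical. For brevity, it ignores the batch dimensions of `x_ref_min` and `x_ref_max`, which take part in the same broadcast.

```python
def interp_output_shape(x_shape, y_ref_shape, axis, nd):
    """Static output shape of the interpolation (hypothetical helper).

    x_shape is [..., D, nd] and y_ref_shape is [..., C1, ..., Cnd, B1, ..., BM].
    The batch dims of x (everything before D) are broadcast against the dims
    of y_ref left of `axis`; the result is [..., D, B1, ..., BM].
    """
    if axis < 0:
        axis += len(y_ref_shape)
    x_batch = list(x_shape[:-2])
    y_batch = list(y_ref_shape[:axis])
    # Right-align the two batch shapes by padding with size-1 dims.
    n = max(len(x_batch), len(y_batch))
    x_batch = [1] * (n - len(x_batch)) + x_batch
    y_batch = [1] * (n - len(y_batch)) + y_batch
    batch = []
    for a, b in zip(x_batch, y_batch):
        if a != b and 1 not in (a, b):
            raise ValueError(f"Incompatible batch dims {a} and {b}")
        batch.append(max(a, b))  # a size-1 dim stretches to the other size
    return batch + [x_shape[-2]] + list(y_ref_shape[axis + nd:])


# Batched 3-D interpolation of a 2 x 2 matrix-valued function (axis=-5).
print(interp_output_shape([4, 1, 7, 3], [5, 6, 6, 6, 2, 2], axis=-5, nd=3))
```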

| Argument | Description |
|---|---|
| `x` | Numeric `Tensor`. The x-coordinates of the interpolated output values for each batch. Shape `[..., D, nd]`, designating [a batch of] `D` coordinates in `nd` space. `D` must be `>= 1` and is not a batch dimension. |
| `x_ref_min` | `Tensor` of same `dtype` as `x`. The minimum values of the (implicitly defined) reference `x_ref`. Shape `[..., nd]`. |
| `x_ref_max` | `Tensor` of same `dtype` as `x`. The maximum values of the (implicitly defined) reference `x_ref`. Shape `[..., nd]`. |
| `y_ref` | `Tensor` of same `dtype` as `x`. The reference output values. Shape `[..., C1, ..., Cnd, B1, ..., BM]`, designating [a batch of] reference values indexed by `nd` dimensions, of a function with values of shape `[B1, ..., BM]` (for `M >= 0`). |
| `axis` | Scalar integer `Tensor`. Dimensions `[axis, axis + nd)` of `y_ref` index the interpolation table. For example, `3-D` interpolation of a scalar-valued function requires `axis=-3`, and `3-D` interpolation of a matrix-valued function requires `axis=-5`. |
| `fill_value` | Determines the output for `x` values that are below `x_ref_min` or above `x_ref_max`. Either a scalar `Tensor`, used as the output at such points, or the string `'constant_extension'`, which extends the function as a constant beyond the grid edges. Default value: `'constant_extension'`. |
| `name` | A name to prepend to created ops. Default value: `'batch_interp_regular_nd_grid'`. |
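
The following one-dimensional pure-Python sketch illustrates the two `fill_value` modes. The function `interp_1d` is hypothetical and is not TFP code. It treats points exactly at `x_ref_min` or `x_ref_max` as in range, matching the "below"/"above" wording of the `fill_value` description.

```python
def interp_1d(x, x_ref_min, x_ref_max, y_ref, fill_value="constant_extension"):
    """1-D linear interpolation on a regular grid (illustrative sketch).

    Assumes len(y_ref) >= 2 and x_ref_max > x_ref_min.
    """
    if x < x_ref_min or x > x_ref_max:
        if fill_value == "constant_extension":
            # Clamp to the nearest grid edge, extending the edge value.
            x = min(max(x, x_ref_min), x_ref_max)
        else:
            # A scalar fill_value replaces the output outside the grid.
            return fill_value
    n = len(y_ref)
    t = (x - x_ref_min) / (x_ref_max - x_ref_min) * (n - 1)  # grid index
    i = min(int(t), n - 2)  # lower neighbor; t >= 0 so int() floors
    w = t - i
    return (1.0 - w) * y_ref[i] + w * y_ref[i + 1]


y = [0., 10., 20.]  # grid points at x = 0, 1, 2
print(interp_1d(3.0, 0., 2., y))                  # 20.0 (edge value extended)
print(interp_1d(3.0, 0., 2., y, fill_value=-1.))  # -1.0
```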

| Return value | Description |
|---|---|
| `y_interp` | Interpolation between members of `y_ref`, at points `x`. `Tensor` of same `dtype` as `x`, with shape `[..., D, B1, ..., BM]`. |

Exceptions are raised if shapes are statically determined to be wrong:

| Exception | Condition |
|---|---|
| `ValueError` | `rank(x) < 2`. |
| `ValueError` | `axis` is not a scalar. |
| `ValueError` | `axis + nd > rank(y_ref)`. |
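
The static checks listed above can be mirrored in plain Python when the shapes are known. This is a hypothetical sketch, not TFP's validation code. It reads `nd` from the last dimension of `x` and assumes a negative `axis` is first normalized by adding `rank(y_ref)`.

```python
def check_static_shapes(x_shape, y_ref_shape, axis):
    """Raise ValueError for the documented static shape errors (sketch)."""
    if len(x_shape) < 2:
        raise ValueError(f"rank(x) must be >= 2; got {len(x_shape)}")
    if not isinstance(axis, int):
        raise ValueError("axis must be a scalar integer")
    nd = x_shape[-1]  # x has shape [..., D, nd]
    if axis < 0:
        axis += len(y_ref_shape)
    if axis + nd > len(y_ref_shape):
        raise ValueError(
            f"axis + nd = {axis + nd} exceeds rank(y_ref) = {len(y_ref_shape)}")


check_static_shapes([10, 2], [100, 100], axis=-2)  # passes silently
```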

#### Examples

Interpolate a function of one variable.

```python
import tensorflow as tf
import tensorflow_probability as tfp

# 20 reference values of exp on a regular grid spanning [0, 10].
y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))

tfp.math.batch_interp_regular_nd_grid(
    # x.shape = [3, 1], x_ref_min/max.shape = [1].  Trailing `1` for `1-D`.
    x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref,
    axis=0)
# ==> approx [exp(6.0), exp(0.5), exp(3.3)]
```

Interpolate a scalar function of two variables.

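```python
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

x_ref_min = [0., 0.]
x_ref_max = [2 * np.pi, 2 * np.pi]

# Build y_ref on a 100 x 100 grid; indexing='ij' makes y_ref[i, j]
# correspond to grid point (x0s[i, j], x1s[i, j]).
x0s, x1s = tf.meshgrid(
    tf.linspace(x_ref_min[0], x_ref_max[0], num=100),
    tf.linspace(x_ref_min[1], x_ref_max[1], num=100),
    indexing='ij')


def func(x0, x1):
    return tf.sin(x0) * tf.cos(x1)


y_ref = func(x0s, x1s)

# 10 random query points in [0, 2*pi)^2, so x.shape = [10, 2].
x = 2 * np.pi * tf.random.uniform(shape=(10, 2))

tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref, axis=-2)
# ==> approx tf.sin(x[:, 0]) * tf.cos(x[:, 1])
```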