
# oryx.core.custom_inverse

Decorates a function to enable defining a custom inverse.

A `custom_inverse`-decorated function is semantically identical to the original except when it is inverted with `core.inverse`. By default, `core.inverse(custom_inverse(f))` will programmatically invert the body of `f`, but the decorated function has two additional methods that can override that behavior: `def_inverse_unary` and `def_inverse_and_ildj`.

## `def_inverse_unary`

`def_inverse_unary` is applicable if `f` is a unary function. `def_inverse_unary` takes in an optional `f_inv` function, which is a unary function from the output of `f` to the input of `f`.

#### Example:

```python
@custom_inverse
def add_one(x):
  return x + 1.
add_one.def_inverse_unary(lambda x: x * 2)  # Define silly custom inverse.
```

With a unary `f_inv` function, Oryx will automatically compute an inverse log-det Jacobian using `core.ildj(core.inverse(f_inv))`, but a user can also override the Jacobian term by providing the optional `f_ildj` keyword argument to `def_inverse_unary`.

#### Example:

```python
@custom_inverse
def add_one(x):
  return x + 1.
# Override both the inverse and its ILDJ term (`jnp` is `jax.numpy`).
add_one.def_inverse_unary(lambda x: x * 2, f_ildj=lambda x: jnp.ones_like(x))
```
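For a scalar function, the inverse log-det Jacobian reduces to `log|d f_inv / dy|`. The sketch below approximates that quantity in plain Python, without Oryx, for the `lambda x: x * 2` inverse used above. The helper `numeric_ildj` is a hypothetical name introduced only for this illustration, and the finite-difference approach is a stand-in for illustration, not how Oryx computes the term.

```python
import math

def f_inv(y):
    # The "silly" custom inverse from the example above.
    return y * 2.0

def numeric_ildj(inv_fn, y, eps=1e-6):
    # Hypothetical helper: central finite difference of the inverse map,
    # followed by log of the absolute derivative.
    deriv = (inv_fn(y + eps) - inv_fn(y - eps)) / (2.0 * eps)
    return math.log(abs(deriv))

# For f_inv(y) = 2 * y the derivative is 2 everywhere, so this is about log(2).
ildj_at_3 = numeric_ildj(f_inv, 3.0)
```

Passing `f_ildj` to `def_inverse_unary` replaces a value like this one with whatever the user-supplied function returns.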

## `def_inverse_and_ildj`

A more general way of defining a custom inverse or ILDJ is to use `def_inverse_and_ildj`, which enables the user to invert functions with partially known inputs and outputs. Take an example like `add = lambda x, y: x + y`, which cannot be inverted with just the output, but can be inverted when one input is also known.

`def_inverse_and_ildj` takes a single function `f_ildj` as an argument. `f_ildj` is a function of `invals` (a set of values corresponding to `f`'s inputs), `outvals` (a set of values corresponding to `f`'s outputs) and `out_ildjs` (a set of inverse diagonal log-Jacobian values, one for each of the `outvals`). Any value that is unknown is passed as `None`.

`f_ildj` should return a tuple `(new_invals, new_inildjs)`, which correspond to the known values of the inputs and their diagonal Jacobian values (which should be the same shape as `invals`). If these values cannot be computed (e.g. too many values are `None`), the user can raise a `NonInvertibleError`, which signals to Oryx to give up trying to invert the function for this set of values.
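To make the calling convention concrete, here is a minimal sketch of an `f_ildj`-style function for multiplication, written in plain Python so it can be called directly without Oryx. The name `mul_ildj`, the local `NonInvertibleError` class (a stand-in for Oryx's error type), the use of `0.0` for already-known inputs, and the handling of the case where both inputs are known are all assumptions made for this illustration.

```python
import math

class NonInvertibleError(Exception):
    """Local stand-in for Oryx's error type (an assumption for this sketch)."""

def mul_ildj(invals, outvals, out_ildjs):
    # Hypothetical f_ildj for mul(x, y) = x * y; unknown values arrive as None.
    x, y = invals
    z = outvals
    z_ildj = out_ildjs
    if z is None or (x is None and y is None):
        # Too many unknowns: signal that this set of values cannot be inverted.
        raise NonInvertibleError()
    if x is None:
        # x = z / y, so log|dx/dz| = -log|y|.
        return (z / y, y), (z_ildj - math.log(abs(y)), 0.0)
    if y is None:
        # y = z / x, so log|dy/dz| = -log|x|.
        return (x, z / x), (0.0, z_ildj - math.log(abs(x)))
    # Both inputs already known: nothing to solve for (a choice for this sketch).
    return (x, y), (0.0, 0.0)

# Solve for x given y = 4 and z = 8, with a zero incoming ILDJ for z.
new_invals, new_inildjs = mul_ildj((None, 4.0), 8.0, 0.0)
```

In real use, a function like this would be registered with `def_inverse_and_ildj` on the decorated function rather than called by hand.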

#### Example:

```python
@custom_inverse
def add(x, y):
  return x + y

def add_ildj(invals, outvals, out_ildjs):
  # Unknown values are passed in as `None`.
  x, y = invals
  z = outvals
  z_ildj = out_ildjs
  if x is None and y is None:
    raise NonInvertibleError()
  if x is None:
    return (z - y, y), (z_ildj + jnp.zeros_like(z), jnp.zeros_like(z))
  if y is None:
    return (x, z - x), (jnp.zeros_like(z), z_ildj + jnp.zeros_like(z))
add.def_inverse_and_ildj(add_ildj)
```

#### Args:

`f`: a function for which we'd like to define a custom inverse.

#### Returns:

A `CustomInverse` object whose inverse can be overridden with `def_inverse_unary` or `def_inverse_and_ildj`.