Decorates a function to enable defining a custom inverse.

A custom_inverse-decorated function is semantically identical to the original except when it is inverted with core.inverse. By default, core.inverse(custom_inverse(f)) programmatically inverts the body of f, but the decorated function has two additional methods that can override that behavior: def_inverse_unary and def_inverse_and_ildj.
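The wrapper-plus-override design can be pictured with a small pure-Python model. This is a sketch under stated assumptions, not Oryx's implementation. `ToyCustomInverse`, `toy_custom_inverse`, and `toy_inverse` are hypothetical names. The "automatic" branch simply raises, standing in for Oryx's programmatic inversion of the function body.

```python
from typing import Callable, Optional


class ToyCustomInverse:
  """Hypothetical stand-in for a custom_inverse-decorated function."""

  def __init__(self, f: Callable[[float], float]):
    self.f = f
    self.f_inv: Optional[Callable[[float], float]] = None

  def __call__(self, x: float) -> float:
    # Calling the wrapper behaves exactly like calling the original function.
    return self.f(x)

  def def_inverse_unary(self, f_inv: Callable[[float], float]) -> "ToyCustomInverse":
    # Register an override that toy_inverse prefers over automatic inversion.
    self.f_inv = f_inv
    return self


def toy_custom_inverse(f):
  # Hypothetical decorator: wrap f without changing its forward behavior.
  return ToyCustomInverse(f)


def toy_inverse(fn):
  """Returns the registered custom inverse if one exists."""
  if isinstance(fn, ToyCustomInverse) and fn.f_inv is not None:
    return fn.f_inv
  # Placeholder for the default path (tracing and inverting the body).
  raise NotImplementedError("automatic inversion is not modeled in this sketch")


@toy_custom_inverse
def add_one(x):
  return x + 1.0

add_one.def_inverse_unary(lambda y: y * 2.0)  # silly override, as in the docs
print(toy_inverse(add_one)(2.0))  # 4.0
```

The forward call is untouched. Only the inversion path consults the override, which mirrors the "semantically identical except when inverted" behavior described above.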


def_inverse_unary applies when f is a unary function. It takes an optional f_inv argument: a unary function that maps an output of f back to the corresponding input of f.

```python
from oryx.core import custom_inverse, inverse
@custom_inverse
def add_one(x):
  return x + 1.
add_one.def_inverse_unary(lambda x: x * 2)  # Define silly custom inverse.
inverse(add_one)(2.)  # ==> 4.
```
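The silly inverse above suggests that whatever f_inv is registered gets used as the inverse. If that holds, it is worth checking round trips yourself. Below is a minimal plain-Python sketch. `is_unary_inverse` is a hypothetical test helper, not part of Oryx, and it only checks the sample points it is given.

```python
import math
from typing import Callable, Iterable


def is_unary_inverse(f: Callable[[float], float],
                     f_inv: Callable[[float], float],
                     xs: Iterable[float],
                     tol: float = 1e-9) -> bool:
  """Checks that f_inv(f(x)) == x (within tol) on each sample point."""
  return all(math.isclose(f_inv(f(x)), x, abs_tol=tol) for x in xs)


def add_one(x):
  return x + 1.0

samples = [-2.0, 0.0, 0.5, 3.0]
print(is_unary_inverse(add_one, lambda y: y * 2.0, samples))  # False
print(is_unary_inverse(add_one, lambda y: y - 1.0, samples))  # True
```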

Given a unary f_inv, Oryx automatically computes the inverse log-det Jacobian (ILDJ) as core.ildj(core.inverse(f_inv)). To override this Jacobian term, pass the optional f_ildj keyword argument to def_inverse_unary.

```python
import jax.numpy as jnp
from oryx.core import custom_inverse, ildj
@custom_inverse
def add_one(x):
  return x + 1.
add_one.def_inverse_unary(lambda x: x * 2, f_ildj=lambda x: jnp.ones_like(x))
ildj(add_one)(2.)  # ==> 1.
```
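For a scalar bijection f, the ILDJ at an output y equals log|d f_inv/dy|. That is the quantity the automatic rule should produce from f_inv. The sketch below is not how Oryx computes it (Oryx works from the function's traced body). It approximates the derivative with a central finite difference so that the value can be checked by hand. `scalar_ildj_from_inverse` is a hypothetical helper.

```python
import math
from typing import Callable


def scalar_ildj_from_inverse(f_inv: Callable[[float], float],
                             y: float, h: float = 1e-5) -> float:
  """Approximates log|d f_inv / dy| at y with a central finite difference.

  For a scalar bijection f, this is the inverse log-det Jacobian of f at y.
  """
  derivative = (f_inv(y + h) - f_inv(y - h)) / (2.0 * h)
  return math.log(abs(derivative))


# f(x) = exp(x) has inverse log(y), so the exact ILDJ is log(1 / y) = -log(y).
y = 2.0
approx = scalar_ildj_from_inverse(math.log, y)
exact = -math.log(y)
print(abs(approx - exact) < 1e-6)  # True
```

By the same formula, the silly inverse lambda x: x * 2 would give an ILDJ of log 2 everywhere. The f_ildj override in the example above replaces that value.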


A more general way to define a custom inverse or ILDJ is def_inverse_and_ildj, which lets the user invert functions whose inputs and outputs are only partially known. For example, add = lambda x, y: x + y cannot be inverted from its output alone, but it can be inverted once one of its inputs is known.

def_inverse_and_ildj takes a single function, f_ildj, with three arguments: invals (values corresponding to f's inputs), outvals (values corresponding to f's outputs), and out_ildjs (the inverse diagonal log-Jacobian values for each of the outvals). Any unknown value is passed as None. f_ildj should return a tuple (new_invals, new_inildjs): the known values of the inputs and their corresponding diagonal ILDJ values, which should have the same shapes as invals. If these values cannot be computed (for example, too many of them are None), f_ildj can raise NonInvertibleError to signal Oryx to give up trying to invert the function for this set of values.

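```python
from functools import partial

import jax.numpy as jnp
from oryx.core import custom_inverse, inverse
# Assumed import path for NonInvertibleError.
from oryx.core.interpreters.inverse.core import NonInvertibleError

@custom_inverse
def add(x, y):
  return x + y

def add_ildj(invals, outvals, out_ildjs):
  x, y = invals
  z = outvals
  z_ildj = out_ildjs
  if x is None and y is None:
    raise NonInvertibleError()  # Output alone is not enough.
  if x is None:
    # x = z - y; dx/dz = 1, so z's ILDJ passes through to x.
    return (z - y, y), (z_ildj + jnp.zeros_like(z), jnp.zeros_like(z))
  if y is None:
    # y = z - x; dy/dz = 1, so z's ILDJ passes through to y.
    return (x, z - x), (jnp.zeros_like(z), z_ildj + jnp.zeros_like(z))

add.def_inverse_and_ildj(add_ildj)
inverse(partial(add, 1.))(2.)  # x known: ==> 1.
inverse(lambda x: add(x, 2.))(2.)  # y known: ==> 0.
```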
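The contract of f_ildj (None for unknowns, an error when the rule cannot proceed) can be exercised without Oryx. The plain-Python sketch below uses scalars instead of arrays and ports the add rule from the example above. It defines a local NonInvertibleError as a stand-in for Oryx's class. `try_invert` is a hypothetical driver that treats that error as "give up".

```python
class NonInvertibleError(Exception):
  """Local stand-in for Oryx's NonInvertibleError (assumption for this sketch)."""


def add_ildj(invals, outval, out_ildj):
  """Scalar add rule: recover the one unknown input from the output z."""
  x, y = invals
  z = outval
  if x is None and y is None:
    raise NonInvertibleError()
  if x is None:
    # x = z - y and dx/dz = 1, so the output ILDJ passes through unchanged.
    return (z - y, y), (out_ildj, 0.0)
  if y is None:
    return (x, z - x), (0.0, out_ildj)


def try_invert(rule, invals, outval, out_ildj=0.0):
  """Hypothetical driver: returns None when the rule gives up."""
  try:
    return rule(invals, outval, out_ildj)
  except NonInvertibleError:
    return None


print(try_invert(add_ildj, (1.0, None), 2.0))   # ((1.0, 1.0), (0.0, 0.0))
print(try_invert(add_ildj, (None, 2.0), 2.0))   # ((0.0, 2.0), (0.0, 0.0))
print(try_invert(add_ildj, (None, None), 2.0))  # None
```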

Args:
  f: a function for which we'd like to define a custom inverse.

Returns:
  A CustomInverse object whose inverse can be overridden with def_inverse_unary or def_inverse_and_ildj.