Represents a tangent space to some manifold M at a point x.
TFP allows one to transform manifolds via Bijectors. Keeping track of the tangent space to a manifold allows us to calculate the correct push-forward density under such transformations.
In general, the density correction involves computing a basis of the tangent space as well as the image of that basis under the transformation. We can avoid this work in special cases that arise from properties of the transformation f (e.g., dimension-preserving, coordinatewise) and of the density p on the manifold (e.g., discrete, supported on all of R^n).
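Written out, the general correction takes the following form. This is the standard change-of-volume formula, given as a sketch under assumptions the text above does not state: M is a smooth k-dimensional submanifold of R^n, B(x) is an n×k matrix whose columns form a basis of the tangent space at x, f is smooth and injective on M with Jacobian J = ∂f/∂x, and densities on M and f(M) are taken with respect to k-dimensional volume. TFP's own log_density return value may use a different sign convention.

```latex
\log p_{f(M)}\big(f(x)\big)
  = \log p_M(x)
  + \tfrac{1}{2}\log\det\!\big(B^\top B\big)
  - \tfrac{1}{2}\log\det\!\big((JB)^\top (JB)\big)
```

With full support on R^n one may take B = I, and the correction reduces to -log|det J|. If f is also coordinatewise, J is diagonal and the correction becomes -Σ_i log|∂f_i/∂x_i|. For a discrete density, k = 0 and no correction is needed.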
Each subclass of TangentSpace represents a specific property of densities that arises in uses of TFP, and each method of TangentSpace corresponds to one of the special Bijector structures TFP provides. Each subclass thus defines how to compute the density correction under each kind of transformation.
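The dispatch described above can be sketched as follows. This is a hypothetical, minimal illustration and not TFP's implementation: the class names SketchTangentSpace, DiscreteSpaceSketch and UnimplementedSpaceSketch are invented here, the fallback chain mirrors the defaults documented for the methods below, and, by this sketch's own convention, log_density is the term added to the source log density. The discrete subclass ignores f, so None is passed for it.

```python
class SketchTangentSpace:
    """Hypothetical base class mirroring the documented fallback chain."""

    def transform_general(self, x, f, **kwargs):
        # Most general case: a subclass must supply it (or a more specific method).
        raise NotImplementedError(
            f"{type(self).__name__} does not implement transform_general")

    def transform_dimension_preserving(self, x, f, **kwargs):
        # Default: treat an R^n -> R^n map like any other map.
        return self.transform_general(x, f, **kwargs)

    def transform_coordinatewise(self, x, f, **kwargs):
        # Default: a coordinatewise map is in particular dimension-preserving.
        return self.transform_dimension_preserving(x, f, **kwargs)

    def transform_projection(self, x, f, **kwargs):
        # Default: a projection (or its inverse) is handled as a general map.
        return self.transform_general(x, f, **kwargs)


class DiscreteSpaceSketch(SketchTangentSpace):
    """Hypothetical subclass for a discrete density (zero-dimensional tangent space)."""

    def transform_general(self, x, f, **kwargs):
        # A point mass has no volume to stretch, so no correction is needed
        # under any transformation, and the image is again discrete.
        return 0.0, DiscreteSpaceSketch()


class UnimplementedSpaceSketch(SketchTangentSpace):
    """Hypothetical subclass that overrides nothing."""


# Usage: a discrete space needs no correction, whichever method is called.
space = DiscreteSpaceSketch()
log_density, new_space = space.transform_coordinatewise([1.0, 2.0], f=None)
print(log_density)  # 0.0

# A subclass that overrides nothing raises through the fallback chain.
try:
    UnimplementedSpaceSketch().transform_projection([1.0], f=None)
except NotImplementedError as err:
    print("NotImplementedError:", err)
```

Because each specialized method defaults to a more general one, a subclass only needs to override the methods for which it knows a shortcut.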
Methods
transform_coordinatewise
transform_coordinatewise(
x, f, **kwargs
)
Same as transform_dimension_preserving, for a coordinatewise f (each output coordinate depends only on the corresponding input coordinate).
By default, this falls back to transform_dimension_preserving; subclasses may override it.
Args | |
---|---|
x | Same as in transform_dimension_preserving. |
f | Same as in transform_dimension_preserving. |
**kwargs | Same as in transform_dimension_preserving. |

Returns | |
---|---|
log_density | A Tensor representing the log density correction of f at x. |
space | A TangentSpace representing the tangent space to f(M) at f(x). |

Raises | |
---|---|
NotImplementedError | If the TangentSpace subclass implements neither transform_dimension_preserving nor, through that method's default fallback, transform_general. |
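Why a separate coordinatewise method pays off can be seen in the full-support case (B = I): the Jacobian of a coordinatewise map is diagonal, so log|det J| is a sum of one-dimensional terms costing O(n), whereas a generic dimension-preserving path factors an n×n matrix in O(n^3). The plain-Python sketch below compares the two under that assumption; the function names are hypothetical, and the elementwise map f(x) = exp(x) is chosen only because its derivative is easy to check.

```python
import math


def log_abs_det(matrix):
    """log|det A| via Gaussian elimination with partial pivoting (O(n^3))."""
    a = [row[:] for row in matrix]  # work on a copy
    n = len(a)
    total = 0.0
    for col in range(n):
        pivot = max(range(col, n), key=lambda r: abs(a[r][col]))
        if a[pivot][col] == 0.0:
            return -math.inf  # singular matrix
        a[col], a[pivot] = a[pivot], a[col]  # a swap only flips the sign
        total += math.log(abs(a[col][col]))
        for r in range(col + 1, n):
            factor = a[r][col] / a[col][col]
            for c in range(col, n):
                a[r][c] -= factor * a[col][c]
    return total


def coordinatewise_log_det(x, derivative):
    """Sum of log|f_i'(x_i)|: the diagonal-Jacobian shortcut (O(n))."""
    return sum(math.log(abs(derivative(xi))) for xi in x)


def dense_log_det(x, derivative):
    """Same quantity via the dense path: build diag(f'(x)) and factor it."""
    n = len(x)
    jac = [[derivative(x[i]) if i == j else 0.0 for j in range(n)]
           for i in range(n)]
    return log_abs_det(jac)


x = [0.5, -1.0, 2.0]
fast = coordinatewise_log_det(x, math.exp)  # f = exp, so f' = exp
slow = dense_log_det(x, math.exp)
# Both are approximately 1.5 (= sum of x); the density correction is its negative.
print(fast, slow)
```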
transform_dimension_preserving
transform_dimension_preserving(
x, f, **kwargs
)
Same as transform_general, assuming f maps R^n to R^n.
By default, this falls back to transform_general; subclasses may override it.
Args | |
---|---|
x | Same as in transform_general. |
f | Same as in transform_general. |
**kwargs | Same as in transform_general. |

Returns | |
---|---|
log_density | A Tensor representing the log density correction of f at x. |
space | A TangentSpace representing the tangent space to f(M) at f(x). |

Raises | |
---|---|
NotImplementedError | If the TangentSpace subclass does not implement transform_general. |
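For a dimension-preserving f on a space with full support, the tangent basis can be taken as B = I, so the general term ½ log det(J^T J) collapses to log|det J| and no basis has to be pushed through f explicitly. The sketch below checks this on R^2 with a central-difference Jacobian. The map f(x) = (exp(x_1), x_1 + x_2) and the helper name jacobian_2d are assumptions for illustration; the exact Jacobian determinant of this f is exp(x_1).

```python
import math


def jacobian_2d(f, x, h=1e-6):
    """Central-difference Jacobian of f: R^2 -> R^2 at x (rows = outputs)."""
    cols = []
    for j in range(2):
        xp, xm = list(x), list(x)
        xp[j] += h
        xm[j] -= h
        fp, fm = f(xp), f(xm)
        cols.append([(fp[i] - fm[i]) / (2 * h) for i in range(2)])
    # Transpose the list of columns into a list of rows.
    return [[cols[j][i] for j in range(2)] for i in range(2)]


def f(x):
    return [math.exp(x[0]), x[0] + x[1]]


x = [0.5, 3.0]
J = jacobian_2d(f, x)

# Dimension-preserving shortcut: log|det J| directly.
det_j = J[0][0] * J[1][1] - J[0][1] * J[1][0]
shortcut = math.log(abs(det_j))

# General formula with B = I: 0.5 * log det(J^T J).
jtj = [[sum(J[k][i] * J[k][j] for k in range(2)) for j in range(2)]
       for i in range(2)]
general = 0.5 * math.log(jtj[0][0] * jtj[1][1] - jtj[0][1] * jtj[1][0])

print(shortcut, general)  # both approximately 0.5, i.e. x[0]
```

Both routes agree, which is why the default fallback to transform_general is correct, only more expensive.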
transform_general
transform_general(
x, f, **kwargs
)
Returns the log density correction corresponding to f at x, together with a new TangentSpace representing the tangent space to f(M) at f(x).
Args | |
---|---|
x | Tensor (structure). The point at which to calculate the density correction. |
f | Bijector or one of its subclasses. The transformation that requires a density correction based on this tangent space. |
**kwargs | Optional keyword arguments for the Bijector. |

Returns | |
---|---|
log_density | A Tensor representing the log density correction of f at x. |
space | A TangentSpace representing the tangent space to f(M) at f(x). |

Raises | |
---|---|
NotImplementedError | If the TangentSpace subclass does not implement this method. |
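When f changes dimension, a basis of the tangent space must actually be pushed through f. The sketch below does this for the simplest case, a curve (k = 1), where the Gram determinant B^T B reduces to a squared norm, so the term added to the log density is log‖B‖ - log‖J B‖, assuming f is smooth and injective on M. The example is an illustration, not something TFP defines: M is the unit circle in R^2 at x = (1, 0), whose tangent direction is (0, 1), and f is the hypothetical linear map f(x) = (x_1, x_2, x_1 + x_2) from R^2 to R^3.

```python
import math


def curve_log_density_correction(jac, tangent):
    """Term added to log p_M(x) when pushing a curve through f.

    jac:     Jacobian of f at x as a list of rows (m x n).
    tangent: one tangent vector B of the curve at x (length n).
    Returns log||B|| - log||J B||, which for k = 1 equals
    0.5*log det(B^T B) - 0.5*log det((JB)^T JB).
    """
    image = [sum(row[j] * tangent[j] for j in range(len(tangent)))
             for row in jac]
    return math.log(math.hypot(*tangent)) - math.log(math.hypot(*image))


# f(x) = (x1, x2, x1 + x2) is linear, so its Jacobian is constant.
J = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0]]
B = [0.0, 1.0]  # tangent to the unit circle at (1, 0)
correction = curve_log_density_correction(J, B)
print(correction)  # -0.5 * log(2), approximately -0.3466

# Rescaling the basis vector leaves the answer unchanged.
print(curve_log_density_correction(J, [0.0, 5.0]))
```

The B^T B term is what makes the result independent of which tangent basis is chosen.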
transform_projection
transform_projection(
x, f, **kwargs
)
Same as transform_general, with f a projection (or its inverse).
By default, this falls back to transform_general; subclasses may override it.
Args | |
---|---|
x | Same as in transform_general. |
f | Same as in transform_general. |
**kwargs | Same as in transform_general. |

Returns | |
---|---|
log_density | A Tensor representing the log density correction of f at x. |
space | A TangentSpace representing the tangent space to f(M) at f(x). |

Raises | |
---|---|
NotImplementedError | If the TangentSpace subclass does not implement transform_general. |
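A concrete projection, offered as an illustration rather than as TFP's own example, and assuming densities are taken with respect to k-dimensional volume: the probability simplex {x in R^3 : x_1 + x_2 + x_3 = 1} is a 2-dimensional manifold, and dropping the last coordinate (a hypothetical projection P) maps it onto a triangle in R^2. With tangent basis columns e_1 - e_3 and e_2 - e_3, the projected basis P B is the 2×2 identity, so the term added to the log density is ½ log det(B^T B) = ½ log 3, and the inverse map gets the opposite sign.

```python
import math


def gram_log_det_2(cols):
    """0.5 * log det(C^T C) for a matrix C given as two column vectors."""
    g = [[sum(a * b for a, b in zip(u, v)) for v in cols] for u in cols]
    return 0.5 * math.log(g[0][0] * g[1][1] - g[0][1] * g[1][0])


def project(v):
    """Hypothetical projection P: drop the last coordinate."""
    return v[:-1]


# Tangent basis of the simplex x1 + x2 + x3 = 1 in R^3 (each column sums to 0).
B = [[1.0, 0.0, -1.0], [0.0, 1.0, -1.0]]
PB = [project(b) for b in B]  # P is linear, so its Jacobian applied to B is P B

# Term added to the simplex log density when mapping to R^2 coordinates.
forward = gram_log_det_2(B) - gram_log_det_2(PB)
# The inverse (embedding back onto the simplex) gets the opposite sign.
inverse = gram_log_det_2(PB) - gram_log_det_2(B)
print(forward, inverse)  # approximately 0.5 * log(3) and its negative
```

As a sanity check, the uniform density on this simplex is 2/√3 with respect to area; multiplying by √3 gives 2, the uniform density on the projected triangle of area 1/2.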