Wrapper to compile an object's public methods using XLA.
tfp.experimental.util.JitPublicMethods(
    object_to_wrap,
    trace_only=False,
    methods_to_exclude=tfp.experimental.util.DEFAULT_METHODS_EXCLUDED_FROM_JIT
)
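The pure-Python sketch below illustrates the wrapping idea under stated assumptions. It assumes that attribute lookups on the wrapper are forwarded to the wrapped object, and that each public, non-excluded method is decorated on first access and then cached on the wrapper. ToyJitPublicMethods, fake_jit and ToyModel are hypothetical names, not part of TFP. fake_jit only tags the function and stands in for tf.function(jit_compile=True) or jax.jit, so nothing is actually compiled.

```python
import functools


def fake_jit(fn):
    """Hypothetical stand-in for tf.function(jit_compile=True) / jax.jit.

    It does not compile anything; it only tags the function so we can
    observe that it was wrapped.
    """
    @functools.wraps(fn)
    def wrapped(*args, **kwargs):
        return fn(*args, **kwargs)
    wrapped.is_jitted = True
    return wrapped


class ToyJitPublicMethods:
    """Hypothetical sketch: forward attribute access, wrap public methods."""

    def __init__(self, object_to_wrap, methods_to_exclude=()):
        self._object_to_wrap = object_to_wrap
        self._methods_to_exclude = frozenset(methods_to_exclude)

    def __getattr__(self, name):
        # Python calls __getattr__ only when normal lookup on the wrapper fails.
        if name in ("_object_to_wrap", "_methods_to_exclude"):
            # Avoid infinite recursion on a partially initialised wrapper.
            raise AttributeError(name)
        attr = getattr(self._object_to_wrap, name)
        if (callable(attr) and not name.startswith("_")
                and name not in self._methods_to_exclude):
            attr = fake_jit(attr)
            # Cache the wrapped method so later lookups bypass __getattr__.
            setattr(self, name, attr)
        return attr


class ToyModel:
    def __init__(self, scale):
        self.scale = scale

    def log_prob(self, x):
        return -0.5 * (x / self.scale) ** 2

    def describe(self):
        return f"ToyModel(scale={self.scale})"


wrapped = ToyJitPublicMethods(ToyModel(2.0), methods_to_exclude=["describe"])
print(wrapped.log_prob(4.0))                           # -2.0
print(getattr(wrapped.log_prob, "is_jitted", False))   # True
print(getattr(wrapped.describe, "is_jitted", False))   # False (excluded)
print(wrapped.scale)                                   # 2.0 (not callable)
```

Caching the wrapped method on first use means the decorator runs once per method, which mirrors why a compiled function would be reused across calls rather than rebuilt each time.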
Args:
  object_to_wrap: Any Python object; for example, a tfd.Distribution instance.
  trace_only: Python bool; if True, the object's methods are not compiled but only traced with tf.function(jit_compile=False). This is only valid in the TensorFlow backend; in JAX, passing trace_only=True will raise an exception.
    Default value: False.
  methods_to_exclude: List of Python str method names not to wrap. For example, these may include methods that do not take or return Tensor values. By default, a number of tfd.Distribution and tfb.Bijector methods and properties are excluded (e.g., event_shape, batch_shape, dtype).
    Default value: tfp.experimental.util.DEFAULT_METHODS_EXCLUDED_FROM_JIT.
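A minimal sketch of how the two options could be interpreted. It assumes a hypothetical backend string ("tensorflow" or "jax"), treats names starting with an underscore as non-public per the usual Python convention, and uses a hypothetical TOY_EXCLUDED set built only from the examples named above; it does not reproduce the real contents of DEFAULT_METHODS_EXCLUDED_FROM_JIT. The sketch raises ValueError where the documentation only says that an exception is raised.

```python
# Hypothetical subset of names; NOT the actual contents of
# tfp.experimental.util.DEFAULT_METHODS_EXCLUDED_FROM_JIT.
TOY_EXCLUDED = frozenset({"event_shape", "batch_shape", "dtype"})


def jit_compile_flag(trace_only, backend):
    """Return the jit_compile value a tf.function-style wrapper would get.

    trace_only=True means "trace with tf.function(jit_compile=False)",
    which only makes sense in the TensorFlow backend.
    """
    if trace_only and backend == "jax":
        raise ValueError("trace_only=True is not supported in the JAX backend.")
    return not trace_only


def should_wrap(name, value, methods_to_exclude=TOY_EXCLUDED):
    """Wrap only public callables that are not explicitly excluded."""
    return (callable(value)
            and not name.startswith("_")
            and name not in methods_to_exclude)


print(jit_compile_flag(False, "tensorflow"))   # True: compile with XLA
print(jit_compile_flag(True, "tensorflow"))    # False: trace only
print(should_wrap("log_prob", len))            # True
print(should_wrap("event_shape", len))         # False: excluded by name
print(should_wrap("_private_helper", len))     # False: not public
```

Here len is just an arbitrary callable used as a stand-in for a method object.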
Attributes:
  methods_to_exclude
  object_to_wrap
  trace_only
Methods
copy
copy(**kwargs)
__getitem__
__getitem__(slices)
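This page does not describe what copy and __getitem__ return. One plausible reading, stated here as an assumption rather than documented behavior, is that each returns a new wrapper around object_to_wrap.copy(**kwargs) or object_to_wrap[slices] that keeps the same trace_only and methods_to_exclude settings. The sketch shows that re-wrapping pattern with the hypothetical classes ToyBatch and ToyWrapper; no TFP code is involved.

```python
class ToyBatch:
    """Hypothetical batched object supporting copy(**kwargs) and slicing."""

    def __init__(self, values, name="batch"):
        self.values = list(values)
        self.name = name

    def copy(self, **kwargs):
        params = {"values": self.values, "name": self.name}
        params.update(kwargs)  # Keyword overrides replace constructor args.
        return type(self)(**params)

    def __getitem__(self, slices):
        return type(self)(self.values[slices], name=self.name)


class ToyWrapper:
    """Hypothetical wrapper whose copy/__getitem__ preserve its settings."""

    def __init__(self, object_to_wrap, trace_only=False, methods_to_exclude=()):
        self.object_to_wrap = object_to_wrap
        self.trace_only = trace_only
        self.methods_to_exclude = tuple(methods_to_exclude)

    def _rewrap(self, new_object):
        # Carry this wrapper's settings over to the new wrapped object.
        return type(self)(new_object,
                          trace_only=self.trace_only,
                          methods_to_exclude=self.methods_to_exclude)

    def copy(self, **kwargs):
        return self._rewrap(self.object_to_wrap.copy(**kwargs))

    def __getitem__(self, slices):
        return self._rewrap(self.object_to_wrap[slices])


w = ToyWrapper(ToyBatch([1, 2, 3, 4]), trace_only=True,
               methods_to_exclude=("event_shape",))
sliced = w[1:3]
renamed = w.copy(name="other")
print(sliced.object_to_wrap.values, sliced.trace_only)          # [2, 3] True
print(renamed.object_to_wrap.name, renamed.methods_to_exclude)  # other ('event_shape',)
```

Re-wrapping, rather than returning the bare copied or sliced object, keeps the result consistent with the original wrapper's compilation settings.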