View source on GitHub |
Returns the (nested) explicit dtype from `args`,
if there is one.
tfp.experimental.distributions.marginal_fns.ps.dtype_util.common_dtype(
args, dtype_hint=None
)
Returns:
  dtype: The (nested) dtype common across all elements of `args`, or `None`.
Examples
Usage with non-nested dtype:
x = tf.ones([3, 4], dtype=tf.float64)
y = 4.
z = None
common_dtype([x, y, z], dtype_hint=tf.float32) # ==> tf.float64
common_dtype([y, z], dtype_hint=tf.float32) # ==> tf.float32
# The arg to `common_dtype` can be an arbitrary nested structure; it is
# flattened, and the common dtype of its contents is returned.
common_dtype({'x': x, 'yz': (y, z)})
# ==> tf.float64
Usage with nested dtype:
# Define `x` and `y` as JointDistributions with the same nested dtype.
x = tfd.JointDistributionNamed(
{'a': tfd.Uniform(np.float64(0.), 1.),
'b': tfd.JointDistributionSequential(
[tfd.Normal(0., 2.), tfd.Bernoulli(0.4)])})
x.dtype # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
y = tfd.JointDistributionNamed(
{'a': tfd.LogitNormal(np.float64(0.), 1.),
'b': tfd.JointDistributionSequential(
[tfd.Normal(-1., 1.), tfd.Bernoulli(0.6)])})
y.dtype # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
# Pack x and y into an arbitrary nested structure and pass it to
# `common_dtype`.
args0 = [x, y]
common_dtype(args0) # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
# If `dtype_hint` is not structured, the nested structure of the argument
# to `common_dtype` is flattened and ignored, and only the nested structures
# of the dtypes are relevant.
args1 = {'x': x, 'yz': {'y': y, 'z': None} }
common_dtype(args1) # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
# Use structured `dtype_hint` to indicate the structure of the expected dtype.
# In this example, `x` is an object with structured dtype, and `t` is a
# structure of objects whose dtypes are compatible with the corresponding
# components of `x.dtype`. Without structured `dtype_hint`, this example
# would fail, since the args `[x, t]` would be flattened entirely, and the
# structured `x.dtype` is incompatible with the non-structured `float32`
# contained in `t`.
t = {'a': [1., 2., 3.], 'b': [np.float32(1.), [[4, 5]]]}
common_dtype([x, t], dtype_hint={'a': None, 'b': [None, None]})
# ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}