tfp.experimental.distributions.marginal_fns.ps.dtype_util.common_dtype

Returns the (nested) explicit dtype from args, if there is one.

Args

args: A nested structure of objects that may have a dtype. If dtype_hint is not nested, the structure of args is flattened and ignored. If dtype_hint is nested, args is interpreted as a depth-1 iterable or mapping, each element of which is one of: an object whose dtype has the structure of dtype_hint (or whose dtype is None), a nested structure whose shallow structure matches dtype_hint, or None. This enables unification of dtypes between objects of nested dtype and nested structures of arrays.
dtype_hint: Optional (nested) dtype containing defaults to use in place of None. If dtype_hint is not nested and the common dtype of args is nested, dtype_hint serves as the default for each element of the common nested dtype structure.

Returns

dtype: The (nested) dtype common across all elements of args, with dtype_hint supplying defaults where no explicit dtype is found; None if neither provides one.

Examples

Usage with non-nested dtype:

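import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
# `common_dtype` is defined in TFP's internal `dtype_util` module.
from tensorflow_probability.python.internal.dtype_util import common_dtype

tfd = tfp.distributions

x = tf.ones([3, 4], dtype=tf.float64)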
y = 4.
z = None
common_dtype([x, y, z], dtype_hint=tf.float32)  # ==> tf.float64
common_dtype([y, z], dtype_hint=tf.float32)     # ==> tf.float32

# The arg to `common_dtype` can be an arbitrary nested structure; it is
# flattened, and the common dtype of its contents is returned.
common_dtype({'x': x, 'yz': (y, z)})
# ==> tf.float64
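To make the non-nested rule concrete, here is a toy model in plain Python. It is a sketch under stated assumptions, not TFP's implementation: dtypes are plain strings, `FakeTensor` is a hypothetical stand-in for a tensor, and `_flatten` and `toy_common_dtype` are illustrative names. It flattens `args`, skips anything without a `dtype` attribute (Python floats, `None`), requires the remaining dtypes to agree, and returns `dtype_hint` only when nothing carries a dtype. Raising `TypeError` on a conflict is the sketch's own choice.

```python
from typing import Any, Optional


class FakeTensor:
    """Hypothetical stand-in for a tensor: it only carries a dtype name."""

    def __init__(self, dtype: str):
        self.dtype = dtype


def _flatten(structure: Any) -> list:
    """Flattens dicts (in sorted-key order), lists and tuples into their leaves."""
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure) for leaf in _flatten(structure[key])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in _flatten(item)]
    return [structure]


def toy_common_dtype(args: Any, dtype_hint: Optional[str] = None) -> Optional[str]:
    """Toy model of the non-nested rule (not TFP's implementation)."""
    dtype = None
    for leaf in _flatten(args):
        leaf_dtype = getattr(leaf, "dtype", None)  # Python floats and None have no dtype
        if leaf_dtype is None:
            continue
        if dtype is not None and leaf_dtype != dtype:
            raise TypeError(f"incompatible dtypes: {dtype} vs {leaf_dtype}")
        dtype = leaf_dtype
    # Fall back to the hint only when no leaf carried an explicit dtype.
    return dtype_hint if dtype is None else dtype


x = FakeTensor("float64")
y, z = 4.0, None
print(toy_common_dtype([x, y, z], dtype_hint="float32"))  # float64
print(toy_common_dtype([y, z], dtype_hint="float32"))     # float32
print(toy_common_dtype({"x": x, "yz": (y, z)}))           # float64
```

In this model, `toy_common_dtype([FakeTensor("float32"), FakeTensor("float64")])` raises, because two explicit dtypes disagree and nothing is promoted.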

Usage with nested dtype:

# Define `x` and `y` as JointDistributions with the same nested dtype.
x = tfd.JointDistributionNamed(
    {'a': tfd.Uniform(np.float64(0.), 1.),
     'b': tfd.JointDistributionSequential(
        [tfd.Normal(0., 2.), tfd.Bernoulli(0.4)])})
x.dtype  # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}

y = tfd.JointDistributionNamed(
    {'a': tfd.LogitNormal(np.float64(0.), 1.),
     'b': tfd.JointDistributionSequential(
        [tfd.Normal(-1., 1.), tfd.Bernoulli(0.6)])})
y.dtype  # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}

# Pack x and y into an arbitrary nested structure and pass it to
# `common_dtype`.
args0 = [x, y]
common_dtype(args0)  # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}

# If `dtype_hint` is not structured, the nested structure of the argument
# to `common_dtype` is flattened and ignored, and only the nested structures
# of the dtypes are relevant.
args1 = {'x': x, 'yz': {'y': y, 'z': None}}
common_dtype(args1)  # ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}

# Use a structured `dtype_hint` to indicate the structure of the expected dtype.
# In this example, `x` is an object with a structured dtype, and `t` is a
# structure of objects whose dtypes are compatible with the corresponding
# components of `x.dtype`. Without a structured `dtype_hint`, this example
# would fail: the args `[x, t]` would be flattened entirely, and the
# structured `x.dtype` is incompatible with the non-structured `float32`
# contained in `t`.
t = {'a': [1., 2., 3.], 'b': [np.float32(1.), [[4, 5]]]}
common_dtype([x, t], dtype_hint={'a': None, 'b': [None, None]})
#   ==> {'a': tf.float64, 'b': [tf.float32, tf.int32]}
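The nested cases can be approximated with a toy model as well. This is a hedged sketch rather than the library's code: dtypes are strings inside dicts and lists, the hypothetical `HasDtype` class stands in for both the joint distributions and `np.float32(1.)`, and all helper names (`_leaves`, `_unify`, `_dtype_along`, `_fill`, `toy_nested_common_dtype`) are illustrative. With a flat `dtype_hint`, objects that carry a `dtype` are treated as leaves and the nesting of `args` is ignored; with a structured `dtype_hint`, each top-level element of `args` is read along the hint's shallow structure. Per-element dtypes are merged leaf by leaf with `None` as a wildcard (an assumption about how compatible dtypes combine), and leftover `None` leaves are filled from `dtype_hint`.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class HasDtype:
    """Hypothetical stand-in for any object with an explicit (possibly nested) dtype."""
    dtype: Any


def _leaves(structure: Any) -> list:
    """Flattens dicts (sorted keys), lists and tuples; dtype-carrying objects are leaves."""
    if hasattr(structure, "dtype"):
        return [structure]
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure) for leaf in _leaves(structure[key])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in _leaves(item)]
    return [structure]


def _unify(a: Any, b: Any) -> Any:
    """Merges two (nested) dtypes leaf by leaf; None acts as a wildcard."""
    if a is None:
        return b
    if b is None:
        return a
    if isinstance(a, dict) and isinstance(b, dict) and a.keys() == b.keys():
        return {key: _unify(a[key], b[key]) for key in a}
    if isinstance(a, list) and isinstance(b, list) and len(a) == len(b):
        return [_unify(u, v) for u, v in zip(a, b)]
    if a == b:
        return a
    raise TypeError(f"incompatible dtypes: {a!r} vs {b!r}")


def _dtype_along(value: Any, hint: Any) -> Any:
    """Reads the dtype of `value`, following the shallow structure of `hint`."""
    if hasattr(value, "dtype"):
        return value.dtype
    if isinstance(hint, dict):
        return {key: _dtype_along(value[key], hint[key]) for key in hint}
    if isinstance(hint, list):
        return [_dtype_along(v, h) for v, h in zip(value, hint)]
    result = None  # at a leaf of `hint`: unify the explicit dtypes found below it
    for leaf in _leaves(value):
        result = _unify(result, getattr(leaf, "dtype", None))
    return result


def _fill(dtype: Any, hint: Any) -> Any:
    """Replaces None leaves with the matching part of `hint` (or `hint` itself if flat)."""
    if dtype is None:
        return hint
    if isinstance(dtype, dict):
        return {k: _fill(v, hint[k] if isinstance(hint, dict) else hint) for k, v in dtype.items()}
    if isinstance(dtype, list):
        return [_fill(v, hint[i] if isinstance(hint, list) else hint) for i, v in enumerate(dtype)]
    return dtype


def toy_nested_common_dtype(args: Any, dtype_hint: Any = None) -> Any:
    """Toy model of the nested rules (not TFP's implementation)."""
    common = None
    if isinstance(dtype_hint, (dict, list)):
        # Structured hint: `args` is a depth-1 iterable or mapping.
        values = args.values() if isinstance(args, dict) else args
        for arg in values:
            if arg is not None:
                common = _unify(common, _dtype_along(arg, dtype_hint))
    else:
        # Flat hint: the nesting of `args` is ignored.
        for leaf in _leaves(args):
            common = _unify(common, getattr(leaf, "dtype", None))
    return _fill(common, dtype_hint)


# Hypothetical stand-ins for the joint distributions `x` and `y`.
jd_x = HasDtype({"a": "float64", "b": ["float32", "int32"]})
jd_y = HasDtype({"a": "float64", "b": ["float32", "int32"]})
t = {"a": [1.0, 2.0, 3.0], "b": [HasDtype("float32"), [[4, 5]]]}

print(toy_nested_common_dtype({"x": jd_x, "yz": {"y": jd_y, "z": None}}))
# {'a': 'float64', 'b': ['float32', 'int32']}
print(toy_nested_common_dtype([jd_x, t], dtype_hint={"a": None, "b": [None, None]}))
# {'a': 'float64', 'b': ['float32', 'int32']}
```

Calling `toy_nested_common_dtype([jd_x, t])` without the structured hint raises `TypeError` in this model, because the flat `'float32'` from `t` meets the nested dtype of `jd_x`; this mirrors the failure described in the comment above.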