Decorator that overrides the default implementation for a TensorFlow API.
tf.experimental.dispatch_for_api(
api, *signatures
)
The decorated function (known as the "dispatch target") will override the
default implementation for the API when the API is called with parameters that
match a specified type signature. Signatures are specified using dictionaries
that map parameter names to type annotations. In the following example,
masked_add will be called for tf.add if both x and y are MaskedTensors:
import tensorflow as tf

class MaskedTensor(tf.experimental.ExtensionType):
  values: tf.Tensor
  mask: tf.Tensor

@tf.experimental.dispatch_for_api(tf.math.add, {'x': MaskedTensor, 'y': MaskedTensor})
def masked_add(x, y, name=None):
  # Add values elementwise; an element is valid only if it is valid in both inputs.
  return MaskedTensor(x.values + y.values, x.mask & y.mask)

mt = tf.add(MaskedTensor([1, 2], [True, False]), MaskedTensor(10, True))
print(f"values={mt.values.numpy()}, mask={mt.mask.numpy()}")
# values=[11 12], mask=[ True False]
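Because a signature is keyed by parameter name, a call can match whether its arguments were passed positionally or by keyword. The following pure-Python sketch models that check with the standard inspect module. It is an illustrative model, not TensorFlow's implementation; CompositeTensor, MaskedTensor and add are hypothetical stand-ins for tf.CompositeTensor, the MaskedTensor above, and tf.math.add.

```python
import inspect


class CompositeTensor:
    """Hypothetical stand-in for tf.CompositeTensor."""


class MaskedTensor(CompositeTensor):
    """Hypothetical stand-in for the MaskedTensor extension type."""

    def __init__(self, values, mask):
        self.values = values
        self.mask = mask


def add(x, y, name=None):
    """Hypothetical stand-in for tf.math.add's Python signature."""


def signature_matches(api, signature, args, kwargs):
    """True if every annotated parameter holds an instance of its type."""
    # Map positional and keyword arguments onto the API's parameter names.
    bound = inspect.signature(api).bind(*args, **kwargs)
    bound.apply_defaults()
    return all(
        isinstance(bound.arguments[param], expected)
        for param, expected in signature.items()
    )


sig = {'x': MaskedTensor, 'y': MaskedTensor}
mt = MaskedTensor([1, 2], [True, False])
print(signature_matches(add, sig, (mt, mt), {}))            # True
print(signature_matches(add, sig, (), {'x': mt, 'y': mt}))  # True (keywords)
print(signature_matches(add, sig, (mt, 3), {}))             # False (y is an int)
```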
If multiple type signatures are specified, then the dispatch target will be
called if any of the signatures match. For example, the following code
registers masked_add to be called if x is a MaskedTensor or y is a
MaskedTensor:
@tf.experimental.dispatch_for_api(tf.math.add, {'x': MaskedTensor}, {'y': MaskedTensor})
def masked_add(x, y):
  # Either argument may be a plain value; treat it as fully unmasked.
  x_values = x.values if isinstance(x, MaskedTensor) else x
  x_mask = x.mask if isinstance(x, MaskedTensor) else True
  y_values = y.values if isinstance(y, MaskedTensor) else y
  y_mask = y.mask if isinstance(y, MaskedTensor) else True
  return MaskedTensor(x_values + y_values, x_mask & y_mask)
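The "any signature matches" rule can be modeled in a few lines of plain Python. This is a simplified sketch, not TensorFlow's dispatcher; MaskedTensor and add are hypothetical stand-ins, and each signature dict is checked independently against the bound arguments.

```python
import inspect


class MaskedTensor:
    """Hypothetical stand-in for the MaskedTensor extension type."""


def add(x, y, name=None):
    """Hypothetical stand-in for tf.math.add's Python signature."""


def any_signature_matches(api, signatures, args, kwargs):
    """True if at least one signature dict matches the bound arguments."""
    arguments = inspect.signature(api).bind(*args, **kwargs).arguments
    return any(
        # A signature matches only if all of its annotated parameters match.
        all(isinstance(arguments.get(p), t) for p, t in sig.items())
        for sig in signatures
    )


signatures = [{'x': MaskedTensor}, {'y': MaskedTensor}]
mt = MaskedTensor()
print(any_signature_matches(add, signatures, (mt, 1.0), {}))   # True: x matches
print(any_signature_matches(add, signatures, (1.0, mt), {}))   # True: y matches
print(any_signature_matches(add, signatures, (1.0, 2.0), {}))  # False
```

This is why the dispatch target above must handle a plain value in either position.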
The type annotations in type signatures may be type objects (e.g.,
MaskedTensor), typing.List values, or typing.Union values. For example, the
following will register masked_concat to be called if values is a list of
MaskedTensor values:
import typing

@tf.experimental.dispatch_for_api(tf.concat, {'values': typing.List[MaskedTensor]})
def masked_concat(values, axis):
  # Concatenate the values and the masks separately along the same axis.
  return MaskedTensor(tf.concat([v.values for v in values], axis),
                      tf.concat([v.mask for v in values], axis))
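One way to picture how a typing.List or typing.Union annotation is checked is a small recursive matcher built on typing.get_origin and typing.get_args (Python 3.8+). This is a hypothetical model rather than TensorFlow's code; MaskedTensor and Tensor are stand-in classes.

```python
import typing


class MaskedTensor:
    """Hypothetical stand-in for the MaskedTensor extension type."""


class Tensor:
    """Hypothetical stand-in for tf.Tensor."""


def value_matches(value, annotation):
    """Check a value against a type, typing.List[...] or typing.Union[...]."""
    origin = typing.get_origin(annotation)
    if origin is list:
        # List[T]: the value must be a list whose elements all match T.
        (elem_type,) = typing.get_args(annotation)
        return isinstance(value, list) and all(
            value_matches(v, elem_type) for v in value)
    if origin is typing.Union:
        # Union[A, B]: the value must match at least one member.
        return any(value_matches(value, a) for a in typing.get_args(annotation))
    # Plain type object.
    return isinstance(value, annotation)


mt = MaskedTensor()
print(value_matches([mt, mt], typing.List[MaskedTensor]))           # True
print(value_matches([mt, Tensor()], typing.List[MaskedTensor]))     # False
print(value_matches(Tensor(), typing.Union[MaskedTensor, Tensor]))  # True
print(value_matches([], typing.List[MaskedTensor]))                 # True (vacuously)
```

The last line shows why type matching alone is not enough: an empty list satisfies List[MaskedTensor], a case handled by the CompositeTensor rule that follows.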
Each type signature must contain at least one subclass of tf.CompositeTensor
(which includes subclasses of tf.ExtensionType), and dispatch will only be
triggered if at least one type-annotated parameter contains a CompositeTensor
value. This rule avoids invoking dispatch in degenerate cases, such as the
following examples:
- @dispatch_for_api(tf.concat, {'values': List[MaskedTensor]}): will not
  dispatch to the decorated dispatch target when the user calls tf.concat([]).
- @dispatch_for_api(tf.add, {'x': Union[MaskedTensor, Tensor], 'y': Union[MaskedTensor, Tensor]}):
  will not dispatch to the decorated dispatch target when the user calls
  tf.add(tf.constant(1), tf.constant(2)).
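The second condition for dispatch can be sketched as a separate check that runs after type matching succeeds. In this hypothetical model (not TensorFlow's code), CompositeTensor, MaskedTensor and Tensor are stand-in classes, and only direct values and values nested in lists are inspected.

```python
class CompositeTensor:
    """Hypothetical stand-in for tf.CompositeTensor."""


class MaskedTensor(CompositeTensor):
    """Hypothetical stand-in for the MaskedTensor extension type."""


class Tensor:
    """Hypothetical stand-in for tf.Tensor (not a CompositeTensor)."""


def contains_composite(value):
    """True if value is, or is a list containing, a CompositeTensor."""
    if isinstance(value, list):
        return any(contains_composite(v) for v in value)
    return isinstance(value, CompositeTensor)


def has_composite_argument(annotated_values):
    """Dispatch is allowed only if some annotated argument holds a CompositeTensor."""
    return any(contains_composite(v) for v in annotated_values)


# tf.concat([]): the empty list type-checks, but holds no CompositeTensor.
print(has_composite_argument([[]]))                        # False
# tf.add(Tensor, Tensor) under Union[MaskedTensor, Tensor] annotations.
print(has_composite_argument([Tensor(), Tensor()]))        # False
print(has_composite_argument([MaskedTensor(), Tensor()]))  # True
print(has_composite_argument([[MaskedTensor()]]))          # True
```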
The dispatch target's signature must match the signature of the API that is
being overridden. In particular, parameters must have the same names, and
must occur in the same order. The dispatch target may optionally elide the
"name" parameter, in which case it will be wrapped with a call to
tf.name_scope when appropriate.
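The signature requirement can be approximated by comparing parameter names in order with inspect.signature. This sketch assumes only names and order matter, with "name" allowed to be left out; TensorFlow's own check may be stricter. The functions concat, masked_concat and bad_concat are hypothetical stand-ins.

```python
import inspect


def concat(values, axis, name="concat"):
    """Hypothetical stand-in for tf.concat's Python signature."""


def masked_concat(values, axis):
    """Dispatch target that elides the optional `name` parameter."""


def bad_concat(axis, values):
    """Dispatch target whose parameters are in the wrong order."""


def signatures_compatible(api, target):
    """True if target's parameter names equal api's, in the same order,
    except that target may leave out the `name` parameter."""
    api_params = list(inspect.signature(api).parameters)
    target_params = list(inspect.signature(target).parameters)
    if target_params == api_params:
        return True
    # Eliding `name` is allowed; per the text, TF then wraps the target in tf.name_scope.
    return [p for p in api_params if p != "name"] == target_params


print(signatures_compatible(concat, masked_concat))  # True
print(signatures_compatible(concat, bad_concat))     # False
```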
Returns
A decorator that overrides the default implementation for api.
Registered APIs
The TensorFlow APIs that may be overridden by @dispatch_for_api are: