Wrapper that provides versions of a function with and without summaries.
orbit.utils.OptionalSummariesFunction(
    function, **tf_function_kwargs
)
This is a utility class for implementing optimized summary recording via a
two-function approach, specifically important for TPUs. Two tf.function versions
of a given function are created: one with soft device placement
enabled (for use on steps that require summary writing), and one with summary
writing and soft device placement entirely disabled (for use on all other
steps). This removes any performance impact of summaries on steps where they
aren't recorded (b/148418718).
This class can be used as a base class to implement summary optimizations for
a function with a specific signature. For example, to implement efficient TPU
summaries for a standard train() method (as in orbit.AbstractTrainer):
import orbit
import tensorflow as tf

class TrainFunctionWithSummaries(orbit.utils.OptionalSummariesFunction):
  '''Implements a two-program approach for summaries on TPU.'''

  def __call__(self, num_steps):
    # Run a single step with summaries when recording is active.
    if tf.summary.should_record_summaries():
      output = self.with_summaries(tf.constant(1))
      num_steps -= 1
    # Run any remaining steps with summaries disabled.
    if num_steps >= 1:
      output = self.without_summaries(num_steps)
    return output
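To see how this `__call__` splits a request for `num_steps` steps, the following plain-Python sketch replaces the two tf.function variants with stand-in methods that only record what they were asked to do. The class `FakeOptionalSummaries`, the helper `trace`, and the `record` flag are hypothetical names for illustration and are not part of Orbit; in real code the decision comes from tf.summary.should_record_summaries().

```python
class FakeOptionalSummaries:
    """Hypothetical stand-in mirroring the dispatch logic shown above."""

    def __init__(self):
        self.calls = []  # (variant, num_steps) pairs, in call order

    def with_summaries(self, num_steps):
        self.calls.append(("with_summaries", num_steps))
        return num_steps

    def without_summaries(self, num_steps):
        self.calls.append(("without_summaries", num_steps))
        return num_steps

    def __call__(self, num_steps, record):
        # `record` stands in for tf.summary.should_record_summaries().
        output = None
        if record:
            # One step runs in the variant that can write summaries.
            output = self.with_summaries(1)
            num_steps -= 1
        if num_steps >= 1:
            # The remaining steps run in the summary-free variant.
            output = self.without_summaries(num_steps)
        return output


def trace(num_steps, record):
    """Returns the list of variant calls made for one invocation."""
    fn = FakeOptionalSummaries()
    fn(num_steps, record)
    return fn.calls
```

Under these assumptions, `trace(5, True)` records one step with summaries followed by four without, while `trace(5, False)` runs all five steps in the summary-free variant.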
This can be used directly or to implement a decorator:
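import functools

def train_function_with_summaries(function=None, **kwargs):
  # Bare use (@train_function_with_summaries): wrap the function directly.
  if function is not None:
    return TrainFunctionWithSummaries(function, **kwargs)
  # Use with keyword arguments: return a decorator that wraps the function.
  return functools.partial(TrainFunctionWithSummaries, **kwargs)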
The decorator can be applied directly to train() methods:
@train_function_with_summaries
def train(self, num_steps):
  ...
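The `function=None` default is what lets such a decorator be applied both bare and with keyword arguments. A plain-Python sketch of the same pattern follows, using a hypothetical `Wrapper` class in place of TrainFunctionWithSummaries so that it runs without TensorFlow; the `tag` keyword is an arbitrary illustrative argument, not a tf.function option.

```python
import functools


class Wrapper:
    """Hypothetical stand-in that stores the function and keyword arguments."""

    def __init__(self, function, **kwargs):
        self.function = function
        self.kwargs = kwargs

    def __call__(self, *args):
        return self.function(*args)


def wrap(function=None, **kwargs):
    if function is not None:
        # Bare use: @wrap
        return Wrapper(function, **kwargs)
    # Parameterized use: @wrap(key=value) returns a one-argument decorator.
    return functools.partial(Wrapper, **kwargs)


@wrap
def double(x):
    return 2 * x


@wrap(tag="demo")
def triple(x):
    return 3 * x
```

Both `double` and `triple` end up as `Wrapper` instances; only `triple` carries the extra keyword arguments.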
A similar approach can be implemented for functions with different signatures.
This wrapper properly handles instance methods (see __get__).
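Instance-method support for a callable wrapper object generally relies on Python's descriptor protocol: without `__get__`, a wrapper instance stored on a class would be called without `self`. The sketch below shows one common way to handle this with `functools.partial`. `CallableWrapper` and `Trainer` are hypothetical names, and this is not claimed to be Orbit's actual implementation.

```python
import functools


class CallableWrapper:
    """Hypothetical wrapper class illustrating descriptor-based binding."""

    def __init__(self, function):
        self.function = function

    def __call__(self, *args, **kwargs):
        return self.function(*args, **kwargs)

    def __get__(self, instance, owner=None):
        if instance is None:
            # Accessed on the class itself: return the unbound wrapper.
            return self
        # Accessed on an instance: pre-bind it as the first argument.
        return functools.partial(self, instance)


class Trainer:
    def __init__(self):
        self.steps_run = 0

    @CallableWrapper
    def train(self, num_steps):
        self.steps_run += num_steps
        return self.steps_run


trainer = Trainer()
```

With this in place, `trainer.train(3)` reaches the wrapped function with `trainer` supplied as `self`, just as a regular method call would.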
Args | |
---|---|
function | The underlying function to wrap. |
**tf_function_kwargs | Additional arguments to pass to tf.function. |
Attributes | |
---|---|
with_summaries | A wrapped version of the underlying function with summaries enabled (using whatever the active predicate is for tf.summary.record_if), and placed inside a "soft device placement" context to enable summary recording on TPU. |
without_summaries | A wrapped version of the underlying function with all summary recording disabled. |