

Base class for polymorphic graph functions.

Inherits From: Callable

Graph functions are Python callable objects that dispatch calls to a TensorFlow graph. Polymorphic graph functions can be backed by multiple TF graphs, and automatically select the appropriate specialization based on the type of input they were called with. They may also create specializations on the fly if necessary, for example by tracing.

Also see tf.function.
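
As a sketch of the dispatch behavior described above, the snippet below uses tf.function, which returns a polymorphic graph function, and counts how many specializations it creates. The names `double` and `trace_log` are illustrative only. The count relies on the standard tf.function behavior that Python side effects run while tracing, not on every call.

```python
import tensorflow as tf

# Python side effects run only while tracing, so this list gets
# one entry per specialization (graph) that tf.function creates.
trace_log = []

@tf.function
def double(x):
    trace_log.append(x.dtype)
    return x * 2

double(tf.constant(1.0))          # float32 scalar: traces a new graph
double(tf.constant(2.0))          # same dtype and shape: reuses that graph
double(tf.constant(3, tf.int32))  # int32 scalar: traces a second graph

print(len(trace_log))  # 2
```

Tensor values do not trigger retracing. Only a change in the input's type (here the dtype) selects or creates a different specialization.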



experimental_get_compiler_ir(*args, **kwargs)

Returns compiler IR for the compiled function.

This API is intended only for debugging, as there are no guarantees of backward compatibility for the returned IR or for the allowed values of stage.

Args:
  *args: Arguments used for compilation; the same arguments as those used to call the function. They must be eager tensors.
  **kwargs: Keyword arguments used for compilation.

Returns:
  A function callable with the following kwargs:

  • stage at which the compiler IR should be serialized. Allowed values are:
    • hlo: HLO output after conversion from TF.
    • hlo_serialized: Like stage=hlo, but the output is a serialized HLO module proto (a bytes object).
    • optimized_hlo: HLO after compiler optimizations.
    • optimized_hlo_serialized: Like stage=optimized_hlo, but the output is a serialized HLO module proto (a bytes object).
    • optimized_hlo_dot: optimized HLO in DOT format suitable for Graphviz.
  • device_name can be either None, in which case the preferred device is used for compilation, or a device name. It can be a full device name, or a partial one, e.g., /device:CPU:0.
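
A minimal sketch of requesting several stages from one compiled function, assuming a TensorFlow build with XLA support (the standard pip packages include it). The function name `add_one` is illustrative. Per the list above, the text stages come back as `str` and the `*_serialized` stages as `bytes`.

```python
import tensorflow as tf

@tf.function(jit_compile=True)  # compiler IR is only available for XLA-compiled functions
def add_one(x):
    return x + 1

x = tf.random.normal([4, 4])  # eager tensor that selects the specialization
get_ir = add_one.experimental_get_compiler_ir(x)

hlo_text = get_ir(stage='hlo')              # HLO text after conversion from TF
hlo_proto = get_ir(stage='hlo_serialized')  # serialized HLO module proto
optimized = get_ir(stage='optimized_hlo', device_name='/device:CPU:0')

print(hlo_text.splitlines()[0])  # first line names the HloModule
```

Passing a partial device name such as /device:CPU:0 pins compilation to that device; omitting device_name uses the preferred device.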

For example, for

import tensorflow as tf

@tf.function(jit_compile=True)
def f(x):
  return x + 1

f.experimental_get_compiler_ir(tf.random.normal([10, 10]))(stage='hlo')

the output is similar to the following (module and instruction names vary between runs):

HloModule a_inference_f_13__.9

ENTRY %a_inference_f_13__.9 (arg0.1: f32[10,10]) -> f32[10,10] {
  %arg0.1 = f32[10,10]{1,0} parameter(0), parameter_replication={false}
  %reshape.2 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %arg0.1)
  %constant.3 = f32[] constant(1)
  %broadcast.4 = f32[10,10]{1,0} broadcast(f32[] %constant.3)
  %add.5 = f32[10,10]{1,0} add(f32[10,10]{1,0} %reshape.2,
                               f32[10,10]{1,0} %broadcast.4)
  %reshape.6 = f32[10,10]{1,0} reshape(f32[10,10]{1,0} %add.5)
  %tuple.7 = (f32[10,10]{1,0}) tuple(f32[10,10]{1,0} %reshape.6)
  ROOT %get-tuple-element.8 = f32[10,10]{1,0}
    get-tuple-element((f32[10,10]{1,0}) %tuple.7), index=0
}

Raises:
  ValueError: If an invalid stage is selected, or if applied to a function that is not compiled (jit_compile=True is not set).
  TypeError: When called with input in graph mode.
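
To illustrate the second ValueError condition above, here is a sketch that requests compiler IR from a tf.function created without jit_compile=True. The function name `plain_add_one` is hypothetical, and the exact error message may differ between TensorFlow versions, so only the exception type is checked.

```python
import tensorflow as tf

@tf.function  # jit_compile=True is deliberately NOT set
def plain_add_one(x):
    return x + 1

x = tf.random.normal([4, 4])
caught = None
try:
    # Both the IR request and the stage selection sit inside the try block,
    # so the check passes whichever of the two calls raises.
    plain_add_one.experimental_get_compiler_ir(x)(stage='hlo')
except ValueError as err:
    caught = err

print(type(caught).__name__)  # ValueError
```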


get_concrete_function(*args, **kwargs)

Returns a ConcreteFunction specialized to input types.

The arguments specified by args and kwargs follow normal function call rules. The returned ConcreteFunction has the same set of positional and keyword arguments as self, but their types are refined to the types specified by args and kwargs.

@tf.function
def f(x):
  return x
f_concrete = f.get_concrete_function(tf.constant(1.0))
f_concrete = f.get_concrete_function(x=tf.constant(1.0))

Unlike normal calls, get_concrete_function allows type specifiers in place of TensorFlow objects; for example, tf.Tensors may be replaced with tf.TensorSpecs.

@tf.function
def f(x):
  return x
f_concrete = f.get_concrete_function(tf.TensorSpec([], tf.float64))
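
As a sketch of how a TensorSpec refines argument types, a None dimension in the spec yields one ConcreteFunction that accepts vectors of any length. The names `total` and `total_concrete` are illustrative.

```python
import tensorflow as tf

@tf.function
def total(x):
    return tf.reduce_sum(x)

# None leaves the length unspecified, so a single graph serves all lengths.
total_concrete = total.get_concrete_function(tf.TensorSpec([None], tf.float32))

short = total_concrete(tf.constant([1.0, 2.0]))
longer = total_concrete(tf.constant([1.0, 2.0, 3.0, 4.0]))
print(short.numpy(), longer.numpy())  # 3.0 10.0
```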

If the function definition allows only one specialization, args and kwargs may be omitted altogether.

@tf.function(input_signature=[tf.TensorSpec(None, tf.float32)])
def f(x):
  return x
f_concrete = f.get_concrete_function()

The returned ConcreteFunction can be called normally:

f_concrete(tf.constant(1.0))
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
f_concrete(x=tf.constant(1.0))
<tf.Tensor: shape=(), dtype=float32, numpy=1.0>
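
Because a ConcreteFunction is backed by a single graph, that graph can be inspected directly. Here is a minimal sketch, assuming TF 2.x, where `x + 1` on a float tensor is recorded as an `AddV2` op. Op type names are an implementation detail and may change. The name `add_one` is illustrative.

```python
import tensorflow as tf

@tf.function
def add_one(x):
    return x + 1

add_one_concrete = add_one.get_concrete_function(tf.TensorSpec([], tf.float32))

# List the op types recorded in the specialized graph.
op_types = [op.type for op in add_one_concrete.graph.get_operations()]
print(op_types)

# Calling the ConcreteFunction runs exactly that graph.
result = add_one_concrete(tf.constant(1.0))
print(result.numpy())  # 2.0
```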

Args:
  *args: Inputs to specialize on.
  **kwargs: Inputs to specialize on.

Returns:
  A ConcreteFunction.