This transformation is a stateful relative of tf.data.Dataset.map.
In addition to mapping scan_func across the elements of the input dataset,
scan() accumulates one or more state tensors, whose initial values are given
by initial_state.
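To show how the state threads through the map-like calls, here is a minimal plain-Python model of the semantics described above. It is not TensorFlow's implementation. The helper name scan_model is hypothetical, and Python ints stand in for tensors. Each call to scan_func receives the previous state and one element, and only the output element is emitted.

```python
from typing import Any, Callable, Iterable, Iterator, Tuple

def scan_model(
    elements: Iterable[Any],
    initial_state: Any,
    scan_func: Callable[[Any, Any], Tuple[Any, Any]],
) -> Iterator[Any]:
    """Hypothetical plain-Python model of scan(): a map that carries state."""
    state = initial_state
    for element in elements:
        # scan_func returns (new_state, output_element); only the output is yielded,
        # and new_state is fed into the call for the next element.
        state, output = scan_func(state, element)
        yield output

# Running sum: the state holds the total so far, and each output is that total.
running_sum = list(scan_model(range(5), 0, lambda total, x: (total + x, total + x)))
print(running_sum)  # [0, 1, 3, 6, 10]
```

Plain tf.data.Dataset.map could not express this running sum, because each output depends on every earlier element. The state is what carries that history forward.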
Args
initial_state
A nested structure of tensors, representing the initial state
of the accumulator.
scan_func
A function that maps (old_state, input_element) to
(new_state, output_element). It must take two arguments and return a
pair of nested structures of tensors. The new_state must match the
structure of initial_state.
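The contract for scan_func can be illustrated with a state that is a nested structure: a (count, total) pair whose output is the running mean. The sketch below is plain Python rather than TensorFlow. same_structure, running_mean_step, and run_scan are hypothetical helpers, and Python numbers stand in for tensors. The structure check mirrors the documented rule that new_state must match initial_state. It is only an approximation of TensorFlow's own checking, which operates on tensor structures.

```python
from typing import Any, Callable, Iterable, List, Tuple

def same_structure(a: Any, b: Any) -> bool:
    """Hypothetical helper: compare tuple/list/dict nesting, ignoring leaf values."""
    containers = (tuple, list, dict)
    if isinstance(a, dict) and isinstance(b, dict):
        return a.keys() == b.keys() and all(same_structure(a[k], b[k]) for k in a)
    if isinstance(a, (tuple, list)) and isinstance(b, (tuple, list)):
        return type(a) is type(b) and len(a) == len(b) and all(
            same_structure(x, y) for x, y in zip(a, b))
    # Otherwise both must be leaves (not tuples, lists, or dicts).
    return not isinstance(a, containers) and not isinstance(b, containers)

def running_mean_step(state: Tuple[int, float], x: float) -> Tuple[Tuple[int, float], float]:
    """A scan_func: takes (old_state, input_element), returns (new_state, output_element)."""
    count, total = state
    new_state = (count + 1, total + x)
    return new_state, new_state[1] / new_state[0]

def run_scan(elements: Iterable[float], initial_state: Any,
             scan_func: Callable[[Any, Any], Tuple[Any, Any]]) -> List[Any]:
    state, outputs = initial_state, []
    for x in elements:
        new_state, out = scan_func(state, x)
        # Enforce the rule that new_state matches the structure of initial_state.
        if not same_structure(new_state, initial_state):
            raise ValueError("new_state does not match the structure of initial_state")
        state = new_state
        outputs.append(out)
    return outputs

print(run_scan([2.0, 4.0, 9.0], (0, 0.0), running_mean_step))  # [2.0, 3.0, 5.0]
```

A scan_func that returned only the total (a single number instead of a (count, total) pair) would fail the check on its first call.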