# tfp.math.scan_associative

Perform a scan with an associative binary operation, in parallel.

The associative scan operation computes the cumulative sum, or all-prefix sum, of a set of elements under an associative binary operation [1]. For example, using the ordinary addition operator `fn = lambda a, b: a + b`, this is equivalent to the ordinary cumulative sum `tf.math.cumsum` along axis 0. This method supports the general case of arbitrary associative binary operations operating on `Tensor`s or structures of `Tensor`s:

```
scan_associative(fn, elems) = tf.stack([
  elems[0],
  fn(elems[0], elems[1]),
  fn(elems[0], fn(elems[1], elems[2])),
  ...
  fn(elems[0], fn(elems[1], fn(..., fn(elems[-2], elems[-1])))),
], axis=0)
```
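
The formula above nests `fn` to the right, while the description of `result` nests the same prefixes to the left; associativity is what makes the two agree. A minimal plain-Python sketch of both readings follows, assuming Python lists stand in for `Tensor`s; `scan_left_fold` and `scan_right_nested` are hypothetical helpers, not TFP functions. String concatenation is used because it is associative but not commutative, so it also checks that operand order is preserved.

```python
import functools
import itertools
import operator


def scan_left_fold(fn, elems):
    """Sequential reference: [e0, fn(e0, e1), fn(fn(e0, e1), e2), ...]."""
    return list(itertools.accumulate(elems, fn))


def scan_right_nested(fn, elems):
    """Builds output k as fn(e0, fn(e1, ... fn(e[k-1], e[k]))), as in the formula."""
    out = []
    for k in range(len(elems)):
        # Fold the prefix elems[0..k] from the right, starting at elems[k].
        out.append(
            functools.reduce(lambda acc, x: fn(x, acc), reversed(elems[:k]), elems[k])
        )
    return out


letters = ["a", "b", "c", "d"]
print(scan_left_fold(operator.add, letters))       # ['a', 'ab', 'abc', 'abcd']
print(scan_right_nested(operator.add, letters))    # ['a', 'ab', 'abc', 'abcd']
print(scan_left_fold(operator.add, [1, 2, 3, 4]))  # [1, 3, 6, 10]
```

The left fold is exactly what `itertools.accumulate` computes, which makes it a convenient oracle when checking a custom `fn` on small inputs.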

The associative structure allows the computation to be decomposed and executed by parallel reduction. Where a naive sequential implementation would loop over all `N` elements, this method requires only a logarithmic number (`2 * ceil(log_2 N)`) of sequential steps, and can thus yield substantial speedups from hardware-accelerated vectorization. The total number of invocations of the binary operation (including those performed in parallel) is `2 * (N / 2 + N / 4 + ... + 1) = 2N - 2`, i.e., approximately twice as many as a naive approach.
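
The decomposition described above can be sketched in plain Python. This is a minimal illustration under stated assumptions, not TFP's implementation: it uses one common recursive odd/even formulation of a parallel prefix scan, Python lists stand in for `Tensor`s, each list comprehension stands in for a single vectorized step, and `scan_odd_even` is a hypothetical helper name. It also counts calls to `fn` so the total can be compared with the `2N - 2` figure.

```python
import operator


def scan_odd_even(fn, elems, counter):
    """Recursive odd/even prefix scan over a Python list.

    Each list comprehension stands in for one vectorized (parallel) step.
    `counter["calls"]` is incremented on every invocation of `fn`.
    """
    n = len(elems)
    if n < 2:
        return list(elems)

    def call(a, b):
        counter["calls"] += 1
        return fn(a, b)

    # Step 1 (parallel): combine adjacent pairs (e0, e1), (e2, e3), ...
    reduced = [call(elems[2 * i], elems[2 * i + 1]) for i in range(n // 2)]

    # Recurse on the half-length sequence. Its i-th output is the prefix
    # that ends at original index 2 * i + 1.
    odd = scan_odd_even(fn, reduced, counter)

    # Step 2 (parallel): the prefix ending at an even index 2 * i (i >= 1)
    # combines the prefix ending at 2 * i - 1 with elems[2 * i].
    even = [elems[0]] + [
        call(odd[i - 1], elems[2 * i]) for i in range(1, (n + 1) // 2)
    ]

    # Interleave: result[2 * i] = even[i], result[2 * i + 1] = odd[i].
    return [even[k // 2] if k % 2 == 0 else odd[k // 2] for k in range(n)]


counter = {"calls": 0}
print(scan_odd_even(operator.add, list(range(1, 9)), counter))
# [1, 3, 6, 10, 15, 21, 28, 36]
print(counter["calls"])  # 11
```

For `N = 8` this formulation makes 11 calls, within the `2N - 2 = 14` bound. The recursion depth is `ceil(log_2 N)`, with two dependent steps (pair reduction, then filling the even indices) per level, which matches the `2 * ceil(log_2 N)` count of sequential steps.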

#### References

[1]: Blelloch, Guy E. Prefix sums and their applications. Technical Report CMU-CS-90-190, School of Computer Science, Carnegie Mellon University, 1990.

| Argument | Description |
|---|---|
| `fn` | Python callable implementing an associative binary operation with signature `r = fn(a, b)`. This must satisfy associativity: `fn(a, fn(b, c)) == fn(fn(a, b), c)`. The inputs and result are (possibly nested structures of) `Tensor`(s), matching `elems`. Each `Tensor` has a batch dimension in place of `elem_length`; `fn` is expected to map over this dimension. The result `r` has the same shape (and structure) as the two inputs `a` and `b`. |
| `elems` | A (possibly nested structure of) `Tensor`(s), each with dimension `elem_length` along `axis`. Note that `elem_length` determines the number of recursive steps required to perform the scan: if, in graph mode, this is not statically available, then ops will be created to handle any `elem_length` up to the maximum dimension of a `Tensor`. |
| `max_num_levels` | Python `int`. The `axis` of the tensors in `elems` must have size less than `2**(max_num_levels + 1)`. The default value is sufficiently large for most needs. Lowering this value can reduce graph-building time when `scan_associative` is used with inputs of unknown shape. Default value: `48`. |
| `axis` | `int` `Tensor` axis along which to perform the scan. |
| `validate_args` | Python `bool`. When `True`, runtime checks for invalid inputs are performed. This may carry a performance cost. Default value: `False`. |
| `name` | Python `str` name prefixed to ops created by this function. |
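
The `fn` contract can be illustrated with the first-order linear recurrence `x[t] = m[t] * x[t - 1] + c[t]`, whose steps are affine maps that compose associatively. The sketch below is an illustration under assumptions, not TFP code: NumPy arrays stand in for `Tensor`s, `elems` is taken to be the tuple `(m, c)`, and `affine_combine` is a hypothetical name. Because every operation is elementwise, a single call combines a whole batch of element pairs, which is how `fn` is expected to map over the batch dimension that replaces `elem_length`.

```python
import numpy as np


def affine_combine(a, b):
    """Composes batches of affine maps x -> mult * x + offset.

    `a` and `b` are (mult, offset) tuples of arrays with matching shapes.
    The map `a` is applied first, then `b`:
        b_mult * (a_mult * x + a_offset) + b_offset.
    All operations are elementwise, so this maps over any leading batch dimension.
    """
    a_mult, a_offset = a
    b_mult, b_offset = b
    return a_mult * b_mult, b_mult * a_offset + b_offset


rng = np.random.default_rng(0)
num_steps = 8
m = rng.uniform(0.5, 1.0, size=num_steps)
c = rng.normal(size=num_steps)

# One vectorized call combines every adjacent pair (0, 1), (2, 3), ... at once,
# as a single parallel step of the scan would.
pair_m, pair_c = affine_combine((m[0::2], c[0::2]), (m[1::2], c[1::2]))

# Sequential prefix scan with the same fn: prefix[t] composes steps 0..t.
prefix = [(m[0], c[0])]
for t in range(1, num_steps):
    prefix.append(affine_combine(prefix[-1], (m[t], c[t])))
# Applying prefix[t] to x[-1] = 0 leaves only its offset, i.e. x[t].
states = np.array([offset for _, offset in prefix])

# Direct evaluation of the recurrence for comparison.
x = 0.0
direct = []
for t in range(num_steps):
    x = m[t] * x + c[t]
    direct.append(x)

print(np.allclose(states, direct))  # True
```

With TensorFlow tensors in place of NumPy arrays, the same function should be usable as `fn` with `elems=(m, c)`; that call is not exercised here.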

| Return value | Description |
|---|---|
| `result` | A (possibly nested structure of) `Tensor`(s) of the same shape and structure as `elems`, in which the `k`th element is the result of recursively applying `fn` to combine the first `k` elements of `elems`. For example, given `elems = [a, b, c, ...]`, the result would be `[a, fn(a, b), fn(fn(a, b), c), ...]`. |

#### Examples

```python
import tensorflow as tf
import tensorflow_probability as tfp
import operator

# Example 1: Partial sums of numbers.