Transforms elems by applying fn to each element unstacked on axis 0. (deprecated arguments)

See also tf.scan.

map_fn unstacks elems on axis 0 to obtain a sequence of elements; calls fn to transform each element; and then stacks the transformed values back together.

Mapping functions with single-Tensor inputs and outputs

If elems is a single tensor and fn's signature is tf.Tensor->tf.Tensor, then map_fn(fn, elems) is equivalent to tf.stack([fn(elem) for elem in tf.unstack(elems)]). E.g.:

tf.map_fn(fn=lambda t: tf.range(t, t + 3), elems=tf.constant([3, 5, 2]))
<tf.Tensor: shape=(3, 3), dtype=int32, numpy=
  array([[3, 4, 5],
         [5, 6, 7],
         [2, 3, 4]], dtype=int32)>

map_fn(fn, elems).shape = [elems.shape[0]] + fn(elems[0]).shape.
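
Both the unstack/stack equivalence and the shape rule can be checked directly. The following is a minimal sketch, assuming TensorFlow 2.x with eager execution; the helper name `fn` is illustrative.

```python
import tensorflow as tf

def fn(t):
    # Maps a scalar start value t to the vector [t, t + 1, t + 2].
    return tf.range(t, t + 3)

elems = tf.constant([3, 5, 2])

# map_fn unstacks elems on axis 0, applies fn, and restacks the results.
mapped = tf.map_fn(fn, elems)

# The same computation written out explicitly with unstack/stack.
reference = tf.stack([fn(elem) for elem in tf.unstack(elems)])

# Shape rule: [elems.shape[0]] + fn(elems[0]).shape
expected_shape = [elems.shape[0]] + fn(elems[0]).shape.as_list()

print(mapped.shape.as_list(), expected_shape)  # [3, 3] [3, 3]
```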

Mapping functions with multi-arity inputs and outputs

map_fn also supports functions with multi-arity inputs and outputs:

  • If elems is a tuple (or nested structure) of tensors, then those tensors must all have the same outer-dimension size (num_elems); and fn is used to transform each tuple (or structure) of corresponding slices from elems. E.g., if elems is a tuple (t1, t2, t3), then fn is used to transform each tuple of slices (t1[i], t2[i], t3[i]) (where 0 <= i < num_elems).

  • If fn returns a tuple (or nested structure) of tensors, then the result is formed by stacking corresponding elements from those structures.
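
A small sketch of both rules together, assuming TensorFlow 2.x eager execution: `fn` receives a tuple of corresponding slices and returns a tuple, and each position of the returned tuple is stacked separately. The names `xs`, `ys`, and `sum_and_product` are illustrative.

```python
import tensorflow as tf

xs = tf.constant([1.0, 2.0, 3.0])
ys = tf.constant([10.0, 20.0, 30.0])

def sum_and_product(pair):
    # fn receives one tuple of corresponding slices: (xs[i], ys[i]).
    x, y = pair
    # Each tuple position is stacked separately into the result.
    return (x + y, x * y)

# Input and output are both (float32, float32) tuples, so no
# fn_output_signature is needed here.
sums, products = tf.map_fn(sum_and_product, (xs, ys))

print(sums.numpy(), products.numpy())
```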

Specifying fn's output signature

If fn's input and output signatures are different, then the output signature must be specified using fn_output_signature. (The input and output signatures differ if their structures, dtypes, or tensor types do not match.) E.g.:

tf.map_fn(fn=tf.strings.length,  # input & output have different dtypes
          elems=tf.constant(["hello", "moon"]),
          fn_output_signature=tf.int32)
<tf.Tensor: shape=(2,), dtype=int32, numpy=array([5, 4], dtype=int32)>
tf.map_fn(fn=tf.strings.join,  # input & output have different structures
          elems=[tf.constant(['The', 'A']), tf.constant(['Dog', 'Cat'])],
          fn_output_signature=tf.string)
<tf.Tensor: shape=(2,), dtype=string,
 numpy=array([b'TheDog', b'ACat'], dtype=object)>

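fn_output_signature may also be a nested structure that mirrors fn's output. The sketch below, assuming TensorFlow 2.x eager execution, maps each string to a dict of two values and describes that dict with one dtype or spec per key; the function name `describe` and the keys are illustrative.

```python
import tensorflow as tf

words = tf.constant(["hello", "moon"])

def describe(word):
    # Output structure (a dict) differs from the input (a single string).
    return {"length": tf.strings.length(word), "upper": tf.strings.upper(word)}

result = tf.map_fn(
    describe,
    words,
    # One entry per key: a TensorSpec for "length", a plain DType for "upper".
    fn_output_signature={"length": tf.TensorSpec(shape=[], dtype=tf.int32),
                         "upper": tf.string},
)

print(result["length"].numpy(), result["upper"].numpy())
```
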
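fn_output_signature can be specified using any of the following:

  • A tf.DType or tf.TensorSpec (to describe a tf.Tensor)
  • A tf.RaggedTensorSpec (to describe a tf.RaggedTensor)
  • A tf.SparseTensorSpec (to describe a tf.sparse.SparseTensor)
  • A (possibly nested) tuple, list, or dict containing the above types.
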
RaggedTensors

map_fn supports tf.RaggedTensor inputs and outputs. In particular:

  • If elems is a RaggedTensor, then fn will be called with each row of that ragged tensor.

    • If elems has only one ragged dimension, then the values passed to fn will be tf.Tensors.
    • If elems has multiple ragged dimensions, then the values passed to fn will be tf.RaggedTensors with one fewer ragged dimension.
  • If the result of map_fn should be a RaggedTensor, then use a tf.RaggedTensorSpec to specify fn_output_signature.
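
To illustrate the multiple-ragged-dimension case, the sketch below (assuming TensorFlow 2.x eager execution) builds a tensor with two ragged dimensions, so each row reaches `fn` as a tf.RaggedTensor with one ragged dimension; `row_total` is an illustrative name. Accessing `.flat_values` relies on the row being a RaggedTensor rather than a plain tf.Tensor.

```python
import tensorflow as tf

# Two ragged dimensions (ragged_rank=2): each row is itself ragged.
rt = tf.ragged.constant([[[1, 2], [3]], [[4, 5, 6]], [[7], [8], [9]]])

def row_total(row):
    # Each row arrives as a tf.RaggedTensor with one ragged dimension,
    # so its underlying values are available as .flat_values.
    return tf.reduce_sum(row.flat_values)

# Output is a scalar int32 per row, so describe it with a plain dtype.
totals = tf.map_fn(row_total, rt, fn_output_signature=tf.int32)

print(rt.ragged_rank, totals.numpy())
```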

# Example: RaggedTensor input
rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
tf.map_fn(tf.reduce_sum, rt, fn_output_signature=tf.int32)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([6, 0, 9, 6], dtype=int32)>
# Example: RaggedTensor output
elems = tf.constant([3, 5, 0, 2])
tf.map_fn(tf.range, elems,
          fn_output_signature=tf.RaggedTensorSpec(shape=[None],
                                                  dtype=tf.int32))
<tf.RaggedTensor [[0, 1, 2], [0, 1, 2, 3, 4], [], [0, 1]]>
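
Note: map_fn should only be used if you need to map a function over the rows of a RaggedTensor. If you wish to map a function over the individual values, then you should use tf.ragged.map_flat_values(fn, rt) (if fn is expressible as TensorFlow ops) or rt.with_flat_values(map_fn(fn, rt.flat_values)) (otherwise). E.g.:
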
rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])
tf.ragged.map_flat_values(lambda x: x + 2, rt)
<tf.RaggedTensor [[3, 4, 5], [], [6, 7], [8]]>
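
Mapping over individual values can be done either with tf.ragged.map_flat_values or by applying tf.map_fn to rt.flat_values and reattaching the row partitions with rt.with_flat_values. A sketch comparing the two, assuming TensorFlow 2.x eager execution; `add_two` is an illustrative helper:

```python
import tensorflow as tf

rt = tf.ragged.constant([[1, 2, 3], [], [4, 5], [6]])

def add_two(x):
    return x + 2

# Applies add_two to all flat values at once, keeping the row structure.
via_flat_values = tf.ragged.map_flat_values(add_two, rt)

# Maps add_two over the flat values one by one, then restores the rows.
via_map_fn = rt.with_flat_values(tf.map_fn(add_two, rt.flat_values))

print(via_flat_values.to_list())
print(via_map_fn.to_list())
```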


SparseTensors

map_fn supports tf.sparse.SparseTensor inputs and outputs. In particular:

  • If elems is a SparseTensor, then fn will be called with each row of that sparse tensor. In particular, the value passed to fn will be a tf.sparse.SparseTensor with one fewer dimension than elems.

  • If the result of map_fn should be a SparseTensor, then use a tf.SparseTensorSpec to specify fn_output_signature. The individual SparseTensors returned by fn will be stacked into a single SparseTensor with one more dimension.

# Example: SparseTensor input
st = tf.sparse.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4])
tf.map_fn(tf.sparse.reduce_sum, st, fn_output_signature=tf.int32)
<tf.Tensor: shape=(4,), dtype=int32, numpy=array([2, 0, 7, 0], dtype=int32)>
# Example: SparseTensor output
tf.sparse.to_dense(
    tf.map_fn(tf.sparse.eye, tf.constant([2, 3]),
              fn_output_signature=tf.SparseTensorSpec(None, tf.float32)))
<tf.Tensor: shape=(2, 3, 3), dtype=float32, numpy=
  array([[[1., 0., 0.],
          [0., 1., 0.],
          [0., 0., 0.]],
         [[1., 0., 0.],
          [0., 1., 0.],
          [0., 0., 1.]]], dtype=float32)>
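
Note: map_fn should only be used if you need to map a function over the rows of a SparseTensor. If you wish to map a function over the nonzero values, then you should use:

  • If the function is expressible as TensorFlow ops, use:
    tf.sparse.SparseTensor(st.indices, fn(st.values), st.dense_shape)
  • Otherwise, use:
    tf.sparse.SparseTensor(st.indices, tf.map_fn(fn, st.values), st.dense_shape)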
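
Both ways of mapping over the nonzero values keep st.indices and st.dense_shape unchanged and only replace the values. A sketch comparing them, assuming TensorFlow 2.x eager execution; `scale` is an illustrative element-wise function:

```python
import tensorflow as tf

st = tf.sparse.SparseTensor([[0, 0], [2, 0], [2, 1]], [2, 3, 4], [4, 4])

def scale(x):
    # Illustrative element-wise function applied to the stored values.
    return x * 10

# fn expressible as TensorFlow ops: apply it to all stored values at once.
via_ops = tf.sparse.SparseTensor(st.indices, scale(st.values), st.dense_shape)

# Otherwise: map fn over the stored values one at a time.
via_map_fn = tf.sparse.SparseTensor(
    st.indices, tf.map_fn(scale, st.values), st.dense_shape)

dense_ops = tf.sparse.to_dense(via_ops).numpy()
dense_map_fn = tf.sparse.to_dense(via_map_fn).numpy()
print(dense_ops)
```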

map_fn vs. vectorized operations

map_fn will apply the operations used by fn to each element of elems, resulting in O(elems.shape[0]) total operations. This is somewhat mitigated by the fact that map_fn can process elements in parallel. However, a transform expressed using map_fn is still typically less efficient than an equivalent transform expressed using vectorized operations.

map_fn should typically only be used if one of the following is true:

  • It is difficult or expensive to express the desired transform with vectorized operations.
  • fn creates large intermediate values, so an equivalent vectorized transform would take too much memory.
  • Processing elements in parallel is more efficient than an equivalent vectorized transform.
  • Efficiency of the transform is not critical, and using map_fn is more readable.
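
The memory trade-off in the second bullet can be sketched as follows, assuming TensorFlow 2.x eager execution. Each call to the illustrative `outer_trace` builds only a [D, D] outer product, while the vectorized version materializes all N of them as one [N, D, D] tensor. (The trace of an outer product is just the sum of squares, so a cheaper formula exists here; the example only illustrates intermediate sizes.)

```python
import tensorflow as tf

# Illustrative data: N=8 rows of dimension D=4.
x = tf.reshape(tf.range(32, dtype=tf.float32), [8, 4])

def outer_trace(v):
    # Builds a [D, D] outer product for a single row, then reduces it.
    outer = tf.tensordot(v, v, axes=0)
    return tf.linalg.trace(outer)

# map_fn: each call to outer_trace creates only a [D, D] intermediate.
per_row = tf.map_fn(outer_trace, x)

# Vectorized: materializes all N outer products as one [N, D, D] tensor.
vectorized = tf.linalg.trace(tf.einsum("ni,nj->nij", x, x))

print(per_row.numpy())
print(vectorized.numpy())
```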

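E.g., the example given above that maps fn=lambda t: tf.range(t, t + 3) across elems could be rewritten more efficiently using vectorized ops as tf.range(3) + tf.expand_dims(elems, 1), which adds the offsets [0, 1, 2] to each start value by broadcasting.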
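
The following sketch computes the example both ways, with map_fn and with a broadcast of tf.range(3) against a column of start values, and checks that the results match. It assumes TensorFlow 2.x eager execution.

```python
import tensorflow as tf

elems = tf.constant([3, 5, 2])

# map_fn: one call to tf.range per element.
looped = tf.map_fn(lambda t: tf.range(t, t + 3), elems)

# Vectorized: offsets [0, 1, 2] (shape [3]) broadcast against the
# start values reshaped to a column (shape [3, 1]), giving shape [3, 3].
vectorized = tf.range(3) + tf.expand_dims(elems, 1)

print(looped.numpy().tolist())
print(vectorized.numpy().tolist())
```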