

Flattens an input tensor while preserving the batch axis (axis 0).
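
The batch-preserving behaviour can be illustrated without TensorFlow. The sketch below is a plain-Python analogy, not the layer's implementation. It uses nested lists in place of tensors and two hypothetical helpers, _flatten_nested and flatten_batch. Axis 0 is kept as-is, and every sample is flattened in row-major order.

```python
from typing import Any, List


def _flatten_nested(x: Any) -> List[Any]:
    """Row-major flatten of an arbitrarily nested list (hypothetical helper)."""
    if isinstance(x, list):
        out: List[Any] = []
        for item in x:
            out.extend(_flatten_nested(item))
        return out
    # A scalar becomes a one-element list, so a rank-1 batch maps to shape (batch, 1).
    return [x]


def flatten_batch(batch: List[Any]) -> List[List[Any]]:
    """Keep axis 0 (the batch axis) and flatten all remaining axes per sample."""
    return [_flatten_nested(sample) for sample in batch]


if __name__ == "__main__":
    # A batch of two 2x2 samples becomes a batch of two length-4 vectors.
    print(flatten_batch([[[1, 2], [3, 4]], [[5, 6], [7, 8]]]))
```

In this sketch, a batch of scalars such as [1, 2] becomes [[1], [2]]. That mirrors the Keras layer's handling of rank-1 inputs, which get an extra axis of size 1.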

Inherits From: tf.keras.layers.Flatten, tf.compat.v1.layers.Layer, tf.keras.layers.Layer, tf.Module

Migrate to TF2

This API is not compatible with eager execution or tf.function.

Refer to the tf.layers section of the migration guide to migrate a TensorFlow v1 model to Keras. The corresponding TensorFlow v2 layer is tf.keras.layers.Flatten.

Structural Mapping to Native TF2

None of the supported arguments have changed names.

Before:

  import tensorflow as tf
  flatten = tf.compat.v1.layers.Flatten()

After:

  import tensorflow as tf
  flatten = tf.keras.layers.Flatten()



Args

data_format: A string, one of channels_last (default) or channels_first, giving the ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch, ..., channels), while channels_first corresponds to inputs with shape (batch, channels, ...).
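
As far as we can tell from the Keras implementation, data_format does not change the output shape. For channels_first inputs of rank 3 or higher, the layer moves the channel axis to the end before reshaping, so the flattened element order matches what channels_last input would produce. The sketch below illustrates that ordering for a single (batch-free) 3D sample using nested lists. The function flatten_sample is hypothetical and is not a TensorFlow API.

```python
from typing import List

Sample3D = List[List[List[float]]]


def flatten_sample(sample: Sample3D, data_format: str = "channels_last") -> List[float]:
    """Flatten one 3D sample (batch axis already removed); hypothetical helper.

    channels_last:  sample is indexed [h][w][c] and is flattened row-major.
    channels_first: sample is indexed [c][h][w]; the channel axis is treated
                    as if moved to the end, so the output order matches channels_last.
    """
    if data_format == "channels_last":
        return [v for row in sample for pixel in row for v in pixel]
    if data_format == "channels_first":
        channels, height, width = len(sample), len(sample[0]), len(sample[0][0])
        return [
            sample[c][h][w]
            for h in range(height)
            for w in range(width)
            for c in range(channels)
        ]
    raise ValueError(f"unknown data_format: {data_format!r}")


if __name__ == "__main__":
    # 2x2 image with 3 channels, stored channels_last as [h][w][c].
    last = [[[1, 2, 3], [4, 5, 6]], [[7, 8, 9], [10, 11, 12]]]
    # The same data stored channels_first as [c][h][w].
    first = [[[last[h][w][c] for w in range(2)] for h in range(2)] for c in range(3)]
    print(flatten_sample(last))
    print(flatten_sample(first, "channels_first"))
```

Both calls produce the same 12-element vector, which is why switching data_format is expected to leave downstream Dense weights compatible.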


  x = tf.compat.v1.placeholder(shape=(None, 4, 4), dtype='float32')
  y = Flatten()(x)
  # now `y` has shape `(None, 16)`

  x = tf.compat.v1.placeholder(shape=(None, 3, None), dtype='float32')
  y = Flatten()(x)
  # now `y` has shape `(None, None)`
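
The static shapes in the examples follow a simple rule. The batch dimension passes through unchanged, and the remaining dimensions collapse into their product, which becomes unknown (None) if any of them is unknown. The following plain-Python sketch reproduces that rule under that assumption. flatten_output_shape is a hypothetical helper, not part of TensorFlow, and it ignores edge cases such as zero-sized dimensions.

```python
from math import prod
from typing import Optional, Sequence, Tuple

Dim = Optional[int]  # None stands for an unknown (dynamic) dimension


def flatten_output_shape(input_shape: Sequence[Dim]) -> Tuple[Dim, Dim]:
    """Static output shape of a batch-preserving flatten (hypothetical helper)."""
    if len(input_shape) < 1:
        raise ValueError("input must have at least a batch dimension")
    batch, rest = input_shape[0], input_shape[1:]
    if any(d is None for d in rest):
        # One unknown dimension makes the flattened size unknown.
        return (batch, None)
    # prod(()) == 1, so a rank-1 input maps to (batch, 1).
    return (batch, prod(rest))


if __name__ == "__main__":
    print(flatten_output_shape((None, 4, 4)))
    print(flatten_output_shape((None, 3, None)))
```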