tf.keras.layers.Normalization

Feature-wise normalization of the data.

Inherits From: PreprocessingLayer, Layer, Module

This layer will coerce its inputs into a distribution centered around 0 with standard deviation 1. It accomplishes this by precomputing the mean and variance of the data, and calling (input - mean) / sqrt(var) at runtime.

What happens in adapt(): Compute mean and variance of the data and store them as the layer's weights. adapt() should be called before fit(), evaluate(), or predict().

axis Integer, tuple of integers, or None. The axis or axes that should have a separate mean and variance for each index in the shape. For example, if shape is (None, 5) and axis=1, the layer will track 5 separate mean and variance values for the last axis. If axis is set to None, the layer will normalize all elements in the input by a scalar mean and variance. Defaults to -1, where the last axis of the input is assumed to be a feature dimension and is normalized per index. Note that in the specific case of batched scalar inputs where the only axis is the batch axis, the default will normalize each index in the batch separately. In this case, consider passing axis=None.
mean The mean value(s) to use during normalization. The passed value(s) will be broadcast to the shape of the kept axes above; if the value(s) cannot be broadcast, an error will be raised when this layer's build() method is called.
variance The variance value(s) to use during normalization. The passed value(s) will be broadcast to the shape of the kept axes above; if the value(s) cannot be broadcast, an error will be raised when this layer's build() method is called.

Examples:

Calculate a global mean and variance by analyzing the dataset in adapt().

adapt_data = np.array([1., 2., 3., 4., 5.], dtype='float32')
input_data = np.array([1., 2., 3.], dtype='float32')
layer = tf.keras.layers.Normalization(axis=None)
layer.adapt(adapt_data)
layer(input_data)
<tf.Tensor: shape=(3,), dtype=float32, numpy=
array([-1.4142135, -0.70710677, 0.], dtype=float32)>

Calculate a mean and variance for each index on the last axis.

adapt_data = np.array([[0., 7., 4.],
                       [2., 9., 6.],
                       [0., 7., 4.],
                       [2., 9., 6.]], dtype='float32')
input_data = np.array([[0., 7., 4.]], dtype='float32')
layer = tf.keras.layers.Normalization(axis=-1)
layer.adapt(adapt_data)
layer(input_data)
<tf.Tensor: shape=(1, 3), dtype=float32, numpy=
array([0., 0., 0.], dtype=float32)>

Pass the mean and variance directly.

input_data = np.array([[1.], [2.], [3.]], dtype='float32')
layer = tf.keras.layers.Normalization(mean=3., variance=2.)
layer(input_data)
<tf.Tensor: shape=(3, 1), dtype=float32, numpy=
array([[-1.4142135 ],
       [-0.70710677],
       [ 0.        ]], dtype=float32)>

is_adapted Whether the layer has been fit to data already.

Methods

adapt

View source

Fits the state of the preprocessing layer to the data being passed.

After calling adapt on a layer, a preprocessing layer's state will not update during training. In order to make preprocessing layers efficient in any distribution context, they are kept constant with respect to any compiled tf.Graphs that call the layer. This does not affect the layer use when adapting each layer only once, but if you adapt a layer multiple times you will need to take care to re-compile any compiled functions as follo