Unit normalization layer.
tf.keras.layers.UnitNormalization(
    axis=-1, **kwargs
)
Normalize a batch of inputs so that each input in the batch has an L2 norm
equal to 1 (across the axes specified in axis).
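To show what "an L2 norm equal to 1 across axis" means numerically, here is a minimal NumPy sketch of the same computation. It is not the layer's implementation. The helper name unit_normalize is hypothetical, and the epsilon floor on the squared norm is an assumption of this sketch, used only to avoid dividing by zero for an all-zero slice. Like the layer's axis argument, axis may be a single integer or a tuple of axes.

```python
import numpy as np


def unit_normalize(x, axis=-1, epsilon=1e-12):
    """Scale x so that every slice taken over `axis` has an L2 norm of 1.

    Hypothetical helper for illustration; `epsilon` is an assumed floor on
    the squared norm, not a documented parameter of UnitNormalization.
    """
    x = np.asarray(x, dtype=np.float32)
    # Sum of squares over the normalized axes. keepdims=True leaves the
    # reduced axes with size 1 so the result broadcasts back against x.
    sq_sum = np.sum(np.square(x), axis=axis, keepdims=True)
    # Floor the squared norm so an all-zero slice maps to zeros instead of NaN.
    return x / np.sqrt(np.maximum(sq_sum, epsilon))


# Default axis=-1: each row of a (2, 3) batch is normalized independently.
rows = unit_normalize(np.arange(6).reshape(2, 3))
print(np.sum(rows ** 2, axis=-1))

# Several axes at once: each (2, 2) slice of a (2, 2, 2) batch gets norm 1.
slices = unit_normalize(np.arange(8).reshape(2, 2, 2), axis=(1, 2))
print(np.sum(slices ** 2, axis=(1, 2)))
```

Both printed arrays should contain values equal to 1 up to float32 rounding. The axes left out of axis (here, the leading batch axis) are the ones along which inputs are treated as separate examples.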
Example:
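import numpy as np
import tensorflow as tf

data = tf.constant(np.arange(6).reshape(2, 3), dtype=tf.float32)
normalized_data = tf.keras.layers.UnitNormalization()(data)
# Squared L2 norm of the first row after normalization.
print(tf.reduce_sum(normalized_data[0, :] ** 2).numpy())
# Output: 1.0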
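As a worked check of the printed value: with the default axis=-1, each row is normalized separately, and the first row of data is (0, 1, 2). Dividing it by its L2 norm and summing the squares gives:

```latex
\|x\|_2 = \sqrt{0^2 + 1^2 + 2^2} = \sqrt{5}, \qquad
\hat{x} = \frac{x}{\|x\|_2} = \left(0, \tfrac{1}{\sqrt{5}}, \tfrac{2}{\sqrt{5}}\right), \qquad
\sum_i \hat{x}_i^2 = 0 + \tfrac{1}{5} + \tfrac{4}{5} = 1
```

The second row, (3, 4, 5), has norm \sqrt{50} and likewise sums to 1 after normalization.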