


This is the class from which all layers inherit.

Inherits From: Module

A layer is a callable object that takes as input one or more tensors and that outputs one or more tensors. It involves computation, defined in the call() method, and a state (weight variables), defined either in the constructor __init__() or in the build() method.

Users will just instantiate a layer and then treat it as a callable.

trainable Boolean, whether the layer's variables should be trainable.
name String name of the layer.
dtype The dtype of the layer's computations and weights (default of None means use tf.keras.backend.floatx in TensorFlow 2, or the type of the first input in TensorFlow 1).
dynamic Set this to True if your layer should only be run eagerly, and should not be used to generate a static computation graph. This would be the case for a Tree-RNN or a recursive network, for example, or generally for any layer that manipulates tensors using Python control flow. If False, we assume that the layer can safely be used to generate a static computation graph.
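As a rough illustration of these constructor arguments, the sketch below passes trainable, name, and dtype to a built-in Dense layer, which forwards them to the base Layer constructor. The layer name "frozen_dense" and the input shape are arbitrary choices for the example, not part of the API.

```python
import tensorflow as tf

# Base-class arguments are accepted by every built-in layer.
# "frozen_dense" is an arbitrary name picked for this sketch.
frozen = tf.keras.layers.Dense(
    4, trainable=False, name="frozen_dense", dtype="float64")

# The first call builds the layer and creates its kernel and bias.
y = frozen(tf.ones((2, 3), dtype=tf.float64))

print(frozen.name)
print(len(frozen.trainable_weights))      # variables updated by backprop
print(len(frozen.non_trainable_weights))  # variables excluded from backprop
print(y.dtype)
```

Because trainable=False, both variables created by the layer should be reported under non_trainable_weights rather than trainable_weights.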

We recommend that descendants of Layer implement the following methods:

  • __init__(): Defines custom layer attributes, and creates layer state variables that do not depend on input shapes, using add_weight().
  • build(self, input_shape): This method can be used to create weights that depend on the shape(s) of the input(s), using add_weight(). __call__() will automatically build the layer (if it has not been built yet) by calling build().
  • call(self, *args, **kwargs): Called in __call__ after making sure build() has been called. call() performs the logic of applying the layer to the input tensors (which should be passed in as argument). Two reserved keyword arguments you can optionally use in call() are:
    • training (boolean, whether the call is in inference mode or training mode)
    • mask (boolean tensor encoding masked timesteps in the input, used in RNN layers)
  • get_config(self): Returns a dictionary containing the configuration used to initialize this layer. If the keys differ from the arguments in __init__, then override from_config() (a classmethod) as well. This method is used when saving the layer or a model that contains this layer.
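The methods listed above can be sketched together in one small subclass. TrainOnlyScale and its factor argument are hypothetical names invented for this illustration; the sketch only assumes the reserved training keyword and the default from_config() behavior of passing the config dictionary back to __init__.

```python
import tensorflow as tf

class TrainOnlyScale(tf.keras.layers.Layer):
    """Hypothetical layer: multiplies inputs by `factor` in training mode only."""

    def __init__(self, factor=2.0, **kwargs):
        super().__init__(**kwargs)
        self.factor = factor  # plain attribute, no input-dependent state

    def call(self, inputs, training=None):
        # `training` is one of the reserved keyword arguments of call().
        if training:
            return inputs * self.factor
        return inputs

    def get_config(self):
        # Extend the base config so the layer can be re-created from it.
        config = super().get_config()
        config.update({"factor": self.factor})
        return config

layer = TrainOnlyScale(factor=3.0, name="scale")
x = tf.ones((1, 2))
train_out = layer(x, training=True)
infer_out = layer(x, training=False)

# The config keys match the __init__ arguments, so the default
# from_config() is enough here.
restored = TrainOnlyScale.from_config(layer.get_config())
```

Since the keys returned by get_config() match the constructor arguments, no from_config() override is needed in this case.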


Here's a basic example: a layer with two variables, w and b, that returns y = w . x + b. It shows how to implement build() and call(). Variables set as attributes of a layer are tracked as weights of the layer (in layer.weights).

import tensorflow as tf
from tensorflow.keras.layers import Layer

class SimpleDense(Layer):

  def __init__(self, units=32):
      super(SimpleDense, self).__init__()
      self.units = units

  def build(self, input_shape):  # Create the state of the layer (weights)
      w_init = tf.random_normal_initializer()
      self.w = tf.Variable(
          initial_value=w_init(shape=(input_shape[-1], self.units),
                               dtype='float32'),
          trainable=True)
      b_init = tf.zeros_initializer()
      self.b = tf.Variable(
          initial_value=b_init(shape=(self.units,), dtype='float32'),
          trainable=True)

  def call(self, inputs):  # Defines the computation from inputs to outputs
      return tf.matmul(inputs, self.w) + self.b

# Instantiates the layer.
linear_layer = SimpleDense(4)

# This will also call `build(input_shape)` and create the weights.
y = linear_layer(tf.ones((2, 2)))
assert len(linear_layer.weights) == 2

# These weights are trainable, so they're listed in `trainable_weights`:
assert len(linear_layer.trainable_weights) == 2

Note that the method add_weight() offers a shortcut to create weights:

class SimpleDense(Layer):

  def __init__(self, units=32):
      super(SimpleDense, self).__init__()
      self.units = units

  def build(self, input_shape):
      self.w = self.add_weight(shape=(input_shape[-1], self.units),
                               initializer='random_normal',
                               trainable=True)
      self.b = self.add_weight(shape=(self.units,),
                               initializer='zeros',
                               trainable=True)

  def call(self, inputs):
      return tf.matmul(inputs, self.w) + self.b
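
To make the deferred weight creation concrete, here is a self-contained sketch of the add_weight() variant that inspects the layer before and after its first call. The initializer choices and the input shape (2, 3) are assumptions made for the example.

```python
import tensorflow as tf

class SimpleDense(tf.keras.layers.Layer):

    def __init__(self, units=32):
        super().__init__()
        self.units = units

    def build(self, input_shape):
        # Weight shapes depend on the last input dimension, so they are
        # created here rather than in __init__().
        self.w = self.add_weight(shape=(input_shape[-1], self.units),
                                 initializer="random_normal", trainable=True)
        self.b = self.add_weight(shape=(self.units,),
                                 initializer="zeros", trainable=True)

    def call(self, inputs):
        return tf.matmul(inputs, self.w) + self.b

layer = SimpleDense(4)
weights_before = len(layer.weights)  # build() has not run yet
y = layer(tf.ones((2, 3)))           # the first call triggers build()
weights_after = len(layer.weights)

print(weights_before, weights_after)
print(layer.w.shape, layer.b.shape)
```

Creating the weights in build() means the layer does not need to know its input size at construction time.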

Besides trainable weights, updated via backpropagation during training, layers can also have non-trainable weights. These weights are meant to be updated manually during call(). Here's an example layer that computes the running sum of its inputs:

class ComputeSum(Layer):

  def __init__(self, input_dim):
      super(ComputeSum, self).__init__()
      # Create a non-trainable weight.
      self.total = tf.Variable(initial_value=tf.zeros((input_dim,)),
                               trainable=False)

  def call(self, inputs):
      # Add the column-wise sum of this batch to the running total.
      self.total.assign_add(tf.reduce_sum(inputs, axis=0))
      return self.total

my_sum = ComputeSum(2)
x = tf.ones((2, 2))

y = my_sum(x)
print(y.numpy())  # [2. 2.]

y = my_sum(x)
print(y.numpy())  # [4. 4.]

assert my_sum.weights == [my_sum.total]
assert my_sum.non_trainable_weights == [my_sum.total]
assert my_sum.trainable_weights == []
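
A layer can also hold both kinds of weights at once. The following sketch, in which CountedScale, scale, and calls are hypothetical names chosen for illustration, creates one trainable and one non-trainable weight with add_weight() and updates the latter manually in call().

```python
import tensorflow as tf

class CountedScale(tf.keras.layers.Layer):
    """Hypothetical layer: scales its inputs and counts its own calls."""

    def __init__(self):
        super().__init__()
        # Trainable weight: would be updated by an optimizer during training.
        self.scale = self.add_weight(
            shape=(), initializer="ones", trainable=True)
        # Non-trainable weight: updated manually inside call().
        self.calls = self.add_weight(
            shape=(), initializer="zeros", trainable=False)

    def call(self, inputs):
        self.calls.assign_add(1.0)
        return inputs * self.scale

layer = CountedScale()
x = tf.ones((2, 2))
layer(x)
layer(x)

call_count = float(layer.calls.numpy())
print(call_count)
print(len(layer.trainable_weights), len(layer.non_trainable_weights))
```

Both weights do not depend on the input shape, so they are created in __init__() rather than in build().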

For more information about creating layers, see the guide Writing custom layers and models with Keras

About the layer's dtype attribute:

Each layer has a dtype, which is typically the dtype of the layer's computations and variables. A layer's dtype can be queried via the Layer.dtype property. The dtype is specified with the dtype constructor argument. In TensorFlow 2, the dtype defaults to tf.keras.backend.floatx() if no dtype is passed. floatx() itself defaults to "float32". Additionally, layers will cast their inputs to the layer's dtype in TensorFlow 2. When mixed precision is used, layers may have different computation and variable dtypes. See tf.keras.mixed_precision.experimental.Policy for details on layer dtypes.
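
The default dtype and the input casting described above can be checked with a short sketch. It assumes TensorFlow 2 behavior without mixed precision and that floatx() has not been changed from its default; the Dense layer and input shape are arbitrary.

```python
import tensorflow as tf

layer = tf.keras.layers.Dense(2)   # no dtype argument passed
default_dtype = layer.dtype        # falls back to tf.keras.backend.floatx()

# A float64 input is cast to the layer's dtype before the computation.
y = layer(tf.ones((1, 3), dtype=tf.float64))

print(default_dtype)
print(y.dtype)
```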

name The name of the layer (string).
dtype The dtype of the layer's computations and weights. If mixed precision is used with a tf.keras.mixed_precision.experimental.Policy, this is instead just the dtype of the layer's weights, as the computations are done in a different dtype.
trainable_weights List of variables to be included in backprop.
non_trainable_weights List of variables that should not be included in backprop.
weights The concatenation of the lists trainable_weights and non_trainable_weights (in this order).
trainable Whether the layer should be trained (boolean), i.e. whether its potentially-trainable weights should be returned as part of layer.trainable_weights.
input_spec Optional (list of) InputSpec object(s) specifying the constraints on inputs that can be accepted by the layer.
activity_regularizer Optional regularizer function for the output of this layer.
dynamic Whether the layer is dynamic (eager-only); set in the constructor.
input Retrieves the input tensor(s) of a layer.

Only applicable if the layer has exactly one input, i.e. if it is connected to one incoming layer.

losses List of losses added using the add_loss()