Clone a Functional or Sequential Model instance.
tf.keras.models.clone_model(
model,
input_tensors=None,
clone_function=None,
call_function=None,
recursive=False,
**kwargs
)
Model cloning is similar to calling a model on new inputs, except that it creates new layers (and thus new weights) instead of sharing the weights of the existing layers.
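A minimal sketch of that difference, assuming a Keras 3 installation imported as keras. The helper var_ids is hypothetical and only collects the identities of a model's weight variables: calling the model on a new input reuses the same variables, while clone_model creates new ones.

```python
import keras

# A tiny Sequential model (illustrative only).
model = keras.Sequential([
    keras.layers.Input(shape=(4,)),
    keras.layers.Dense(3),
])

def var_ids(m):
    # Hypothetical helper: identities of the model's weight variables.
    return {id(v) for v in m.weights}

# Calling the model on a new input reuses its layers, so the weights are shared.
new_input = keras.Input(shape=(4,))
shared = keras.Model(new_input, model(new_input))

# clone_model builds new layers, so the clone gets its own weight variables.
cloned = keras.models.clone_model(model)

print(var_ids(shared) == var_ids(model))            # True
print(var_ids(cloned).isdisjoint(var_ids(model)))   # True
```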
Note that clone_model will not preserve the uniqueness of shared objects within the model (e.g. a single variable attached to two distinct layers will be restored as two separate variables).
Returns:
An instance of Model reproducing the behavior of the original model, on top of new input tensors, using newly instantiated weights. The cloned model may behave differently from the original model if a custom clone_function or call_function modifies a layer or layer call.
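To build the clone on top of specific input tensors, they can be passed through input_tensors. The sketch below assumes Keras 3 and a single-input functional model; new_input is a hypothetical name for the replacement symbolic input.

```python
import numpy as np
import keras

# A single-input functional model (illustrative only).
inputs = keras.Input(shape=(4,))
outputs = keras.layers.Dense(1)(inputs)
model = keras.Model(inputs, outputs)

# Hypothetical replacement input; its structure matches model.input.
new_input = keras.Input(shape=(4,), name="new_input")
new_model = keras.models.clone_model(model, input_tensors=new_input)

# The clone accepts batches shaped like the new input.
x = np.zeros((3, 4), dtype="float32")
print(new_model.predict(x, verbose=0).shape)
```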
Example:
import keras
from keras.models import clone_model

# Create a test Sequential model.
model = keras.Sequential([
keras.layers.Input(shape=(728,)),
keras.layers.Dense(32, activation='relu'),
keras.layers.Dense(1, activation='sigmoid'),
])
# Create a copy of the test model (with freshly initialized weights).
new_model = clone_model(model)
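Because the clone starts from freshly initialized weights, its predictions will generally differ from the original's. A common follow-up, sketched here under the assumption of Keras 3, is to copy the weight values with get_weights() and set_weights(), after which both models produce the same outputs on the same batch.

```python
import numpy as np
import keras

model = keras.Sequential([
    keras.layers.Input(shape=(728,)),
    keras.layers.Dense(32, activation="relu"),
    keras.layers.Dense(1, activation="sigmoid"),
])

# The clone has the same architecture but new, freshly initialized weights.
new_model = keras.models.clone_model(model)

# Copy the weight values from the original into the clone.
new_model.set_weights(model.get_weights())

# With identical weights, both models now agree on the same input batch.
x = np.random.rand(4, 728).astype("float32")
same = np.allclose(model.predict(x, verbose=0), new_model.predict(x, verbose=0))
print(same)  # True
```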
Using a clone_function to make a model deterministic by setting the random seed everywhere:
def clone_function(layer):
config = layer.get_config()
if "seed" in config:
config["seed"] = 1337
return layer.__class__.from_config(config)
new_model = clone_model(model, clone_function=clone_function)
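The example model above has no layer with a top-level seed entry in its config, so that clone_function leaves it unchanged. The sketch below (assuming Keras 3) adds a Dropout layer, whose config does include a seed key, and checks that the clone picks up the fixed seed.

```python
import keras

def clone_function(layer):
    # Rebuild each layer from its config, overriding any top-level seed.
    config = layer.get_config()
    if "seed" in config:
        config["seed"] = 1337
    return layer.__class__.from_config(config)

# Illustrative model containing a seeded layer (Dropout).
model = keras.Sequential([
    keras.layers.Input(shape=(16,)),
    keras.layers.Dense(8, activation="relu"),
    keras.layers.Dropout(0.5),
    keras.layers.Dense(1),
])

new_model = keras.models.clone_model(model, clone_function=clone_function)

# Collect the seeds of the cloned Dropout layers.
dropout_seeds = [
    layer.get_config()["seed"]
    for layer in new_model.layers
    if isinstance(layer, keras.layers.Dropout)
]
print(dropout_seeds)  # [1337]
```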
Using a call_function to add a Dropout layer after each Dense layer (without recreating new layers):
def call_function(layer, *args, **kwargs):
out = layer(*args, **kwargs)
if isinstance(layer, keras.layers.Dense):
out = keras.layers.Dropout(0.5)(out)
return out
new_model = clone_model(
model,
clone_function=lambda x: x, # Reuse the same layers.
call_function=call_function,
)
Note that subclassed models cannot be cloned by default, since their internal layer structure is not known. To achieve equivalent functionality as clone_model in the case of a subclassed model, simply make sure that the model class implements get_config() (and optionally from_config()), and call:
new_model = model.__class__.from_config(model.get_config())
In the case of a subclassed model, you cannot use a custom clone_function.
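A minimal sketch of that approach, assuming Keras 3. TwoLayerMLP is a hypothetical subclassed model whose get_config() returns only its constructor arguments; the sketch relies on the inherited from_config() passing that dict back to the constructor.

```python
import numpy as np
import keras

class TwoLayerMLP(keras.Model):
    """Hypothetical subclassed model used only for illustration."""

    def __init__(self, hidden_units=16, **kwargs):
        super().__init__(**kwargs)
        self.hidden_units = hidden_units
        self.hidden = keras.layers.Dense(hidden_units, activation="relu")
        self.head = keras.layers.Dense(1)

    def call(self, inputs):
        return self.head(self.hidden(inputs))

    def get_config(self):
        # Only the constructor arguments are needed to rebuild the model.
        return {"hidden_units": self.hidden_units}

model = TwoLayerMLP(hidden_units=8)
x = np.zeros((2, 4), dtype="float32")
model(x)  # build the original model's weights

# Equivalent of clone_model for a subclassed model: new layers, new weights.
new_model = model.__class__.from_config(model.get_config())
new_model(x)

print(new_model.hidden_units)            # 8
print(new_model.hidden is model.hidden)  # False
```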