Specifies which layers to prune in the model.

`PruningPolicy` controls the application of the `PruneLowMagnitude` wrapper on a per-layer basis and checks that the model contains only supported layers. It works together with `prune_low_magnitude`, which accepts a policy through its `pruning_policy` argument, to provide fine-grained control over pruning in the model.
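```python
import tensorflow as tf
import tensorflow_model_optimization as tfmot

keras = tf.keras
layers = tf.keras.layers
ConstantSparsity = tfmot.sparsity.keras.ConstantSparsity
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude
# Built-in PruningPolicy subclass shipped with tfmot.
PruneForLatencyOnXNNPack = tfmot.sparsity.keras.PruneForLatencyOnXNNPack

pruning_params = {
    'pruning_schedule': ConstantSparsity(0.5, 0),
    'block_size': (1, 1),
    'block_pooling_type': 'AVG'
}

model = prune_low_magnitude(
    keras.Sequential([
        layers.Dense(10, activation='relu', input_shape=(100,)),
        layers.Dense(2, activation='sigmoid')
    ]),
    pruning_policy=PruneForLatencyOnXNNPack(),
    **pruning_params)
```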
You can subclass `PruningPolicy` to write your own custom pruning policy by implementing the two abstract methods listed under Methods.
The API is experimental and subject to change.
Methods
allow_pruning
```python
@abc.abstractmethod
allow_pruning(
    layer
)
```
Checks whether the pruning wrapper should be applied to the current layer.
| Args | |
|---|---|
| `layer` | Current layer in the model. |
| Returns |
|---|
| `True` if the pruning wrapper should be applied to the layer, `False` otherwise. |
ensure_model_supports_pruning
```python
@abc.abstractmethod
ensure_model_supports_pruning(
    model
)
```
Checks that the model contains only supported layers.
| Args | |
|---|---|
| `model` | A `tf.keras.Model` instance that is going to be pruned. |
| Raises | |
|---|---|
| `ValueError` | If the Keras model does not support the pruning policy, i.e. the model contains an unsupported layer. |