Specifies when to prune a layer and the sparsity (%) to apply at each training step.
`PruningSchedule` controls pruning during training by reporting, at each step, whether the layer's weights should be pruned and the sparsity (%) to which they should be pruned.
It can be invoked as a callable by passing the training step as a `Tensor`. It returns a tuple of a bool tensor and a float tensor:

```python
should_prune, sparsity = pruning_schedule(step)
```
You can inherit from this class to write your own custom pruning schedule.
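A minimal, framework-free sketch of a custom schedule follows. It mirrors the interface described on this page (`get_config`, `from_config`, and `__call__(step)` returning `(should_prune, sparsity)`) but uses plain Python values instead of TensorFlow tensors. The classes `PruningScheduleSketch` and `EveryNStepsSchedule` and their parameters are hypothetical; a real schedule would subclass `PruningSchedule` itself and return tensors.

```python
import abc


class PruningScheduleSketch(abc.ABC):
    """Hypothetical stand-in for PruningSchedule using plain Python values."""

    @classmethod
    def from_config(cls, config):
        # Assumption: get_config() returns the constructor's keyword arguments.
        return cls(**config)

    @abc.abstractmethod
    def get_config(self):
        """Returns a dict that from_config() can consume."""

    @abc.abstractmethod
    def __call__(self, step):
        """Returns (should_prune, sparsity) for the given training step."""


class EveryNStepsSchedule(PruningScheduleSketch):
    """Hypothetical schedule: constant sparsity, pruning every `frequency` steps."""

    def __init__(self, target_sparsity, begin_step=0, frequency=100):
        # In this sketch sparsity is a fraction (0.5 means 50%).
        self.target_sparsity = target_sparsity
        self.begin_step = begin_step
        self.frequency = frequency

    def get_config(self):
        return {
            "target_sparsity": self.target_sparsity,
            "begin_step": self.begin_step,
            "frequency": self.frequency,
        }

    def __call__(self, step):
        # Prune only once begin_step is reached, and then only on every
        # `frequency`-th step counted from begin_step.
        started = step >= self.begin_step
        on_boundary = (step - self.begin_step) % self.frequency == 0
        should_prune = started and on_boundary
        return should_prune, self.target_sparsity


schedule = EveryNStepsSchedule(target_sparsity=0.5, begin_step=10, frequency=5)
for step in (0, 10, 12, 15):
    should_prune, sparsity = schedule(step)
    print(step, should_prune, sparsity)
```

Keeping the decision logic inside `__call__` means the training loop only has to unpack the tuple; the schedule alone decides when pruning happens.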
Methods
from_config
```python
@classmethod
from_config(
    config
)
```
Instantiates a PruningSchedule from its config.

| Args | |
|---|---|
| `config` | Output of `get_config()`. |

| Returns |
|---|
| A `PruningSchedule` instance. |
get_config
```python
@abc.abstractmethod
get_config()
```
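`get_config` and `from_config` are meant to round-trip a schedule through a plain config. The sketch below assumes `get_config()` returns exactly the constructor's keyword arguments; `ConstantScheduleSketch` is a hypothetical class that does not subclass the real `PruningSchedule`, and the JSON step only illustrates that such a config can be stored as plain data.

```python
import json


class ConstantScheduleSketch:
    """Hypothetical schedule used only to illustrate config round-tripping."""

    def __init__(self, target_sparsity, begin_step=0):
        self.target_sparsity = target_sparsity
        self.begin_step = begin_step

    def get_config(self):
        # Everything needed to rebuild the object, as JSON-friendly values.
        return {"target_sparsity": self.target_sparsity, "begin_step": self.begin_step}

    @classmethod
    def from_config(cls, config):
        # Rebuild by passing the saved values back to the constructor.
        return cls(**config)


original = ConstantScheduleSketch(target_sparsity=0.8, begin_step=2000)
saved = json.dumps(original.get_config())  # e.g. stored alongside a model
restored = ConstantScheduleSketch.from_config(json.loads(saved))
print(restored.get_config())  # {'target_sparsity': 0.8, 'begin_step': 2000}
```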
__call__
```python
@abc.abstractmethod
__call__(
    step
)
```
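Returns whether to prune at this step and the sparsity (%) to be applied.
If the returned sparsity (%) is 0, pruning is ignored for the step.

| Args | |
|---|---|
| `step` | Current step in graph execution. |

| Returns |
|---|
| A tuple `(should_prune, sparsity)`: a bool tensor indicating whether the weights should be pruned at this step, and a float tensor with the sparsity (%) that should be applied to the weights. |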
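The sketch below shows one possible `__call__` body together with a consumer that honours the rule that a returned sparsity of 0 means pruning is ignored for the step. Both functions, `ramp_schedule` and `pruning_step`, are hypothetical and use plain Python floats (sparsity as a fraction) rather than tensors; the cubic ramp shape and default parameters are illustrative choices, not the library's behaviour.

```python
def ramp_schedule(step, begin_step=0, end_step=1000, final_sparsity=0.9, power=3):
    """Hypothetical __call__ body: sparsity ramps from 0 up to final_sparsity.

    Returns (should_prune, sparsity), with sparsity as a fraction.
    """
    if step < begin_step:
        # Before pruning starts, report zero sparsity so the step is ignored.
        return False, 0.0
    # Fraction of the ramp completed, capped at 1.0 once end_step is passed.
    progress = min(1.0, (step - begin_step) / (end_step - begin_step))
    # Polynomial ramp: rises quickly at first, then flattens near final_sparsity.
    sparsity = final_sparsity * (1.0 - (1.0 - progress) ** power)
    return sparsity > 0.0, sparsity


def pruning_step(schedule, step):
    """Hypothetical consumer: skips the step when told not to prune or sparsity is 0."""
    should_prune, sparsity = schedule(step)
    if not should_prune or sparsity == 0.0:
        return "skipped"
    return f"prune to {sparsity:.3f}"


for step in (0, 500, 1000, 2000):
    print(step, pruning_step(ramp_schedule, step))
```

At step 0 the ramp has not progressed, so the sparsity is 0 and the step is skipped; from `end_step` onward the schedule holds at `final_sparsity`.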