Abstract Base Class for making your own keras layer clusterable.
Your layer could be derived from a keras built-in layer or it could be a keras custom layer.
The function get_clusterable_weights should be provided in both cases.
The function get_clusterable_algorithm should be provided when weights for clustering are added in the keras layer.
Methods
get_clusterable_algorithm
get_clusterable_algorithm(
weight_name
)
Returns class with the clustering algorithm for the given weight_name.
This function needs to be implemented for custom layers. If the layer is derived from a built-in keras layer, the clustering algorithm for the base built-in keras layer is used.
The returned class should be derived from ClusteringAlgorithm and implement the function get_pulling_indices. This function provides a special lookup function for the custom weights. It reshapes and tiles the centroids the same way as the weights, which allows the pulling indices to be found efficiently.
Args | |
---|---|
weight_name | [string] The name of the weight variable. |
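To make the idea of pulling indices concrete, the following is a minimal pure-Python sketch, not the library's implementation. It assumes that a pulling index is the index of the centroid closest to a given weight. The free function below and the sample values are hypothetical and only borrow the name get_pulling_indices; a real ClusteringAlgorithm subclass would work on tensors and tile the centroids to the weight shape instead of looping.

```python
# Hypothetical stand-in: plain lists replace weight tensors and centroids.
def get_pulling_indices(weights, centroids):
    """Return, for every weight, the index of its nearest centroid."""
    return [
        [min(range(len(centroids)), key=lambda i: abs(w - centroids[i])) for w in row]
        for row in weights
    ]


weights = [[0.1, 0.9], [0.5, 0.4]]   # illustrative 2x2 kernel
centroids = [0.0, 0.5, 1.0]          # illustrative cluster centroids

indices = get_pulling_indices(weights, centroids)
# indices == [[0, 2], [1, 1]]

# Looking the indices up in the centroid list gives the clustered kernel.
clustered = [[centroids[i] for i in row] for row in indices]
# clustered == [[0.0, 1.0], [0.5, 0.5]]
```

The result has the same shape as the weights, so every weight can be replaced by its centroid with a single lookup.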
get_clusterable_weights
@abc.abstractmethod
get_clusterable_weights()
Returns list of clusterable weight tensors.
All the weight tensors which the layer wants to be clustered during training must be returned by this method.
Returns: List of weight tensors/kernels in the keras layer which must be clustered during training. Each element in the list is a (name, kernel) 2-tuple that consists of the name of the clusterable kernel and the kernel object itself.
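The contract described above can be sketched without the library. The class ClusterableLayerSketch below is a hypothetical stand-in that mirrors the documented interface using Python's abc module. ToyDense and its list-based weights are illustrative assumptions, and the default return value of get_clusterable_algorithm is assumed rather than taken from this page.

```python
import abc


class ClusterableLayerSketch(abc.ABC):
    """Hypothetical stand-in for the documented interface, not the library class."""

    @abc.abstractmethod
    def get_clusterable_weights(self):
        """Return a list of (name, kernel) 2-tuples to be clustered."""

    def get_clusterable_algorithm(self, weight_name):
        # Assumption: None means "no custom clustering algorithm for this weight".
        return None


class ToyDense(ClusterableLayerSketch):
    """Illustrative layer whose weights are plain Python lists."""

    def __init__(self):
        self.kernel = [[0.1, 0.9], [0.5, 0.4]]
        self.bias = [0.0, 0.0]

    def get_clusterable_weights(self):
        # Only the kernel is offered for clustering; the bias is left out.
        return [("kernel", self.kernel)]


layer = ToyDense()
clusterable = layer.get_clusterable_weights()
name, kernel = clusterable[0]
```

Because get_clusterable_weights is abstract, a subclass that omits it cannot be instantiated. Weights that are not returned, such as the bias here, are left unclustered.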