Class to implement highly efficient vectorised look-ups.
tfmot.clustering.keras.ClusteringAlgorithm( clusters_centroids, cluster_gradient_aggregation=GradientAggregation.SUM )
We do not use loops for this purpose; instead, we reshape and tile arrays. The trade-off is that this can use considerably more memory than a loop-based implementation would.
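The reshape-and-tile idea can be illustrated with a minimal NumPy sketch that finds the closest centroid for each element of a 1D weight array. This is an illustration, not the tfmot implementation: NumPy stands in for TensorFlow, and closest_centroid_tiled and closest_centroid_loop are hypothetical names. The tiled version materialises an (n_weights, n_clusters) distance array in one step, which is where the extra memory goes; the loop version is included only as a reference.

```python
import numpy as np


def closest_centroid_tiled(weights, centroids):
    """Vectorised lookup: tile both arrays to shape (n_weights, n_clusters)."""
    n_w, n_c = weights.shape[0], centroids.shape[0]
    # Each row repeats one weight n_c times.
    w_tiled = np.tile(weights.reshape(n_w, 1), (1, n_c))
    # Each row is a full copy of the centroids.
    c_tiled = np.tile(centroids.reshape(1, n_c), (n_w, 1))
    # n_w * n_c distances held in memory at once: the cost of avoiding a loop.
    distances = np.abs(w_tiled - c_tiled)
    return np.argmin(distances, axis=1).astype(np.int32)


def closest_centroid_loop(weights, centroids):
    """Loop-based reference: only one row of distances exists at a time."""
    indices = np.empty(weights.shape[0], dtype=np.int32)
    for i, w in enumerate(weights):
        indices[i] = np.argmin(np.abs(centroids - w))
    return indices


weights = np.array([0.1, -0.9, 0.45, 1.2])
centroids = np.array([-1.0, 0.0, 0.5, 1.0])
indices = closest_centroid_tiled(weights, centroids)
```

Both functions return the same indices ([1, 0, 2, 3] here); the tiled version simply trades memory proportional to n_weights * n_clusters for the absence of a Python-level loop.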
Each subclass is expected to implement the lookup function for a particular weight shape.
For example, look-ups for a 2D table differ from those for a 3D table.
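The per-shape subclassing described above might look roughly like the following sketch. ShapeSpecificLookup, Lookup1D and Lookup2D are hypothetical names chosen for illustration, and NumPy is used in place of TensorFlow; the point is only that the base class fixes the interface while each subclass reshapes its input differently before the shared nearest-centroid reduction.

```python
import abc

import numpy as np


class ShapeSpecificLookup(abc.ABC):
    """Hypothetical base class: each subclass handles one weight rank."""

    def __init__(self, centroids):
        self.centroids = np.asarray(centroids, dtype=np.float64)

    @abc.abstractmethod
    def get_pulling_indices(self, weight):
        """Return int32 indices of the closest centroid, shaped like weight."""


class Lookup1D(ShapeSpecificLookup):
    def get_pulling_indices(self, weight):
        # (n,) -> (n, 1) so it broadcasts against the (N,) centroids.
        d = np.abs(weight.reshape(-1, 1) - self.centroids)
        return np.argmin(d, axis=1).astype(np.int32)


class Lookup2D(ShapeSpecificLookup):
    def get_pulling_indices(self, weight):
        # (r, c) -> (r, c, 1); reduce over the trailing centroid axis.
        d = np.abs(weight[:, :, np.newaxis] - self.centroids)
        return np.argmin(d, axis=2).astype(np.int32)


lookup = Lookup2D([-1.0, 0.0, 1.0])
w = np.array([[0.9, -0.2], [-0.7, 0.3]])
idx = lookup.get_pulling_indices(w)
```

Because get_pulling_indices is abstract, the base class itself cannot be instantiated; only rank-specific subclasses can.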
Args:
clusters_centroids: An array of shape (N,) that contains the initial values of the cluster centroids.
cluster_gradient_aggregation: An enum that specifies the aggregation method for the cluster gradient.
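The following NumPy sketch illustrates what cluster_gradient_aggregation decides: how the gradients of all weights pulled from the same centroid are combined into that centroid's gradient. It assumes two modes, "SUM" (the default shown in the signature) and "AVG" (averaging by cluster size, as described for get_clustered_weight); the mode strings and aggregate_centroid_grads are hypothetical and not the tfmot enum itself. np.bincount with weights serves as a segment sum.

```python
import numpy as np


def aggregate_centroid_grads(upstream_grad, pulling_indices, n_clusters, mode="SUM"):
    """Hypothetical helper: reduce per-weight gradients to per-centroid gradients."""
    flat_idx = pulling_indices.ravel()
    flat_grad = upstream_grad.ravel()
    # Segment sum: add each weight's gradient to the centroid it was pulled from.
    summed = np.bincount(flat_idx, weights=flat_grad, minlength=n_clusters)
    if mode == "SUM":
        return summed
    if mode == "AVG":
        # Divide by cluster size; empty clusters keep a zero gradient.
        counts = np.bincount(flat_idx, minlength=n_clusters)
        return summed / np.maximum(counts, 1)
    raise ValueError(f"unknown aggregation mode: {mode}")


pulling_indices = np.array([[0, 1], [1, 1]], dtype=np.int32)
upstream_grad = np.array([[0.5, 1.0], [2.0, 3.0]])
grad_sum = aggregate_centroid_grads(upstream_grad, pulling_indices, 3, "SUM")
grad_avg = aggregate_centroid_grads(upstream_grad, pulling_indices, 3, "AVG")
```

With summing, a large cluster (centroid 1 here, with three members) receives a proportionally larger update; averaging removes that dependence on cluster size.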
get_clustered_weight( pulling_indices, original_weight )
Returns clustered weights with custom gradients.
Takes indices (pulling_indices) as input and forms a new array by gathering the cluster centroids at those indices. The original gradients are also modified in two ways:
- By averaging the gradient of cluster_centroids based on the size of each cluster.
- By adding an estimated gradient onto the non-differentiable original weight.
Args:
pulling_indices: a tensor of indices used for lookup, of the same size as original_weight.
original_weight: the original weights of the wrapped layer.
Returns: array with the same shape as
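The forward pass of get_clustered_weight is essentially a gather: each entry of pulling_indices is replaced by the centroid value it points to, so the result takes the shape of the index array. The minimal NumPy sketch below also assumes a straight-through estimator for the "estimated gradient" on the original weight, meaning the upstream gradient is passed through unchanged. That choice and the name clustered_weight_forward_backward are assumptions for illustration, not a description of the tfmot internals.

```python
import numpy as np


def clustered_weight_forward_backward(centroids, pulling_indices, upstream_grad):
    """Hypothetical sketch: gather forward pass plus straight-through gradient."""
    # Forward: gather centroid values; the output has pulling_indices' shape.
    clustered = np.take(centroids, pulling_indices)
    # Backward (assumed straight-through estimator): the original weight
    # receives the upstream gradient unchanged, as if clustering were identity.
    grad_original_weight = upstream_grad.copy()
    return clustered, grad_original_weight


centroids = np.array([-1.0, 0.0, 1.0])
pulling_indices = np.array([[2, 0], [1, 2]], dtype=np.int32)
upstream_grad = np.ones((2, 2))
clustered, grad_w = clustered_weight_forward_backward(centroids, pulling_indices, upstream_grad)
```

In TensorFlow, a straight-through estimator is commonly written as original_weight + tf.stop_gradient(clustered - original_weight): its value equals the clustered weight, while its gradient with respect to original_weight is the identity.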
get_pulling_indices( weight )
Returns indices of closest cluster centroids.
Takes a weight (1D, 2D, or ND) and creates a tf.int32 array of the same shape that holds, for each element, the index of the cluster centroid from which the corresponding element of the clustered array will be pulled.
In the current setup, pulling indices are meant to be created once and reused everywhere.
Args:
weight: ND array of weights. For each weight in this array, the closest cluster centroid is found.
Returns: ND array of the same shape as
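For weights of arbitrary rank, the nearest-centroid search can also be written with broadcasting instead of explicit tiling: appending a trailing axis to the weight array compares every element with every centroid at once, and an argmin over that axis yields an int32 index array with exactly the weight's shape. This is a minimal NumPy sketch under those assumptions; get_pulling_indices_nd is a hypothetical name, and the memory cost is still proportional to the number of weights times the number of centroids.

```python
import numpy as np


def get_pulling_indices_nd(weight, centroids):
    """Hypothetical sketch: index of the closest centroid for every element."""
    # weight[..., None] has shape weight.shape + (1,); broadcasting against
    # centroids of shape (N,) yields distances of shape weight.shape + (N,).
    distances = np.abs(weight[..., np.newaxis] - centroids)
    # Reduce over the centroid axis; the result has exactly weight.shape.
    return np.argmin(distances, axis=-1).astype(np.int32)


rng = np.random.default_rng(0)
weight = rng.normal(size=(3, 3, 2, 4))  # e.g. a small conv kernel
centroids = np.linspace(-1.0, 1.0, num=8)
indices = get_pulling_indices_nd(weight, centroids)
```

Since the indices are meant to be computed once and reused, this memory is spent only when they are (re)computed, not on every lookup.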