tfmot.clustering.keras.ClusteringAlgorithm

Base class for implementing efficient, vectorised look-ups.

We do not use loops for this purpose; instead, we reshape and tile arrays. The trade-off is potentially much higher memory usage than a loop-based implementation would need.

Each subclass is expected to implement the look-up function for a particular weight shape.

For example, look-ups for a 2D table differ from those for a 3D table.

Args
clusters_centroids An array of shape (N,) that contains the initial values of the cluster centroids.
cluster_gradient_aggregation An enum that specifies the aggregation method of the cluster gradient.
data_format Used in cluster_per_channel to ensure that the weight kernel is permuted properly when updating the weights and calculating gradients.
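To illustrate the reshape-and-tile idea and its memory trade-off, here is a minimal NumPy sketch for a 2D weight. It is not the library's implementation: the function names `nearest_centroid_loop` and `nearest_centroid_tiled` are hypothetical, and the nearest-centroid assignment is used only as an example of a vectorised look-up. The tiled version builds intermediates of shape (rows, cols, N), so it holds rows * cols * N values at once, whereas the loop only ever compares one weight against the N centroids.

```python
import numpy as np


def nearest_centroid_loop(weights, centroids):
    """Reference version: one Python-level iteration per weight element."""
    out = np.empty(weights.shape, dtype=np.int32)
    for idx, w in np.ndenumerate(weights):
        out[idx] = np.argmin(np.abs(centroids - w))
    return out


def nearest_centroid_tiled(weights, centroids):
    """Vectorised version for a 2D weight: reshape and tile, no Python loop."""
    rows, cols = weights.shape
    n = centroids.shape[0]
    # Every weight repeated n times along a new last axis: shape (rows, cols, n).
    tiled_w = np.tile(weights.reshape(rows, cols, 1), (1, 1, n))
    # The centroid vector repeated for every weight: shape (rows, cols, n).
    tiled_c = np.tile(centroids.reshape(1, 1, n), (rows, cols, 1))
    # Distance to every centroid, then the index of the closest one per weight.
    return np.argmin(np.abs(tiled_w - tiled_c), axis=-1).astype(np.int32)


weights = np.array([[0.1, 0.9], [-0.4, 0.45]], dtype=np.float32)
centroids = np.array([-0.5, 0.0, 0.5, 1.0], dtype=np.float32)

# Indices of the closest centroids: [[1, 3], [0, 2]]
print(nearest_centroid_tiled(weights, centroids))
```

A 3D weight would need a different reshape and tile pattern (an extra leading axis), which is why a separate look-up per shape is needed when arrays are tiled explicitly like this.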

Methods

add_gradient_to_original_weight

Overrides gradients in the backprop stage.

This function overrides gradients in the backprop stage: the Jacobian matrix of the forward-pass multiplication is replaced with the identity matrix, which effectively turns the multiplication into an addition in the backprop. Since the gradient of tf.sign is 0, overwriting it with the identity follows the design of the straight-through estimator (STE), which accepts all upstream gradients and uses them to update the original, non-clustered weights of the layer. Here, we assume that the gradient updates on individual elements inside a cluster will differ, so there is no point in mapping the gradient updates back to the original non-clustered weights using the look-up table (LUT).

Args
clustered_weight The clustered weights.
original_weight The original weights.

Returns
The result and the custom gradient function, as expected by @tf.custom_gradient.
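The straight-through estimator can be sketched in NumPy as an explicit forward/backward pair instead of a @tf.custom_gradient function. This is an illustration under assumptions, not the library code: the forward pass is assumed to multiply clustered_weight by tf.sign(original_weight + OFFSET) with a large hypothetical OFFSET, so that the sign is 1 for any realistic weight and original_weight still takes part in the computation. The names `ste_forward`, `ste_backward` and `OFFSET` are hypothetical.

```python
import numpy as np

# Hypothetical large offset: sign(original_weight + OFFSET) == 1 for any
# realistic weight, so the forward pass returns clustered_weight unchanged.
OFFSET = 1e6


def ste_forward(clustered_weight, original_weight):
    """Forward pass: clustered_weight * sign(original_weight + OFFSET)."""
    return clustered_weight * np.sign(original_weight + OFFSET)


def ste_backward(upstream_grad):
    """Backward pass with the multiplication Jacobian replaced by the identity.

    The true gradient w.r.t. original_weight would be 0 (d sign / dx == 0).
    Instead, the upstream gradient is passed unchanged to both inputs, as it
    would be for an addition clustered_weight + original_weight.
    """
    return upstream_grad, upstream_grad  # (d_clustered_weight, d_original_weight)


clustered = np.array([0.5, 0.5, -0.25])
original = np.array([0.47, 0.52, -0.3])
out = ste_forward(clustered, original)
d_clustered, d_original = ste_backward(np.array([0.1, -0.2, 0.3]))
print(out, d_original)
```

The forward output equals the clustered weights, while every element of the original weights receives its own upstream gradient, which is what lets the non-clustered weights keep training.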

average_centroids_gradient_by_cluster_size

Averages the gradient of each cluster centroid by the number of weights assigned to that cluster.
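As a rough illustration, the NumPy sketch below shows why averaging matters. When clustered weights are formed by gathering centroids, each centroid's gradient is the sum of the upstream gradients of all weights pointing at it; dividing by the cluster size turns that sum into a mean. The function names are hypothetical, and treating empty clusters as size 1 (to avoid division by zero) is an assumption of this sketch.

```python
import numpy as np


def centroid_grad_sum(upstream_grad, pulling_indices, num_clusters):
    """Gradient of a gather: each centroid receives the SUM of its members' grads."""
    return np.bincount(pulling_indices.ravel(),
                       weights=upstream_grad.ravel(),
                       minlength=num_clusters)


def centroid_grad_avg(upstream_grad, pulling_indices, num_clusters):
    """Same gradient divided by the number of weights in each cluster."""
    summed = centroid_grad_sum(upstream_grad, pulling_indices, num_clusters)
    sizes = np.bincount(pulling_indices.ravel(), minlength=num_clusters)
    # Assumption: empty clusters count as size 1 to avoid dividing by zero.
    return summed / np.maximum(sizes, 1)


indices = np.array([[0, 1], [1, 1]])
grads = np.array([[0.4, 0.3], [0.6, 0.9]])
print(centroid_grad_sum(grads, indices, 3))
print(centroid_grad_avg(grads, indices, 3))
```

With the plain sum, cluster 1 (three members) gets a gradient three times larger than a single member would contribute; averaging keeps the step size of a centroid from growing with the number of weights assigned to it.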

get_clustered_weight

Returns clustered weights with custom gradients.

Takes indices (pulling_indices) as input and forms a new array by gathering cluster centroids at those indices. The original gradients are also modified in two ways:

  • By averaging the gradient of cluster_centroids based on the size of each cluster.
  • By adding an estimated gradient onto the non-differentiable original weight.

Args
pulling_indices A tensor of indices used for look-up, of the same size as original_weight.
original_weight The original weights of the wrapped layer.

Returns
An array with the same shape as pulling_indices, where each element is a member of self.cluster_centroids. The backward pass is modified by adding the custom gradients described above.
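Here is a minimal NumPy sketch of the forward gather and the two gradient modifications, written as explicit functions rather than TensorFlow custom gradients. The names `clustered_weight_forward` and `clustered_weight_backward` are hypothetical, and treating empty clusters as size 1 is an assumption of this sketch.

```python
import numpy as np


def clustered_weight_forward(centroids, pulling_indices):
    """Gather: every element is replaced by the centroid its index points to."""
    return centroids[pulling_indices]


def clustered_weight_backward(upstream_grad, pulling_indices, num_clusters):
    """Custom gradients for the centroids and the original weight.

    - centroids: sum the upstream gradient per cluster, then divide by the
      cluster size (empty clusters are assumed to have size 1).
    - original weight: straight-through, upstream gradient passed unchanged.
    """
    flat_idx = pulling_indices.ravel()
    summed = np.bincount(flat_idx, weights=upstream_grad.ravel(),
                         minlength=num_clusters)
    sizes = np.bincount(flat_idx, minlength=num_clusters)
    d_centroids = summed / np.maximum(sizes, 1)
    d_original = upstream_grad.copy()
    return d_centroids, d_original


centroids = np.array([-0.5, 0.0, 0.5])
pulling_indices = np.array([[2, 0], [1, 2]])
clustered = clustered_weight_forward(centroids, pulling_indices)
upstream = np.array([[0.2, 0.4], [0.6, 0.8]])
d_c, d_w = clustered_weight_backward(upstream, pulling_indices, 3)
print(clustered, d_c, d_w)
```

The output has the same shape as pulling_indices and contains only centroid values, while the original weight receives a per-element gradient.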

get_pulling_indices

Returns indices of closest cluster centroids.

Takes a weight (1D, 2D or ND) and creates a tf.int32 array of the same shape that holds the indices of the cluster centroids from which the elements of the clustered array will be pulled.

In the current setup, pulling indices are meant to be created once and used everywhere.

Args
weight ND array of weights. For each weight in this array, the closest cluster centroid is found.
centroids Optional list of cluster centroids.

Returns
ND array of type tf.int32 with the same shape as the weight parameter. The returned array contains the weight look-up indices.
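A minimal NumPy sketch of the same idea for an ND weight, using broadcasting instead of explicit tiling: every weight is compared against all centroids and the index of the closest one is kept. The indices are then reused after the centroids change, matching the "created once and used everywhere" setup. The function name mirrors the method for readability, but this is a standalone illustration, not the library's code, and the centroid update is a made-up placeholder.

```python
import numpy as np


def get_pulling_indices(weight, centroids):
    """Index of the closest centroid for every element of an ND weight."""
    # Shape weight.shape + (1,) broadcast against (N,): all pairwise distances.
    distances = np.abs(weight[..., np.newaxis] - centroids)
    return np.argmin(distances, axis=-1).astype(np.int32)


weight = np.array([[[0.12, -0.31], [0.55, 0.08]],
                   [[-0.52, 0.49], [0.02, -0.27]]], dtype=np.float32)
centroids = np.array([-0.3, 0.1, 0.5], dtype=np.float32)

# Computed once.
indices = get_pulling_indices(weight, centroids)
# Reused after the centroids have changed (placeholder update for illustration).
updated_centroids = centroids + np.float32(0.01)
clustered = updated_centroids[indices]
print(indices)
```

Because only the centroid values change between steps, the gather with the stored indices is enough to rebuild the clustered weight without repeating the distance computation.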