Module: tf.contrib.model_pruning

Model pruning implementation in TensorFlow.

Classes

class MaskedBasicLSTMCell: Basic LSTM recurrent network cell with pruning.

class MaskedLSTMCell: LSTMCell with pruning.

class Pruning

Functions

apply_mask(...): Apply mask to a given weight tensor.
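
Conceptually, applying a mask means multiplying the weight tensor element-wise by a 0/1 mask of the same shape, so that pruned entries become zero while kept entries are unchanged. The sketch below shows that idea in plain Python on nested lists. It is an illustration only, not the library's implementation, and the helper name `apply_mask_sketch` is hypothetical.

```python
def apply_mask_sketch(weights, mask):
    """Element-wise product of a 2-D weight matrix and a same-shaped 0/1 mask.

    Hypothetical helper for illustration; not the TensorFlow API.
    """
    if len(weights) != len(mask) or any(len(w) != len(m) for w, m in zip(weights, mask)):
        raise ValueError("weights and mask must have the same shape")
    # Entries where the mask is 0 are pruned (become 0.0); others pass through.
    return [[w * m for w, m in zip(w_row, m_row)] for w_row, m_row in zip(weights, mask)]


weights = [[0.5, -1.2], [0.03, 2.0]]
mask = [[1, 1], [0, 1]]
print(apply_mask_sketch(weights, mask))  # [[0.5, -1.2], [0.0, 2.0]]
```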

get_masked_weights(...)

get_masks(...)

get_pruning_hparams(...): Get a tf.HParams object with the default values for the pruning hyperparameters.

get_thresholds(...)
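
The thresholds decide which weights a mask keeps. Assuming magnitude-based pruning, a mask for a target sparsity can be built by choosing a threshold so that the requested fraction of smallest-magnitude weights falls at or below it, then keeping only weights whose magnitude is strictly greater. The sketch below picks that threshold by sorting absolute values; `magnitude_mask_sketch` is a hypothetical name, and the library's own threshold computation (driven by the pruning hyperparameters) may differ in detail.

```python
def magnitude_mask_sketch(weights, target_sparsity):
    """Return (threshold, mask) that zeroes the smallest-magnitude weights.

    Assumes magnitude-based pruning on a flat list of weights; hypothetical helper.
    """
    if not 0.0 <= target_sparsity < 1.0:
        raise ValueError("target_sparsity must be in [0, 1)")
    magnitudes = sorted(abs(w) for w in weights)
    # Number of weights to prune (floor of sparsity * count).
    num_pruned = int(target_sparsity * len(weights))
    if num_pruned == 0:
        return 0.0, [1.0] * len(weights)
    # The threshold is the largest magnitude that gets pruned; ties at the
    # threshold are also pruned, so the result can be slightly sparser.
    threshold = magnitudes[num_pruned - 1]
    mask = [1.0 if abs(w) > threshold else 0.0 for w in weights]
    return threshold, mask


threshold, mask = magnitude_mask_sketch([0.1, -0.8, 0.05, 0.4], 0.5)
print(threshold, mask)  # 0.1 [0.0, 1.0, 0.0, 1.0]
```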

get_weight_sparsity(...): Get sparsity of the weights.
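
Sparsity here is the fraction of entries that are zero. The sketch below computes that fraction for each mask in a list, assuming masks are given as flat lists of 0/1 values. `weight_sparsity_sketch` and `zero_fraction` are hypothetical names; the code only illustrates the quantity being reported, not the library's code.

```python
def zero_fraction(values):
    """Fraction of entries in a flat list that are exactly zero."""
    if not values:
        raise ValueError("values must be non-empty")
    return sum(1 for v in values if v == 0) / len(values)


def weight_sparsity_sketch(masks):
    """One sparsity value per mask (hypothetical helper for illustration)."""
    return [zero_fraction(m) for m in masks]


print(weight_sparsity_sketch([[0.0, 1.0, 0.0, 1.0], [1.0, 1.0, 1.0, 0.0]]))  # [0.5, 0.25]
```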

get_weights(...)

graph_def_from_checkpoint(...): Converts checkpoint data to GraphDef.

masked_conv2d(...): Adds a 2D convolution followed by an optional batch_norm layer.

masked_convolution(...): Adds a 2D convolution followed by an optional batch_norm layer.
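
A masked convolution follows the same pattern as the other masked layers: the kernel is multiplied by its mask before the convolution is computed, so pruned kernel taps contribute nothing at any output position. The sketch below shows this for a single-channel, stride-1, "VALID"-padded 2D cross-correlation (the operation TensorFlow's convolution layers compute) in plain Python. It omits batch normalization, multiple channels, and padding options, and `masked_conv2d_sketch` is a hypothetical name.

```python
def masked_conv2d_sketch(image, kernel, mask):
    """Single-channel, stride-1, VALID 2-D cross-correlation with a masked kernel."""
    # Prune the kernel first: zeroed taps drop out of every output position.
    masked = [[k * m for k, m in zip(k_row, m_row)] for k_row, m_row in zip(kernel, mask)]
    kh, kw = len(masked), len(masked[0])
    out_h = len(image) - kh + 1
    out_w = len(image[0]) - kw + 1
    output = []
    for i in range(out_h):
        row = []
        for j in range(out_w):
            # Slide the masked kernel over the image window starting at (i, j).
            row.append(sum(image[i + a][j + b] * masked[a][b]
                           for a in range(kh) for b in range(kw)))
        output.append(row)
    return output


image = [[1, 2, 3],
         [4, 5, 6],
         [7, 8, 9]]
kernel = [[1, 1],
          [1, 1]]
mask = [[1, 0],
        [0, 1]]
print(masked_conv2d_sketch(image, kernel, mask))  # [[6, 8], [12, 14]]
```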

masked_fully_connected(...): Adds a sparse fully connected layer. The weight matrix is masked.
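
Masking the weight matrix of a fully connected layer means the layer computes x · (W ⊙ M) + b, so each zero in the mask removes one input-to-output connection. The sketch below evaluates that for a single input vector in plain Python, assuming weights shaped [in_dim][out_dim] and no activation function; `masked_dense_sketch` is a hypothetical name, not part of the module.

```python
def masked_dense_sketch(x, weights, mask, bias):
    """Compute x @ (weights * mask) + bias for one input vector.

    weights and mask have shape [in_dim][out_dim]; hypothetical helper.
    """
    in_dim, out_dim = len(weights), len(weights[0])
    if len(x) != in_dim or len(bias) != out_dim:
        raise ValueError("shape mismatch")
    output = []
    for j in range(out_dim):
        total = bias[j]
        for i in range(in_dim):
            # A zero in the mask removes the connection from input i to output j.
            total += x[i] * weights[i][j] * mask[i][j]
        output.append(total)
    return output


x = [1.0, 2.0]
weights = [[0.5, -1.0],
           [0.25, 3.0]]
mask = [[1.0, 0.0],
        [1.0, 1.0]]
bias = [0.0, 1.0]
print(masked_dense_sketch(x, weights, mask, bias))  # [1.0, 7.0]
```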

strip_pruning_vars_fn(...): Removes mask variables from the graph.
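
The reason mask variables can be removed after training is that the mask can be multiplied into the weights a single time: the resulting weights give the same outputs as the masked layer, so the separate mask variable is no longer needed at inference. The sketch below checks that equivalence for a small dense layer in plain Python. The helper names `fold_mask` and `dense` are hypothetical, and the code does not reproduce how `strip_pruning_vars_fn` rewrites a graph.

```python
def fold_mask(weights, mask):
    """Bake a 0/1 mask into a 2-D weight matrix so the mask is no longer needed."""
    return [[w * m for w, m in zip(w_row, m_row)] for w_row, m_row in zip(weights, mask)]


def dense(x, weights):
    """Plain product x @ weights, with weights shaped [in_dim][out_dim]."""
    return [sum(x[i] * weights[i][j] for i in range(len(x))) for j in range(len(weights[0]))]


weights = [[0.5, -1.0], [0.25, 3.0]]
mask = [[1.0, 0.0], [1.0, 1.0]]
x = [1.0, 2.0]

# Output of the masked layer, applying the mask on every forward pass.
masked_output = [sum(x[i] * weights[i][j] * mask[i][j] for i in range(2)) for j in range(2)]
# Output after folding the mask into the weights once.
stripped_output = dense(x, fold_mask(weights, mask))
print(stripped_output, stripped_output == masked_output)  # [1.0, 6.0] True
```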

train(...): Wrapper around tf-slim's train function.