
Computes and returns the sampled sparse softmax training loss.

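This is a faster way to train a softmax classifier over a very large number of classes: each step computes logits only for the target class and a small random sample of other classes, rather than for every class.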

This operation is for training only. It is generally an underestimate of the full softmax loss.
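The following is a minimal NumPy sketch of the idea, not TensorFlow's implementation. It assumes the standard candidate-sampling correction: each logit is shifted by the log of that class's expected count under the sampler, and the loss is a softmax cross-entropy over the target class plus the sampled classes. The function names are hypothetical, and the caller supplies the sampled classes and expected counts, which play the role of the sampled_values tuple.

```python
import numpy as np

def full_softmax_loss(inputs, weights, biases, labels):
    """Exact per-example softmax cross-entropy over every class."""
    labels = np.asarray(labels).reshape(-1)          # [batch_size, 1] -> [batch_size]
    logits = inputs @ weights.T + biases             # [batch_size, num_classes]
    m = logits.max(axis=1, keepdims=True)            # subtract the max for stability
    log_z = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    return log_z - logits[np.arange(len(labels)), labels]

def sampled_softmax_loss_sketch(inputs, weights, biases, labels, sampled,
                                true_expected, sampled_expected,
                                remove_accidental_hits=True):
    """Softmax cross-entropy over {target} plus a batch-wide set of sampled classes.

    Every logit is corrected by subtracting log(expected count) under the sampler.
    """
    labels = np.asarray(labels).reshape(-1)
    sampled = np.asarray(sampled)
    # Target logit per row: w_t . h + b_t - log Q(t)
    true_logits = (np.einsum("bd,bd->b", inputs, weights[labels]) + biases[labels]
                   - np.log(np.asarray(true_expected).reshape(-1)))
    # Sampled logits, shared across the batch: [batch_size, num_sampled]
    sampled_logits = (inputs @ weights[sampled].T + biases[sampled]
                      - np.log(np.asarray(sampled_expected)))
    if remove_accidental_hits:
        # A sampled class equal to the row's target is masked out of the softmax.
        hits = sampled[None, :] == labels[:, None]
        sampled_logits = np.where(hits, -np.inf, sampled_logits)
    logits = np.concatenate([true_logits[:, None], sampled_logits], axis=1)
    m = logits.max(axis=1, keepdims=True)
    log_z = m[:, 0] + np.log(np.exp(logits - m).sum(axis=1))
    return log_z - logits[:, 0]                      # the target is always column 0
```

If the sample covers every non-target class and all expected counts are 1 (no correction), the two functions agree. With fewer sampled classes and no correction, the log-partition term sums over only a subset of classes, which is one way to see why the sampled loss tends to underestimate the full loss.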

A common use case is to use this method for training, and calculate the full softmax loss for evaluation or inference. In this case, you must set partition_strategy="div" for the two losses to be consistent, as in the following example:

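```python
import tensorflow as tf

if mode == "train":
  loss = tf.nn.sampled_sparse_softmax_loss(
      weights=weights, biases=biases, labels=labels, inputs=inputs,
      num_sampled=num_sampled, num_classes=num_classes,
      partition_strategy="div")
elif mode == "eval":
  logits = tf.matmul(inputs, tf.transpose(weights))
  logits = tf.nn.bias_add(logits, biases)
  # The full loss expects labels of shape [batch_size], so drop the last axis.
  loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
      labels=tf.squeeze(labels, axis=1), logits=logits)
```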
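Why "div" matters: when weights is a list of shards, the partition strategy decides which class ids live in which shard. The plain-Python sketch below (the helper name partition_ids is hypothetical) reproduces the id-to-shard assignment that tf.nn.embedding_lookup describes for 13 ids split across 5 shards. Under "div", concatenating the shards in order yields rows in class-id order, which is what the evaluation path assumes when it multiplies inputs by the full weight matrix; under "mod", the ids are interleaved across shards.

```python
def partition_ids(num_ids, num_shards, strategy):
    """Return, for each shard, the list of ids assigned to it."""
    if strategy == "mod":
        # Id i goes to shard i % num_shards (ids are interleaved).
        return [list(range(s, num_ids, num_shards)) for s in range(num_shards)]
    if strategy == "div":
        # Contiguous ranges; the first (num_ids % num_shards) shards get one extra id.
        base, extra = divmod(num_ids, num_shards)
        shards, start = [], 0
        for s in range(num_shards):
            size = base + (1 if s < extra else 0)
            shards.append(list(range(start, start + size)))
            start += size
        return shards
    raise ValueError(f"unknown partition strategy: {strategy!r}")

div = partition_ids(13, 5, "div")  # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9, 10], [11, 12]]
mod = partition_ids(13, 5, "mod")  # [[0, 5, 10], [1, 6, 11], [2, 7, 12], [3, 8], [4, 9]]
```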

See our Candidate Sampling Algorithms Reference.

Also see Section 3 of Jean et al., 2014 for the math.
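Written out for a single example, the loss looks roughly as follows. The notation is ours and assumes the log-expected-count correction used by candidate samplers; Jean et al. derive their estimator in their own notation. Here h is one row of inputs, t its target label, S the multiset of num_sampled sampled classes, w_j and b_j the row of weights and entry of biases for class j, and Q(j) the expected count of class j under the sampler (true_expected_count or sampled_expected_count). The condition j ≠ t corresponds to remove_accidental_hits=True.

```latex
\begin{aligned}
\tilde{o}_j &= w_j^\top h + b_j - \log Q(j), \qquad j \in \{t\} \cup S, \\
\mathcal{L}(h, t) &= -\tilde{o}_t + \log\Big( e^{\tilde{o}_t} + \sum_{j \in S,\; j \neq t} e^{\tilde{o}_j} \Big).
\end{aligned}
```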

Args:
weights: A Tensor of shape [num_classes, dim], or a list of Tensor objects whose concatenation along dimension 0 has shape [num_classes, dim]. The (possibly sharded) class embeddings.
biases: A Tensor of shape [num_classes]. The class biases.
labels: A Tensor of type int64 and shape [batch_size, 1]. The index of the single target class for each row of logits. Note that this format differs from the labels argument of nn.sparse_softmax_cross_entropy_with_logits, which expects shape [batch_size].
inputs: A Tensor of shape [batch_size, dim]. The forward activations of the input network.
num_sampled: An int. The number of classes to randomly sample per batch.
num_classes: An int. The number of possible classes.
sampled_values: A tuple of (sampled_candidates, true_expected_count, sampled_expected_count) returned by a *_candidate_sampler function. If None, defaults to log_uniform_candidate_sampler.
remove_accidental_hits: A bool. Whether to remove "accidental hits", where a sampled class equals one of the target classes. Default is True.
partition_strategy: A string specifying the partitioning strategy, relevant if len(weights) > 1. Currently "div" and "mod" are supported. Default is "mod". See tf.nn.embedding_lookup for more details.
name: A name for the operation (optional).
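When sampled_values is None, classes are drawn from a log-uniform (approximately Zipfian) distribution, P(k) = (log(k + 2) - log(k + 1)) / log(range_max + 1) for k in [0, range_max), which favors small class ids; it works best when ids are assigned in order of decreasing class frequency. The NumPy sketch below builds a tuple shaped like sampled_values under a simplifying assumption: it samples with replacement, so the expected count of class k is num_sampled * P(k). TensorFlow's sampler with unique=True computes expected counts differently, and the helper names here are hypothetical.

```python
import numpy as np

def log_uniform_probs(range_max):
    """P(k) = (log(k + 2) - log(k + 1)) / log(range_max + 1) for k in [0, range_max)."""
    k = np.arange(range_max, dtype=np.float64)
    return (np.log(k + 2) - np.log(k + 1)) / np.log(range_max + 1)

def sample_log_uniform(labels, num_sampled, range_max, rng):
    """Return (sampled_candidates, true_expected_count, sampled_expected_count).

    Samples WITH replacement, so each expected count is num_sampled * P(k).
    """
    probs = log_uniform_probs(range_max)  # sums to 1 (the log terms telescope)
    sampled = rng.choice(range_max, size=num_sampled, replace=True, p=probs)
    true_expected = num_sampled * probs[labels]     # same shape as labels, e.g. [batch_size, 1]
    sampled_expected = num_sampled * probs[sampled]  # [num_sampled]
    return sampled, true_expected, sampled_expected

rng = np.random.default_rng(0)
sampled_values = sample_log_uniform(np.array([[3], [7]]), 64, 1000, rng)
```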

Returns: A 1-D tensor of shape [batch_size] containing the per-example sampled softmax losses.