
Wiki Talk Comments Toxicity Prediction


In this example, we consider the task of predicting whether a discussion comment posted on a Wiki talk page contains toxic content (i.e. contains content that is “rude, disrespectful or unreasonable”). We use a public dataset released by the Conversation AI project, which contains over 100k comments from the English Wikipedia that are annotated by crowd workers (see paper for labeling methodology).

One of the challenges with this dataset is that a very small proportion of the comments cover sensitive topics such as sexuality or religion. As such, training a neural network model on this dataset leads to disparate performance on the smaller sensitive topics. This can mean that innocuous statements about those topics might get incorrectly flagged as ‘toxic’ at higher rates, causing speech to be unfairly censored.

By imposing constraints during training, we can train a fairer model that performs more equitably across the different topic groups.

We will use the TFCO library to optimize for our fairness goal during training.


Let's first install and import the relevant libraries. Note that you may have to restart your Colab runtime once after running the first cell because of outdated packages in the runtime. After doing so, there should be no further issues with imports.

pip installs

Note that depending on when you run the cell below, you may receive a warning about the default version of TensorFlow in Colab switching to TensorFlow 2.X soon. You can safely ignore that warning as this notebook was designed to be compatible with TensorFlow 1.X and 2.X.

Import Modules

Though TFCO is compatible with eager and graph execution, this notebook assumes that eager execution is enabled by default. To ensure that nothing breaks, eager execution will be enabled in the cell below.

Enable Eager Execution and Print Versions

Eager execution enabled by default.
TensorFlow 2.9.0-rc2
TFMA 0.38.0


First, we set some hyper-parameters needed for the data preprocessing and model training.

hparams = {
    "batch_size": 128,
    "cnn_filter_sizes": [128, 128, 128],
    "cnn_kernel_sizes": [5, 5, 5],
    "cnn_pooling_sizes": [5, 5, 40],
    "constraint_learning_rate": 0.01,
    "embedding_dim": 100,
    "embedding_trainable": False,
    "learning_rate": 0.005,
    "max_num_words": 10000,
    "max_sequence_length": 250
}

Load and pre-process dataset

Next, we download the dataset and preprocess it. The train, test and validation sets are provided as separate CSV files.

toxicity_data_url = (""

data_train = pd.read_csv(toxicity_data_url + "wiki_train.csv")
data_test = pd.read_csv(toxicity_data_url + "wiki_test.csv")
data_vali = pd.read_csv(toxicity_data_url + "wiki_dev.csv")


The comment column contains the discussion comments, and the is_toxic column indicates whether or not a comment is annotated as toxic.

In the following, we:

  1. Separate out the labels
  2. Tokenize the text comments
  3. Identify comments that contain sensitive topic terms

First, we separate the labels from the train, test and validation sets. The labels are all binary (0 or 1).

labels_train = data_train["is_toxic"].values.reshape(-1, 1) * 1.0
labels_test = data_test["is_toxic"].values.reshape(-1, 1) * 1.0
labels_vali = data_vali["is_toxic"].values.reshape(-1, 1) * 1.0

Next, we tokenize the textual comments using the Tokenizer provided by Keras. We use the training set comments alone to build a vocabulary of tokens, and use them to convert all the comments into a (padded) sequence of tokens of the same length.

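tokenizer = text.Tokenizer(num_words=hparams["max_num_words"])
# Build the vocabulary from the training comments only.
tokenizer.fit_on_texts(data_train["comment"])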

def prep_text(texts, tokenizer, max_sequence_length):
    # Turns text into padded sequences.
    text_sequences = tokenizer.texts_to_sequences(texts)
    return sequence.pad_sequences(text_sequences, maxlen=max_sequence_length)

text_train = prep_text(data_train["comment"], tokenizer, hparams["max_sequence_length"])
text_test = prep_text(data_test["comment"], tokenizer, hparams["max_sequence_length"])
text_vali = prep_text(data_vali["comment"], tokenizer, hparams["max_sequence_length"])
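To make the tokenize-and-pad step concrete, here is a minimal pure-Python sketch of the same idea. It is not the Keras implementation: `build_vocab` and `to_padded_ids` are hypothetical helpers, and splitting on whitespace is a simplification. It assumes the default behavior of Keras's `pad_sequences`, namely that short sequences are padded on the left with 0 and long ones are truncated from the front.

```python
from collections import Counter

def build_vocab(train_texts, max_num_words):
    # Rank words by frequency in the training texts only; index 0 is reserved for padding.
    counts = Counter(word for t in train_texts for word in t.lower().split())
    ranked = [w for w, _ in counts.most_common(max_num_words - 1)]
    return {word: i + 1 for i, word in enumerate(ranked)}

def to_padded_ids(texts, vocab, max_len):
    rows = []
    for t in texts:
        # Words that are not in the training vocabulary are dropped.
        ids = [vocab[w] for w in t.lower().split() if w in vocab]
        ids = ids[-max_len:]  # keep only the last max_len tokens
        rows.append([0] * (max_len - len(ids)) + ids)  # left-pad with zeros
    return rows

vocab = build_vocab(["good edit thanks", "bad edit"], max_num_words=10)
padded = to_padded_ids(
    ["good edit", "unseen word edit thanks again please"], vocab, max_len=4)
print(padded)
```

Every row has the same length regardless of the comment length, which is what lets the comments be stacked into a single array for the model.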

Finally, we identify comments related to certain sensitive topic groups. We consider a subset of the identity terms provided with the dataset and group them into four broad topic groups: sexuality, gender identity, religion, and race.

terms = {
    'sexuality': ['gay', 'lesbian', 'bisexual', 'homosexual', 'straight', 'heterosexual'], 
    'gender identity': ['trans', 'transgender', 'cis', 'nonbinary'],
    'religion': ['christian', 'muslim', 'jewish', 'buddhist', 'catholic', 'protestant', 'sikh', 'taoist'],
    'race': ['african', 'african american', 'black', 'white', 'european', 'hispanic', 'latino', 'latina', 
             'latinx', 'mexican', 'canadian', 'american', 'asian', 'indian', 'middle eastern', 'chinese', 

group_names = list(terms.keys())
num_groups = len(group_names)

We then create separate group membership matrices for the train, test and validation sets, where the rows correspond to comments, the columns correspond to the four sensitive groups, and each entry is a boolean indicating whether the comment contains a term from the topic group.

def get_groups(text):
    # Returns a NumPy array of shape (n, k) with 0/1 entries, where n is the number of
    # comments and k is the number of groups. Entry (i, j) is 1 if the i-th comment
    # contains a term from the j-th group.
    groups = np.zeros((text.shape[0], num_groups))
    for ii in range(num_groups):
        groups[:, ii] = text.str.contains('|'.join(terms[group_names[ii]]), case=False)
    return groups

groups_train = get_groups(data_train["comment"])
groups_test = get_groups(data_test["comment"])
groups_vali = get_groups(data_vali["comment"])
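The group membership computation can be illustrated on a toy example without pandas. The sketch below uses Python's `re` module and a hypothetical two-group subset of the terms; `membership_matrix`, `toy_terms` and `toy_comments` are illustrative names that are not part of the tutorial. Like the regular-expression matching in `get_groups`, it matches substrings case-insensitively.

```python
import re

toy_terms = {  # hypothetical two-group subset, for illustration only
    "sexuality": ["gay", "lesbian"],
    "religion": ["christian", "muslim"],
}
toy_comments = [
    "I am a gay man",
    "The Christian calendar",
    "Nothing sensitive here",
]

def membership_matrix(comments, terms):
    names = list(terms.keys())
    # One alternation pattern per group, e.g. "gay|lesbian".
    patterns = [re.compile("|".join(terms[n]), re.IGNORECASE) for n in names]
    # Row i, column j is True when comment i contains any term of group j.
    return [[bool(p.search(c)) for p in patterns] for c in comments]

matrix = membership_matrix(toy_comments, toy_terms)
print(matrix)
```

Because the match is on substrings, a short term such as 'trans' would also fire on a word like 'transport'. That is a property of this style of matching and worth keeping in mind when reading the group statistics.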

As shown below, all four topic groups constitute only a small fraction of the overall dataset, and have varying proportions of toxic comments.

print("Overall label proportion = %.1f%%" % (labels_train.mean() * 100))

group_stats = []
for ii in range(num_groups):
    group_proportion = groups_train[:, ii].mean()
    group_pos_proportion = labels_train[groups_train[:, ii] == 1].mean()
    group_stats.append([group_names[ii],
                        "%.2f%%" % (group_proportion * 100), 
                        "%.1f%%" % (group_pos_proportion * 100)])
group_stats = pd.DataFrame(group_stats, 
                           columns=["Topic group", "Group proportion", "Label proportion"])
Overall label proportion = 9.7%

We see that only 1.3% of the dataset contains comments related to sexuality. Among them, 37% of the comments have been annotated as being toxic. Note that this is significantly larger than the overall proportion of comments annotated as toxic. This could be because the few comments that used those identity terms did so in pejorative contexts. As mentioned above, this could cause our model to disproportionately misclassify comments as toxic when they include those terms. Since this is the concern, we'll make sure to look at the False Positive Rate when we evaluate the model's performance.
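Since the false positive rate is the metric of interest, it may help to see how it is computed overall and for one group. The following is a minimal sketch on made-up data. `false_positive_rate` is a hypothetical helper, not a function from the tutorial or from Fairness Indicators, and it assumes that scores are raw model outputs thresholded at 0.

```python
def false_positive_rate(labels, scores, mask=None, threshold=0.0):
    # labels: 0/1 ints; scores: raw model outputs; mask: optional booleans selecting a group.
    if mask is None:
        mask = [True] * len(labels)
    # Scores of the selected examples whose true label is "not toxic".
    negatives = [s for y, s, m in zip(labels, scores, mask) if m and y == 0]
    if not negatives:
        return float("nan")
    # Fraction of non-toxic examples that the model flags as toxic.
    return sum(s > threshold for s in negatives) / len(negatives)

labels = [0, 0, 0, 0, 1, 1]
scores = [-2.0, 0.5, -1.0, 1.5, 2.0, -0.3]
in_group = [False, True, False, True, True, False]

overall_fpr = false_positive_rate(labels, scores)          # 2 of 4 negatives flagged
group_fpr = false_positive_rate(labels, scores, in_group)  # 2 of 2 group negatives flagged
print(overall_fpr, group_fpr)
```

In this toy data the group's false positive rate is twice the overall rate even though the group holds only half of the examples, which is the kind of gap the rest of the tutorial looks for.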

Build CNN toxicity prediction model

Having prepared the dataset, we now build a Keras model for predicting toxicity. The model we use is a convolutional neural network (CNN) with the same architecture used by the Conversation AI project for their debiasing analysis. We adapt code provided by them to construct the model layers.

The model uses an embedding layer to convert the text tokens to fixed-length vectors. This layer converts the input text sequence into a sequence of vectors, and passes them through several layers of convolution and pooling operations, followed by a final fully-connected layer.

We make use of pre-trained GloVe word vector embeddings, which we download below. This may take a few minutes to complete.

zip_file_url = ""
zip_file = urllib.request.urlopen(zip_file_url)
archive = zipfile.ZipFile(io.BytesIO(zip_file.read()))

We use the downloaded GloVe embeddings to create an embedding matrix, where the rows contain the word embeddings for the tokens in the Tokenizer's vocabulary.

embeddings_index = {}
glove_file = "glove.6B.100d.txt"

with archive.open(glove_file) as f:
    for line in f:
        values = line.split()
        word = values[0].decode("utf-8") 
        coefs = np.asarray(values[1:], dtype="float32")
        embeddings_index[word] = coefs

embedding_matrix = np.zeros((len(tokenizer.word_index) + 1, hparams["embedding_dim"]))
num_words_in_embedding = 0
for word, i in tokenizer.word_index.items():
    embedding_vector = embeddings_index.get(word)
    if embedding_vector is not None:
        num_words_in_embedding += 1
        embedding_matrix[i] = embedding_vector
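The construction of the embedding matrix can be seen on a toy example. The sketch below is pure Python and uses made-up 3-dimensional vectors in the GloVe text layout, where each line holds a word followed by its coefficients. The names `glove_lines`, `word_index` and `matrix` are illustrative assumptions rather than objects from the tutorial.

```python
glove_lines = [  # hypothetical 3-dimensional vectors in the "word v1 v2 ..." layout
    b"edit 0.1 0.2 0.3",
    b"good 0.4 0.5 0.6",
]
word_index = {"edit": 1, "good": 2, "zzz": 3}  # toy vocabulary; index 0 is padding

# Parse each line into a word and its list of coefficients.
vectors = {}
for line in glove_lines:
    parts = line.decode("utf-8").split()
    vectors[parts[0]] = [float(v) for v in parts[1:]]

# Row i holds the vector of the word with index i; words without a
# pre-trained vector (and the padding index 0) keep an all-zero row.
dim = 3
matrix = [[0.0] * dim for _ in range(len(word_index) + 1)]
for word, i in word_index.items():
    if word in vectors:
        matrix[i] = vectors[word]

print(matrix)
```

The extra row at index 0 is why the matrix has one more row than there are words in the vocabulary.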

We are now ready to specify the Keras layers. We write a function to create a new model, which we will invoke whenever we wish to train a new model.

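def create_model():
    model = keras.Sequential()

    # Embedding layer, initialized from the GloVe embedding matrix
    # (arguments assumed from the definitions above).
    embedding_layer = layers.Embedding(
        embedding_matrix.shape[0],
        embedding_matrix.shape[1],
        weights=[embedding_matrix],
        input_length=hparams["max_sequence_length"],
        trainable=hparams['embedding_trainable'])
    model.add(embedding_layer)

    # Convolution layers.
    for filter_size, kernel_size, pool_size in zip(
        hparams['cnn_filter_sizes'], hparams['cnn_kernel_sizes'],
        hparams['cnn_pooling_sizes']):

        conv_layer = layers.Conv1D(
            filter_size, kernel_size, activation='relu', padding='same')
        model.add(conv_layer)

        pooled_layer = layers.MaxPooling1D(pool_size, padding='same')
        model.add(pooled_layer)

    # Add a flatten layer, a fully-connected layer and an output layer.
    # The output layer emits a single raw score (logit) per comment.
    model.add(layers.Flatten())
    model.add(layers.Dense(128, activation='relu'))
    model.add(layers.Dense(1))

    return model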

We also define a method to set random seeds. This is done to ensure reproducible results.

def set_seeds():

Fairness indicators

We also write functions to plot fairness indicators.

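def create_examples(labels, predictions, groups, group_names):
  # Returns tf.examples with given labels, predictions, and group information.  
  examples = []
  sigmoid = lambda x: 1/(1 + np.exp(-x)) 
  for ii in range(labels.shape[0]):
    example = tf.train.Example()
    # Feature names follow the feature map used in evaluate_results.
    example.features.feature['toxicity'].float_list.value.append(
        labels[ii][0])
    example.features.feature['prediction'].float_list.value.append(
        sigmoid(predictions[ii][0]))  # predictions need to be in [0, 1].
    for jj in range(groups.shape[1]):
      example.features.feature[group_names[jj]].bytes_list.value.append(
          b'Yes' if groups[ii, jj] else b'No')
    examples.append(example)
  return examples
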
def evaluate_results(labels, predictions, groups, group_names):
  # Evaluates fairness indicators for given labels, predictions and group
  # membership info.
  examples = create_examples(labels, predictions, groups, group_names)

  # Create feature map for labels, predictions and each group.
  feature_map = {
      'prediction': tf.io.FixedLenFeature([], tf.float32),
      'toxicity': tf.io.FixedLenFeature([], tf.float32),
  }
  for group in group_names:
    feature_map[group] = tf.io.FixedLenFeature([], tf.string)

  # Serialize the examples.
  serialized_examples = [e.SerializeToString() for e in examples]

  BASE_DIR = tempfile.gettempdir()
  OUTPUT_DIR = os.path.join(BASE_DIR, 'output')

  with beam.Pipeline() as pipeline:
    model_agnostic_config = agnostic_predict.ModelAgnosticConfig(

    slices = [tfma.slicer.SingleSliceSpec()]
    for group in group_names:

    extractors = [

    metrics_callbacks = [

    # Create a model agnostic aggregator.
    eval_shared_model = tfma.types.EvalSharedModel(

    # Run Model Agnostic Eval.
    _ = (
        | beam.Create(serialized_examples)
        | 'ExtractEvaluateAndWriteResults' >>

  fairness_ind_result = tfma.load_eval_result(output_path=OUTPUT_DIR)

  # Also evaluate accuracy of the model.
  accuracy = np.mean(labels == (predictions > 0.0))

  return fairness_ind_result, accuracy
def plot_fairness_indicators(eval_result, title):
  fairness_ind_result, accuracy = eval_result
  display(HTML("<center><h2>" + title + 
               " (Accuracy = %.2f%%)" % (accuracy * 100) + "</h2></center>"))
def plot_multi_fairness_indicators(multi_eval_results):

  multi_results = {}
  multi_accuracy = {}
  for title, (fairness_ind_result, accuracy) in multi_eval_results.items():
    multi_results[title] = fairness_ind_result
    multi_accuracy[title] = accuracy

  title_str = "<center><h2>"
  for title in multi_eval_results.keys():
      title_str+=title + " (Accuracy = %.2f%%)" % (multi_accuracy[title] * 100) + "; "
  # fairness_ind_result, accuracy = eval_result

Train unconstrained model

For the first model we train, we optimize a simple cross-entropy loss without any constraints.

# Set random seed for reproducible results.
set_seeds()

# Optimizer and loss.
optimizer = tf.keras.optimizers.Adam(learning_rate=hparams["learning_rate"])
loss = lambda y_true, y_pred: tf.keras.losses.binary_crossentropy(
    y_true, y_pred, from_logits=True)

# Create, compile and fit model.
model_unconstrained = create_model()
model_unconstrained.compile(optimizer=optimizer, loss=loss)
model_unconstrained.fit(
    x=text_train, y=labels_train, batch_size=hparams["batch_size"], epochs=2)
Epoch 1/2
748/748 [==============================] - 14s 5ms/step - loss: 0.1575
Epoch 2/2
748/748 [==============================] - 4s 5ms/step - loss: 0.1202
<keras.callbacks.History at 0x7f98bc419d10>
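The loss above is computed with `from_logits=True`, which means the model outputs raw scores rather than probabilities. The sketch below checks two consequences on a few hand-picked values: the cross-entropy can be evaluated directly from the logit in a numerically stable form (the form documented for `tf.nn.sigmoid_cross_entropy_with_logits`), and thresholding the logit at 0 gives the same decision as thresholding the sigmoid probability at 0.5. The helper names are illustrative and the code is a sketch, not the Keras implementation.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def bce_from_logit(label, logit):
    # Numerically stable form of -[y*log(p) + (1-y)*log(1-p)] with p = sigmoid(logit).
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def bce_from_prob(label, prob):
    return -(label * math.log(prob) + (1.0 - label) * math.log(1.0 - prob))

for y, x in [(1.0, 2.0), (0.0, 2.0), (1.0, -1.5), (0.0, 0.0)]:
    # Both formulas agree up to floating-point error.
    assert abs(bce_from_logit(y, x) - bce_from_prob(y, sigmoid(x))) < 1e-9
    # A logit above 0 corresponds exactly to a probability above 0.5.
    assert (x > 0.0) == (sigmoid(x) > 0.5)

print(bce_from_logit(1.0, 2.0))
```

This is why the evaluation helpers can apply a sigmoid before passing predictions on, while accuracy is computed by comparing the raw scores against 0.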

Having trained the unconstrained model, we plot various evaluation metrics for the model on the test set.

scores_unconstrained_test = model_unconstrained.predict(text_test)
eval_result_unconstrained = evaluate_results(
    labels_test, scores_unconstrained_test, groups_test, group_names)
996/996 [==============================] - 2s 2ms/step

As explained above, we are concentrating on the false positive rate. In their current version (0.1.2), Fairness Indicators select false negative rate by default. After running the line below, go ahead and deselect false_negative_rate and select false_positive_rate to look at the metric we are interested in.

plot_fairness_indicators(eval_result_unconstrained, "Unconstrained")

While the overall false positive rate is less than 2%, the false positive rate on the sexuality-related comments is significantly higher. This is because the sexuality group is very small in size, and has a disproportionately higher fraction of comments annotated as toxic. Hence, training a model without constraints results in the model believing that sexuality-related terms are a strong indicator of toxicity.

Train with constraints on false positive rates

To avoid large differences in false positive rates across different groups, we next train a model by constraining the false positive rates for each group to be within a desired limit. In this case, we will optimize the error rate of the model subject to the per-group false positive rates being at most 2%.

Training on minibatches with per-group constraints can be challenging for this dataset, however, as the groups we wish to constrain are all small in size, and individual minibatches are likely to contain very few examples from each group. Hence the gradients we compute during training will be noisy and cause the model to converge very slowly.

To mitigate this problem, we recommend using two streams of minibatches, with the first stream formed as before from the entire training set, and the second stream formed solely from the sensitive group examples. We will compute the objective using minibatches from the first stream and the per-group constraints using minibatches from the second stream. Because the batches from the second stream are likely to contain a larger number of examples from each group, we expect our updates to be less noisy.
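The two-stream idea can be sketched with plain index arithmetic. In the toy example below, `stream_indices` is a hypothetical helper and the index lists are made up. Both streams advance in lockstep, and the smaller sensitive pool simply wraps around when it is exhausted.

```python
def stream_indices(step, batch_size, pool):
    # Take the step-th batch from `pool`, wrapping around when the pool is exhausted.
    start = step * batch_size
    return [pool[i % len(pool)] for i in range(start, start + batch_size)]

all_indices = list(range(10))   # toy training set of 10 examples
sensitive_indices = [2, 5, 7]   # hypothetical positions of sensitive-topic examples

for step in range(2):
    main_batch = stream_indices(step, 4, all_indices)
    sen_batch = stream_indices(step, 4, sensitive_indices)
    print(step, main_batch, sen_batch)
```

Every batch from the second stream consists entirely of sensitive-topic examples, so the constraint is always estimated from a full batch of relevant data, at the cost of revisiting the same examples many times per epoch.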

We create separate features, labels and groups tensors to hold the minibatches from the two streams.

# Set random seed.
set_seeds()

# Features tensors.
batch_shape = (hparams["batch_size"], hparams['max_sequence_length'])
features_tensor = tf.Variable(np.zeros(batch_shape, dtype='int32'), name='x')
features_tensor_sen = tf.Variable(np.zeros(batch_shape, dtype='int32'), name='x_sen')

# Labels tensors.
batch_shape = (hparams["batch_size"], 1)
labels_tensor = tf.Variable(np.zeros(batch_shape, dtype='float32'), name='labels')
labels_tensor_sen = tf.Variable(np.zeros(batch_shape, dtype='float32'), name='labels_sen')

# Groups tensors.
batch_shape = (hparams["batch_size"], num_groups)
groups_tensor_sen = tf.Variable(np.zeros(batch_shape, dtype='float32'), name='groups_sen')

We instantiate a new model, and compute predictions for minibatches from the two streams.

# Create model, and separate prediction functions for the two streams. 
# For the predictions, we use a nullary function returning a Tensor to support eager mode.
model_constrained = create_model()

def predictions():
  return model_constrained(features_tensor)

def predictions_sen():
  return model_constrained(features_tensor_sen)

We then set up a constrained optimization problem with the error rate as the objective and with constraints on the per-group false positive rate.
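Written out, the problem has roughly the following form. The notation is ours rather than the tutorial's: $f_\theta(x)$ denotes the raw model score for comment $x$, $y \in \{0, 1\}$ its label, $g$ a sensitive topic group, and $\epsilon = 0.02$ the threshold set in the code.

```latex
\begin{aligned}
\min_{\theta} \quad & \mathrm{err}(\theta)
  = \Pr\big[\, \mathbb{1}[f_\theta(x) > 0] \neq y \,\big] \\
\text{subject to} \quad & \mathrm{FPR}_g(\theta)
  = \Pr\big[\, f_\theta(x) > 0 \;\big|\; y = 0,\ x \in g \,\big] \le \epsilon
\end{aligned}
```

In the code that follows, the constraint is added for the sexuality group only (column 0 of the group matrix). Assuming constraints are reported in the normalized form $\mathrm{FPR}_g(\theta) - \epsilon \le 0$, a reported violation that is negative or zero means the constraint is satisfied on the current minibatch.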

epsilon = 0.02  # Desired false-positive rate threshold.

# Set up separate contexts for the two minibatch streams.
context = tfco.rate_context(predictions, lambda:labels_tensor)
context_sen = tfco.rate_context(predictions_sen, lambda:labels_tensor_sen)

# Compute the objective using the first stream.
objective = tfco.error_rate(context)

# Compute the constraint using the second stream.
# Subset the examples belonging to the "sexuality" group from the second stream 
# and add a constraint on the group's false positive rate.
context_sen_subset = context_sen.subset(lambda: groups_tensor_sen[:, 0] > 0)
constraint = [tfco.false_positive_rate(context_sen_subset) <= epsilon]

# Create a rate minimization problem.
problem = tfco.RateMinimizationProblem(objective, constraint)

# Set up a constrained optimizer.
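optimizer = tfco.ProxyLagrangianOptimizerV2(
    optimizer=tf.keras.optimizers.Adam(learning_rate=hparams["learning_rate"]),
    constraint_optimizer=tf.keras.optimizers.Adam(
        learning_rate=hparams["constraint_learning_rate"]),
    num_constraints=problem.num_constraints)

# The list of variables to optimize includes the model weights,
# and the trainable variables from the rate minimization problem and
# the constrained optimizer.
var_list = (model_constrained.trainable_weights + list(problem.trainable_variables) +
            optimizer.trainable_variables())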

We are ready to train the model. We maintain a separate counter for the two minibatch streams. Every time we perform a gradient update, we will have to copy the minibatch contents from the first stream to the tensors features_tensor and labels_tensor, and the minibatch contents from the second stream to the tensors features_tensor_sen, labels_tensor_sen and groups_tensor_sen.

# Indices of sensitive group members.
protected_group_indices = np.nonzero(groups_train.sum(axis=1))[0]

num_examples = text_train.shape[0]
num_examples_sen = protected_group_indices.shape[0]
batch_size = hparams["batch_size"]

# Number of steps needed for one epoch over the training sample.
num_steps = int(num_examples / batch_size)

start_time = time.time()

# Loop over minibatches.
for batch_index in range(num_steps):
    # Indices for current minibatch in the first stream.
    batch_indices = np.arange(
        batch_index * batch_size, (batch_index + 1) * batch_size)
    batch_indices = [ind % num_examples for ind in batch_indices]

    # Indices for current minibatch in the second stream.
    batch_indices_sen = np.arange(
        batch_index * batch_size, (batch_index + 1) * batch_size)
    batch_indices_sen = [protected_group_indices[ind % num_examples_sen]
                         for ind in batch_indices_sen]

    # Assign features, labels, groups from the minibatches to the respective tensors.
    features_tensor.assign(text_train[batch_indices, :])
    labels_tensor.assign(labels_train[batch_indices, :])

    features_tensor_sen.assign(text_train[batch_indices_sen, :])
    labels_tensor_sen.assign(labels_train[batch_indices_sen, :])
    groups_tensor_sen.assign(groups_train[batch_indices_sen, :])

    # Gradient update.
    optimizer.minimize(problem, var_list=var_list)

    # Record and print batch training stats every 10 steps.
    if (batch_index + 1) % 10 == 0 or batch_index in (0, num_steps - 1):
      hinge_loss = problem.objective()
      max_violation = max(problem.constraints())

      elapsed_time = time.time() - start_time
      print(
          "\rStep %d / %d: Elapsed time = %ds, Loss = %.3f, Violation = %.3f" % 
          (batch_index + 1, num_steps, elapsed_time, hinge_loss, max_violation),
          end="")
Step 747 / 747: Elapsed time = 69s, Loss = 0.076, Violation = -0.020

Having trained the constrained model, we plot various evaluation metrics for the model on the test set.

scores_constrained_test = model_constrained.predict(text_test)
eval_result_constrained = evaluate_results(
    labels_test, scores_constrained_test, groups_test, group_names)
996/996 [==============================] - 2s 2ms/step

As with last time, remember to select false_positive_rate.

plot_fairness_indicators(eval_result_constrained, "Constrained")
multi_results = {

As we can see from the Fairness Indicators, compared to the unconstrained model the constrained model yields significantly lower false positive rates for the sexuality-related comments, and does so with only a slight dip in the overall accuracy.