This text classification tutorial trains a recurrent neural network on the IMDB large movie review dataset for sentiment analysis.
Setup
import numpy as np
import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()
Import matplotlib and create a helper function to plot graphs:
import matplotlib.pyplot as plt
def plot_graphs(history, metric):
  plt.plot(history.history[metric])
  plt.plot(history.history['val_'+metric], '')
  plt.xlabel("Epochs")
  plt.ylabel(metric)
  plt.legend([metric, 'val_'+metric])
Setup input pipeline
The IMDB large movie review dataset is a binary classification dataset—all the reviews have either a positive or negative sentiment.
Download the dataset using TFDS. See the loading text tutorial for details on how to load this sort of data manually.
dataset, info = tfds.load('imdb_reviews', with_info=True,
                          as_supervised=True)
train_dataset, test_dataset = dataset['train'], dataset['test']
train_dataset.element_spec
(TensorSpec(shape=(), dtype=tf.string, name=None), TensorSpec(shape=(), dtype=tf.int64, name=None))
Initially this returns a dataset of (text, label) pairs:
for example, label in train_dataset.take(1):
  print('text: ', example.numpy())
  print('label: ', label.numpy())
text: b"This was an absolutely terrible movie. Don't be lured in by Christopher Walken or Michael Ironside. Both are great actors, but this must simply be their worst role in history. Even their great acting could not redeem this movie's ridiculous storyline. This movie is an early nineties US propaganda piece. The most pathetic scenes were those when the Columbian rebels were making their cases for revolutions. Maria Conchita Alonso appeared phony, and her pseudo-love affair with Walken was nothing but a pathetic emotional plug in a movie that was devoid of any real meaning. I am disappointed that there are movies like this, ruining actor's like Christopher Walken's good name. I could barely sit through it." label: 0
Next, shuffle the data for training and create batches of these (text, label) pairs:
BUFFER_SIZE = 10000
BATCH_SIZE = 64
train_dataset = train_dataset.shuffle(BUFFER_SIZE).batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
test_dataset = test_dataset.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)
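The `shuffle(BUFFER_SIZE).batch(BATCH_SIZE)` chain can be pictured with the pure-Python sketch below. It is an approximation of the idea and not the tf.data implementation. Shuffling keeps a buffer of at most `buffer_size` elements and emits a randomly chosen one at each step, refilling from the input. Batching then groups consecutive elements, and the last batch may be smaller. The names `buffered_shuffle` and `batch_items` are hypothetical.

```python
import random
from typing import Iterable, Iterator, List, TypeVar

T = TypeVar("T")

def buffered_shuffle(items: Iterable[T], buffer_size: int, seed: int = 0) -> Iterator[T]:
    """Illustrative bounded-buffer shuffle (a sketch, not tf.data's code)."""
    rng = random.Random(seed)
    buffer: List[T] = []
    for item in items:
        buffer.append(item)
        if len(buffer) == buffer_size:
            # Emit a uniformly chosen element; its slot is refilled next iteration.
            idx = rng.randrange(len(buffer))
            buffer[idx], buffer[-1] = buffer[-1], buffer[idx]
            yield buffer.pop()
    # Input exhausted: drain whatever is left in random order.
    rng.shuffle(buffer)
    yield from buffer

def batch_items(items: Iterable[T], batch_size: int) -> Iterator[List[T]]:
    """Group consecutive elements; the final batch may be smaller."""
    current: List[T] = []
    for item in items:
        current.append(item)
        if len(current) == batch_size:
            yield current
            current = []
    if current:
        yield current

pairs = [(f"review {i}", i % 2) for i in range(10)]
batches = list(batch_items(buffered_shuffle(pairs, buffer_size=4), batch_size=3))
print([len(b) for b in batches])  # [3, 3, 3, 1]
```

The partial final batch also accounts for the 391 steps per epoch in the training logs below. The train split has 25,000 reviews, and 25,000 / 64 = 390.6, which rounds up to 391 batches.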
for example, label in train_dataset.take(1):
  print('texts: ', example.numpy()[:3])
  print()
  print('labels: ', label.numpy()[:3])
texts: [b'I managed to grab a viewing of this with the aid of MST3K, and oh boy, even with the riffing this movie was excruciatingly bad. Imagine someone whose competence with a camera could be out done by a monkey.<br /><br />The highlights (what little there were) came from the special effects, which were "OK". The acting for the most part was also "OK"; though nothing special, it was of a higher quality than other B-Movies I have seen in the past.<br /><br />The rest of this movie is dismally bad, The camera work often looks like they\'ve just put the camera man on roller skates and pushed him along. The story (if it can be called that) is so full of holes it\'s almost funny, It never really explains why the hell he survived in the first place, or needs human flesh in order to survive. The script is poorly written and the dialogue verges on just plane stupid. The climax to movie (if there is one) is absolutely laughable.<br /><br />If you can\'t find the MST3K version, avoid this at all costs.' b'This movie does contradict the first one as far as the origins of the Care Bears and the Care Bear Cousins goes. I won\'t deny that. However, if you look at "Part II" as a separate film, then it\'s a very good movie. I remember watching this in the early 80\'s (and fitting into its targeted demographic audience then), and absolutely loving it much more than the first movie (not that I didn\'t enjoy that one too, it\'s just that this one seemed to have a little something extra to it). Sure it\'s darker than the first one too, but perhaps maybe that\'s why it\'s so good. And it\'s dark in deeper kind of subtle way too (that kids may not fully understand, but could still be a bit scared of because of the atmosphere it gives off, and adults watching will surely get quicker as I have now watching this film again now in my mid-twenties) where you basically have a young girl making a deal with an evil spirit/demon in exchange for something else. Get the picture? 
But simply watching that as a child, sure as I said it may have been a little scary, but nothing traumatizing. In fact if anything it gave me another fantasy game I could play when I was that age. I can\'t tell you the number of times I used to pretend Dark Heart wanted to imprison me, have me help him capture the Care Bears, tried to make me turn over to his dark side, and other things like that etc. So this movie was also good for my imagination. And it\'s also got great emotional depth to it too. I used to watch it at least once a week.<br /><br />Also Hadley Kay was the perfect choice for the voice of Dark Heart (I always thought so and I always will).<br /><br />Now it\'s just too bad that they never made a soundtrack available. Sometimes I just want to hear Growing Up without watching the movie, as good as it is.<br /><br />"What good is love and caring if it can\'t save her?"' b'Let\'s just say I had to suspend my disbelief less for Spiderman than I did for Hooligans. That is, to say, I have less of a problem believing Toby McGuire can stick to buildings than I do Elija Wood throwing down with toughs in Manchester. I won\'t get into specifics, as I don\'t want to write a spoiler, but the idea of grown, professional, British men getting into near death scraps every weekend is, well... funny. And this film is not. The fighting, the idea of fighting, is taken far too seriously. The gravity of the pugilism, the reverence with which the subject matter is treated becomes irritating, as it neither establishes or resolves the conflict. It seems as though the plot, with holes big enough to drive a Guiness truck through, has been slapped together with a contrived "fish out of water" theme so that viewers can gaze into Woods teary eyes as he learns how to become a man ie. hitting other young men of opposing football tastes with blunt objects and then running away as fast as he can. The characters are cartoonish, especially the Americans at Harvard. 
The character development and story line are telegraphed to the viewer throughout the picture. Unfortunately, the absurdity of the film doesn\'t reach its height until nearly the end, which by then you\'ll have spent nearly two hours of your life you are never getting back. Pick up "The Football Factory" or "Fight Club" instead of this corny, and disappointing dud. It doesn\'t waste time with empty melodrama, the tired old "Yankee in King Aurthur\'s Court," or weepy, parables of coming of age bullsh*t. They\'re just pure, dark, and clever fun; the way violence is supposed to be.'] labels: [0 1 0]
Create the text encoder
The raw text loaded by tfds needs to be processed before it can be used in a model. The simplest way to process text for training is using the experimental.preprocessing.TextVectorization layer. This layer has many capabilities, but this tutorial sticks to the default behavior.
Create the layer, and pass the dataset's text to the layer's .adapt method:
VOCAB_SIZE = 1000
encoder = tf.keras.layers.experimental.preprocessing.TextVectorization(
    max_tokens=VOCAB_SIZE)
encoder.adapt(train_dataset.map(lambda text, label: text))
The .adapt method sets the layer's vocabulary. Here are the first 20 tokens. After the padding and unknown tokens, they're sorted by frequency:
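As a rough mental model of what `.adapt` computes, here is a simplified sketch. It is not the actual TextVectorization code. The sketch standardizes the text, splits on whitespace, and counts tokens. It then keeps the most frequent tokens after the reserved padding (`''`) and OOV (`'[UNK]'`) entries, with `max_tokens` counting those two reserved entries. Two details are assumptions of this sketch: using `string.punctuation` to approximate the default punctuation stripping, and breaking frequency ties alphabetically. Names are prefixed `toy_` so they don't clash with the tutorial's `vocab`.

```python
import string
from collections import Counter
from typing import Iterable, List

PAD, OOV = "", "[UNK]"

def standardize(text: str) -> str:
    """Approximates "lower_and_strip_punctuation": lowercase, drop ASCII punctuation."""
    return text.lower().translate(str.maketrans("", "", string.punctuation))

def build_vocab(texts: Iterable[str], max_tokens: int) -> List[str]:
    """Frequency-sorted vocabulary; max_tokens includes the two reserved entries."""
    counts = Counter(tok for t in texts for tok in standardize(t).split())
    # Most frequent first; ties broken alphabetically (an assumption of this sketch).
    ranked = sorted(counts, key=lambda tok: (-counts[tok], tok))
    return [PAD, OOV] + ranked[: max_tokens - 2]

corpus = ["The movie was great!", "The plot was thin.", "Great acting, the end."]
toy_vocab = build_vocab(corpus, max_tokens=6)
print(toy_vocab)  # ['', '[UNK]', 'the', 'great', 'was', 'acting']
```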
vocab = np.array(encoder.get_vocabulary())
vocab[:20]
array(['', '[UNK]', 'the', 'and', 'a', 'of', 'to', 'is', 'in', 'it', 'i', 'this', 'that', 'br', 'was', 'as', 'for', 'with', 'movie', 'but'], dtype='<U14')
Once the vocabulary is set, the layer can encode text into indices. The tensors of indices are 0-padded to the longest sequence in the batch (unless you set a fixed output_sequence_length):
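The encoding step can be sketched the same way under simplified assumptions. Each token maps to its vocabulary index, and unknown tokens map to index 1 (`[UNK]`). Shorter sequences are right-padded with 0 up to the longest sequence in the batch. The toy vocabulary and the `encode_batch` helper are hypothetical.

```python
import string
from typing import Dict, List

toy_vocab = ["", "[UNK]", "the", "great", "was", "acting"]  # hypothetical vocabulary
toy_index: Dict[str, int] = {tok: i for i, tok in enumerate(toy_vocab)}

def encode_batch(texts: List[str]) -> List[List[int]]:
    """Map tokens to indices (OOV -> 1) and 0-pad to the longest sequence."""
    table = str.maketrans("", "", string.punctuation)
    seqs = [[toy_index.get(tok, 1) for tok in t.lower().translate(table).split()]
            for t in texts]
    longest = max(len(s) for s in seqs)
    return [s + [0] * (longest - len(s)) for s in seqs]

encoded = encode_batch(["The acting was great.", "Great!"])
print(encoded)  # [[2, 5, 4, 3], [3, 0, 0, 0]]
```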
encoded_example = encoder(example)[:3].numpy()
encoded_example
array([[ 10, 1, 6, ..., 0, 0, 0], [ 11, 18, 121, ..., 0, 0, 0], [599, 41, 130, ..., 0, 0, 0]])
With the default settings, the process is not completely reversible. There are two main reasons for that:
- The default value for preprocessing.TextVectorization's standardize argument is "lower_and_strip_punctuation".
- The limited vocabulary size and lack of character-based fallback result in some unknown tokens.
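Both effects show up in a toy round trip. This sketch uses a hypothetical five-word vocabulary and is not the TextVectorization API. Case and punctuation disappear during standardization, so "wasn't" becomes "wasnt". Words that are missing from the vocabulary come back as `[UNK]`.

```python
import string

toy_vocab = ["", "[UNK]", "the", "movie", "was"]  # hypothetical toy vocabulary
toy_index = {tok: i for i, tok in enumerate(toy_vocab)}

def encode(text: str) -> list:
    # Lowercase and strip punctuation, then look up each token (OOV -> 1).
    cleaned = text.lower().translate(str.maketrans("", "", string.punctuation))
    return [toy_index.get(tok, 1) for tok in cleaned.split()]

def decode(ids: list) -> str:
    return " ".join(toy_vocab[i] for i in ids)

original = "The movie wasn't GREAT."
round_trip = decode(encode(original))
print(round_trip)  # the movie [UNK] [UNK]
```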
for n in range(3):
  print("Original: ", example[n].numpy())
  print("Round-trip: ", " ".join(vocab[encoded_example[n]]))
  print()
Original: b'I managed to grab a viewing of this with the aid of MST3K, and oh boy, even with the riffing this movie was excruciatingly bad. Imagine someone whose competence with a camera could be out done by a monkey.<br /><br />The highlights (what little there were) came from the special effects, which were "OK". The acting for the most part was also "OK"; though nothing special, it was of a higher quality than other B-Movies I have seen in the past.<br /><br />The rest of this movie is dismally bad, The camera work often looks like they\'ve just put the camera man on roller skates and pushed him along. The story (if it can be called that) is so full of holes it\'s almost funny, It never really explains why the hell he survived in the first place, or needs human flesh in order to survive. The script is poorly written and the dialogue verges on just plane stupid. The climax to movie (if there is one) is absolutely laughable.<br /><br />If you can\'t find the MST3K version, avoid this at all costs.' 
Round-trip: i [UNK] to [UNK] a viewing of this with the [UNK] of [UNK] and oh boy even with the [UNK] this movie was [UNK] bad imagine someone whose [UNK] with a camera could be out done by a [UNK] br the [UNK] what little there were came from the special effects which were ok the acting for the most part was also ok though nothing special it was of a [UNK] quality than other [UNK] i have seen in the [UNK] br the rest of this movie is [UNK] bad the camera work often looks like [UNK] just put the camera man on [UNK] [UNK] and [UNK] him along the story if it can be called that is so full of [UNK] its almost funny it never really [UNK] why the hell he [UNK] in the first place or needs human [UNK] in order to [UNK] the script is poorly written and the dialogue [UNK] on just [UNK] stupid the [UNK] to movie if there is one is absolutely [UNK] br if you cant find the [UNK] version avoid this at all [UNK] Original: b'This movie does contradict the first one as far as the origins of the Care Bears and the Care Bear Cousins goes. I won\'t deny that. However, if you look at "Part II" as a separate film, then it\'s a very good movie. I remember watching this in the early 80\'s (and fitting into its targeted demographic audience then), and absolutely loving it much more than the first movie (not that I didn\'t enjoy that one too, it\'s just that this one seemed to have a little something extra to it). Sure it\'s darker than the first one too, but perhaps maybe that\'s why it\'s so good. And it\'s dark in deeper kind of subtle way too (that kids may not fully understand, but could still be a bit scared of because of the atmosphere it gives off, and adults watching will surely get quicker as I have now watching this film again now in my mid-twenties) where you basically have a young girl making a deal with an evil spirit/demon in exchange for something else. Get the picture? 
But simply watching that as a child, sure as I said it may have been a little scary, but nothing traumatizing. In fact if anything it gave me another fantasy game I could play when I was that age. I can\'t tell you the number of times I used to pretend Dark Heart wanted to imprison me, have me help him capture the Care Bears, tried to make me turn over to his dark side, and other things like that etc. So this movie was also good for my imagination. And it\'s also got great emotional depth to it too. I used to watch it at least once a week.<br /><br />Also Hadley Kay was the perfect choice for the voice of Dark Heart (I always thought so and I always will).<br /><br />Now it\'s just too bad that they never made a soundtrack available. Sometimes I just want to hear Growing Up without watching the movie, as good as it is.<br /><br />"What good is love and caring if it can\'t save her?"' Round-trip: this movie does [UNK] the first one as far as the [UNK] of the care [UNK] and the care [UNK] [UNK] goes i wont [UNK] that however if you look at part [UNK] as a [UNK] film then its a very good movie i remember watching this in the early 80s and [UNK] into its [UNK] [UNK] audience then and absolutely [UNK] it much more than the first movie not that i didnt enjoy that one too its just that this one seemed to have a little something [UNK] to it sure its [UNK] than the first one too but perhaps maybe thats why its so good and its dark in [UNK] kind of [UNK] way too that kids may not [UNK] understand but could still be a bit [UNK] of because of the atmosphere it gives off and [UNK] watching will [UNK] get [UNK] as i have now watching this film again now in my [UNK] where you basically have a young girl making a deal with an evil [UNK] in [UNK] for something else get the picture but simply watching that as a child sure as i said it may have been a little scary but nothing [UNK] in fact if anything it gave me another fantasy game i could play when i was that age i cant tell you 
the number of times i used to [UNK] dark heart wanted to [UNK] me have me help him [UNK] the care [UNK] tried to make me turn over to his dark side and other things like that etc so this movie was also good for my [UNK] and its also got great emotional [UNK] to it too i used to watch it at least once a [UNK] br also [UNK] [UNK] was the perfect [UNK] for the voice of dark heart i always thought so and i always [UNK] br now its just too bad that they never made a soundtrack [UNK] sometimes i just want to hear [UNK] up without watching the movie as good as it [UNK] br what good is love and [UNK] if it cant save her Original: b'Let\'s just say I had to suspend my disbelief less for Spiderman than I did for Hooligans. That is, to say, I have less of a problem believing Toby McGuire can stick to buildings than I do Elija Wood throwing down with toughs in Manchester. I won\'t get into specifics, as I don\'t want to write a spoiler, but the idea of grown, professional, British men getting into near death scraps every weekend is, well... funny. And this film is not. The fighting, the idea of fighting, is taken far too seriously. The gravity of the pugilism, the reverence with which the subject matter is treated becomes irritating, as it neither establishes or resolves the conflict. It seems as though the plot, with holes big enough to drive a Guiness truck through, has been slapped together with a contrived "fish out of water" theme so that viewers can gaze into Woods teary eyes as he learns how to become a man ie. hitting other young men of opposing football tastes with blunt objects and then running away as fast as he can. The characters are cartoonish, especially the Americans at Harvard. The character development and story line are telegraphed to the viewer throughout the picture. Unfortunately, the absurdity of the film doesn\'t reach its height until nearly the end, which by then you\'ll have spent nearly two hours of your life you are never getting back. 
Pick up "The Football Factory" or "Fight Club" instead of this corny, and disappointing dud. It doesn\'t waste time with empty melodrama, the tired old "Yankee in King Aurthur\'s Court," or weepy, parables of coming of age bullsh*t. They\'re just pure, dark, and clever fun; the way violence is supposed to be.' Round-trip: lets just say i had to [UNK] my [UNK] less for [UNK] than i did for [UNK] that is to say i have less of a problem [UNK] [UNK] [UNK] can [UNK] to [UNK] than i do [UNK] [UNK] [UNK] down with [UNK] in [UNK] i wont get into [UNK] as i dont want to write a [UNK] but the idea of [UNK] [UNK] british men getting into near death [UNK] every [UNK] is well funny and this film is not the fighting the idea of fighting is taken far too seriously the [UNK] of the [UNK] the [UNK] with which the subject matter is [UNK] becomes [UNK] as it [UNK] [UNK] or [UNK] the [UNK] it seems as though the plot with [UNK] big enough to [UNK] a [UNK] [UNK] through has been [UNK] together with a [UNK] [UNK] out of [UNK] theme so that viewers can [UNK] into [UNK] [UNK] eyes as he [UNK] how to become a man [UNK] [UNK] other young men of [UNK] [UNK] [UNK] with [UNK] [UNK] and then running away as fast as he can the characters are [UNK] especially the [UNK] at [UNK] the character development and story line are [UNK] to the viewer throughout the picture unfortunately the [UNK] of the film doesnt [UNK] its [UNK] until nearly the end which by then youll have [UNK] nearly two hours of your life you are never getting back [UNK] up the [UNK] [UNK] or fight [UNK] instead of this [UNK] and [UNK] [UNK] it doesnt waste time with [UNK] [UNK] the [UNK] old [UNK] in king [UNK] [UNK] or [UNK] [UNK] of coming of age [UNK] theyre just [UNK] dark and [UNK] fun the way violence is supposed to be
Create the model
Above is a diagram of the model.
This model can be built as a tf.keras.Sequential. The first layer is the encoder, which converts the text to a sequence of token indices.
After the encoder is an embedding layer. An embedding layer stores one vector per word. When called, it converts the sequences of word indices to sequences of vectors. These vectors are trainable. After training (on enough data), words with similar meanings often have similar vectors.
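An embedding layer is essentially a trainable table with one row per vocabulary index. The pure-Python sketch below uses toy numbers and is not Keras code. It shows that looking up a row gives the same vector as multiplying a one-hot vector by the table, which is what a bias-free Dense layer would compute. The lookup simply skips all the multiplications by zero.

```python
from typing import List

# Hypothetical 4-word vocabulary with 3-dimensional embeddings.
table: List[List[float]] = [
    [0.0, 0.0, 0.0],   # index 0: padding
    [0.1, 0.2, 0.3],   # index 1: [UNK]
    [0.5, -0.4, 0.9],  # index 2
    [-0.7, 0.8, 0.2],  # index 3
]

def embed(ids: List[int]) -> List[List[float]]:
    """Embedding lookup: select one row per token index."""
    return [table[i] for i in ids]

def one_hot_matmul(token_id: int) -> List[float]:
    """Same result via a one-hot vector times the table (a bias-free Dense layer)."""
    one_hot = [1.0 if i == token_id else 0.0 for i in range(len(table))]
    return [sum(one_hot[r] * table[r][c] for r in range(len(table)))
            for c in range(len(table[0]))]

vectors = embed([2, 3, 0])
print(vectors[0] == one_hot_matmul(2))  # True
```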
This index lookup is much more efficient than the equivalent operation of passing a one-hot encoded vector through a tf.keras.layers.Dense layer.
A recurrent neural network (RNN) processes sequence input by iterating through the elements. RNNs pass the outputs from one timestep to their input on the next timestep.
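The "iterate and feed the output back in" idea can be written as a plain loop. This is a sketch of a simple (Elman-style) RNN with a scalar hidden state and made-up weights `w_x`, `w_h` and `b`. The LSTM used in this tutorial adds gates and a cell state, but it follows the same timestep loop.

```python
import math
from typing import List

def simple_rnn(inputs: List[float], w_x: float = 0.5, w_h: float = 0.8,
               b: float = 0.0) -> List[float]:
    """Toy RNN: h_t = tanh(w_x * x_t + w_h * h_{t-1} + b), starting from h_0 = 0."""
    h = 0.0
    outputs = []
    for x in inputs:
        h = math.tanh(w_x * x + w_h * h + b)  # previous output feeds the next step
        outputs.append(h)
    return outputs

states = simple_rnn([1.0, 0.0, -1.0])
print(len(states))  # 3: one output per timestep
```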
The tf.keras.layers.Bidirectional wrapper can also be used with an RNN layer. This propagates the input forward and backwards through the RNN layer and then concatenates the final output.
The main advantage of a bidirectional RNN is that the signal from the beginning of the input doesn't need to be processed all the way through every timestep to affect the output.
The main disadvantage of a bidirectional RNN is that you can't efficiently stream predictions as words are being added to the end.
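Conceptually, the wrapper runs one copy of the RNN over the sequence as given and a second copy over the reversed sequence. It then concatenates their final outputs. In the sketch below, `rnn_last` is a hypothetical scalar RNN with made-up weights. In Keras, the forward and backward layers have separate trained weights, which the two weight pairs here only mimic.

```python
import math
from typing import List, Tuple

def rnn_last(inputs: List[float], w_x: float, w_h: float) -> float:
    """Final hidden state of a toy scalar tanh RNN."""
    h = 0.0
    for x in inputs:
        h = math.tanh(w_x * x + w_h * h)
    return h

def bidirectional_last(inputs: List[float]) -> Tuple[float, float]:
    forward = rnn_last(inputs, w_x=0.5, w_h=0.8)
    # The backward copy sees the first token last, so it is close to the output.
    backward = rnn_last(list(reversed(inputs)), w_x=-0.3, w_h=0.6)
    return (forward, backward)  # stands in for the concatenated outputs

seq = [1.0, 0.0, -1.0]
fwd, bwd = bidirectional_last(seq)
```

With the default concatenation, the Bidirectional(LSTM(64)) layer in this tutorial therefore outputs 128 features per review: 64 from each direction.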
After the RNN has converted the sequence to a single vector, the two layers.Dense layers do some final processing and convert this vector representation to a single logit as the classification output.
The code to implement this is below:
model = tf.keras.Sequential([
    encoder,
    tf.keras.layers.Embedding(
        input_dim=len(encoder.get_vocabulary()),
        output_dim=64,
        # Use masking to handle the variable sequence lengths
        mask_zero=True),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dense(1)
])
Note that a Keras Sequential model is used here because every layer in the model has a single input and produces a single output. If you want to use a stateful RNN layer, you might want to build your model with the Keras functional API or model subclassing so that you can retrieve and reuse the RNN layer states. See the Keras RNN guide for more details.
The embedding layer uses masking to handle the varying sequence lengths. All the layers after the Embedding layer support masking:
print([layer.supports_masking for layer in model.layers])
[False, True, True, True, True]
To confirm that this works as expected, evaluate a sentence twice. First, alone so there's no padding to mask:
# predict on a sample text without padding.
sample_text = ('The movie was cool. The animation and the graphics '
               'were out of this world. I would recommend this movie.')
predictions = model.predict(np.array([sample_text]))
print(predictions[0])
[0.00256683]
Now, evaluate it again in a batch with a longer sentence. The result should be identical:
# predict on a sample text with padding
padding = "the " * 2000
predictions = model.predict(np.array([sample_text, padding]))
print(predictions[0])
[0.00256683]
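The two predictions match because of masking. A masked recurrent step leaves the state unchanged at padded positions instead of updating it. The sketch below uses a toy scalar RNN, a hypothetical `masked_rnn_last` helper and made-up 1-d embeddings. It treats index 0 as padding, mirroring `mask_zero=True`. Index 0 is deliberately given a nonzero embedding to show that the mask, not a zero vector, is what keeps padding from affecting the result.

```python
import math
from typing import List

# Toy 1-d embeddings; index 0 (padding) is nonzero on purpose.
toy_embedding = {0: 0.7, 1: 0.2, 2: -0.5, 3: 0.9}

def masked_rnn_last(ids: List[int], w_x: float = 0.5, w_h: float = 0.8) -> float:
    h = 0.0
    for token_id in ids:
        if token_id == 0:
            continue  # masked step: carry the previous state forward unchanged
        h = math.tanh(w_x * toy_embedding[token_id] + w_h * h)
    return h

unpadded = masked_rnn_last([3, 2, 1])
padded = masked_rnn_last([3, 2, 1, 0, 0, 0])
print(unpadded == padded)  # True
```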
Compile the Keras model to configure the training process:
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(1e-4),
              metrics=['accuracy'])
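`from_logits=True` tells the loss that the model outputs raw logits, because the final `Dense(1)` has no sigmoid. For a label y in {0, 1} and a logit x, the loss is -[y log σ(x) + (1 - y) log(1 - σ(x))]. The sketch below uses the algebraically equivalent and numerically stable form max(x, 0) - x·y + log(1 + e^(-|x|)), which is the formula documented for `tf.nn.sigmoid_cross_entropy_with_logits`. It is an illustration and not the Keras implementation, which also averages over the batch.

```python
import math

def sigmoid(x: float) -> float:
    return 1.0 / (1.0 + math.exp(-x))

def bce_from_logit(logit: float, label: float) -> float:
    """Numerically stable binary cross-entropy computed directly from a logit."""
    return max(logit, 0.0) - logit * label + math.log1p(math.exp(-abs(logit)))

def bce_naive(logit: float, label: float) -> float:
    """Textbook form via the sigmoid; breaks down for large |logit|."""
    p = sigmoid(logit)
    return -(label * math.log(p) + (1 - label) * math.log(1 - p))

# The two forms agree for moderate logits.
for logit, label in [(2.0, 1.0), (-1.5, 0.0), (0.3, 0.0)]:
    assert math.isclose(bce_from_logit(logit, label), bce_naive(logit, label))

print(bce_from_logit(100.0, 0.0))  # prints 100.0; bce_naive would hit log(0) here
```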
Train the model
history = model.fit(train_dataset, epochs=10,
                    validation_data=test_dataset,
                    validation_steps=30)
Epoch 1/10 391/391 [==============================] - 46s 99ms/step - loss: 0.6837 - accuracy: 0.5113 - val_loss: 0.4732 - val_accuracy: 0.7500 Epoch 2/10 391/391 [==============================] - 36s 91ms/step - loss: 0.4369 - accuracy: 0.7796 - val_loss: 0.3933 - val_accuracy: 0.8026 Epoch 3/10 391/391 [==============================] - 36s 90ms/step - loss: 0.3665 - accuracy: 0.8418 - val_loss: 0.3396 - val_accuracy: 0.8448 Epoch 4/10 391/391 [==============================] - 35s 88ms/step - loss: 0.3245 - accuracy: 0.8615 - val_loss: 0.3298 - val_accuracy: 0.8526 Epoch 5/10 391/391 [==============================] - 36s 86ms/step - loss: 0.3090 - accuracy: 0.8687 - val_loss: 0.3259 - val_accuracy: 0.8531 Epoch 6/10 391/391 [==============================] - 33s 82ms/step - loss: 0.3053 - accuracy: 0.8693 - val_loss: 0.3231 - val_accuracy: 0.8531 Epoch 7/10 391/391 [==============================] - 35s 83ms/step - loss: 0.3100 - accuracy: 0.8666 - val_loss: 0.3243 - val_accuracy: 0.8604 Epoch 8/10 391/391 [==============================] - 35s 88ms/step - loss: 0.2963 - accuracy: 0.8756 - val_loss: 0.3249 - val_accuracy: 0.8578 Epoch 9/10 391/391 [==============================] - 34s 85ms/step - loss: 0.2992 - accuracy: 0.8722 - val_loss: 0.3200 - val_accuracy: 0.8562 Epoch 10/10 391/391 [==============================] - 33s 84ms/step - loss: 0.2925 - accuracy: 0.8756 - val_loss: 0.3208 - val_accuracy: 0.8552
test_loss, test_acc = model.evaluate(test_dataset)
print('Test Loss:', test_loss)
print('Test Accuracy:', test_acc)
391/391 [==============================] - 17s 41ms/step - loss: 0.3151 - accuracy: 0.8588 Test Loss: 0.31505945324897766 Test Accuracy: 0.8587999939918518
plt.figure(figsize=(16, 8))
plt.subplot(1, 2, 1)
plot_graphs(history, 'accuracy')
plt.ylim(None, 1)
plt.subplot(1, 2, 2)
plot_graphs(history, 'loss')
plt.ylim(0, None)
(0.0, 0.6610833600163459)
Run a prediction on a new sentence:
If the prediction is >= 0.0, the review is classified as positive; otherwise, it is negative.
sample_text = ('The movie was cool. The animation and the graphics '
               'were out of this world. I would recommend this movie.')
predictions = model.predict(np.array([sample_text]))
print(predictions)
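The 0.0 threshold follows from the model emitting a logit. The sigmoid maps a logit of 0 to a probability of exactly 0.5, and because it is monotonic, logit >= 0 gives the same decision as probability >= 0.5. The hypothetical `classify` helper below converts raw outputs, such as the values printed in this tutorial, into a label and a probability.

```python
import math

def classify(logit: float) -> tuple:
    """Turn a raw model output (a logit) into a label and a probability."""
    # Numerically stable sigmoid: never call exp() on a large positive number.
    if logit >= 0.0:
        probability = 1.0 / (1.0 + math.exp(-logit))
    else:
        z = math.exp(logit)
        probability = z / (1.0 + z)
    verdict = "positive" if logit >= 0.0 else "negative"
    return verdict, probability

for logit in [0.00256683, -1.8525568]:
    verdict, prob = classify(logit)
    print(f"{logit:+.4f} -> {verdict} (p={prob:.3f})")
```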
Stack two or more LSTM layers
Keras recurrent layers have two available modes that are controlled by the return_sequences constructor argument:
- If False, it returns only the last output for each input sequence (a 2D tensor of shape (batch_size, output_features)). This is the default, used in the previous model.
- If True, the full sequence of successive outputs for each timestep is returned (a 3D tensor of shape (batch_size, timesteps, output_features)).
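The two modes differ only in which hidden states are kept. This pure-Python sketch uses a toy scalar RNN and a hypothetical `run_rnn` helper. A batch of 2 sequences with 4 timesteps yields a (batch, timesteps) grid of states when `return_sequences=True`, and a (batch,) list of final states otherwise. With `output_features` units, each entry would be a vector, which gives the 3D and 2D shapes described above.

```python
import math
from typing import List, Union

def run_rnn(batch: List[List[float]], return_sequences: bool = False
            ) -> Union[List[float], List[List[float]]]:
    """Toy scalar tanh RNN applied to each sequence in a batch."""
    results = []
    for seq in batch:
        h, states = 0.0, []
        for x in seq:
            h = math.tanh(0.5 * x + 0.8 * h)
            states.append(h)
        # Keep every timestep's state, or only the last one.
        results.append(states if return_sequences else states[-1])
    return results

toy_batch = [[1.0, 0.0, -1.0, 0.5], [0.2, 0.4, 0.6, 0.8]]
full = run_rnn(toy_batch, return_sequences=True)
last = run_rnn(toy_batch)
print(len(full), len(full[0]))  # 2 4
print(len(last))  # 2
```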
Here is what the flow of information looks like with return_sequences=True:
The interesting thing about using an RNN with return_sequences=True is that the output still has three axes, like the input, so it can be passed to another RNN layer, like this:
model = tf.keras.Sequential([
    encoder,
    tf.keras.layers.Embedding(len(encoder.get_vocabulary()), 64, mask_zero=True),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(64, return_sequences=True)),
    tf.keras.layers.Bidirectional(tf.keras.layers.LSTM(32)),
    tf.keras.layers.Dense(64, activation='relu'),
    tf.keras.layers.Dropout(0.5),
    tf.keras.layers.Dense(1)
])
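Stacking works because the first layer's per-timestep outputs form a sequence again. The sketch below makes the same toy assumptions as before: a scalar tanh RNN, a hypothetical `rnn_layer` helper and made-up weights. The first layer returns every state, and the second consumes them and returns only its final state. This mirrors LSTM(64, return_sequences=True) followed by LSTM(32) in the model above.

```python
import math
from typing import List

def rnn_layer(seq: List[float], w_x: float, w_h: float, return_sequences: bool):
    """Toy scalar tanh RNN over one sequence."""
    h, states = 0.0, []
    for x in seq:
        h = math.tanh(w_x * x + w_h * h)
        states.append(h)
    return states if return_sequences else h

tokens = [0.9, -0.5, 0.2]  # stand-ins for embedded tokens
# First layer: one output per timestep, so the result is still a sequence.
hidden_seq = rnn_layer(tokens, w_x=0.5, w_h=0.8, return_sequences=True)
# Second layer: consumes that sequence and keeps only its final state.
final = rnn_layer(hidden_seq, w_x=1.2, w_h=0.4, return_sequences=False)
print(len(hidden_seq))  # 3
```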
model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer=tf.keras.optimizers.Adam(1e-4),
              metrics=['accuracy'])
history = model.fit(train_dataset, epochs=10,
                    validation_data=test_dataset,
                    validation_steps=30)
Epoch 1/10 391/391 [==============================] - 80s 162ms/step - loss: 0.6801 - accuracy: 0.5097 - val_loss: 0.4260 - val_accuracy: 0.8135 Epoch 2/10 391/391 [==============================] - 62s 153ms/step - loss: 0.3985 - accuracy: 0.8262 - val_loss: 0.3529 - val_accuracy: 0.8391 Epoch 3/10 391/391 [==============================] - 60s 151ms/step - loss: 0.3397 - accuracy: 0.8507 - val_loss: 0.3302 - val_accuracy: 0.8500 Epoch 4/10 391/391 [==============================] - 59s 150ms/step - loss: 0.3143 - accuracy: 0.8627 - val_loss: 0.3269 - val_accuracy: 0.8589 Epoch 5/10 391/391 [==============================] - 60s 150ms/step - loss: 0.3124 - accuracy: 0.8649 - val_loss: 0.3248 - val_accuracy: 0.8464 Epoch 6/10 391/391 [==============================] - 60s 151ms/step - loss: 0.3065 - accuracy: 0.8689 - val_loss: 0.3403 - val_accuracy: 0.8448 Epoch 7/10 391/391 [==============================] - 56s 141ms/step - loss: 0.2983 - accuracy: 0.8753 - val_loss: 0.3165 - val_accuracy: 0.8495 Epoch 8/10 391/391 [==============================] - 56s 142ms/step - loss: 0.3023 - accuracy: 0.8711 - val_loss: 0.3188 - val_accuracy: 0.8557 Epoch 9/10 391/391 [==============================] - 57s 143ms/step - loss: 0.2971 - accuracy: 0.8744 - val_loss: 0.3179 - val_accuracy: 0.8526 Epoch 10/10 391/391 [==============================] - 57s 145ms/step - loss: 0.2964 - accuracy: 0.8724 - val_loss: 0.3252 - val_accuracy: 0.8589
test_loss, test_acc = model.evaluate(test_dataset)
print('Test Loss:', test_loss)
print('Test Accuracy:', test_acc)
391/391 [==============================] - 26s 66ms/step - loss: 0.3233 - accuracy: 0.8618 Test Loss: 0.32332155108451843 Test Accuracy: 0.8618000149726868
# predict on a sample text without padding.
sample_text = ('The movie was not good. The animation and the graphics '
               'were terrible. I would not recommend this movie.')
predictions = model.predict(np.array([sample_text]))
print(predictions)
[[-1.8525568]]
plt.figure(figsize=(16, 6))
plt.subplot(1, 2, 1)
plot_graphs(history, 'accuracy')
plt.subplot(1, 2, 2)
plot_graphs(history, 'loss')
Check out other existing recurrent layers, such as GRU layers.
If you're interested in building custom RNNs, see the Keras RNN Guide.