This tutorial demonstrates how to generate text using a character-based RNN. You will work with a dataset of Shakespeare's writing from Andrej Karpathy's The Unreasonable Effectiveness of Recurrent Neural Networks. Given a sequence of characters from this data ("Shakespear"), train a model to predict the next character in the sequence ("e"). Longer sequences of text can be generated by calling the model repeatedly.
This tutorial includes runnable code implemented using tf.keras and eager execution. The following is sample output from the model in this tutorial when it was trained for 30 epochs and started with the prompt "Q":
QUEENE: I had thought thou hadst a Roman; for the oracle, Thus by All bids the man against the word, Which are so weak of care, by old care done; Your children were in your holy love, And the precipitation through the bleeding throne. BISHOP OF ELY: Marry, and will, my lord, to weep in such a one were prettiest; Yet now I was adopted heir Of the world's lamentable day, To watch the next way with his father with his face? ESCALUS: The cause why then we are all resolved more sons. VOLUMNIA: O, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, it is no sin it should be dead, And love and pale as any will to that word. QUEEN ELIZABETH: But how long have I heard the soul for this world, And show his hands of life be proved to stand. PETRUCHIO: I say he look'd on, if I must be content To stay him from the fatal of our country's bliss. His lordship pluck'd from this sentence then for prey, And then let us twain, being the moon, were she such a case as fills m
While some of the sentences are grammatical, most do not make sense. The model has not learned the meaning of words, but consider:
- The model is character-based. When training started, the model did not know how to spell an English word, or even that words were a unit of text.
- The structure of the output resembles a play: blocks of text generally begin with a speaker name in all capital letters, as in the dataset.
- As demonstrated below, the model is trained on small batches of text (100 characters each), and is still able to generate a longer sequence of text with coherent structure.
Setup
Import TensorFlow and other libraries
import tensorflow as tf
from tensorflow.keras.layers.experimental import preprocessing
import numpy as np
import os
import time
Download the Shakespeare dataset
Change the following line to run this code on your own data.
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt 1122304/1115394 [==============================] - 0s 0us/step
Read the data
First, take a look at the text:
# Read, then decode for py2 compat.
text = open(path_to_file, 'rb').read().decode(encoding='utf-8')
# length of text is the number of characters in it
print(f'Length of text: {len(text)} characters')
Length of text: 1115394 characters
# Take a look at the first 250 characters in text
print(text[:250])
First Citizen: Before we proceed any further, hear me speak. All: Speak, speak. First Citizen: You are all resolved rather to die than to famish? All: Resolved. resolved. First Citizen: First, you know Caius Marcius is chief enemy to the people.
# The unique characters in the file
vocab = sorted(set(text))
print(f'{len(vocab)} unique characters')
65 unique characters
Process the text
Vectorize the text
Before training, you need to convert the strings to a numerical representation.
The preprocessing.StringLookup layer can convert each character into a numeric ID. It just needs the text to be split into tokens first.
example_texts = ['abcdefg', 'xyz']
chars = tf.strings.unicode_split(example_texts, input_encoding='UTF-8')
chars
<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>
Now create the preprocessing.StringLookup layer:
ids_from_chars = preprocessing.StringLookup(
vocabulary=list(vocab))
It converts from tokens to character IDs, padding with 0:
ids = ids_from_chars(chars)
ids
<tf.RaggedTensor [[41, 42, 43, 44, 45, 46, 47], [64, 65, 66]]>
Since the goal of this tutorial is to generate text, it will also be important to invert this representation and recover human-readable strings from it. For this you can use preprocessing.StringLookup(..., invert=True).
chars_from_ids = preprocessing.StringLookup(
    vocabulary=ids_from_chars.get_vocabulary(), invert=True)
This layer recovers the characters from the vectors of IDs, and returns them as a tf.RaggedTensor of characters:
chars = chars_from_ids(ids)
chars
<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>
You can use tf.strings.reduce_join to join the characters back into strings.
tf.strings.reduce_join(chars, axis=-1).numpy()
array([b'abcdefg', b'xyz'], dtype=object)
def text_from_ids(ids):
    # Map IDs back to characters, then join each row into a single string.
    return tf.strings.reduce_join(chars_from_ids(ids), axis=-1)
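To make the ID scheme concrete, the plain-Python sketch below mimics the same two-way mapping without TensorFlow. It assumes, consistent with the IDs printed above, that ID 0 is reserved for the empty string "", ID 1 for "[UNK]", and the sorted characters start at ID 2. The helper names (build_lookup, ids_from_text, text_from_id_list) are hypothetical and not part of Keras.

```python
# Plain-Python sketch of a two-way character <-> ID lookup.
# Assumption: ID 0 = "" (mask), ID 1 = "[UNK]", real characters start at ID 2,
# which is consistent with the IDs printed by StringLookup above.


def build_lookup(text):
    """Return (vocabulary list, char -> id dict) for the characters in text."""
    vocabulary = ['', '[UNK]'] + sorted(set(text))
    char_to_id = {ch: i for i, ch in enumerate(vocabulary)}
    return vocabulary, char_to_id


def ids_from_text(s, char_to_id):
    # Characters missing from the vocabulary fall back to the [UNK] ID.
    return [char_to_id.get(ch, char_to_id['[UNK]']) for ch in s]


def text_from_id_list(ids, vocabulary):
    # Inverse mapping: look each ID up and join the characters.
    return ''.join(vocabulary[i] for i in ids)


vocabulary, char_to_id = build_lookup("First Citizen")
ids = ids_from_text("Cities", char_to_id)
print(ids)                                  # [3, 6, 10, 6, 5, 9]
print(text_from_id_list(ids, vocabulary))   # Cities
```

As with the real layer, an unknown character such as "x" maps to the [UNK] ID, so the round trip is only lossless for characters that appear in the vocabulary.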
The prediction task
Given a character, or a sequence of characters, what is the most probable next character? This is the task you're training the model to perform. The input to the model will be a sequence of characters, and you train the model to predict the output: the following character at each time step.
Since RNNs maintain an internal state that depends on the previously seen elements, the question becomes: given all the characters seen until this moment, what is the next character?
Create training examples and targets
Next, divide the text into example sequences. Each input sequence will contain seq_length characters from the text.
For each input sequence, the corresponding targets contain the same length of text, except shifted one character to the right.
So break the text into chunks of seq_length+1. For example, say seq_length is 4 and the text is "Hello". The input sequence would be "Hell", and the target sequence "ello".
To do this, first convert the text into a vector of character IDs, then use the tf.data.Dataset.from_tensor_slices function to turn that vector into a stream of character indices.
all_ids = ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))
all_ids
<tf.Tensor: shape=(1115394,), dtype=int64, numpy=array([20, 49, 58, ..., 47, 10, 2])>
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)
for ids in ids_dataset.take(10):
    print(chars_from_ids(ids).numpy().decode('utf-8'))
F i r s t C i t i
seq_length = 100
examples_per_epoch = len(text)//(seq_length+1)
The batch method lets you easily convert these individual characters to sequences of the desired size.
sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)
for seq in sequences.take(1):
    print(chars_from_ids(seq))
tf.Tensor( [b'F' b'i' b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':' b'\n' b'B' b'e' b'f' b'o' b'r' b'e' b' ' b'w' b'e' b' ' b'p' b'r' b'o' b'c' b'e' b'e' b'd' b' ' b'a' b'n' b'y' b' ' b'f' b'u' b'r' b't' b'h' b'e' b'r' b',' b' ' b'h' b'e' b'a' b'r' b' ' b'm' b'e' b' ' b's' b'p' b'e' b'a' b'k' b'.' b'\n' b'\n' b'A' b'l' b'l' b':' b'\n' b'S' b'p' b'e' b'a' b'k' b',' b' ' b's' b'p' b'e' b'a' b'k' b'.' b'\n' b'\n' b'F' b'i' b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':' b'\n' b'Y' b'o' b'u' b' '], shape=(101,), dtype=string)
It's easier to see what this is doing if you join the tokens back into strings:
for seq in sequences.take(5):
    print(text_from_ids(seq).numpy())
b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou ' b'are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you k' b"now Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us ki" b"ll him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be d" b'one: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citi'
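As a quick sanity check on examples_per_epoch, the arithmetic below works out how many sequences and how many full batches one epoch contains. It uses the text length printed earlier and assumes the BATCH_SIZE of 64 that is set later in this tutorial; the result matches the 172 steps per epoch reported in the training logs.

```python
text_length = 1115394   # characters, as printed above
seq_length = 100
batch_size = 64         # assumed: the BATCH_SIZE used later in the tutorial

# Each sequence consumes seq_length + 1 characters; drop_remainder=True
# discards the leftover characters at the end of the text.
num_sequences = text_length // (seq_length + 1)

# Batching with drop_remainder=True keeps only full batches.
steps_per_epoch = num_sequences // batch_size

print(num_sequences)    # 11043
print(steps_per_epoch)  # 172
```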
For training you'll need a dataset of (input, label) pairs, where input and label are sequences. At each time step the input is the current character and the label is the next character.
Here's a function that takes a sequence as input, duplicates it, and shifts it to align the input and label for each timestep:
def split_input_target(sequence):
    input_text = sequence[:-1]
    target_text = sequence[1:]
    return input_text, target_text
split_input_target(list("Tensorflow"))
(['T', 'e', 'n', 's', 'o', 'r', 'f', 'l', 'o'], ['e', 'n', 's', 'o', 'r', 'f', 'l', 'o', 'w'])
dataset = sequences.map(split_input_target)
for input_example, target_example in dataset.take(1):
    print("Input :", text_from_ids(input_example).numpy())
    print("Target:", text_from_ids(target_example).numpy())
Input : b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou' Target: b'irst Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
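The pipeline so far (cut the text into chunks of seq_length+1 characters, then split each chunk with split_input_target) can be sketched in plain Python on the "Hello" example. This only illustrates the idea and is not how tf.data works internally; make_examples is a hypothetical helper.

```python
def make_examples(text, seq_length):
    """Split text into (input, target) pairs of length seq_length.

    Each chunk holds seq_length + 1 characters. The input drops the last
    character and the target drops the first, so target[i] is the
    character that follows input[i]. A trailing partial chunk is
    discarded, mirroring batch(..., drop_remainder=True).
    """
    chunk_size = seq_length + 1
    examples = []
    for start in range(0, len(text) - chunk_size + 1, chunk_size):
        chunk = text[start:start + chunk_size]
        examples.append((chunk[:-1], chunk[1:]))
    return examples


print(make_examples("Hello", 4))         # [('Hell', 'ello')]
print(make_examples("Hello world!", 4))  # [('Hell', 'ello'), (' wor', 'worl')]
```

In the second call the final "d!" is shorter than a full chunk, so it is dropped.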
Create training batches
You used tf.data to split the text into manageable sequences. But before feeding this data into the model, you need to shuffle the data and pack it into batches.
# Batch size
BATCH_SIZE = 64
# Buffer size to shuffle the dataset
# (TF data is designed to work with possibly infinite sequences,
# so it doesn't attempt to shuffle the entire sequence in memory. Instead,
# it maintains a buffer in which it shuffles elements).
BUFFER_SIZE = 10000
dataset = (
dataset
.shuffle(BUFFER_SIZE)
.batch(BATCH_SIZE, drop_remainder=True)
.prefetch(tf.data.experimental.AUTOTUNE))
dataset
<PrefetchDataset shapes: ((64, 100), (64, 100)), types: (tf.int64, tf.int64)>
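The comment on BUFFER_SIZE describes a bounded-memory shuffle. Below is a plain-Python sketch of that idea. It assumes a simple fill-then-swap strategy and illustrates the concept only; it is not tf.data's actual implementation, and buffered_shuffle is a hypothetical name. One consequence worth noting: with a buffer smaller than the dataset, no element can be emitted more than buffer_size - 1 positions earlier than its original position, which is why the buffer size controls how thoroughly the data is mixed.

```python
import random


def buffered_shuffle(items, buffer_size, seed=0):
    """Yield items in a partially shuffled order using a fixed-size buffer.

    Sketch only: fill the buffer, then repeatedly emit a random buffered
    element and put the next incoming item into the freed slot.
    """
    rng = random.Random(seed)
    buffer = []
    for item in items:
        if len(buffer) < buffer_size:
            buffer.append(item)
            continue
        # Buffer is full: emit a random element and reuse its slot.
        index = rng.randrange(len(buffer))
        yield buffer[index]
        buffer[index] = item
    # Input exhausted: drain what is left in random order.
    rng.shuffle(buffer)
    yield from buffer


order = list(buffered_shuffle(range(20), buffer_size=4))
print(order)
```

A buffer size of 1 degenerates to no shuffling at all, while a buffer at least as large as the dataset gives a full shuffle.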
Build the model
This section defines the model as a keras.Model subclass (for details, see Making new Layers and Models via subclassing).
This model has three layers:
- tf.keras.layers.Embedding: The input layer. A trainable lookup table that will map each character ID to a vector with embedding_dim dimensions.
- tf.keras.layers.GRU: A type of RNN with size units=rnn_units. (You can also use an LSTM layer here.)
- tf.keras.layers.Dense: The output layer, with vocab_size outputs. It outputs one logit for each character in the vocabulary: the model's unnormalized log-probability for that character.
# Length of the vocabulary in chars
vocab_size = len(vocab)
# The embedding dimension
embedding_dim = 256
# Number of RNN units
rnn_units = 1024
class MyModel(tf.keras.Model):
    def __init__(self, vocab_size, embedding_dim, rnn_units):
        super().__init__()
        self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
        self.gru = tf.keras.layers.GRU(rnn_units,
                                       return_sequences=True,
                                       return_state=True)
        self.dense = tf.keras.layers.Dense(vocab_size)

    def call(self, inputs, states=None, return_state=False, training=False):
        x = inputs
        x = self.embedding(x, training=training)
        if states is None:
            states = self.gru.get_initial_state(x)
        x, states = self.gru(x, initial_state=states, training=training)
        x = self.dense(x, training=training)

        if return_state:
            return x, states
        else:
            return x
model = MyModel(
# Be sure the vocabulary size matches the `StringLookup` layers.
vocab_size=len(ids_from_chars.get_vocabulary()),
embedding_dim=embedding_dim,
rnn_units=rnn_units)
For each character the model looks up the embedding, runs the GRU one timestep with the embedding as input, and applies the dense layer to generate logits predicting the log-likelihood of the next character.
Try the model
Now run the model to see that it behaves as expected.
First check the shape of the output:
for input_example_batch, target_example_batch in dataset.take(1):
    example_batch_predictions = model(input_example_batch)
    print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")
(64, 100, 67) # (batch_size, sequence_length, vocab_size)
In the above example the sequence length of the input is 100, but the model can be run on inputs of any length:
model.summary()
Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
embedding (Embedding)        multiple                  17152
_________________________________________________________________
gru (GRU)                    multiple                  3938304
_________________________________________________________________
dense (Dense)                multiple                  68675
=================================================================
Total params: 4,024,131
Trainable params: 4,024,131
Non-trainable params: 0
_________________________________________________________________
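The parameter counts in the summary can be checked by hand. The sketch below assumes Keras' default GRU configuration in TensorFlow 2 (reset_after=True, which gives each of the three gates two bias vectors) and a vocabulary of 67 IDs (65 characters plus "" and "[UNK]"). The helper functions are illustrative, not part of Keras.

```python
vocab_size = 67       # 65 characters + "" + "[UNK]"
embedding_dim = 256
rnn_units = 1024


def embedding_params(vocab_size, embedding_dim):
    # One trainable vector per token ID.
    return vocab_size * embedding_dim


def gru_params(input_dim, units):
    # Three gates (update, reset, candidate), each with an input kernel,
    # a recurrent kernel and, with reset_after=True, two bias vectors.
    return 3 * (input_dim * units + units * units + 2 * units)


def dense_params(input_dim, output_dim):
    # Weight matrix plus one bias per output.
    return input_dim * output_dim + output_dim


counts = {
    'embedding': embedding_params(vocab_size, embedding_dim),
    'gru': gru_params(embedding_dim, rnn_units),
    'dense': dense_params(rnn_units, vocab_size),
}
print(counts)                # {'embedding': 17152, 'gru': 3938304, 'dense': 68675}
print(sum(counts.values()))  # 4024131
```

Almost all of the parameters live in the GRU, dominated by the 1024 x 1024 recurrent kernels.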
To get actual predictions from the model, you need to sample from the output distribution to obtain character indices. This distribution is defined by the logits over the character vocabulary.
Try it for the first example in the batch:
sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()
This gives us, at each timestep, a prediction of the next character index:
sampled_indices
array([41, 34, 7, 28, 21, 45, 11, 14, 59, 8, 15, 11, 6, 33, 44, 33, 8, 40, 5, 39, 32, 22, 37, 14, 53, 0, 48, 22, 23, 46, 44, 58, 39, 41, 47, 7, 6, 62, 48, 9, 2, 27, 58, 17, 26, 17, 6, 16, 36, 28, 36, 8, 3, 23, 19, 57, 50, 51, 59, 27, 6, 12, 39, 9, 23, 29, 37, 30, 62, 51, 63, 35, 45, 52, 18, 7, 58, 17, 53, 41, 28, 37, 10, 64, 55, 49, 61, 45, 57, 56, 21, 0, 0, 7, 48, 27, 51, 12, 1, 49])
Decode these to see the text predicted by this untrained model:
print("Input:\n", text_from_ids(input_example_batch[0]).numpy())
print()
print("Next Char Predictions:\n", text_from_ids(sampled_indices).numpy())
Input: b'seeming man!\nOr ill-beseeming beast in seeming both!\nThou hast amazed me: by my holy order,\nI though' Next Char Predictions: b"aT'NGe3?s,A3&SdS,Z$YRHW?mhHIfdrYag'&vh-\nMrCLC&BVNV, IEqjksM&:Y-IOWPvkwUelD'rCmaNW.xoiueqpG'hMk:[UNK]i"
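Conceptually, tf.random.categorical treats each row of logits as unnormalized log-probabilities. The plain-Python sketch below spells this out: apply a softmax, then draw an index by inverse-CDF sampling. It illustrates the idea and is not TensorFlow's implementation; the function names are hypothetical.

```python
import math
import random


def softmax(logits):
    # Subtract the max for numerical stability; the result is unchanged.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sample_index(logits, rng):
    """Draw one index with probability softmax(logits)[index]."""
    probs = softmax(logits)
    r = rng.random()
    cumulative = 0.0
    for index, p in enumerate(probs):
        cumulative += p
        if r < cumulative:
            return index
    return len(probs) - 1  # guard against floating-point round-off


rng = random.Random(42)
logits = [2.0, 1.0, 0.1]
draws = [sample_index(logits, rng) for _ in range(10000)]
print(softmax(logits))
print([draws.count(i) / len(draws) for i in range(3)])
```

With enough draws the empirical frequencies approach the softmax probabilities, which is why the untrained model above still produces varied, if meaningless, characters.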
Train the model
At this point the problem can be treated as a standard classification problem. Given the previous RNN state and the input at this time step, predict the class of the next character.
Attach an optimizer and a loss function
The standard tf.keras.losses.sparse_categorical_crossentropy loss function works in this case because it is applied across the last dimension of the predictions.
Because your model returns logits, you need to set the from_logits flag.
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
example_batch_loss = loss(target_example_batch, example_batch_predictions)
mean_loss = example_batch_loss.numpy().mean()
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("Mean loss: ", mean_loss)
Prediction shape: (64, 100, 67) # (batch_size, sequence_length, vocab_size) Mean loss: 4.20401
A newly initialized model shouldn't be too sure of itself: the output logits should all have similar magnitudes. To confirm this, you can check that the exponential of the mean loss is approximately equal to the vocabulary size. A much higher loss means the model is sure of its wrong answers and is badly initialized:
tf.exp(mean_loss).numpy()
66.954285
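The reasoning behind this check: if all 67 logits are equal, the softmax assigns probability 1/67 to every character, so the cross-entropy of any target is ln(67), roughly 4.2047, and exponentiating recovers 67. The short plain-Python sketch below verifies that arithmetic; cross_entropy_from_logits is a hypothetical helper.

```python
import math

vocab_size = 67


def cross_entropy_from_logits(logits, target):
    # -log softmax(logits)[target], computed with the log-sum-exp trick.
    m = max(logits)
    log_sum_exp = m + math.log(sum(math.exp(x - m) for x in logits))
    return log_sum_exp - logits[target]


uniform_logits = [0.0] * vocab_size
loss = cross_entropy_from_logits(uniform_logits, target=5)
print(loss)            # ~4.2047, i.e. ln(67)
print(math.exp(loss))  # ~67.0
```

The measured mean loss of 4.204 and its exponential of about 66.95 are therefore what a well-initialized, nearly uniform model should produce.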
Configure the training procedure using the tf.keras.Model.compile method. Use tf.keras.optimizers.Adam with default arguments and the loss function.
model.compile(optimizer='adam', loss=loss)
Configure checkpoints
Use a tf.keras.callbacks.ModelCheckpoint to ensure that checkpoints are saved during training:
# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_prefix,
save_weights_only=True)
Execute the training
To keep training time reasonable, use 20 epochs to train the model. In Colab, set the runtime to GPU for faster training.
EPOCHS = 20
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])
Epoch 1/20 172/172 [==============================] - 7s 26ms/step - loss: 3.2441
Epoch 2/20 172/172 [==============================] - 6s 26ms/step - loss: 2.0601
Epoch 3/20 172/172 [==============================] - 5s 26ms/step - loss: 1.7388
Epoch 4/20 172/172 [==============================] - 5s 26ms/step - loss: 1.5589
Epoch 5/20 172/172 [==============================] - 5s 26ms/step - loss: 1.4484
Epoch 6/20 172/172 [==============================] - 5s 26ms/step - loss: 1.3771
Epoch 7/20 172/172 [==============================] - 5s 26ms/step - loss: 1.3199
Epoch 8/20 172/172 [==============================] - 5s 26ms/step - loss: 1.2721
Epoch 9/20 172/172 [==============================] - 5s 26ms/step - loss: 1.2302
Epoch 10/20 172/172 [==============================] - 5s 26ms/step - loss: 1.1853
Epoch 11/20 172/172 [==============================] - 5s 26ms/step - loss: 1.1447
Epoch 12/20 172/172 [==============================] - 5s 26ms/step - loss: 1.0984
Epoch 13/20 172/172 [==============================] - 5s 26ms/step - loss: 1.0547
Epoch 14/20 172/172 [==============================] - 5s 26ms/step - loss: 1.0081
Epoch 15/20 172/172 [==============================] - 5s 26ms/step - loss: 0.9530
Epoch 16/20 172/172 [==============================] - 5s 26ms/step - loss: 0.9024
Epoch 17/20 172/172 [==============================] - 6s 26ms/step - loss: 0.8471
Epoch 18/20 172/172 [==============================] - 5s 26ms/step - loss: 0.7954
Epoch 19/20 172/172 [==============================] - 6s 26ms/step - loss: 0.7426
Epoch 20/20 172/172 [==============================] - 6s 26ms/step - loss: 0.6958
Generate text
The simplest way to generate text with this model is to run it in a loop, and keep track of the model's internal state as you execute it.
Each time you call the model you pass in some text and an internal state. The model returns a prediction for the next character and its new state. Pass the prediction and state back in to continue generating text.
The following makes a single-step prediction:
class OneStep(tf.keras.Model):
    def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):
        super().__init__()
        self.temperature = temperature
        self.model = model
        self.chars_from_ids = chars_from_ids
        self.ids_from_chars = ids_from_chars

        # Create a mask to prevent "" or "[UNK]" from being generated.
        skip_ids = self.ids_from_chars(['', '[UNK]'])[:, None]
        sparse_mask = tf.SparseTensor(
            # Put a -inf at each bad index.
            values=[-float('inf')]*len(skip_ids),
            indices=skip_ids,
            # Match the shape to the vocabulary
            dense_shape=[len(ids_from_chars.get_vocabulary())])
        self.prediction_mask = tf.sparse.to_dense(sparse_mask)

    @tf.function
    def generate_one_step(self, inputs, states=None):
        # Convert strings to token IDs.
        input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
        input_ids = self.ids_from_chars(input_chars).to_tensor()

        # Run the model.
        # predicted_logits.shape is [batch, char, next_char_logits]
        predicted_logits, states = self.model(inputs=input_ids, states=states,
                                              return_state=True)
        # Only use the last prediction.
        predicted_logits = predicted_logits[:, -1, :]
        predicted_logits = predicted_logits/self.temperature
        # Apply the prediction mask: prevent "" or "[UNK]" from being generated.
        predicted_logits = predicted_logits + self.prediction_mask

        # Sample the output logits to generate token IDs.
        predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
        predicted_ids = tf.squeeze(predicted_ids, axis=-1)

        # Convert from token ids to characters
        predicted_chars = self.chars_from_ids(predicted_ids)

        # Return the characters and model state.
        return predicted_chars, states
one_step_model = OneStep(model, chars_from_ids, ids_from_chars)
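The prediction mask works because adding -inf to a logit drives its softmax probability to exactly zero, so sampling can never pick that ID. Here is a plain-Python sketch of the same trick on a toy six-entry vocabulary, assuming (as above) that "" has ID 0 and "[UNK]" has ID 1; the logit values are made up for illustration.

```python
import math

vocab_size = 6      # toy vocabulary size for the sketch
skip_ids = {0, 1}   # assumed IDs of "" and "[UNK]"

# Dense mask: -inf at the skipped IDs, 0.0 everywhere else.
prediction_mask = [float('-inf') if i in skip_ids else 0.0
                   for i in range(vocab_size)]

logits = [3.0, 2.5, 0.2, 1.0, -0.5, 0.7]
masked_logits = [x + m for x, m in zip(logits, prediction_mask)]

# Softmax over the masked logits: math.exp(-inf) is 0.0, so the
# skipped IDs receive exactly zero probability.
peak = max(masked_logits)
exps = [math.exp(x - peak) for x in masked_logits]
total = sum(exps)
probs = [e / total for e in exps]
print(probs)
```

Note that IDs 0 and 1 had the largest raw logits here, yet after masking the most likely choice is ID 3.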
Run it in a loop to generate some text. Looking at the generated text, you'll see the model knows when to capitalize, makes paragraphs, and imitates a Shakespeare-like writing vocabulary. With the small number of training epochs, it has not yet learned to form coherent sentences.
start = time.time()
states = None
next_char = tf.constant(['ROMEO:'])
result = [next_char]
for n in range(1000):
    next_char, states = one_step_model.generate_one_step(next_char, states=states)
    result.append(next_char)
result = tf.strings.join(result)
end = time.time()
print(result[0].numpy().decode('utf-8'), '\n\n' + '_'*80)
print('\nRun time:', end - start)
ROMEO: The poxio's blows, no: bitter and call'd up Angelo. ELBOW: With all my heart the other through what they would were, Or else a knee each parks of rage, To undertake the truth of person will command. Now, by the imjured both yield three vantage. Now, by my state I should accuse me, and I will make a doum with the best on the deed. 'Tis numbed glad I break no other from her death: in arms Between'd their hearts, and fled to sleep no greater than enemy Where I have subjects for your death: Therefore they fall in substitute black,-- As I, Jove large and true mine adversaries: Make full obsenve for a quarrel or us, Elves a quarter old. Hold, then farewell. BENVOLIO: In faith, be most agreed, and want that hang'd our guest? KING RICHARD II: Why, Buckingham, be there be with you. LUCIO: 'Tis better ord, I follow it. CAMILLO: Swear you, poor soul, O valiant point on eightee out goodness, threatening stock Against black-parting 'love and gainsay the extremest warrior, nays suffer as thy n ________________________________________________________________________________ Run time: 2.4890763759613037
The easiest thing you can do to improve the results is to train it for longer (try EPOCHS = 30).
You can also experiment with a different start string, try adding another RNN layer to improve the model's accuracy, or adjust the temperature parameter to generate more or less random predictions.
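The temperature divides the logits before sampling, as in generate_one_step. The plain-Python sketch below shows its effect on the resulting distribution: a temperature below 1.0 sharpens it toward the most likely character (more predictable text), and a temperature above 1.0 flattens it toward uniform (more random text). The helper name and the example logits are hypothetical.

```python
import math


def softmax_with_temperature(logits, temperature):
    # Scale the logits, then apply a numerically stable softmax.
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


logits = [2.0, 1.0, 0.1]
for t in (0.5, 1.0, 2.0):
    probs = softmax_with_temperature(logits, t)
    print(t, [round(p, 3) for p in probs])
```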
If you want the model to generate text faster, the easiest thing you can do is batch the text generation. In the example below, the model generates 5 outputs in about the same time it took to generate 1 above.
start = time.time()
states = None
next_char = tf.constant(['ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:'])
result = [next_char]
for n in range(1000):
    next_char, states = one_step_model.generate_one_step(next_char, states=states)
    result.append(next_char)
result = tf.strings.join(result)
end = time.time()
print(result, '\n\n' + '_'*80)
print('\nRun time:', end - start)
tf.Tensor( [b"ROMEO:\nGood even, my lord.\n\nGoNZALO:\nI go. Now now be gold! I mean again.\n\nROMEO:\nI warrant thee, tell me, be sure the duke and all these walls\nRichard our cadismed; prick it not.\n\nBUSHBALYHAM:\nRomeo, the next way homewible.\n\nKING EDWARD IV:\nWilt thou my son, my master is so, nor dangerous\nBy help to noble peerish; whom, and 'tis be.\n\nPOMPEY:\nWhy, but a word moreof? what do yet that seek of his prince:\nMethinks a smell my part there: where are you?\nFor he is all and make you well.\n\nLUCIO:\n\nISABELLA:\nO, weed my banners be so bold to give him o'erwook\nWhich shows me for himself; but we want them,\nBut one seven years so cries, and thus with us\nStill 'tis cartly as't, we will consumed\nWhile Hereford, and my father and thy gentle senseboly,\nWhich never lived nor rage debase more bellies 'by;\nBut I'll Clarence and the lord. How far infirmity!\nWhat, will you gave her? Phile tyree, grieve now?\nHere, merry mistracious Warwick's daughter,\nO wonder is it like to a chair with sadless indeed.\nYour fresh " b"ROMEO:\nPardon is in my tent these two courses of the king.\n\nNORTHUMBERLAND:\nThere on his well-said i' the world; and brick-sprain'd up\nThy great Bolingbroke days be ruled by me.\n\nLord Mayor:\nWhat good mad hat the lords of Clarence! Well a wench!\nShould a lord wader, by Adibe to't.\n\nMENENIUS:\nWell, I beseech you,\nBy charing home and you my sir were not to know.\n\nNurse:\nYour will should knock you were not weeping\nTill he hath prosper best of all prace;\nThe proud issue with her soul flish wounds and realm; in\na leaden all the very penitence, if the\npeofle sworn, I'll crave the woroneth of my true dear kinds.\n\nKING EDWARD IV:\nThen love, as it was, but most proofing and smiles,\n'Think of wine and called by Bland,\nBut such a year and aguments. 
I'll to my true opinion!'\nAnd thinkest serve awhile to humbly brother?\n\nKING EDWARD IV:\nSent thou draw me.\nThere is no virtue, or that hear that\nto tell us our more. But Your dun ratis and Duke of Clarence weeds.\nI will tell you would not proud-heads to stan" b"ROMEO:\nYour reasons are my fairly queen?\nBetter on that son should fled to seet in heaten;\nI will prove a sweeter blood in his under well;\nTo-morrow must I tell the counterfeitns:\nAgainst the early tide that we will weep,\nBut raimould by hundry fortune and flowers.\nSups, I would think up:\nMy life drowned, untimely brought\nWhen the abound son of peeresp, nob life,\nAs is the linent bunky way to sweat: alas!\nWhy should have I with our hands, now death.\nHere in your mother?\n\nSerdnever:\nBe it so, for the lurs; whose love I had guest\nWas mutind ours: sweet Kate to meet it.\n\nGLOUCESTER:\nHarp! mark me.\nI am about me: our generance take then\nyou are like to: no man but beasts he did she be\nEmburatement, look into a hell met,\nFoe, like an your quit my disposition,\nAgainst them, fought with you.\nAh, what say'st thou? Camillo, pace me, repeat:\nMany guilty clouds and not of Mine; and it\nAre come another person. My queen's husband and my soveriness\nTo bring his daughter to my sink in mide of death\nI'ed so" b"ROMEO:\nNay, if thy wits have all for gravenced, so husband!\nAh, I by leave, and lugh his pleasures to the fair\nShould, suck any coward as the sun.\n\nAUTOLYCUS:\nHence!\n\nFLORIZEL:\nShould I.\nGo to, a bride that he against the deed.\n\nHMORE ETWARD IV:\nWell, Clifford, tell me, how believe hed,\nAnd bid her fail that he did seem to bed,\nWho look'd for Rome is next dewivers.\n\nANHERIO:\nIf so, or at a bowling bovel? 
Menenius!\n\nAll:\nCall that greater say he till then be spent myself:\nbut I shall go see the lonal bub tybalt from the cause,\nAnd back'd, as when we should hear me in my cold\nPost-take of her own sword, you would susprome them,\nwife, too weak deaf, and the king post to the senate:\nAcquaint her life, against what shot fares,\nNot that in the iced for their glift enemies,\nyet, if I would entreat it folly, my rights are all\ndearly against Their and all aforements:\n'Twas more than meant to say\nthat I am going. Metum's monstrous town;\nAnd not means yet; so weep; for I intend joy,\nand when I wand to " b"ROMEO:\nSyop wherein, the Fater, there; and with the land\nIn soldiers ributances, was so gentleman:\nher heavens have fallen out with all the world were butcher'd.\nCome, come, King Edward to his soul of mine.\n\nKING RICHARD III:\nSandly, were there will not be distinguingly.\n\nDUKE VINCENTIO:\nI know with me: in this be steaded in the king.\n\nDUCHESS OF YORK:\nWhat he's with sweat? is it a lawd we stand alone?\n\nPedant:\nHencefactors, traitors! From Ploucester's death.\nMeantime the hope to growl she were they all,\nTrue straight from the world and all the ship special sport.\n\nBIONDELLO:\nWhy, hear you, while's a spain were as tubority?\n\nBUSHY:\nDenolaculate, a vessed gunecome a moturer?\n\nPETRUCHIO:\nWhy, I thank your wanton alteration: my manory have\nlabourer out;\nEven love, and, suppine, a vessel, ever stood doth limb.\n\nLORD ROSS:\nThe senate and your crothes slow, then in praise.\n\nISABELLA:\nI warrant him, and keep thy name; And now I mean our father's sake.\n\nPETRUCHIO:\nA man hath besied long-'janion's wi"], shape=(5,), dtype=string) ________________________________________________________________________________ Run time: 2.3512351512908936
Export the generator
This single-step model can easily be saved and restored, allowing you to use it anywhere a tf.saved_model is accepted.
tf.saved_model.save(one_step_model, 'one_step')
one_step_reloaded = tf.saved_model.load('one_step')
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.OneStep object at 0x7f926497ff28>, because it is not built. WARNING:absl:Found untraced functions such as gru_cell_layer_call_fn, gru_cell_layer_call_and_return_conditional_losses, gru_cell_layer_call_fn, gru_cell_layer_call_and_return_conditional_losses, gru_cell_layer_call_and_return_conditional_losses while saving (showing 5 of 5). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: one_step/assets
states = None
next_char = tf.constant(['ROMEO:'])
result = [next_char]
for n in range(100):
    next_char, states = one_step_reloaded.generate_one_step(next_char, states=states)
    result.append(next_char)
print(tf.strings.join(result)[0].numpy().decode("utf-8"))
WARNING:tensorflow:5 out of the last 5 calls to <function recreate_function.<locals>.restored_function_body at 0x7f91e02ac730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details. WARNING:tensorflow:6 out of the last 6 calls to <function recreate_function.<locals>.restored_function_body at 0x7f91e02ac730> triggered tf.function retracing. Tracing is expensive and the excessive number of tracings could be due to (1) creating @tf.function repeatedly in a loop, (2) passing tensors with different shapes, (3) passing Python objects instead of tensors. For (1), please define your @tf.function outside of the loop. For (2), @tf.function has experimental_relax_shapes=True option that relaxes argument shapes that can avoid unnecessary retracing. For (3), please refer to https://www.tensorflow.org/guide/function#controlling_retracing and https://www.tensorflow.org/api_docs/python/tf/function for more details.
ROMEO: HERMIONE: Why, what man? and say you will. WARWICK: It will be safer than thee, and next be debt:
Advanced: Customized Training
The above training procedure is simple, but it does not give you much control. It uses teacher forcing, which prevents bad predictions from being fed back to the model, so the model never learns to recover from mistakes.
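To see why teacher forcing hides mistakes, the following pure-Python sketch uses a hypothetical lookup-table "model" (`predict_next`, not the tutorial's RNN) that gets exactly one transition wrong. When every input is the true character, as in training, the single error stays isolated; when each prediction is fed back as the next input, as in generation, the error propagates to every later step.

```python
def predict_next(ch):
    """Hypothetical toy model: predicts the next character of "abcde",
    but has wrongly learned 'b' -> 'z'. Unknown inputs map to '?'."""
    table = {"a": "b", "b": "z", "c": "d", "d": "e"}
    return table.get(ch, "?")


def teacher_forced(sequence):
    """Training-style pass: every input is the ground-truth character."""
    return "".join(predict_next(ch) for ch in sequence[:-1])


def free_running(start, length):
    """Generation-style pass: each prediction becomes the next input."""
    out, ch = [], start
    for _ in range(length):
        ch = predict_next(ch)
        out.append(ch)
    return "".join(out)


def num_correct(predicted, expected):
    """Count positions where the prediction matches the target."""
    return sum(p == e for p, e in zip(predicted, expected))


target = "abcde"
forced_out = teacher_forced(target)   # inputs: a, b, c, d
open_out = free_running("a", 4)       # inputs: a, then the model's own outputs
print(forced_out, num_correct(forced_out, target[1:]))  # bzde 3
print(open_out, num_correct(open_out, target[1:]))      # bz?? 1
```

The same single wrong transition costs one character under teacher forcing but three under free-running generation, which is the failure mode a training loop that never feeds predictions back cannot learn to correct.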
Now that you've seen how to run the model manually, you'll implement the training loop yourself. This gives you a starting point if, for example, you want to implement curriculum learning to help stabilize the model's open-loop output.
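Curriculum learning is only mentioned here, not implemented. One possible curriculum, shown purely as an illustration, is scheduled sampling: start with pure teacher forcing and, as training progresses, feed the model's own predictions back as inputs more often. The sketch below is plain Python; the linear schedule, the `max_prob` value, the function names, and the toy predictor (which only looks at the previous character) are all assumptions. A real implementation would make this choice inside the RNN's per-step loop, on tensors.

```python
import random


def sampling_probability(epoch, num_epochs, max_prob=0.5):
    """Hypothetical linear schedule: 0.0 at the first epoch, max_prob at the last."""
    if num_epochs <= 1:
        return max_prob
    return max_prob * epoch / (num_epochs - 1)


def mixed_inputs(sequence, predict_next, prob, rng):
    """Build the input characters for one training sequence.

    Step 0 always uses the true first character. At each later step, with
    probability `prob` the input is the model's prediction made from the
    previous input (open loop, as in generation); otherwise it is the
    ground-truth character (teacher forcing)."""
    inputs = [sequence[0]]
    for t in range(1, len(sequence) - 1):
        if rng.random() < prob:
            inputs.append(predict_next(inputs[-1]))  # feed the prediction back
        else:
            inputs.append(sequence[t])               # teacher forcing
    return inputs


# Usage with a hypothetical lookup-table model that wrongly maps 'b' -> 'z'.
toy_table = {"a": "b", "b": "z", "c": "d", "d": "e"}
rng = random.Random(0)
for epoch in (0, 5, 9):
    p = sampling_probability(epoch, num_epochs=10)
    print(epoch, round(p, 3), mixed_inputs("abcde", lambda c: toy_table.get(c, "?"), p, rng))
```

At probability 0 this reduces to the teacher-forced inputs used above; at probability 1 it is fully open-loop, so the schedule gradually moves training toward the conditions the model faces during generation.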
The most important part of a custom training loop is the train step function. Use tf.GradientTape to track the gradients. You can learn more about this approach by reading the eager execution guide.
The basic procedure is:
- Execute the model and calculate the loss under a tf.GradientTape.
- Calculate the updates and apply them to the model using the optimizer.
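The two steps above can be traced by hand in a framework-free sketch. Here the "model" is just a hypothetical vector of logits for a single example, the gradient of the sparse categorical cross-entropy is derived analytically (softmax minus the one-hot label) instead of being recorded by tf.GradientTape, and plain SGD with an assumed learning rate of 0.5 stands in for the Adam optimizer.

```python
import math


def softmax(logits):
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def sparse_ce_from_logits(logits, label):
    """Cross-entropy of one example, given raw logits and an integer label."""
    return -math.log(softmax(logits)[label])


def train_step(logits, label, lr=0.5):
    # 1. Forward pass and loss (the part tf.GradientTape would record).
    loss = sparse_ce_from_logits(logits, label)
    # 2. Gradient w.r.t. the logits, derived by hand rather than by a tape:
    #    d(loss)/d(logits) = softmax(logits) - one_hot(label)
    probs = softmax(logits)
    grads = [p - (1.0 if i == label else 0.0) for i, p in enumerate(probs)]
    # 3. Apply the update (plain SGD standing in for optimizer.apply_gradients).
    new_logits = [z - lr * g for z, g in zip(logits, grads)]
    return new_logits, loss


logits = [0.0, 0.0, 0.0]  # hypothetical 3-character vocabulary, uniform start
for step in range(20):
    logits, loss = train_step(logits, label=1)
    if step % 5 == 0:
        print(f"step {step} loss {loss:.4f}")
```

The first loss is ln 3, the cross-entropy of a uniform guess over three classes, and it falls on every step as probability mass moves to the correct label. TensorFlow's tape computes the same kind of gradient automatically for every trainable variable in the real model.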
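import tensorflow as tf
# MyModel is the model class defined earlier in this tutorial.
class CustomTraining(MyModel):
    @tf.function
    def train_step(self, inputs):
        # Unpack one batch of (input, target) sequences.
        inputs, labels = inputs
        # Record the forward pass and loss so the tape can differentiate them.
        with tf.GradientTape() as tape:
            predictions = self(inputs, training=True)
            loss = self.loss(labels, predictions)
        # Compute gradients with respect to this model's own weights and apply them.
        grads = tape.gradient(loss, self.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
        # Keras convention: return a dict mapping metric names to values.
        return {'loss': loss}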
The above implementation of the train_step method follows Keras' train_step conventions. This is optional, but it allows you to change the behavior of the train step and still use Keras' Model.compile and Model.fit methods.
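The payoff of the convention is that the outer loop only needs the dict that train_step returns. The sketch below makes that contract explicit with a deliberately simplified, hypothetical `simple_fit` and `ToyModel`; it is not how Keras implements Model.fit, which does considerably more (callbacks, progress reporting, distribution strategies).

```python
class ToyModel:
    """Hypothetical stand-in whose train_step follows the Keras convention:
    take one batch and return a dict mapping metric names to values."""

    def __init__(self):
        self.steps = 0

    def train_step(self, batch):
        inputs, labels = batch
        self.steps += 1
        # Placeholder "loss": mean absolute difference between inputs and labels.
        loss = sum(abs(x - y) for x, y in zip(inputs, labels)) / len(inputs)
        return {"loss": loss}


def simple_fit(model, dataset, epochs=1):
    """Hypothetical, heavily simplified fit loop (not Keras's implementation).
    It relies only on train_step returning a dict, and records the per-epoch
    mean of every returned value."""
    history = []
    for _ in range(epochs):
        totals, num_batches = {}, 0
        for batch in dataset:
            logs = model.train_step(batch)
            for name, value in logs.items():
                totals[name] = totals.get(name, 0.0) + value
            num_batches += 1
        history.append({name: total / num_batches for name, total in totals.items()})
    return history


dataset = [([1.0, 2.0], [1.0, 3.0]), ([0.0, 0.0], [2.0, 2.0])]
toy = ToyModel()
print(simple_fit(toy, dataset, epochs=2))  # [{'loss': 1.25}, {'loss': 1.25}]
```

Because `simple_fit` only iterates over the returned dict, a train_step that also returned, say, an accuracy entry would be reported without any change to the loop.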
model = CustomTraining(
    vocab_size=len(ids_from_chars.get_vocabulary()),
    embedding_dim=embedding_dim,
    rnn_units=rnn_units)
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.fit(dataset, epochs=1)
172/172 [==============================] - 16s 81ms/step - loss: 2.7281
<tensorflow.python.keras.callbacks.History at 0x7f918c6fb400>
Or if you need more control, you can write your own complete custom training loop:
import time
EPOCHS = 10
# Running average of the per-batch loss over the current epoch.
mean = tf.metrics.Mean()
for epoch in range(EPOCHS):
    start = time.time()
    mean.reset_states()
    for (batch_n, (inp, target)) in enumerate(dataset):
        logs = model.train_step([inp, target])
        mean.update_state(logs['loss'])
        if batch_n % 50 == 0:
            template = f"Epoch {epoch+1} Batch {batch_n} Loss {logs['loss']:.4f}"
            print(template)
    # saving (checkpoint) the model every 5 epochs
    if (epoch + 1) % 5 == 0:
        model.save_weights(checkpoint_prefix.format(epoch=epoch))
    print()
    print(f'Epoch {epoch+1} Loss: {mean.result().numpy():.4f}')
    print(f'Time taken for 1 epoch {time.time() - start:.2f} sec')
    print("_"*80)
# Save the final weights after the last epoch.
model.save_weights(checkpoint_prefix.format(epoch=epoch))
Epoch 1 Batch 0 Loss 2.1604
Epoch 1 Batch 50 Loss 2.0819
Epoch 1 Batch 100 Loss 2.0088
Epoch 1 Batch 150 Loss 1.9042
Epoch 1 Loss: 2.0039
Time taken for 1 epoch 15.29 sec
________________________________________________________________________________
Epoch 2 Batch 0 Loss 1.8115
Epoch 2 Batch 50 Loss 1.7712
Epoch 2 Batch 100 Loss 1.6842
Epoch 2 Batch 150 Loss 1.6678
Epoch 2 Loss: 1.7288
Time taken for 1 epoch 14.73 sec
________________________________________________________________________________
Epoch 3 Batch 0 Loss 1.5976
Epoch 3 Batch 50 Loss 1.6310
Epoch 3 Batch 100 Loss 1.5008
Epoch 3 Batch 150 Loss 1.5508
Epoch 3 Loss: 1.5654
Time taken for 1 epoch 14.79 sec
________________________________________________________________________________
Epoch 4 Batch 0 Loss 1.5221
Epoch 4 Batch 50 Loss 1.4466
Epoch 4 Batch 100 Loss 1.4530
Epoch 4 Batch 150 Loss 1.4677
Epoch 4 Loss: 1.4635
Time taken for 1 epoch 14.78 sec
________________________________________________________________________________
Epoch 5 Batch 0 Loss 1.4302
Epoch 5 Batch 50 Loss 1.4034
Epoch 5 Batch 100 Loss 1.4557
Epoch 5 Batch 150 Loss 1.4137
Epoch 5 Loss: 1.3943
Time taken for 1 epoch 15.19 sec
________________________________________________________________________________
Epoch 6 Batch 0 Loss 1.3380
Epoch 6 Batch 50 Loss 1.3404
Epoch 6 Batch 100 Loss 1.3174
Epoch 6 Batch 150 Loss 1.3430
Epoch 6 Loss: 1.3400
Time taken for 1 epoch 15.05 sec
________________________________________________________________________________
Epoch 7 Batch 0 Loss 1.3027
Epoch 7 Batch 50 Loss 1.3185
Epoch 7 Batch 100 Loss 1.2899
Epoch 7 Batch 150 Loss 1.2744
Epoch 7 Loss: 1.2955
Time taken for 1 epoch 15.01 sec
________________________________________________________________________________
Epoch 8 Batch 0 Loss 1.1957
Epoch 8 Batch 50 Loss 1.2315
Epoch 8 Batch 100 Loss 1.2380
Epoch 8 Batch 150 Loss 1.2457
Epoch 8 Loss: 1.2549
Time taken for 1 epoch 15.05 sec
________________________________________________________________________________
Epoch 9 Batch 0 Loss 1.2290
Epoch 9 Batch 50 Loss 1.2181
Epoch 9 Batch 100 Loss 1.1872
Epoch 9 Batch 150 Loss 1.2097
Epoch 9 Loss: 1.2160
Time taken for 1 epoch 14.81 sec
________________________________________________________________________________
Epoch 10 Batch 0 Loss 1.1551
Epoch 10 Batch 50 Loss 1.1808
Epoch 10 Batch 100 Loss 1.1605
Epoch 10 Batch 150 Loss 1.1976
Epoch 10 Loss: 1.1772
Time taken for 1 epoch 15.23 sec
________________________________________________________________________________
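The bookkeeping in the loop above can be checked without TensorFlow. The sketch below uses a hypothetical `RunningMean` class that only mimics the method names of tf.metrics.Mean (it is not that implementation) and a hypothetical `schedule` helper that reproduces the loop's logging and checkpoint conditions. With 172 batches per epoch (the count shown by model.fit earlier), it logs batches 0, 50, 100 and 150, matching the output, and saves inside the loop at epoch indices 4 and 9.

```python
class RunningMean:
    """Hypothetical pure-Python stand-in that mimics the update_state /
    result / reset_states method names of tf.metrics.Mean."""

    def __init__(self):
        self.reset_states()

    def reset_states(self):
        self.total, self.count = 0.0, 0

    def update_state(self, value):
        self.total += value
        self.count += 1

    def result(self):
        # This sketch returns 0.0 before any update.
        return self.total / self.count if self.count else 0.0


def schedule(epochs, batches_per_epoch, log_every=50, save_every=5):
    """Return the (epoch number, batch index) pairs that print a log line and
    the 0-based epoch indices that call save_weights inside the loop."""
    logged = [(epoch + 1, batch_n)
              for epoch in range(epochs)
              for batch_n in range(batches_per_epoch)
              if batch_n % log_every == 0]
    saved = [epoch for epoch in range(epochs) if (epoch + 1) % save_every == 0]
    return logged, saved


logged, saved = schedule(epochs=10, batches_per_epoch=172)
print([b for e, b in logged if e == 1])  # [0, 50, 100, 150]
print(saved)                             # [4, 9]

mean = RunningMean()
for loss in [2.0, 1.5, 1.0]:
    mean.update_state(loss)
print(mean.result())                     # 1.5
```

Two consequences follow. The per-epoch "Loss" line averages every batch in the epoch, not just the four logged ones, so it need not equal their average. And with EPOCHS = 10, the final save_weights call after the loop uses the same epoch index (9) as the last in-loop checkpoint, so it repeats that save.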