This tutorial demonstrates how to generate text using a character-based RNN. You will work with a dataset of Shakespeare's writing from Andrej Karpathy's The Unreasonable Effectiveness of Recurrent Neural Networks. Given a sequence of characters from this data ("Shakespear"), train a model to predict the next character in the sequence ("e"). Longer sequences of text can be generated by calling the model repeatedly.
This tutorial includes runnable code implemented using tf.keras and eager execution. The following is sample output when the model in this tutorial was trained for 30 epochs and started with the prompt "Q":
QUEENE: I had thought thou hadst a Roman; for the oracle, Thus by All bids the man against the word, Which are so weak of care, by old care done; Your children were in your holy love, And the precipitation through the bleeding throne. BISHOP OF ELY: Marry, and will, my lord, to weep in such a one were prettiest; Yet now I was adopted heir Of the world's lamentable day, To watch the next way with his father with his face? ESCALUS: The cause why then we are all resolved more sons. VOLUMNIA: O, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, no, it is no sin it should be dead, And love and pale as any will to that word. QUEEN ELIZABETH: But how long have I heard the soul for this world, And show his hands of life be proved to stand. PETRUCHIO: I say he look'd on, if I must be content To stay him from the fatal of our country's bliss. His lordship pluck'd from this sentence then for prey, And then let us twain, being the moon, were she such a case as fills m
While some of the sentences are grammatical, most do not make sense. The model has not learned the meaning of words, but consider:
The model is character-based. When training started, the model did not know how to spell an English word, or that words were even a unit of text.
The structure of the output resembles a play: blocks of text generally begin with a speaker name in all capital letters, similar to the dataset.
As demonstrated below, the model is trained on small batches of text (100 characters each), and is still able to generate a longer sequence of text with coherent structure.
Setup
Import TensorFlow and other libraries
import tensorflow as tf
import numpy as np
import os
import time
Download the Shakespeare dataset
Change the following line to run this code on your own data.
path_to_file = tf.keras.utils.get_file('shakespeare.txt', 'https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt')
Downloading data from https://storage.googleapis.com/download.tensorflow.org/data/shakespeare.txt 1115394/1115394 [==============================] - 0s 0us/step
Read the data
First, look at the text:
# Read, then decode for py2 compat.
text = open(path_to_file, 'rb').read().decode(encoding='utf-8')
# length of text is the number of characters in it
print(f'Length of text: {len(text)} characters')
Length of text: 1115394 characters
# Take a look at the first 250 characters in text
print(text[:250])
First Citizen: Before we proceed any further, hear me speak. All: Speak, speak. First Citizen: You are all resolved rather to die than to famish? All: Resolved. resolved. First Citizen: First, you know Caius Marcius is chief enemy to the people.
# The unique characters in the file
vocab = sorted(set(text))
print(f'{len(vocab)} unique characters')
65 unique characters
Process the text
Vectorize the text
Before training, you need to convert the strings to a numerical representation.
The tf.keras.layers.StringLookup layer can convert each character into a numeric ID. It just needs the text to be split into tokens first.
example_texts = ['abcdefg', 'xyz']
chars = tf.strings.unicode_split(example_texts, input_encoding='UTF-8')
chars
<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>
Now create the tf.keras.layers.StringLookup layer:
ids_from_chars = tf.keras.layers.StringLookup(
vocabulary=list(vocab), mask_token=None)
It converts from tokens to character IDs:
ids = ids_from_chars(chars)
ids
<tf.RaggedTensor [[40, 41, 42, 43, 44, 45, 46], [63, 64, 65]]>
Since the goal of this tutorial is to generate text, it will also be important to invert this representation and recover human-readable strings from it. For this you can use tf.keras.layers.StringLookup(..., invert=True).
chars_from_ids = tf.keras.layers.StringLookup(
vocabulary=ids_from_chars.get_vocabulary(), invert=True, mask_token=None)
This layer recovers the characters from the vectors of IDs, and returns them as a tf.RaggedTensor of characters:
chars = chars_from_ids(ids)
chars
<tf.RaggedTensor [[b'a', b'b', b'c', b'd', b'e', b'f', b'g'], [b'x', b'y', b'z']]>
You can use tf.strings.reduce_join to join the characters back into strings.
tf.strings.reduce_join(chars, axis=-1).numpy()
array([b'abcdefg', b'xyz'], dtype=object)
def text_from_ids(ids):
return tf.strings.reduce_join(chars_from_ids(ids), axis=-1)
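As a quick sanity check, you can round-trip the example ids from above through this helper (this assumes the earlier StringLookup cells have already run, so ids is still in scope); the result should match the reduce_join output:
# Round-trip the example IDs back to strings.
# Assumes `ids` from the StringLookup example above is still in scope.
text_from_ids(ids).numpy()
# Expected: array([b'abcdefg', b'xyz'], dtype=object)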
The prediction task
Given a character, or a sequence of characters, what is the most probable next character? This is the task you're training the model to perform. The input to the model will be a sequence of characters, and you train the model to predict the output—the following character at each time step.
Since RNNs maintain an internal state that depends on the previously seen elements, the question becomes: given all the characters seen so far, what is the next character?
Create training examples and targets
Next divide the text into example sequences. Each input sequence will contain seq_length characters from the text.
For each input sequence, the corresponding targets contain the same length of text, except shifted one character to the right.
So break the text into chunks of seq_length+1. For example, say seq_length is 4 and our text is "Hello". The input sequence would be "Hell", and the target sequence "ello".
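As a plain-Python sketch of that split (for illustration only; the pipeline below does the same thing on tensors of character IDs):
# Illustration: a chunk of seq_length + 1 = 5 characters is split into
# an input sequence and a target sequence shifted one character to the right.
chunk = "Hello"
input_seq, target_seq = chunk[:-1], chunk[1:]
print(input_seq, target_seq)  # Hell ello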
To do this, first use the tf.data.Dataset.from_tensor_slices function to convert the text vector into a stream of character indices.
all_ids = ids_from_chars(tf.strings.unicode_split(text, 'UTF-8'))
all_ids
<tf.Tensor: shape=(1115394,), dtype=int64, numpy=array([19, 48, 57, ..., 46, 9, 1])>
ids_dataset = tf.data.Dataset.from_tensor_slices(all_ids)
for ids in ids_dataset.take(10):
print(chars_from_ids(ids).numpy().decode('utf-8'))
F i r s t C i t i
seq_length = 100
The batch method lets you easily convert these individual characters to sequences of the desired size.
sequences = ids_dataset.batch(seq_length+1, drop_remainder=True)
for seq in sequences.take(1):
print(chars_from_ids(seq))
tf.Tensor( [b'F' b'i' b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':' b'\n' b'B' b'e' b'f' b'o' b'r' b'e' b' ' b'w' b'e' b' ' b'p' b'r' b'o' b'c' b'e' b'e' b'd' b' ' b'a' b'n' b'y' b' ' b'f' b'u' b'r' b't' b'h' b'e' b'r' b',' b' ' b'h' b'e' b'a' b'r' b' ' b'm' b'e' b' ' b's' b'p' b'e' b'a' b'k' b'.' b'\n' b'\n' b'A' b'l' b'l' b':' b'\n' b'S' b'p' b'e' b'a' b'k' b',' b' ' b's' b'p' b'e' b'a' b'k' b'.' b'\n' b'\n' b'F' b'i' b'r' b's' b't' b' ' b'C' b'i' b't' b'i' b'z' b'e' b'n' b':' b'\n' b'Y' b'o' b'u' b' '], shape=(101,), dtype=string)
It's easier to see what this is doing if you join the tokens back into strings:
for seq in sequences.take(5):
print(text_from_ids(seq).numpy())
b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou ' b'are all resolved rather to die than to famish?\n\nAll:\nResolved. resolved.\n\nFirst Citizen:\nFirst, you k' b"now Caius Marcius is chief enemy to the people.\n\nAll:\nWe know't, we know't.\n\nFirst Citizen:\nLet us ki" b"ll him, and we'll have corn at our own price.\nIs't a verdict?\n\nAll:\nNo more talking on't; let it be d" b'one: away, away!\n\nSecond Citizen:\nOne word, good citizens.\n\nFirst Citizen:\nWe are accounted poor citi'
For training you'll need a dataset of (input, label) pairs, where input and label are sequences. At each time step the input is the current character and the label is the next character.
Here's a function that takes a sequence as input, duplicates it, and shifts it to align the input and label for each timestep:
def split_input_target(sequence):
input_text = sequence[:-1]
target_text = sequence[1:]
return input_text, target_text
split_input_target(list("Tensorflow"))
(['T', 'e', 'n', 's', 'o', 'r', 'f', 'l', 'o'], ['e', 'n', 's', 'o', 'r', 'f', 'l', 'o', 'w'])
dataset = sequences.map(split_input_target)
for input_example, target_example in dataset.take(1):
print("Input :", text_from_ids(input_example).numpy())
print("Target:", text_from_ids(target_example).numpy())
Input : b'First Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou' Target: b'irst Citizen:\nBefore we proceed any further, hear me speak.\n\nAll:\nSpeak, speak.\n\nFirst Citizen:\nYou '
Create training batches
You used tf.data to split the text into manageable sequences. But before feeding this data into the model, you need to shuffle the data and pack it into batches.
# Batch size
BATCH_SIZE = 64
# Buffer size to shuffle the dataset
# (TF data is designed to work with possibly infinite sequences,
# so it doesn't attempt to shuffle the entire sequence in memory. Instead,
# it maintains a buffer in which it shuffles elements).
BUFFER_SIZE = 10000
dataset = (
dataset
.shuffle(BUFFER_SIZE)
.batch(BATCH_SIZE, drop_remainder=True)
.prefetch(tf.data.experimental.AUTOTUNE))
dataset
<_PrefetchDataset element_spec=(TensorSpec(shape=(64, 100), dtype=tf.int64, name=None), TensorSpec(shape=(64, 100), dtype=tf.int64, name=None))>
Build the model
This section defines the model as a keras.Model subclass (for details see Making new Layers and Models via subclassing).
This model has three layers:
- tf.keras.layers.Embedding: The input layer. A trainable lookup table that will map each character ID to a vector with embedding_dim dimensions.
- tf.keras.layers.GRU: A type of RNN with size units=rnn_units. (You can also use an LSTM layer here.)
- tf.keras.layers.Dense: The output layer, with vocab_size outputs. It outputs one logit for each character in the vocabulary. These are the log-likelihoods of each character according to the model.
# Length of the vocabulary in StringLookup Layer
vocab_size = len(ids_from_chars.get_vocabulary())
# The embedding dimension
embedding_dim = 256
# Number of RNN units
rnn_units = 1024
class MyModel(tf.keras.Model):
def __init__(self, vocab_size, embedding_dim, rnn_units):
    super().__init__()
self.embedding = tf.keras.layers.Embedding(vocab_size, embedding_dim)
self.gru = tf.keras.layers.GRU(rnn_units,
return_sequences=True,
return_state=True)
self.dense = tf.keras.layers.Dense(vocab_size)
def call(self, inputs, states=None, return_state=False, training=False):
x = inputs
x = self.embedding(x, training=training)
if states is None:
states = self.gru.get_initial_state(x)
x, states = self.gru(x, initial_state=states, training=training)
x = self.dense(x, training=training)
if return_state:
return x, states
else:
return x
model = MyModel(
vocab_size=vocab_size,
embedding_dim=embedding_dim,
rnn_units=rnn_units)
For each character the model looks up the embedding, runs the GRU one timestep with the embedding as input, and applies the dense layer to generate logits predicting the log-likelihood of the next character:
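A minimal sketch of that flow, tracing the tensor shapes layer by layer (the all-zeros dummy_ids batch here is hypothetical and is used only to show the shapes):
# Trace shapes through the three layers for a dummy batch of character IDs.
dummy_ids = tf.zeros([1, 10], dtype=tf.int64)  # (batch, sequence_length)
emb = model.embedding(dummy_ids)               # (1, 10, embedding_dim)
out, state = model.gru(emb)                    # (1, 10, rnn_units), (1, rnn_units)
logits = model.dense(out)                      # (1, 10, vocab_size)
print(emb.shape, out.shape, state.shape, logits.shape)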
Try the model
Now run the model to see that it behaves as expected.
First check the shape of the output:
for input_example_batch, target_example_batch in dataset.take(1):
example_batch_predictions = model(input_example_batch)
print(example_batch_predictions.shape, "# (batch_size, sequence_length, vocab_size)")
(64, 100, 66) # (batch_size, sequence_length, vocab_size)
In the above example the sequence length of the input is 100, but the model can be run on inputs of any length:
model.summary()
Model: "my_model" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= embedding (Embedding) multiple 16896 gru (GRU) multiple 3938304 dense (Dense) multiple 67650 ================================================================= Total params: 4,022,850 Trainable params: 4,022,850 Non-trainable params: 0 _________________________________________________________________
To get actual predictions from the model you need to sample from the output distribution to obtain character indices. This distribution is defined by the logits over the character vocabulary.
Try it for the first example in the batch:
sampled_indices = tf.random.categorical(example_batch_predictions[0], num_samples=1)
sampled_indices = tf.squeeze(sampled_indices, axis=-1).numpy()
This gives us, at each timestep, a prediction of the next character index:
sampled_indices
array([27, 18, 48, 53, 63, 44, 27, 31, 11, 43, 34, 40, 11, 21, 2, 24, 30, 0, 32, 65, 24, 8, 53, 36, 21, 1, 50, 52, 62, 20, 21, 41, 48, 59, 7, 30, 0, 61, 13, 58, 61, 63, 29, 32, 16, 43, 19, 64, 64, 49, 36, 32, 56, 17, 43, 9, 45, 53, 23, 6, 27, 49, 52, 5, 47, 16, 32, 31, 64, 59, 22, 11, 7, 21, 44, 25, 28, 0, 19, 17, 52, 1, 13, 32, 45, 64, 30, 24, 14, 29, 30, 2, 56, 58, 11, 13, 56, 7, 3, 56])
Decode these to see the text predicted by this untrained model:
print("Input:\n", text_from_ids(input_example_batch[0]).numpy())
print()
print("Next Char Predictions:\n", text_from_ids(sampled_indices).numpy())
Input: b'I forget that by the house of York\nMy father came untimely to his death?\nDid I let pass the abuse do' Next Char Predictions: b"NEinxeNR:dUa:H KQ[UNK]SzK-nWH\nkmwGHbit,Q[UNK]v?svxPSCdFyyjWSqDd.fnJ'Njm&hCSRytI:,HeLO[UNK]FDm\n?SfyQKAPQ qs:?q,!q"
Train the model
At this point the problem can be treated as a standard classification problem. Given the previous RNN state, and the input this time step, predict the class of the next character.
Attach an optimizer, and a loss function
The standard tf.keras.losses.sparse_categorical_crossentropy loss function works in this case because it is applied across the last dimension of the predictions.
Because your model returns logits, you need to set the from_logits flag.
loss = tf.losses.SparseCategoricalCrossentropy(from_logits=True)
example_batch_mean_loss = loss(target_example_batch, example_batch_predictions)
print("Prediction shape: ", example_batch_predictions.shape, " # (batch_size, sequence_length, vocab_size)")
print("Mean loss: ", example_batch_mean_loss)
Prediction shape: (64, 100, 66) # (batch_size, sequence_length, vocab_size) Mean loss: tf.Tensor(4.1913805, shape=(), dtype=float32)
A newly initialized model shouldn't be too sure of itself; the output logits should all have similar magnitudes. To confirm this you can check that the exponential of the mean loss is approximately equal to the vocabulary size. A much higher loss means the model is sure of its wrong answers, and is badly initialized:
tf.exp(example_batch_mean_loss).numpy()
66.114
Configure the training procedure using the tf.keras.Model.compile method. Use tf.keras.optimizers.Adam with default arguments and the loss function.
model.compile(optimizer='adam', loss=loss)
Configure checkpoints
Use a tf.keras.callbacks.ModelCheckpoint to ensure that checkpoints are saved during training:
# Directory where the checkpoints will be saved
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_prefix,
save_weights_only=True)
Execute the training
To keep training time reasonable, use 20 epochs to train the model. In Colab, set the runtime to GPU for faster training.
EPOCHS = 20
history = model.fit(dataset, epochs=EPOCHS, callbacks=[checkpoint_callback])
Epoch 1/20 172/172 [==============================] - 10s 41ms/step - loss: 2.7194 Epoch 2/20 172/172 [==============================] - 8s 38ms/step - loss: 1.9872 Epoch 3/20 172/172 [==============================] - 7s 37ms/step - loss: 1.7153 Epoch 4/20 172/172 [==============================] - 8s 38ms/step - loss: 1.5552 Epoch 5/20 172/172 [==============================] - 7s 37ms/step - loss: 1.4550 Epoch 6/20 172/172 [==============================] - 8s 37ms/step - loss: 1.3868 Epoch 7/20 172/172 [==============================] - 7s 37ms/step - loss: 1.3343 Epoch 8/20 172/172 [==============================] - 7s 37ms/step - loss: 1.2899 Epoch 9/20 172/172 [==============================] - 8s 37ms/step - loss: 1.2497 Epoch 10/20 172/172 [==============================] - 7s 37ms/step - loss: 1.2106 Epoch 11/20 172/172 [==============================] - 7s 37ms/step - loss: 1.1727 Epoch 12/20 172/172 [==============================] - 8s 37ms/step - loss: 1.1321 Epoch 13/20 172/172 [==============================] - 8s 38ms/step - loss: 1.0901 Epoch 14/20 172/172 [==============================] - 7s 37ms/step - loss: 1.0462 Epoch 15/20 172/172 [==============================] - 7s 37ms/step - loss: 0.9999 Epoch 16/20 172/172 [==============================] - 7s 37ms/step - loss: 0.9505 Epoch 17/20 172/172 [==============================] - 7s 37ms/step - loss: 0.9000 Epoch 18/20 172/172 [==============================] - 7s 37ms/step - loss: 0.8484 Epoch 19/20 172/172 [==============================] - 7s 37ms/step - loss: 0.7972 Epoch 20/20 172/172 [==============================] - 7s 37ms/step - loss: 0.7494
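If you later need these weights in a fresh session, one way (a sketch, not part of the original flow; restored_model is just an illustrative name) is to rebuild the architecture and load the most recent checkpoint written above:
# Sketch: restore the latest checkpoint into a newly constructed model,
# reusing the hyperparameters defined earlier.
restored_model = MyModel(vocab_size=vocab_size,
                         embedding_dim=embedding_dim,
                         rnn_units=rnn_units)
restored_model.load_weights(tf.train.latest_checkpoint(checkpoint_dir))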
Generate text
The simplest way to generate text with this model is to run it in a loop, and keep track of the model's internal state as you execute it.
Each time you call the model you pass in some text and an internal state. The model returns a prediction for the next character and its new state. Pass the prediction and state back in to continue generating text.
The following makes a single step prediction:
class OneStep(tf.keras.Model):
def __init__(self, model, chars_from_ids, ids_from_chars, temperature=1.0):
super().__init__()
self.temperature = temperature
self.model = model
self.chars_from_ids = chars_from_ids
self.ids_from_chars = ids_from_chars
# Create a mask to prevent "[UNK]" from being generated.
skip_ids = self.ids_from_chars(['[UNK]'])[:, None]
sparse_mask = tf.SparseTensor(
# Put a -inf at each bad index.
values=[-float('inf')]*len(skip_ids),
indices=skip_ids,
# Match the shape to the vocabulary
dense_shape=[len(ids_from_chars.get_vocabulary())])
self.prediction_mask = tf.sparse.to_dense(sparse_mask)
@tf.function
def generate_one_step(self, inputs, states=None):
# Convert strings to token IDs.
input_chars = tf.strings.unicode_split(inputs, 'UTF-8')
input_ids = self.ids_from_chars(input_chars).to_tensor()
# Run the model.
# predicted_logits.shape is [batch, char, next_char_logits]
predicted_logits, states = self.model(inputs=input_ids, states=states,
return_state=True)
# Only use the last prediction.
predicted_logits = predicted_logits[:, -1, :]
predicted_logits = predicted_logits/self.temperature
# Apply the prediction mask: prevent "[UNK]" from being generated.
predicted_logits = predicted_logits + self.prediction_mask
# Sample the output logits to generate token IDs.
predicted_ids = tf.random.categorical(predicted_logits, num_samples=1)
predicted_ids = tf.squeeze(predicted_ids, axis=-1)
# Convert from token ids to characters
predicted_chars = self.chars_from_ids(predicted_ids)
# Return the characters and model state.
return predicted_chars, states
one_step_model = OneStep(model, chars_from_ids, ids_from_chars)
Run it in a loop to generate some text. Looking at the generated text, you'll see that the model knows when to capitalize, makes paragraphs, and imitates a Shakespeare-like vocabulary. With the small number of training epochs, it has not yet learned to form coherent sentences.
start = time.time()
states = None
next_char = tf.constant(['ROMEO:'])
result = [next_char]
for n in range(1000):
next_char, states = one_step_model.generate_one_step(next_char, states=states)
result.append(next_char)
result = tf.strings.join(result)
end = time.time()
print(result[0].numpy().decode('utf-8'), '\n\n' + '_'*80)
print('\nRun time:', end - start)
ROMEO: Then will I grieve: you shall not love. BRUTUS: I do beseech you, I, fellow-bring fortune, unskin my name remies; My mirth be no kind of stay, All fuirting to the fair strang-herrio. BIANCA: He hath no faults upon some report. HORTENSIO: What are you, if Coriolanus come in badlity. BENVOLIO: Romeo is banish'd, keep your sword your father's death. Boy: Garden, many fend to me and kiss me up. GLOUCESTER: Whither? MOPSA: What should I, for thy waning streaks, When he may be reported. She is time enough. STANLEY: Unless he say I may thy youth, How doth had part in hollow pathic. Though you have bring once seed thee so much eremy for both yourself to me for him, As tell take creature in him fix, With many children, give shield now than him, For my sweet love's cousin, lest mine ear hath pregnant Hath been burnt told on thine eyes; For she upon to have a noble city for his blood. QUEEN: Sonder more much loss. ROMEO: Well, low now, go: farewell. LARTIUS: O my desera-ceence! BAANT ________________________________________________________________________________ Run time: 2.548593044281006
The easiest thing you can do to improve the results is to train it for longer (try EPOCHS = 30).
You can also experiment with a different start string, try adding another RNN layer to improve the model's accuracy, or adjust the temperature parameter to generate more or less random predictions.
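For example, lower temperatures make the sampling more conservative (and more repetitive), while higher temperatures make it more random. Here is a minimal sketch that reuses the trained model with a lower temperature (cautious_model is just an illustrative name):
# Build a second generator that samples with a lower temperature.
cautious_model = OneStep(model, chars_from_ids, ids_from_chars, temperature=0.5)

states = None
next_char = tf.constant(['ROMEO:'])
result = [next_char]

for n in range(200):
  next_char, states = cautious_model.generate_one_step(next_char, states=states)
  result.append(next_char)

print(tf.strings.join(result)[0].numpy().decode('utf-8'))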
If you want the model to generate text faster, the easiest thing you can do is batch the text generation. In the example below the model generates 5 outputs in about the same time it took to generate 1 above.
start = time.time()
states = None
next_char = tf.constant(['ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:', 'ROMEO:'])
result = [next_char]
for n in range(1000):
next_char, states = one_step_model.generate_one_step(next_char, states=states)
result.append(next_char)
result = tf.strings.join(result)
end = time.time()
print(result, '\n\n' + '_'*80)
print('\nRun time:', end - start)
tf.Tensor( [b"ROMEO:\nGood queen, my dear lord uncle Gloucester?\nI would take me an hour together: days appoor a\nhew of life or dismalond; and imprison't\nVoices on the quarrel duke of mantle's daughter.\n\nQUEEN ELIZABETH:\nGo all his lands, in how the doorman boast\nMost patient, and but even follow'd so much them;\nWhere you gave death'd of that your own daughters that\njay frums, where you should be proclaimed us in a loyal,\nthanks, indeed with the prince my ripe, my lord.\n\nBRUTUS:\nGrantagenemed me else,\nProud queen in presence to my rughts'l stain\nThan I desire.\n\nISABELLA:\nO, pardon me,\nBe presently as you shall command me to the\ngrace! Know you your words?\nOr was not a word?\n\nANTONIO:\nWhat, I were as liet the hand for a feward of,\nIf bring me to thy most obsin amends.\n\nPAge:\nMy liege, I hope, but one that is Lord Angelo;\nWhy, how farewell, dear queen, you shall be of her?\n\nLUCENTIO:\nTranio, I shall not like the sea,\nWill in armormed breaths for thee to auguen!\nShup do instructions make the imagination my lo" b"ROMEO:\nGrommous back.\n\nMENENIUS:\nPrithee, now my sword; I would say we to of;\nIf ever he needs make thee scarce, some loss\nOf wace every one: suse, now the\nprize of old and rulensapies with the very night\nHath pawn'd me of report you are too swore,\nOf the wolf worthy wretch dead alike.\n\nFirst Watchman:\nTake hear thee hunt: what answer so,\nYou have put you yourself.\n\nPOLIXENES:\nStay, the shadow of my son!\n\nDUKE OF AUMERLE:\nCounty me, I beseech you, hear me back more\nin love.\n\nHASTINGS:\nNo, Pompey, nor no more of him; but\nTheir sorrows lips me, and I am sorry,\nMost fitted forgiveness!\nThe other dukes for truths for her true-king,\nMight have deserved no prosperous\nmalicied and aptlar to this hoar.\n\nHENRY BOWINGS:\nHow fares our brains, Tybalt, within the eldest shrew.\nLet that galt call God forbid!\n\nQUEEN MARGARET:\nBelieve me, love, it may be you and in time\nWill be exhy till one hand that Henry\nHill not we known. He that actions have play'd to use,\nAnd make the gods as hard as it,\nOr that thy h" b"ROMEO:\nSo came from Mantua; what are you,\nFirst, tell me, good my lords, be blunt, I warrant\nhim.'\n\nSICINIUS:\nShe you must stay.\n\nYORK:\nNothing to Angelo, bring me to thy bed:\nIn't now have borne up, and come in his heart.\nHold, that Aniel, sirrah friendly drubs,\nTo clothe our company.\n\nLEONTES:\nWhat news with spying?\n\nQUEEN MARGARET:\nYes, look upon his scorn and tears, when you parted\nI prayed and marry her. My own wife,\nThy brother, thus I loved their ground:\nI wis a pottion spend.\n\nKING RICHARD II:\nWell you then for fear? impatience.\nThis as poor Elbow, noble Norfolk,\nThen give me leave and left no less\nThan Honey of my country's queen.\nI would foul word. This sue mocks; all I meet.\n\nLADY ANNE:\nI have been poor, I saw him with him.\n\nPAULINA:\nNo conful, and so young Henry's heart?\n\nLEONTES:\nShall sweet sir, he\nhit with inkspician liking: thou art dead and with\nthe sharper of the grieve.\nPreport, no: a man I am.\n\nPETRUCHIO:\nA sink in that work, how to beauty of it?\n\nDUCHESS OF YORK:\n\nHENRY " b"ROMEO:\n\nHORTENSIO:\n\nPignard:\nGood Perdita,\nIn them and my sweet love-pattle, heaven will I so.\n\nESCALUS:\nDo not repent the surpast.\n\nGREGORY:\n\nPARIS:\nThink you a piece of it is but love, a\ngovernment.\n\nPETRUCHIO:\nA boor i' the younger grave? 
now art thou livest:\nThe grace of Rome traitors in my liering blood\nWith many untired enemy.\n\nCLAUDIO:\nO, good hopice, murder! know you what he did\nIn this spoil fortune's veins.\n\nCAMILLO:\nShall sut it is the branlard girl:\nWith pilit in the Lord Bohemia cheers-hised Freence,\nTen thousand joins with wings not die.\n\nAUTOLYCUS:\nWhich you are not! why, who' the mighty\nGovernness: for ever I was badded me to-mortal\nThe day I make beworded.\n\nANGELO:\nPlease ye one, you warm but outch ile\nere her tears that I shall do with us, and\nfor the world: is given for the bound of Mortian\n'Courage, flotenar, let me see; for I am sorry.\nI shall not see her men, the greatest daughten instructs.\n\nGLOUCESTER:\nWhat now you accept? Pray?\nTend them that are patience: then, fair" b"ROMEO:\nGood my poor fellow, sir.\n\nHERRO ENV:\nAs clear, my liege; more of the malice or\nCan to before you go; and if you break him ground,\nAs misedeen blood this land at the worst,\nIs all your former proud world obder hereafter.\nFarewell, by his soul.\nThis deed unskin. But who is remains\nThe flood of the tortured sons.\n\nPAULINA:\nI no; for gratity:\nNothing both dispatch; 'tis by right and not my close.\nAnd sir; He'er replied the ballad in this lord\nWhen villains, brainness. Her knowing of't not with his body.\n\nHORTENSIO:\nPerpetuable and lowly mail'd\nat: they every daughter made am Jonn\nand are but now 'pardon' doth reputed and fight;\nWho gladly found a virginant.\n\nVOLUMNIA:\nI will be.\n\nPOMPEY:\nHe shall bring Romeo than my life:\nPeace have we put me this hour her honour\nThan deciver in tears.\nCome, you are cluel for't; and every thing\nMore honour than our folly. Friar, thus I swear to-night.\n\nGLOY:\nSort news aboat, as penitentless in abready when their\ndrums blood that I have there and more in "], shape=(5,), dtype=string) ________________________________________________________________________________ Run time: 2.388988971710205
Export the generator
This single-step model can easily be saved and restored, allowing you to use it anywhere a tf.saved_model is accepted.
tf.saved_model.save(one_step_model, 'one_step')
one_step_reloaded = tf.saved_model.load('one_step')
WARNING:tensorflow:Skipping full serialization of Keras layer <__main__.OneStep object at 0x7f0bac1778b0>, because it is not built. WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config. WARNING:tensorflow:Model's `__init__()` arguments contain non-serializable objects. Please implement a `get_config()` method in the subclassed Model for proper saving and loading. Defaulting to empty config. WARNING:absl:Found untraced functions such as gru_cell_layer_call_fn, gru_cell_layer_call_and_return_conditional_losses while saving (showing 2 of 2). These functions will not be directly callable after loading. INFO:tensorflow:Assets written to: one_step/assets INFO:tensorflow:Assets written to: one_step/assets
states = None
next_char = tf.constant(['ROMEO:'])
result = [next_char]
for n in range(100):
next_char, states = one_step_reloaded.generate_one_step(next_char, states=states)
result.append(next_char)
print(tf.strings.join(result)[0].numpy().decode("utf-8"))
ROMEO: Verona. POMPEY: Varrail! These I will bring not on our knowledge, Is colder than besides and not o
Advanced: Customized Training
The above training procedure is simple, but does not give you much control. It uses teacher forcing, which prevents bad predictions from being fed back to the model, so the model never learns to recover from mistakes.
Now that you've seen how to run the model manually, you'll implement the training loop yourself. This gives a starting point if, for example, you want to implement curriculum learning to help stabilize the model's open-loop output.
The most important part of a custom training loop is the train step function.
Use tf.GradientTape to track the gradients. You can learn more about this approach by reading the eager execution guide.
The basic procedure is:
- Execute the model and calculate the loss under a tf.GradientTape.
- Calculate the updates and apply them to the model using the optimizer.
class CustomTraining(MyModel):
@tf.function
def train_step(self, inputs):
inputs, labels = inputs
with tf.GradientTape() as tape:
predictions = self(inputs, training=True)
loss = self.loss(labels, predictions)
      grads = tape.gradient(loss, self.trainable_variables)
      self.optimizer.apply_gradients(zip(grads, self.trainable_variables))
return {'loss': loss}
The above implementation of the train_step method follows Keras' train_step conventions. This is optional, but it allows you to change the behavior of the train step and still use Keras' Model.compile and Model.fit methods.
model = CustomTraining(
vocab_size=len(ids_from_chars.get_vocabulary()),
embedding_dim=embedding_dim,
rnn_units=rnn_units)
model.compile(optimizer = tf.keras.optimizers.Adam(),
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True))
model.fit(dataset, epochs=1)
172/172 [==============================] - 10s 39ms/step - loss: 2.7315 <keras.callbacks.History at 0x7f0b741e5730>
Or if you need more control, you can write your own complete custom training loop:
EPOCHS = 10
mean = tf.metrics.Mean()
for epoch in range(EPOCHS):
start = time.time()
mean.reset_states()
for (batch_n, (inp, target)) in enumerate(dataset):
logs = model.train_step([inp, target])
mean.update_state(logs['loss'])
if batch_n % 50 == 0:
template = f"Epoch {epoch+1} Batch {batch_n} Loss {logs['loss']:.4f}"
print(template)
# saving (checkpoint) the model every 5 epochs
if (epoch + 1) % 5 == 0:
model.save_weights(checkpoint_prefix.format(epoch=epoch))
print()
print(f'Epoch {epoch+1} Loss: {mean.result().numpy():.4f}')
print(f'Time taken for 1 epoch {time.time() - start:.2f} sec')
print("_"*80)
model.save_weights(checkpoint_prefix.format(epoch=epoch))
Epoch 1 Batch 0 Loss 2.2062 Epoch 1 Batch 50 Loss 2.0497 Epoch 1 Batch 100 Loss 1.9332 Epoch 1 Batch 150 Loss 1.8559 Epoch 1 Loss: 2.0034 Time taken for 1 epoch 8.75 sec ________________________________________________________________________________ Epoch 2 Batch 0 Loss 1.8145 Epoch 2 Batch 50 Loss 1.7844 Epoch 2 Batch 100 Loss 1.6931 Epoch 2 Batch 150 Loss 1.6794 Epoch 2 Loss: 1.7251 Time taken for 1 epoch 7.27 sec ________________________________________________________________________________ Epoch 3 Batch 0 Loss 1.5716 Epoch 3 Batch 50 Loss 1.5892 Epoch 3 Batch 100 Loss 1.6109 Epoch 3 Batch 150 Loss 1.5411 Epoch 3 Loss: 1.5612 Time taken for 1 epoch 7.27 sec ________________________________________________________________________________ Epoch 4 Batch 0 Loss 1.5222 Epoch 4 Batch 50 Loss 1.4486 Epoch 4 Batch 100 Loss 1.4479 Epoch 4 Batch 150 Loss 1.4528 Epoch 4 Loss: 1.4601 Time taken for 1 epoch 7.27 sec ________________________________________________________________________________ Epoch 5 Batch 0 Loss 1.3614 Epoch 5 Batch 50 Loss 1.3993 Epoch 5 Batch 100 Loss 1.3564 Epoch 5 Batch 150 Loss 1.3712 Epoch 5 Loss: 1.3912 Time taken for 1 epoch 7.61 sec ________________________________________________________________________________ Epoch 6 Batch 0 Loss 1.3379 Epoch 6 Batch 50 Loss 1.3391 Epoch 6 Batch 100 Loss 1.3659 Epoch 6 Batch 150 Loss 1.3547 Epoch 6 Loss: 1.3378 Time taken for 1 epoch 7.39 sec ________________________________________________________________________________ Epoch 7 Batch 0 Loss 1.2858 Epoch 7 Batch 50 Loss 1.2782 Epoch 7 Batch 100 Loss 1.2711 Epoch 7 Batch 150 Loss 1.2913 Epoch 7 Loss: 1.2928 Time taken for 1 epoch 7.28 sec ________________________________________________________________________________ Epoch 8 Batch 0 Loss 1.2410 Epoch 8 Batch 50 Loss 1.2365 Epoch 8 Batch 100 Loss 1.2342 Epoch 8 Batch 150 Loss 1.2675 Epoch 8 Loss: 1.2513 Time taken for 1 epoch 7.36 sec ________________________________________________________________________________ Epoch 9 Batch 0 Loss 1.2105 Epoch 9 Batch 50 Loss 1.2482 Epoch 9 Batch 100 Loss 1.2228 Epoch 9 Batch 150 Loss 1.2096 Epoch 9 Loss: 1.2126 Time taken for 1 epoch 7.29 sec ________________________________________________________________________________ Epoch 10 Batch 0 Loss 1.1437 Epoch 10 Batch 50 Loss 1.1936 Epoch 10 Batch 100 Loss 1.2005 Epoch 10 Batch 150 Loss 1.1695 Epoch 10 Loss: 1.1726 Time taken for 1 epoch 7.53 sec ________________________________________________________________________________