Richiamate dei componenti aggiuntivi di TensorFlow: barra di avanzamento TQDM

Panoramica

Questo notebook dimostrerà come utilizzare TQDMCallback in TensorFlow Addons.

Impostare

pip install -U tensorflow-addons
!pip install -q "tqdm>=4.36.1"

import tensorflow as tf
import tensorflow_addons as tfa

from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense, Dropout, Flatten
import tqdm

# quietly deep-reload tqdm
import sys
from IPython.lib import deepreload

stdout
= sys.stdout
sys
.stdout = open('junk','w')
deepreload
.reload(tqdm)
sys
.stdout = stdout

tqdm
.__version__
'4.62.3'

Importa e normalizza i dati

# the data, split between train and test sets
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# normalize data
x_train
, x_test = x_train / 255.0, x_test / 255.0

Costruisci un modello CNN MNIST semplice

# build the model using the Sequential API
model
= Sequential()
model
.add(Flatten(input_shape=(28, 28)))
model
.add(Dense(128, activation='relu'))
model
.add(Dropout(0.2))
model
.add(Dense(10, activation='softmax'))

model
.compile(optimizer='adam',
              loss
= 'sparse_categorical_crossentropy',
              metrics
=['accuracy'])

Utilizzo predefinito di TQDMCallback

# initialize tqdm callback with default parameters
tqdm_callback
= tfa.callbacks.TQDMProgressBar()

# train the model with tqdm_callback
# make sure to set verbose = 0 to disable
# the default progress bar.
model
.fit(x_train, y_train,
          batch_size
=64,
          epochs
=10,
          verbose
=0,
          callbacks
=[tqdm_callback],
          validation_data
=(x_test, y_test))
Training:   0%|           0/10 ETA: ?s,  ?epochs/s
Epoch 1/10
0/938           ETA: ?s -
Epoch 2/10
0/938           ETA: ?s -
Epoch 3/10
0/938           ETA: ?s -
Epoch 4/10
0/938           ETA: ?s -
Epoch 5/10
0/938           ETA: ?s -
Epoch 6/10
0/938           ETA: ?s -
Epoch 7/10
0/938           ETA: ?s -
Epoch 8/10
0/938           ETA: ?s -
Epoch 9/10
0/938           ETA: ?s -
Epoch 10/10
0/938           ETA: ?s -
<keras.callbacks.History at 0x7f4a8d35aed0>

Di seguito è riportato l'output previsto quando si esegue la cella sopra Figura della barra di avanzamento TQDM

# TQDMProgressBar() also works with evaluate()
model
.evaluate(x_test, y_test, batch_size=64, callbacks=[tqdm_callback], verbose=0)
0/157           ETA: ?s - Evaluating
[0.06689586490392685, 0.9805999994277954]

Di seguito è riportato l'output previsto quando si esegue la cella sopra TQDM Valuta la figura della barra di avanzamento