Menampilkan data gambar di TensorBoard

Lihat di TensorFlow.org Jalankan di Google Colab Lihat sumber di GitHub Unduh buku catatan

Ringkasan

Menggunakan TensorFlow Gambar Ringkasan API, Anda dapat dengan mudah login tensor dan gambar sewenang-wenang dan melihatnya dalam TensorBoard. Hal ini dapat sangat membantu untuk sampel dan memeriksa data masukan Anda, atau untuk memvisualisasikan bobot lapisan dan tensor dihasilkan . Anda juga dapat mencatat data diagnostik sebagai gambar yang dapat membantu dalam pengembangan model Anda.

Dalam tutorial ini, Anda akan belajar cara menggunakan Image Summary API untuk memvisualisasikan tensor sebagai gambar. Anda juga akan belajar cara mengambil gambar arbitrer, mengubahnya menjadi tensor, dan memvisualisasikannya di TensorBoard. Anda akan bekerja melalui contoh sederhana namun nyata yang menggunakan Ringkasan Gambar untuk membantu Anda memahami bagaimana kinerja model Anda.

Mempersiapkan

try:
  # %tensorflow_version only exists in Colab.
  %tensorflow_version 2.x
except Exception:
  pass

# Load the TensorBoard notebook extension.
%load_ext tensorboard
TensorFlow 2.x selected.
from datetime import datetime
import io
import itertools
from packaging import version

import tensorflow as tf
from tensorflow import keras

import matplotlib.pyplot as plt
import numpy as np
import sklearn.metrics

print("TensorFlow version: ", tf.__version__)
assert version.parse(tf.__version__).release[0] >= 2, \
    "This notebook requires TensorFlow 2.0 or above."
TensorFlow version:  2.2

Unduh kumpulan data Fashion-MNIST

Anda akan membangun jaringan saraf sederhana untuk gambar mengklasifikasikan di dalam Fashion-MNIST dataset. Dataset ini terdiri dari 70.000 gambar skala abu-abu 28x28 produk fashion dari 10 kategori, dengan 7.000 gambar per kategori.

Pertama, unduh datanya:

# Download the data. The data is already divided into train and test.
# The labels are integers representing classes.
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = \
    fashion_mnist.load_data()

# Names of the integer classes, i.e., 0 -> T-short/top, 1 -> Trouser, etc.
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 
    'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
32768/29515 [=================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
26427392/26421880 [==============================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
8192/5148 [===============================================] - 0s 0us/step
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
4423680/4422102 [==============================] - 0s 0us/step

Memvisualisasikan satu gambar

Untuk memahami cara kerja Image Summary API, Anda sekarang cukup mencatat gambar pelatihan pertama di set pelatihan Anda di TensorBoard.

Sebelum Anda melakukannya, periksa bentuk data pelatihan Anda:

print("Shape: ", train_images[0].shape)
print("Label: ", train_labels[0], "->", class_names[train_labels[0]])
Shape:  (28, 28)
Label:  9 -> Ankle boot

Perhatikan bahwa bentuk setiap gambar dalam kumpulan data adalah tensor bentuk peringkat-2 (28, 28), yang mewakili tinggi dan lebar.

Namun, tf.summary.image() mengharapkan peringkat ke-4 tensor mengandung (batch_size, height, width, channels) . Oleh karena itu, tensor perlu dibentuk kembali.

Anda sedang logging hanya satu gambar, sehingga batch_size adalah 1. gambar yang grayscale, sehingga mengatur channels ke 1.

# Reshape the image for the Summary API.
img = np.reshape(train_images[0], (-1, 28, 28, 1))

Anda sekarang siap untuk mencatat gambar ini dan melihatnya di TensorBoard.

# Clear out any prior log data.
!rm -rf logs

# Sets up a timestamped log directory.
logdir = "logs/train_data/" + datetime.now().strftime("%Y%m%d-%H%M%S")
# Creates a file writer for the log directory.
file_writer = tf.summary.create_file_writer(logdir)

# Using the file writer, log the reshaped image.
with file_writer.as_default():
  tf.summary.image("Training data", img, step=0)

Sekarang, gunakan TensorBoard untuk memeriksa gambar. Tunggu beberapa detik hingga UI berputar.

%tensorboard --logdir logs/train_data

Tab "Gambar" menampilkan gambar yang baru saja Anda login. Ini adalah "sepatu bot".

Gambar diskalakan ke ukuran default agar lebih mudah dilihat. Jika Anda ingin melihat gambar asli tanpa skala, centang "Tampilkan ukuran gambar sebenarnya" di kiri atas.

Mainkan dengan penggeser kecerahan dan kontras untuk melihat pengaruhnya terhadap piksel gambar.

Memvisualisasikan banyak gambar

Mencatat satu tensor memang bagus, tetapi bagaimana jika Anda ingin mencatat beberapa contoh pelatihan?

Hanya menentukan jumlah gambar yang ingin Anda log ketika melewati data ke tf.summary.image() .

with file_writer.as_default():
  # Don't forget to reshape.
  images = np.reshape(train_images[0:25], (-1, 28, 28, 1))
  tf.summary.image("25 training data examples", images, max_outputs=25, step=0)

%tensorboard --logdir logs/train_data

Mencatat data gambar arbitrer

Bagaimana jika Anda ingin memvisualisasikan gambar yang tidak tensor, seperti gambar yang dihasilkan oleh matplotlib ?

Anda memerlukan beberapa kode boilerplate untuk mengonversi plot menjadi tensor, tetapi setelah itu, Anda siap melakukannya.

Dalam kode di bawah, Anda akan log 25 gambar pertama sebagai bagus grid menggunakan matplotlib ini subplot() fungsi. Anda kemudian akan melihat kisi di TensorBoard:

# Clear out prior logging data.
!rm -rf logs/plots

logdir = "logs/plots/" + datetime.now().strftime("%Y%m%d-%H%M%S")
file_writer = tf.summary.create_file_writer(logdir)

def plot_to_image(figure):
  """Converts the matplotlib plot specified by 'figure' to a PNG image and
  returns it. The supplied figure is closed and inaccessible after this call."""
  # Save the plot to a PNG in memory.
  buf = io.BytesIO()
  plt.savefig(buf, format='png')
  # Closing the figure prevents it from being displayed directly inside
  # the notebook.
  plt.close(figure)
  buf.seek(0)
  # Convert PNG buffer to TF image
  image = tf.image.decode_png(buf.getvalue(), channels=4)
  # Add the batch dimension
  image = tf.expand_dims(image, 0)
  return image

def image_grid():
  """Return a 5x5 grid of the MNIST images as a matplotlib figure."""
  # Create a figure to contain the plot.
  figure = plt.figure(figsize=(10,10))
  for i in range(25):
    # Start next subplot.
    plt.subplot(5, 5, i + 1, title=class_names[train_labels[i]])
    plt.xticks([])
    plt.yticks([])
    plt.grid(False)
    plt.imshow(train_images[i], cmap=plt.cm.binary)

  return figure

# Prepare the plot
figure = image_grid()
# Convert to image and log
with file_writer.as_default():
  tf.summary.image("Training data", plot_to_image(figure), step=0)

%tensorboard --logdir logs/plots

Membangun pengklasifikasi gambar

Sekarang gabungkan ini semua dengan contoh nyata. Lagi pula, Anda di sini untuk melakukan pembelajaran mesin dan tidak merencanakan gambar-gambar cantik!

Anda akan menggunakan ringkasan gambar untuk memahami seberapa baik performa model Anda saat melatih pengklasifikasi sederhana untuk set data Fashion-MNIST.

Pertama, buat model yang sangat sederhana dan kompilasi, atur fungsi pengoptimal dan kehilangan. Langkah kompilasi juga menentukan bahwa Anda ingin mencatat keakuratan pengklasifikasi di sepanjang jalan.

model = keras.models.Sequential([
    keras.layers.Flatten(input_shape=(28, 28)),
    keras.layers.Dense(32, activation='relu'),
    keras.layers.Dense(10, activation='softmax')
])

model.compile(
    optimizer='adam', 
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

Ketika melatih classifier, itu berguna untuk melihat matriks kebingungan . Matriks kebingungan memberi Anda pengetahuan mendetail tentang kinerja pengklasifikasi Anda pada data uji.

Tentukan fungsi yang menghitung matriks konfusi. Anda akan menggunakan nyaman Scikit-belajar fungsi untuk melakukan hal ini, dan kemudian plot menggunakan matplotlib.

def plot_confusion_matrix(cm, class_names):
  """
  Returns a matplotlib figure containing the plotted confusion matrix.

  Args:
    cm (array, shape = [n, n]): a confusion matrix of integer classes
    class_names (array, shape = [n]): String names of the integer classes
  """
  figure = plt.figure(figsize=(8, 8))
  plt.imshow(cm, interpolation='nearest', cmap=plt.cm.Blues)
  plt.title("Confusion matrix")
  plt.colorbar()
  tick_marks = np.arange(len(class_names))
  plt.xticks(tick_marks, class_names, rotation=45)
  plt.yticks(tick_marks, class_names)

  # Compute the labels from the normalized confusion matrix.
  labels = np.around(cm.astype('float') / cm.sum(axis=1)[:, np.newaxis], decimals=2)

  # Use white text if squares are dark; otherwise black.
  threshold = cm.max() / 2.
  for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
    color = "white" if cm[i, j] > threshold else "black"
    plt.text(j, i, labels[i, j], horizontalalignment="center", color=color)

  plt.tight_layout()
  plt.ylabel('True label')
  plt.xlabel('Predicted label')
  return figure

Anda sekarang siap untuk melatih pengklasifikasi dan secara teratur mencatat matriks kebingungan di sepanjang jalan.

Inilah yang akan Anda lakukan:

  1. Buat callback Keras TensorBoard login metrik dasar
  2. Buat Keras LambdaCallback untuk log matriks kebingungan pada akhir setiap zaman
  3. Latih model menggunakan Model.fit(), pastikan untuk melewati kedua callback

Saat pelatihan berlangsung, gulir ke bawah untuk melihat TensorBoard memulai.

# Clear out prior logging data.
!rm -rf logs/image

logdir = "logs/image/" + datetime.now().strftime("%Y%m%d-%H%M%S")
# Define the basic TensorBoard callback.
tensorboard_callback = keras.callbacks.TensorBoard(log_dir=logdir)
file_writer_cm = tf.summary.create_file_writer(logdir + '/cm')
def log_confusion_matrix(epoch, logs):
  # Use the model to predict the values from the validation dataset.
  test_pred_raw = model.predict(test_images)
  test_pred = np.argmax(test_pred_raw, axis=1)

  # Calculate the confusion matrix.
  cm = sklearn.metrics.confusion_matrix(test_labels, test_pred)
  # Log the confusion matrix as an image summary.
  figure = plot_confusion_matrix(cm, class_names=class_names)
  cm_image = plot_to_image(figure)

  # Log the confusion matrix as an image summary.
  with file_writer_cm.as_default():
    tf.summary.image("Confusion Matrix", cm_image, step=epoch)

# Define the per-epoch callback.
cm_callback = keras.callbacks.LambdaCallback(on_epoch_end=log_confusion_matrix)
# Start TensorBoard.
%tensorboard --logdir logs/image

# Train the classifier.
model.fit(
    train_images,
    train_labels,
    epochs=5,
    verbose=0, # Suppress chatty output
    callbacks=[tensorboard_callback, cm_callback],
    validation_data=(test_images, test_labels),
)

Perhatikan bahwa akurasi meningkat pada set kereta dan validasi. Itu pertanda baik. Tetapi bagaimana kinerja model pada subset data tertentu?

Pilih tab "Gambar" untuk memvisualisasikan matriks kebingungan yang dicatat. Centang "Tampilkan ukuran gambar sebenarnya" di kiri atas untuk melihat matriks kebingungan dalam ukuran penuh.

Secara default, dasbor menampilkan ringkasan gambar untuk langkah atau zaman terakhir yang dicatat. Gunakan penggeser untuk melihat matriks kebingungan sebelumnya. Perhatikan bagaimana matriks berubah secara signifikan saat pelatihan berlangsung, dengan kotak yang lebih gelap bergabung di sepanjang diagonal, dan matriks lainnya cenderung ke arah 0 dan putih. Ini berarti pengklasifikasi Anda meningkat seiring dengan kemajuan pelatihan! Kerja bagus!

Confusion matrix menunjukkan bahwa model sederhana ini memiliki beberapa masalah. Meskipun kemajuan besar, Shirts, T-Shirts, dan Pullover semakin bingung satu sama lain. Model membutuhkan lebih banyak pekerjaan.

Jika Anda tertarik, mencoba untuk meningkatkan model ini dengan jaringan convolutional (CNN).