Image segmentation

This tutorial focuses on the task of image segmentation, using a modified U-Net.

What is image segmentation?

In an image classification task, the network assigns a label (or class) to each input image. However, suppose you want to know the shape of that object, which pixel belongs to which object, and so on. In this case, you need to assign a class to each pixel of the image. This task is known as segmentation. A segmentation model returns much more detailed information about the image. Image segmentation has many applications in medical imaging, self-driving cars, and satellite imaging, to name a few.

This tutorial uses the Oxford-IIIT Pet Dataset (Parkhi et al, 2012). The dataset consists of images of 37 pet breeds, with 200 images per breed (~100 each in the training and test splits). Each image includes the corresponding labels and pixel-wise masks. The masks are class labels for each pixel. Each pixel is given one of three categories:

  • Class 1: Pixels belonging to the pet.
  • Class 2: Pixels bordering the pet.
  • Class 3: None of the above/surrounding pixels.
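
To make the per-pixel idea concrete, here is a minimal, illustrative snippet (an addition to this tutorial; the 128x128 size matches the images used later) comparing the tensor shapes involved:

import tensorflow as tf

# Classification: one label per image. Segmentation: one label per pixel.
images = tf.zeros([8, 128, 128, 3])              # a batch of 8 RGB images
classification_labels = tf.zeros([8])            # one class id per image
segmentation_masks = tf.zeros([8, 128, 128, 1])  # one class id per pixel
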
pip install git+https://github.com/tensorflow/examples.git
import tensorflow as tf

import tensorflow_datasets as tfds
2023-11-07 22:50:13.585808: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:9261] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2023-11-07 22:50:13.585864: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:607] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2023-11-07 22:50:13.587520: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1515] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
from tensorflow_examples.models.pix2pix import pix2pix

from IPython.display import clear_output
import matplotlib.pyplot as plt

Download the Oxford-IIIT Pets dataset

The dataset is available from TensorFlow Datasets. The segmentation masks are included in version 3+.

dataset, info = tfds.load('oxford_iiit_pet:3.*.*', with_info=True)
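
As a quick, optional sanity check (an addition, not part of the original tutorial flow), you can inspect a single element to confirm that both the image and its segmentation mask are present:

sample = next(iter(dataset['train']))
print(sample['image'].shape)              # (height, width, 3)
print(sample['segmentation_mask'].shape)  # (height, width, 1)
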

The following function resizes both the images and their masks to 128x128 pixels. In addition, the image color values are normalized to the [0, 1] range. Finally, as mentioned above, the pixels in the segmentation mask are labeled either {1, 2, 3}. For the sake of convenience, subtract 1 from the segmentation mask, resulting in labels that are: {0, 1, 2}.

def normalize(input_image, input_mask):
  input_image = tf.cast(input_image, tf.float32) / 255.0
  input_mask -= 1
  return input_image, input_mask
def load_image(datapoint):
  input_image = tf.image.resize(datapoint['image'], (128, 128))
  input_mask = tf.image.resize(
    datapoint['segmentation_mask'],
    (128, 128),
    method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
  )

  input_image, input_mask = normalize(input_image, input_mask)

  return input_image, input_mask

The dataset already contains the required training and test splits, so continue to use the same splits:

TRAIN_LENGTH = info.splits['train'].num_examples
BATCH_SIZE = 64
BUFFER_SIZE = 1000
STEPS_PER_EPOCH = TRAIN_LENGTH // BATCH_SIZE
train_images = dataset['train'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)
test_images = dataset['test'].map(load_image, num_parallel_calls=tf.data.AUTOTUNE)

The following class performs a simple augmentation by randomly flipping an image. Go to the Image augmentation tutorial to learn more.

class Augment(tf.keras.layers.Layer):
  def __init__(self, seed=42):
    super().__init__()
    # both use the same seed, so they'll make the same random changes.
    self.augment_inputs = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)
    self.augment_labels = tf.keras.layers.RandomFlip(mode="horizontal", seed=seed)

  def call(self, inputs, labels):
    inputs = self.augment_inputs(inputs)
    labels = self.augment_labels(labels)
    return inputs, labels

Build the input pipeline, applying the augmentation after batching the inputs:

train_batches = (
    train_images
    .cache()
    .shuffle(BUFFER_SIZE)
    .batch(BATCH_SIZE)
    .repeat()
    .map(Augment())
    .prefetch(buffer_size=tf.data.AUTOTUNE))

test_batches = test_images.batch(BATCH_SIZE)

Visualize an image example and its corresponding mask from the dataset:

def display(display_list):
  plt.figure(figsize=(15, 15))

  title = ['Input Image', 'True Mask', 'Predicted Mask']

  for i in range(len(display_list)):
    plt.subplot(1, len(display_list), i+1)
    plt.title(title[i])
    plt.imshow(tf.keras.utils.array_to_img(display_list[i]))
    plt.axis('off')
  plt.show()
for images, masks in train_batches.take(2):
  sample_image, sample_mask = images[0], masks[0]
  display([sample_image, sample_mask])
Corrupt JPEG data: 240 extraneous bytes before marker 0xd9
Corrupt JPEG data: premature end of data segment

(figure: a sample input image and its true mask)

2023-11-07 22:50:22.420833: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.

(figure: a second sample input image and its true mask)

Define the model

The model being used here is a modified U-Net. A U-Net consists of an encoder (downsampler) and a decoder (upsampler). To learn robust features and reduce the number of trainable parameters, use a pretrained model, MobileNetV2, as the encoder. For the decoder, you will use the upsample block, which is already implemented in the pix2pix example in the TensorFlow Examples repo. (Check out the pix2pix: Image-to-image translation with a conditional GAN tutorial in a notebook.)

As mentioned, the encoder is a pretrained MobileNetV2 model. You will use the model from tf.keras.applications. The encoder consists of specific outputs from intermediate layers in the model. Note that the encoder will not be trained during the training process.

base_model = tf.keras.applications.MobileNetV2(input_shape=[128, 128, 3], include_top=False)

# Use the activations of these layers
layer_names = [
    'block_1_expand_relu',   # 64x64
    'block_3_expand_relu',   # 32x32
    'block_6_expand_relu',   # 16x16
    'block_13_expand_relu',  # 8x8
    'block_16_project',      # 4x4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]

# Create the feature extraction model
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)

down_stack.trainable = False
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_128_no_top.h5
9406464/9406464 [==============================] - 0s 0us/step

The decoder/upsampler is simply a series of upsample blocks implemented in TensorFlow examples:

up_stack = [
    pix2pix.upsample(512, 3),  # 4x4 -> 8x8
    pix2pix.upsample(256, 3),  # 8x8 -> 16x16
    pix2pix.upsample(128, 3),  # 16x16 -> 32x32
    pix2pix.upsample(64, 3),   # 32x32 -> 64x64
]
def unet_model(output_channels: int):
  inputs = tf.keras.layers.Input(shape=[128, 128, 3])

  # Downsampling through the model
  skips = down_stack(inputs)
  x = skips[-1]
  skips = reversed(skips[:-1])

  # Upsampling and establishing the skip connections
  for up, skip in zip(up_stack, skips):
    x = up(x)
    concat = tf.keras.layers.Concatenate()
    x = concat([x, skip])

  # This is the last layer of the model
  last = tf.keras.layers.Conv2DTranspose(
      filters=output_channels, kernel_size=3, strides=2,
      padding='same')  # 64x64 -> 128x128

  x = last(x)

  return tf.keras.Model(inputs=inputs, outputs=x)

Note that the number of filters on the last layer is set to the number of output_channels. This will be one output channel per class.

Train the model

Now, all that is left to do is to compile and train the model.

Since this is a multiclass classification problem, use the tf.keras.losses.SparseCategoricalCrossentropy loss function with the from_logits argument set to True, since the labels are scalar integers instead of vectors of scores for each pixel of each class.
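
If it helps, here is a minimal standalone check, with assumed shapes, of how this loss pairs integer labels with per-pixel logits:

# Assumed shapes for illustration: 3 classes, one logit vector per pixel.
logits = tf.random.normal([1, 128, 128, 3])          # model output (logits)
labels = tf.zeros([1, 128, 128, 1], dtype=tf.int32)  # integer class ids
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
print(loss_fn(labels, logits).numpy())  # mean cross-entropy over all pixels
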

When running inference, the label assigned to the pixel is the channel with the highest value. This is what the create_mask function is doing.

OUTPUT_CLASSES = 3

model = unet_model(output_channels=OUTPUT_CLASSES)
model.compile(optimizer='adam',
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
              metrics=['accuracy'])
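
As an optional sanity check (an addition to the tutorial), the compiled model should output one logit per class for every pixel:

print(model.output_shape)  # expected: (None, 128, 128, 3)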

Plot the resulting model architecture:

tf.keras.utils.plot_model(model, show_shapes=True)

(figure: the model architecture diagram)

Try out the model to check what it predicts before training:

def create_mask(pred_mask):
  pred_mask = tf.math.argmax(pred_mask, axis=-1)
  pred_mask = pred_mask[..., tf.newaxis]
  return pred_mask[0]
def show_predictions(dataset=None, num=1):
  if dataset:
    for image, mask in dataset.take(num):
      pred_mask = model.predict(image)
      display([image[0], mask[0], create_mask(pred_mask)])
  else:
    display([sample_image, sample_mask,
             create_mask(model.predict(sample_image[tf.newaxis, ...]))])
show_predictions()
1/1 [==============================] - 3s 3s/step

(figure: the untrained model's predicted mask alongside the input image and true mask)

The callback defined below is used to observe how the model improves while it is training:

class DisplayCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    clear_output(wait=True)
    show_predictions()
    print('\nSample Prediction after epoch {}\n'.format(epoch+1))
EPOCHS = 20
VAL_SUBSPLITS = 5
VALIDATION_STEPS = info.splits['test'].num_examples//BATCH_SIZE//VAL_SUBSPLITS

model_history = model.fit(train_batches, epochs=EPOCHS,
                          steps_per_epoch=STEPS_PER_EPOCH,
                          validation_steps=VALIDATION_STEPS,
                          validation_data=test_batches,
                          callbacks=[DisplayCallback()])
1/1 [==============================] - 0s 47ms/step

(figure: sample prediction after the final training epoch)

Sample Prediction after epoch 20

57/57 [==============================] - 8s 144ms/step - loss: 0.1713 - accuracy: 0.9301 - val_loss: 0.2807 - val_accuracy: 0.9046
loss = model_history.history['loss']
val_loss = model_history.history['val_loss']

plt.figure()
plt.plot(model_history.epoch, loss, 'r', label='Training loss')
plt.plot(model_history.epoch, val_loss, 'bo', label='Validation loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss Value')
plt.ylim([0, 1])
plt.legend()
plt.show()

(figure: training and validation loss curves)

Make predictions

Now, make some predictions. In the interest of saving time, the number of epochs was kept small, but you may set it higher to achieve more accurate results.

show_predictions(test_batches, 3)
2/2 [==============================] - 0s 25ms/step

(figure: predicted mask for a test image)

2/2 [==============================] - 0s 32ms/step

(figure: predicted mask for a test image)

2/2 [==============================] - 0s 35ms/step

(figure: predicted mask for a test image)

Optional: Imbalanced classes and class weights

Semantic segmentation datasets can be highly imbalanced, meaning that pixels of a particular class can be present more inside images than those of other classes. Since segmentation problems can be treated as per-pixel classification problems, you can deal with the imbalance by weighting the loss function to account for it. It's a simple and elegant way to deal with this problem. Refer to the Classification on imbalanced data tutorial to learn more.
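
Before choosing weights, it can help to estimate how imbalanced the classes actually are. The following rough sketch (an addition to this tutorial) counts mask pixels over a few batches:

import numpy as np

counts = np.zeros(3, dtype=np.int64)
for _, masks in train_batches.take(5):  # sample a few batches
  flat = tf.reshape(tf.cast(masks, tf.int32), [-1])
  values, _, batch_counts = tf.unique_with_counts(flat)
  counts[values.numpy()] += batch_counts.numpy()
print(counts / counts.sum())  # approximate per-class pixel frequencies
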

To avoid ambiguity, Model.fit does not support the class_weight argument for targets with 3+ dimensions.

try:
  model_history = model.fit(train_batches, epochs=EPOCHS,
                            steps_per_epoch=STEPS_PER_EPOCH,
                            class_weight = {0:2.0, 1:2.0, 2:1.0})
  assert False
except Exception as e:
  print(f"Expected {type(e).__name__}: {e}")
Epoch 1/20
57/57 [==============================] - 8s 114ms/step - loss: 0.2579 - accuracy: 0.9219
Epoch 2/20
57/57 [==============================] - 6s 114ms/step - loss: 0.2671 - accuracy: 0.9201
Epoch 3/20
57/57 [==============================] - 6s 114ms/step - loss: 0.2479 - accuracy: 0.9245
Epoch 4/20
57/57 [==============================] - 6s 114ms/step - loss: 0.2331 - accuracy: 0.9279
Epoch 5/20
57/57 [==============================] - 6s 114ms/step - loss: 0.2283 - accuracy: 0.9293
Epoch 6/20
57/57 [==============================] - 6s 113ms/step - loss: 0.2157 - accuracy: 0.9325
Epoch 7/20
57/57 [==============================] - 6s 113ms/step - loss: 0.2114 - accuracy: 0.9337
Epoch 8/20
57/57 [==============================] - 6s 113ms/step - loss: 0.2032 - accuracy: 0.9357
Epoch 9/20
57/57 [==============================] - 6s 113ms/step - loss: 0.1988 - accuracy: 0.9371
Epoch 10/20
57/57 [==============================] - 6s 113ms/step - loss: 0.1935 - accuracy: 0.9384
Epoch 11/20
57/57 [==============================] - 6s 113ms/step - loss: 0.1971 - accuracy: 0.9377
Epoch 12/20
57/57 [==============================] - 6s 113ms/step - loss: 0.1869 - accuracy: 0.9403
Epoch 13/20
57/57 [==============================] - 6s 113ms/step - loss: 0.1814 - accuracy: 0.9419
Epoch 14/20
57/57 [==============================] - 6s 113ms/step - loss: 0.1727 - accuracy: 0.9446
Epoch 15/20
57/57 [==============================] - 6s 113ms/step - loss: 0.1699 - accuracy: 0.9452
Epoch 16/20
57/57 [==============================] - 6s 113ms/step - loss: 0.1683 - accuracy: 0.9458
Epoch 17/20
57/57 [==============================] - 6s 113ms/step - loss: 0.1666 - accuracy: 0.9465
Epoch 18/20
57/57 [==============================] - 6s 113ms/step - loss: 0.1594 - accuracy: 0.9485
Epoch 19/20
57/57 [==============================] - 6s 113ms/step - loss: 0.1538 - accuracy: 0.9502
Epoch 20/20
57/57 [==============================] - 6s 113ms/step - loss: 0.1524 - accuracy: 0.9507
Expected AssertionError:

So, in this case you need to implement the weighting yourself. You'll do this using sample weights: in addition to (data, label) pairs, Model.fit also accepts (data, label, sample_weight) triples.

Keras Model.fit propagates the sample_weight to the losses and metrics, which also accept a sample_weight argument. The sample weight is multiplied by the sample's value before the reduction step. For example:

label = [0,0]
prediction = [[-3., 0], [-3, 0]] 
sample_weight = [1, 10] 

loss = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction=tf.keras.losses.Reduction.NONE)
loss(label, prediction, sample_weight).numpy()
array([ 3.0485873, 30.485874 ], dtype=float32)

So, to make sample weights for this tutorial, you need a function that takes a (data, label) pair and returns a (data, label, sample_weight) triple, where sample_weight is a 1-channel image containing the class weight for each pixel.

The simplest possible implementation is to use the label as an index into a class_weight list:

def add_sample_weights(image, label):
  # The weights for each class, with the constraint that:
  #     sum(class_weights) == 1.0
  class_weights = tf.constant([2.0, 2.0, 1.0])
  class_weights = class_weights/tf.reduce_sum(class_weights)

  # Create an image of `sample_weights` by using the label at each pixel as an
  # index into the `class_weights`.
  sample_weights = tf.gather(class_weights, indices=tf.cast(label, tf.int32))

  return image, label, sample_weights

The resulting dataset elements contain 3 images each:

train_batches.map(add_sample_weights).element_spec
(TensorSpec(shape=(None, 128, 128, 3), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None),
 TensorSpec(shape=(None, 128, 128, 1), dtype=tf.float32, name=None))

Now, you can train a model on this weighted dataset:

weighted_model = unet_model(OUTPUT_CLASSES)
weighted_model.compile(
    optimizer='adam',
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=['accuracy'])
weighted_model.fit(
    train_batches.map(add_sample_weights),
    epochs=1,
    steps_per_epoch=10)
10/10 [==============================] - 6s 119ms/step - loss: 0.2683 - accuracy: 0.6645
<keras.src.callbacks.History at 0x7f6d97411700>

Next steps

Now that you have an understanding of what image segmentation is and how it works, you can try this tutorial out with different intermediate layer outputs, or even different pretrained models. You may also challenge yourself by trying out the Carvana image masking challenge hosted on Kaggle.
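
For example, below is a hypothetical sketch of swapping the encoder for a different pretrained backbone. The layer names are assumptions based on the standard tf.keras.applications.ResNet50 architecture; verify them with base_model.summary() and make sure the chosen layers' resolutions still match the up_stack:

base_model = tf.keras.applications.ResNet50(
    input_shape=[128, 128, 3], include_top=False)

# Assumed ResNet50 layer names at each downsampling resolution.
layer_names = [
    'conv1_relu',        # 64x64
    'conv2_block3_out',  # 32x32
    'conv3_block4_out',  # 16x16
    'conv4_block6_out',  # 8x8
    'conv5_block3_out',  # 4x4
]
base_model_outputs = [base_model.get_layer(name).output for name in layer_names]
down_stack = tf.keras.Model(inputs=base_model.input, outputs=base_model_outputs)
down_stack.trainable = False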

You may also want to see the TensorFlow Object Detection API for another model you can retrain on your own data. Pretrained models are available on TensorFlow Hub.