Transferlernen und Feinabstimmung

Auf TensorFlow.org ansehen Quelle auf GitHub anzeigen Notizbuch herunterladen

In diesem Tutorial erfahren Sie, wie Sie Bilder von Katzen und Hunden mithilfe von Transferlernen aus einem vortrainierten Netzwerk klassifizieren.

Ein vortrainiertes Modell ist ein gespeichertes Netzwerk, das zuvor an einem großen Datensatz trainiert wurde, typischerweise an einer groß angelegten Bildklassifizierungsaufgabe. Sie verwenden entweder das vortrainierte Modell unverändert oder verwenden Transfer-Learning, um dieses Modell an eine bestimmte Aufgabe anzupassen.

Die Intuition hinter dem Transferlernen für die Bildklassifizierung ist, dass, wenn ein Modell mit einem großen und ausreichend allgemeinen Datensatz trainiert wird, dieses Modell effektiv als generisches Modell der visuellen Welt dient. Sie können dann diese erlernten Feature-Maps nutzen, ohne bei Null anfangen zu müssen, indem Sie ein großes Modell mit einem großen Dataset trainieren.

In diesem Notebook werden Sie zwei Möglichkeiten ausprobieren, um ein vortrainiertes Modell anzupassen:

  1. Merkmalsextraktion: Verwenden Sie die Repräsentationen, die von einem früheren Netzwerk gelernt wurden, um aussagekräftige Merkmale aus neuen Stichproben zu extrahieren. Sie fügen dem vortrainierten Modell einfach einen neuen Klassifikator hinzu, der von Grund auf neu trainiert wird, damit Sie die zuvor gelernten Feature-Maps für das Dataset wiederverwenden können.

    Sie müssen nicht das gesamte Modell (neu) trainieren. Das grundlegende Faltungsnetzwerk enthält bereits Merkmale, die allgemein zum Klassifizieren von Bildern nützlich sind. Der letzte Klassifikationsteil des vortrainierten Modells ist jedoch spezifisch für die ursprüngliche Klassifikationsaufgabe und anschließend spezifisch für den Klassensatz, auf dem das Modell trainiert wurde.

  2. Feinabstimmung: Entfrosten Sie einige der oberen Schichten einer eingefrorenen Modellbasis und trainieren Sie gemeinsam sowohl die neu hinzugefügten Klassifikatorschichten als auch die letzten Schichten des Basismodells. Auf diese Weise können wir die Feature-Repräsentationen höherer Ordnung im Basismodell "feinabstimmen", um sie für die spezifische Aufgabe relevanter zu machen.

Sie folgen dem allgemeinen Workflow für maschinelles Lernen.

  1. Untersuchen und verstehen Sie die Daten
  2. Erstellen Sie eine Eingabepipeline, in diesem Fall mit Keras ImageDataGenerator
  3. Stellen Sie das Modell zusammen
    • Laden im vortrainierten Basismodell (und vortrainierten Gewichten)
    • Stapeln Sie die Klassifizierungsebenen darüber
  4. Trainiere das Modell
  5. Modell auswerten
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf

from tensorflow.keras.preprocessing import image_dataset_from_directory

Datenvorverarbeitung

Datendownload

In diesem Tutorial verwenden Sie einen Datensatz mit mehreren Tausend Bildern von Katzen und Hunden. Download und Dekomprimierung eine ZIP - Datei , die die Bilder enthält, erstellen Sie dann eine tf.data.Dataset für die Ausbildung und Validierung der Verwendung von tf.keras.preprocessing.image_dataset_from_directory Dienstprogramm. Sie können mehr über den Laden von Bildern in diesen lernen Tutorial .

_URL = 'https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip'
path_to_zip = tf.keras.utils.get_file('cats_and_dogs.zip', origin=_URL, extract=True)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')

train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')

BATCH_SIZE = 32
IMG_SIZE = (160, 160)

train_dataset = image_dataset_from_directory(train_dir,
                                             shuffle=True,
                                             batch_size=BATCH_SIZE,
                                             image_size=IMG_SIZE)
Downloading data from https://storage.googleapis.com/mledu-datasets/cats_and_dogs_filtered.zip
68608000/68606236 [==============================] - 2s 0us/step
Found 2000 files belonging to 2 classes.
validation_dataset = image_dataset_from_directory(validation_dir,
                                                  shuffle=True,
                                                  batch_size=BATCH_SIZE,
                                                  image_size=IMG_SIZE)
Found 1000 files belonging to 2 classes.

Zeigen Sie die ersten neun Bilder und Labels aus dem Trainingsset an:

class_names = train_dataset.class_names

plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    plt.imshow(images[i].numpy().astype("uint8"))
    plt.title(class_names[labels[i]])
    plt.axis("off")

png

Da das ursprüngliche Dataset kein Testset enthält, erstellen Sie eines. Gehen Sie bestimmen so, wie viele Partien von Daten in dem Validierungssatz zur Verfügung stehen mit tf.data.experimental.cardinality , dann bewegen 20% von ihnen auf einem Test - Set.

val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
print('Number of validation batches: %d' % tf.data.experimental.cardinality(validation_dataset))
print('Number of test batches: %d' % tf.data.experimental.cardinality(test_dataset))
Number of validation batches: 26
Number of test batches: 6

Konfigurieren Sie das Dataset für die Leistung

Verwenden Sie gepuffertes Prefetching, um Bilder von der Festplatte zu laden, ohne dass die E/A blockiert wird. Weitere Informationen zu diesem Verfahren finden Sie in der lernen , Daten - Performance Guide.

AUTOTUNE = tf.data.AUTOTUNE

train_dataset = train_dataset.prefetch(buffer_size=AUTOTUNE)
validation_dataset = validation_dataset.prefetch(buffer_size=AUTOTUNE)
test_dataset = test_dataset.prefetch(buffer_size=AUTOTUNE)

Datenerweiterung verwenden

Wenn Sie kein großes Bild-Dataset haben, empfiehlt es sich, die Stichprobendiversität künstlich einzuführen, indem Sie zufällige, aber realistische Transformationen auf die Trainingsbilder anwenden, z. B. Drehung und horizontales Spiegeln. Dies hilft , um das Modell zu verschiedenen Aspekten der Trainingsdaten aussetzen und reduziert Überanpassung . Sie können mehr über Datenvergrößerung in diesen lernen Tutorial .

data_augmentation = tf.keras.Sequential([
  tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
  tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
])

Lassen Sie uns diese Ebenen wiederholt auf dasselbe Bild anwenden und das Ergebnis sehen.

for image, _ in train_dataset.take(1):
  plt.figure(figsize=(10, 10))
  first_image = image[0]
  for i in range(9):
    ax = plt.subplot(3, 3, i + 1)
    augmented_image = data_augmentation(tf.expand_dims(first_image, 0))
    plt.imshow(augmented_image[0] / 255)
    plt.axis('off')

png

Pixelwerte neu skalieren

In einem Moment werden Sie den Download tf.keras.applications.MobileNetV2 zur Verwendung als Basismodell. Dieses Modell erwartet Pixelwerte in [-1, 1] , aber an diesem Punkt wurde die Pixelwerte in Ihren Bildern sind in [0, 255] . Um sie neu zu skalieren, verwenden Sie die im Modell enthaltene Vorverarbeitungsmethode.

preprocess_input = tf.keras.applications.mobilenet_v2.preprocess_input
rescale = tf.keras.layers.experimental.preprocessing.Rescaling(1./127.5, offset= -1)

Erstellen Sie das Basismodell aus den vortrainierten Convnets

Sie werden das Basismodell aus dem MobileNet V2 Modell bei Google entwickelt erstellen. Dies wird auf dem ImageNet-Datensatz vortrainiert, einem großen Datensatz, der aus 1,4 Millionen Bildern und 1000 Klassen besteht. IMAGEnet ist ein Forschungs - Trainingsdaten mit einer Vielzahl von Kategorien wie jackfruit und syringe . Diese Wissensbasis wird uns helfen, Katzen und Hunde aus unserem spezifischen Datensatz zu klassifizieren.

Zuerst müssen Sie auswählen, welche Schicht von MobileNet V2 Sie für die Funktionsextraktion verwenden möchten. Die allerletzte Klassifikationsschicht (auf "oben", da die meisten Diagramme von Modellen des maschinellen Lernens von unten nach oben verlaufen) ist nicht sehr nützlich. Stattdessen folgen Sie der üblichen Praxis, sich vor dem Abflachen auf die allerletzte Schicht zu verlassen. Diese Schicht wird als "Flaschenhalsschicht" bezeichnet. Die Merkmale der Engpassschicht behalten im Vergleich zur letzten/obersten Schicht eine größere Allgemeingültigkeit.

Instanziieren Sie zunächst ein MobileNet V2-Modell, das mit auf ImageNet trainierten Gewichtungen vorinstalliert ist. Durch die Angabe des include_top = False Argument, laden Sie ein Netzwerk , das nicht die Klassifizierungs Schichten an der Spitze enthält, die ideal für die Merkmalsextraktion ist.

# Create the base model from the pre-trained model MobileNet V2
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,
                                               include_top=False,
                                               weights='imagenet')
Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/mobilenet_v2/mobilenet_v2_weights_tf_dim_ordering_tf_kernels_1.0_160_no_top.h5
9412608/9406464 [==============================] - 0s 0us/step

Dieser Merkmal - Extraktor wandelt jedes 160x160x3 Bild in einen 5x5x1280 Block von Funktionen. Sehen wir uns an, was es mit einem Beispielstapel von Bildern macht:

image_batch, label_batch = next(iter(train_dataset))
feature_batch = base_model(image_batch)
print(feature_batch.shape)
(32, 5, 5, 1280)

Merkmalsextraktion

In diesem Schritt frieren Sie die im vorherigen Schritt erstellte Faltungsbasis ein und verwenden sie als Feature-Extraktor. Darüber hinaus fügen Sie darüber einen Klassifikator hinzu und trainieren den Klassifikator der obersten Ebene.

Friere die Faltungsbasis ein

Es ist wichtig, die Faltungsbasis einzufrieren, bevor Sie das Modell kompilieren und trainieren. Das Einfrieren (durch Setzen von layer.trainable = False) verhindert, dass die Gewichtungen in einem bestimmten Layer während des Trainings aktualisiert werden. MobileNet V2 hat viele Schichten, so dass das gesamte Modell der Einstellung trainable Flag auf False werden alle von ihnen einfrieren.

base_model.trainable = False

Wichtiger Hinweis zu BatchNormalization-Layern

Viele Modelle enthalten tf.keras.layers.BatchNormalization Schichten. Dieser Layer ist ein Sonderfall und im Rahmen der Feinabstimmung sollten Vorsichtsmaßnahmen getroffen werden, wie später in diesem Tutorial gezeigt.

Wenn Sie setzen layer.trainable = False , die BatchNormalization wird Schicht in Inferenz - Modus ausgeführt, und nicht seine Mittel und die Varianz Statistiken aktualisieren.

Wenn Sie ein Modell aufzutauen , die BatchNormalization Schichten enthält , um die Feinabstimmung zu tun, sollten Sie die BatchNormalization Schichten in Inferenz - Modus halten , indem training = False , wenn das Basismodell aufrufen. Andernfalls zerstören die Aktualisierungen, die auf die nicht trainierbaren Gewichtungen angewendet werden, das, was das Modell gelernt hat.

Weitere Einzelheiten finden Sie in der Übertragung Lernbegleiter .

# Let's take a look at the base model architecture
base_model.summary()
Model: "mobilenetv2_1.00_160"
__________________________________________________________________________________________________
Layer (type)                    Output Shape         Param #     Connected to                     
==================================================================================================
input_1 (InputLayer)            [(None, 160, 160, 3) 0                                            
__________________________________________________________________________________________________
Conv1 (Conv2D)                  (None, 80, 80, 32)   864         input_1[0][0]                    
__________________________________________________________________________________________________
bn_Conv1 (BatchNormalization)   (None, 80, 80, 32)   128         Conv1[0][0]                      
__________________________________________________________________________________________________
Conv1_relu (ReLU)               (None, 80, 80, 32)   0           bn_Conv1[0][0]                   
__________________________________________________________________________________________________
expanded_conv_depthwise (Depthw (None, 80, 80, 32)   288         Conv1_relu[0][0]                 
__________________________________________________________________________________________________
expanded_conv_depthwise_BN (Bat (None, 80, 80, 32)   128         expanded_conv_depthwise[0][0]    
__________________________________________________________________________________________________
expanded_conv_depthwise_relu (R (None, 80, 80, 32)   0           expanded_conv_depthwise_BN[0][0] 
__________________________________________________________________________________________________
expanded_conv_project (Conv2D)  (None, 80, 80, 16)   512         expanded_conv_depthwise_relu[0][0
__________________________________________________________________________________________________
expanded_conv_project_BN (Batch (None, 80, 80, 16)   64          expanded_conv_project[0][0]      
__________________________________________________________________________________________________
block_1_expand (Conv2D)         (None, 80, 80, 96)   1536        expanded_conv_project_BN[0][0]   
__________________________________________________________________________________________________
block_1_expand_BN (BatchNormali (None, 80, 80, 96)   384         block_1_expand[0][0]             
__________________________________________________________________________________________________
block_1_expand_relu (ReLU)      (None, 80, 80, 96)   0           block_1_expand_BN[0][0]          
__________________________________________________________________________________________________
block_1_pad (ZeroPadding2D)     (None, 81, 81, 96)   0           block_1_expand_relu[0][0]        
__________________________________________________________________________________________________
block_1_depthwise (DepthwiseCon (None, 40, 40, 96)   864         block_1_pad[0][0]                
__________________________________________________________________________________________________
block_1_depthwise_BN (BatchNorm (None, 40, 40, 96)   384         block_1_depthwise[0][0]          
__________________________________________________________________________________________________
block_1_depthwise_relu (ReLU)   (None, 40, 40, 96)   0           block_1_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_1_project (Conv2D)        (None, 40, 40, 24)   2304        block_1_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_1_project_BN (BatchNormal (None, 40, 40, 24)   96          block_1_project[0][0]            
__________________________________________________________________________________________________
block_2_expand (Conv2D)         (None, 40, 40, 144)  3456        block_1_project_BN[0][0]         
__________________________________________________________________________________________________
block_2_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_2_expand[0][0]             
__________________________________________________________________________________________________
block_2_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_2_expand_BN[0][0]          
__________________________________________________________________________________________________
block_2_depthwise (DepthwiseCon (None, 40, 40, 144)  1296        block_2_expand_relu[0][0]        
__________________________________________________________________________________________________
block_2_depthwise_BN (BatchNorm (None, 40, 40, 144)  576         block_2_depthwise[0][0]          
__________________________________________________________________________________________________
block_2_depthwise_relu (ReLU)   (None, 40, 40, 144)  0           block_2_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_2_project (Conv2D)        (None, 40, 40, 24)   3456        block_2_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_2_project_BN (BatchNormal (None, 40, 40, 24)   96          block_2_project[0][0]            
__________________________________________________________________________________________________
block_2_add (Add)               (None, 40, 40, 24)   0           block_1_project_BN[0][0]         
                                                                 block_2_project_BN[0][0]         
__________________________________________________________________________________________________
block_3_expand (Conv2D)         (None, 40, 40, 144)  3456        block_2_add[0][0]                
__________________________________________________________________________________________________
block_3_expand_BN (BatchNormali (None, 40, 40, 144)  576         block_3_expand[0][0]             
__________________________________________________________________________________________________
block_3_expand_relu (ReLU)      (None, 40, 40, 144)  0           block_3_expand_BN[0][0]          
__________________________________________________________________________________________________
block_3_pad (ZeroPadding2D)     (None, 41, 41, 144)  0           block_3_expand_relu[0][0]        
__________________________________________________________________________________________________
block_3_depthwise (DepthwiseCon (None, 20, 20, 144)  1296        block_3_pad[0][0]                
__________________________________________________________________________________________________
block_3_depthwise_BN (BatchNorm (None, 20, 20, 144)  576         block_3_depthwise[0][0]          
__________________________________________________________________________________________________
block_3_depthwise_relu (ReLU)   (None, 20, 20, 144)  0           block_3_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_3_project (Conv2D)        (None, 20, 20, 32)   4608        block_3_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_3_project_BN (BatchNormal (None, 20, 20, 32)   128         block_3_project[0][0]            
__________________________________________________________________________________________________
block_4_expand (Conv2D)         (None, 20, 20, 192)  6144        block_3_project_BN[0][0]         
__________________________________________________________________________________________________
block_4_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_4_expand[0][0]             
__________________________________________________________________________________________________
block_4_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_4_expand_BN[0][0]          
__________________________________________________________________________________________________
block_4_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_4_expand_relu[0][0]        
__________________________________________________________________________________________________
block_4_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_4_depthwise[0][0]          
__________________________________________________________________________________________________
block_4_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_4_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_4_project (Conv2D)        (None, 20, 20, 32)   6144        block_4_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_4_project_BN (BatchNormal (None, 20, 20, 32)   128         block_4_project[0][0]            
__________________________________________________________________________________________________
block_4_add (Add)               (None, 20, 20, 32)   0           block_3_project_BN[0][0]         
                                                                 block_4_project_BN[0][0]         
__________________________________________________________________________________________________
block_5_expand (Conv2D)         (None, 20, 20, 192)  6144        block_4_add[0][0]                
__________________________________________________________________________________________________
block_5_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_5_expand[0][0]             
__________________________________________________________________________________________________
block_5_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_5_expand_BN[0][0]          
__________________________________________________________________________________________________
block_5_depthwise (DepthwiseCon (None, 20, 20, 192)  1728        block_5_expand_relu[0][0]        
__________________________________________________________________________________________________
block_5_depthwise_BN (BatchNorm (None, 20, 20, 192)  768         block_5_depthwise[0][0]          
__________________________________________________________________________________________________
block_5_depthwise_relu (ReLU)   (None, 20, 20, 192)  0           block_5_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_5_project (Conv2D)        (None, 20, 20, 32)   6144        block_5_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_5_project_BN (BatchNormal (None, 20, 20, 32)   128         block_5_project[0][0]            
__________________________________________________________________________________________________
block_5_add (Add)               (None, 20, 20, 32)   0           block_4_add[0][0]                
                                                                 block_5_project_BN[0][0]         
__________________________________________________________________________________________________
block_6_expand (Conv2D)         (None, 20, 20, 192)  6144        block_5_add[0][0]                
__________________________________________________________________________________________________
block_6_expand_BN (BatchNormali (None, 20, 20, 192)  768         block_6_expand[0][0]             
__________________________________________________________________________________________________
block_6_expand_relu (ReLU)      (None, 20, 20, 192)  0           block_6_expand_BN[0][0]          
__________________________________________________________________________________________________
block_6_pad (ZeroPadding2D)     (None, 21, 21, 192)  0           block_6_expand_relu[0][0]        
__________________________________________________________________________________________________
block_6_depthwise (DepthwiseCon (None, 10, 10, 192)  1728        block_6_pad[0][0]                
__________________________________________________________________________________________________
block_6_depthwise_BN (BatchNorm (None, 10, 10, 192)  768         block_6_depthwise[0][0]          
__________________________________________________________________________________________________
block_6_depthwise_relu (ReLU)   (None, 10, 10, 192)  0           block_6_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_6_project (Conv2D)        (None, 10, 10, 64)   12288       block_6_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_6_project_BN (BatchNormal (None, 10, 10, 64)   256         block_6_project[0][0]            
__________________________________________________________________________________________________
block_7_expand (Conv2D)         (None, 10, 10, 384)  24576       block_6_project_BN[0][0]         
__________________________________________________________________________________________________
block_7_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_7_expand[0][0]             
__________________________________________________________________________________________________
block_7_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_7_expand_BN[0][0]          
__________________________________________________________________________________________________
block_7_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_7_expand_relu[0][0]        
__________________________________________________________________________________________________
block_7_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_7_depthwise[0][0]          
__________________________________________________________________________________________________
block_7_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_7_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_7_project (Conv2D)        (None, 10, 10, 64)   24576       block_7_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_7_project_BN (BatchNormal (None, 10, 10, 64)   256         block_7_project[0][0]            
__________________________________________________________________________________________________
block_7_add (Add)               (None, 10, 10, 64)   0           block_6_project_BN[0][0]         
                                                                 block_7_project_BN[0][0]         
__________________________________________________________________________________________________
block_8_expand (Conv2D)         (None, 10, 10, 384)  24576       block_7_add[0][0]                
__________________________________________________________________________________________________
block_8_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_8_expand[0][0]             
__________________________________________________________________________________________________
block_8_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_8_expand_BN[0][0]          
__________________________________________________________________________________________________
block_8_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_8_expand_relu[0][0]        
__________________________________________________________________________________________________
block_8_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_8_depthwise[0][0]          
__________________________________________________________________________________________________
block_8_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_8_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_8_project (Conv2D)        (None, 10, 10, 64)   24576       block_8_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_8_project_BN (BatchNormal (None, 10, 10, 64)   256         block_8_project[0][0]            
__________________________________________________________________________________________________
block_8_add (Add)               (None, 10, 10, 64)   0           block_7_add[0][0]                
                                                                 block_8_project_BN[0][0]         
__________________________________________________________________________________________________
block_9_expand (Conv2D)         (None, 10, 10, 384)  24576       block_8_add[0][0]                
__________________________________________________________________________________________________
block_9_expand_BN (BatchNormali (None, 10, 10, 384)  1536        block_9_expand[0][0]             
__________________________________________________________________________________________________
block_9_expand_relu (ReLU)      (None, 10, 10, 384)  0           block_9_expand_BN[0][0]          
__________________________________________________________________________________________________
block_9_depthwise (DepthwiseCon (None, 10, 10, 384)  3456        block_9_expand_relu[0][0]        
__________________________________________________________________________________________________
block_9_depthwise_BN (BatchNorm (None, 10, 10, 384)  1536        block_9_depthwise[0][0]          
__________________________________________________________________________________________________
block_9_depthwise_relu (ReLU)   (None, 10, 10, 384)  0           block_9_depthwise_BN[0][0]       
__________________________________________________________________________________________________
block_9_project (Conv2D)        (None, 10, 10, 64)   24576       block_9_depthwise_relu[0][0]     
__________________________________________________________________________________________________
block_9_project_BN (BatchNormal (None, 10, 10, 64)   256         block_9_project[0][0]            
__________________________________________________________________________________________________
block_9_add (Add)               (None, 10, 10, 64)   0           block_8_add[0][0]                
                                                                 block_9_project_BN[0][0]         
__________________________________________________________________________________________________
block_10_expand (Conv2D)        (None, 10, 10, 384)  24576       block_9_add[0][0]                
__________________________________________________________________________________________________
block_10_expand_BN (BatchNormal (None, 10, 10, 384)  1536        block_10_expand[0][0]            
__________________________________________________________________________________________________
block_10_expand_relu (ReLU)     (None, 10, 10, 384)  0           block_10_expand_BN[0][0]         
__________________________________________________________________________________________________
block_10_depthwise (DepthwiseCo (None, 10, 10, 384)  3456        block_10_expand_relu[0][0]       
__________________________________________________________________________________________________
block_10_depthwise_BN (BatchNor (None, 10, 10, 384)  1536        block_10_depthwise[0][0]         
__________________________________________________________________________________________________
block_10_depthwise_relu (ReLU)  (None, 10, 10, 384)  0           block_10_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_10_project (Conv2D)       (None, 10, 10, 96)   36864       block_10_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_10_project_BN (BatchNorma (None, 10, 10, 96)   384         block_10_project[0][0]           
__________________________________________________________________________________________________
block_11_expand (Conv2D)        (None, 10, 10, 576)  55296       block_10_project_BN[0][0]        
__________________________________________________________________________________________________
block_11_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_11_expand[0][0]            
__________________________________________________________________________________________________
block_11_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_11_expand_BN[0][0]         
__________________________________________________________________________________________________
block_11_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_11_expand_relu[0][0]       
__________________________________________________________________________________________________
block_11_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_11_depthwise[0][0]         
__________________________________________________________________________________________________
block_11_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_11_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_11_project (Conv2D)       (None, 10, 10, 96)   55296       block_11_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_11_project_BN (BatchNorma (None, 10, 10, 96)   384         block_11_project[0][0]           
__________________________________________________________________________________________________
block_11_add (Add)              (None, 10, 10, 96)   0           block_10_project_BN[0][0]        
                                                                 block_11_project_BN[0][0]        
__________________________________________________________________________________________________
block_12_expand (Conv2D)        (None, 10, 10, 576)  55296       block_11_add[0][0]               
__________________________________________________________________________________________________
block_12_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_12_expand[0][0]            
__________________________________________________________________________________________________
block_12_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_12_expand_BN[0][0]         
__________________________________________________________________________________________________
block_12_depthwise (DepthwiseCo (None, 10, 10, 576)  5184        block_12_expand_relu[0][0]       
__________________________________________________________________________________________________
block_12_depthwise_BN (BatchNor (None, 10, 10, 576)  2304        block_12_depthwise[0][0]         
__________________________________________________________________________________________________
block_12_depthwise_relu (ReLU)  (None, 10, 10, 576)  0           block_12_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_12_project (Conv2D)       (None, 10, 10, 96)   55296       block_12_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_12_project_BN (BatchNorma (None, 10, 10, 96)   384         block_12_project[0][0]           
__________________________________________________________________________________________________
block_12_add (Add)              (None, 10, 10, 96)   0           block_11_add[0][0]               
                                                                 block_12_project_BN[0][0]        
__________________________________________________________________________________________________
block_13_expand (Conv2D)        (None, 10, 10, 576)  55296       block_12_add[0][0]               
__________________________________________________________________________________________________
block_13_expand_BN (BatchNormal (None, 10, 10, 576)  2304        block_13_expand[0][0]            
__________________________________________________________________________________________________
block_13_expand_relu (ReLU)     (None, 10, 10, 576)  0           block_13_expand_BN[0][0]         
__________________________________________________________________________________________________
block_13_pad (ZeroPadding2D)    (None, 11, 11, 576)  0           block_13_expand_relu[0][0]       
__________________________________________________________________________________________________
block_13_depthwise (DepthwiseCo (None, 5, 5, 576)    5184        block_13_pad[0][0]               
__________________________________________________________________________________________________
block_13_depthwise_BN (BatchNor (None, 5, 5, 576)    2304        block_13_depthwise[0][0]         
__________________________________________________________________________________________________
block_13_depthwise_relu (ReLU)  (None, 5, 5, 576)    0           block_13_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_13_project (Conv2D)       (None, 5, 5, 160)    92160       block_13_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_13_project_BN (BatchNorma (None, 5, 5, 160)    640         block_13_project[0][0]           
__________________________________________________________________________________________________
block_14_expand (Conv2D)        (None, 5, 5, 960)    153600      block_13_project_BN[0][0]        
__________________________________________________________________________________________________
block_14_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_14_expand[0][0]            
__________________________________________________________________________________________________
block_14_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_14_expand_BN[0][0]         
__________________________________________________________________________________________________
block_14_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_14_expand_relu[0][0]       
__________________________________________________________________________________________________
block_14_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_14_depthwise[0][0]         
__________________________________________________________________________________________________
block_14_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_14_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_14_project (Conv2D)       (None, 5, 5, 160)    153600      block_14_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_14_project_BN (BatchNorma (None, 5, 5, 160)    640         block_14_project[0][0]           
__________________________________________________________________________________________________
block_14_add (Add)              (None, 5, 5, 160)    0           block_13_project_BN[0][0]        
                                                                 block_14_project_BN[0][0]        
__________________________________________________________________________________________________
block_15_expand (Conv2D)        (None, 5, 5, 960)    153600      block_14_add[0][0]               
__________________________________________________________________________________________________
block_15_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_15_expand[0][0]            
__________________________________________________________________________________________________
block_15_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_15_expand_BN[0][0]         
__________________________________________________________________________________________________
block_15_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_15_expand_relu[0][0]       
__________________________________________________________________________________________________
block_15_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_15_depthwise[0][0]         
__________________________________________________________________________________________________
block_15_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_15_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_15_project (Conv2D)       (None, 5, 5, 160)    153600      block_15_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_15_project_BN (BatchNorma (None, 5, 5, 160)    640         block_15_project[0][0]           
__________________________________________________________________________________________________
block_15_add (Add)              (None, 5, 5, 160)    0           block_14_add[0][0]               
                                                                 block_15_project_BN[0][0]        
__________________________________________________________________________________________________
block_16_expand (Conv2D)        (None, 5, 5, 960)    153600      block_15_add[0][0]               
__________________________________________________________________________________________________
block_16_expand_BN (BatchNormal (None, 5, 5, 960)    3840        block_16_expand[0][0]            
__________________________________________________________________________________________________
block_16_expand_relu (ReLU)     (None, 5, 5, 960)    0           block_16_expand_BN[0][0]         
__________________________________________________________________________________________________
block_16_depthwise (DepthwiseCo (None, 5, 5, 960)    8640        block_16_expand_relu[0][0]       
__________________________________________________________________________________________________
block_16_depthwise_BN (BatchNor (None, 5, 5, 960)    3840        block_16_depthwise[0][0]         
__________________________________________________________________________________________________
block_16_depthwise_relu (ReLU)  (None, 5, 5, 960)    0           block_16_depthwise_BN[0][0]      
__________________________________________________________________________________________________
block_16_project (Conv2D)       (None, 5, 5, 320)    307200      block_16_depthwise_relu[0][0]    
__________________________________________________________________________________________________
block_16_project_BN (BatchNorma (None, 5, 5, 320)    1280        block_16_project[0][0]           
__________________________________________________________________________________________________
Conv_1 (Conv2D)                 (None, 5, 5, 1280)   409600      block_16_project_BN[0][0]        
__________________________________________________________________________________________________
Conv_1_bn (BatchNormalization)  (None, 5, 5, 1280)   5120        Conv_1[0][0]                     
__________________________________________________________________________________________________
out_relu (ReLU)                 (None, 5, 5, 1280)   0           Conv_1_bn[0][0]                  
==================================================================================================
Total params: 2,257,984
Trainable params: 0
Non-trainable params: 2,257,984
__________________________________________________________________________________________________

Klassifikationskopf hinzufügen

Vorhersagen zu erzeugen , aus dem Block der Merkmale, Durchschnitt der räumlichen 5x5 räumlichen Stellen, unter Verwendung einer tf.keras.layers.GlobalAveragePooling2D Schicht die Eigenschaften auf einen einzigen 1280 - Element - Vektor pro Bild zu konvertieren.

global_average_layer = tf.keras.layers.GlobalAveragePooling2D()
feature_batch_average = global_average_layer(feature_batch)
print(feature_batch_average.shape)
(32, 1280)

Tragen Sie eine tf.keras.layers.Dense Schicht diese Funktionen in einem einzigen Vorhersage pro Bild zu konvertieren. Sie benötigen keine Aktivierungsfunktion hier brauchen , weil diese Vorhersage wird als behandelt werden logit oder einen rohen Vorhersagewert. Positive Zahlen sagen Klasse 1 voraus, negative Zahlen sagen Klasse 0 voraus.

prediction_layer = tf.keras.layers.Dense(1)
prediction_batch = prediction_layer(feature_batch_average)
print(prediction_batch.shape)
(32, 1)

Baue ein Modell von der Daten Augmentation Verkettungs zusammen, neu zu skalieren, und base_model Merkmalsextraktor Schichten , die die Verwendung von Keras Functional API . Wie bereits erwähnt, verwenden Sie training=False, da unser Modell eine BatchNormalization-Schicht enthält.

inputs = tf.keras.Input(shape=(160, 160, 3))
x = data_augmentation(inputs)
x = preprocess_input(x)
x = base_model(x, training=False)
x = global_average_layer(x)
x = tf.keras.layers.Dropout(0.2)(x)
outputs = prediction_layer(x)
model = tf.keras.Model(inputs, outputs)

Kompilieren Sie das Modell

Kompilieren Sie das Modell, bevor Sie es trainieren. Da gibt es zwei Klassen sind, verwenden Sie eine binäre Quer Entropieverlust mit from_logits=True , da das Modell ein lineares Ausgangssignal zur Verfügung stellt.

base_learning_rate = 0.0001
model.compile(optimizer=tf.keras.optimizers.Adam(lr=base_learning_rate),
              loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              metrics=['accuracy'])
/tmpfs/src/tf_docs_env/lib/python3.7/site-packages/tensorflow/python/keras/optimizer_v2/optimizer_v2.py:375: UserWarning: The `lr` argument is deprecated, use `learning_rate` instead.
  "The `lr` argument is deprecated, use `learning_rate` instead.")
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,281
Non-trainable params: 2,257,984
_________________________________________________________________

Die 2,5 M Parameter in MobileNet sind eingefroren, aber es gibt 1.2K trainierbar Parameter in der dichten Schicht. Diese sind aufgeteilt zwischen zwei tf.Variable Objekten, die Gewichte und Vorspannungen.

len(model.trainable_variables)
2

Trainiere das Modell

Nach dem Training für 10 Epochen sollten Sie eine Genauigkeit von ~94 % auf dem Validierungssatz sehen.

initial_epochs = 10

loss0, accuracy0 = model.evaluate(validation_dataset)
26/26 [==============================] - 2s 16ms/step - loss: 1.0528 - accuracy: 0.5223
print("initial loss: {:.2f}".format(loss0))
print("initial accuracy: {:.2f}".format(accuracy0))
initial loss: 1.05
initial accuracy: 0.52
history = model.fit(train_dataset,
                    epochs=initial_epochs,
                    validation_data=validation_dataset)
Epoch 1/10
63/63 [==============================] - 4s 22ms/step - loss: 0.8021 - accuracy: 0.6015 - val_loss: 0.5626 - val_accuracy: 0.7166
Epoch 2/10
63/63 [==============================] - 1s 20ms/step - loss: 0.5231 - accuracy: 0.7185 - val_loss: 0.4190 - val_accuracy: 0.8243
Epoch 3/10
63/63 [==============================] - 1s 20ms/step - loss: 0.4331 - accuracy: 0.7880 - val_loss: 0.3245 - val_accuracy: 0.8738
Epoch 4/10
63/63 [==============================] - 1s 20ms/step - loss: 0.3672 - accuracy: 0.8330 - val_loss: 0.2675 - val_accuracy: 0.9010
Epoch 5/10
63/63 [==============================] - 1s 20ms/step - loss: 0.3213 - accuracy: 0.8660 - val_loss: 0.2280 - val_accuracy: 0.9257
Epoch 6/10
63/63 [==============================] - 1s 20ms/step - loss: 0.2867 - accuracy: 0.8750 - val_loss: 0.1962 - val_accuracy: 0.9369
Epoch 7/10
63/63 [==============================] - 1s 20ms/step - loss: 0.2659 - accuracy: 0.8920 - val_loss: 0.1704 - val_accuracy: 0.9517
Epoch 8/10
63/63 [==============================] - 1s 20ms/step - loss: 0.2401 - accuracy: 0.8990 - val_loss: 0.1520 - val_accuracy: 0.9542
Epoch 9/10
63/63 [==============================] - 1s 20ms/step - loss: 0.2368 - accuracy: 0.8985 - val_loss: 0.1384 - val_accuracy: 0.9592
Epoch 10/10
63/63 [==============================] - 1s 20ms/step - loss: 0.2096 - accuracy: 0.9140 - val_loss: 0.1312 - val_accuracy: 0.9592

Lernkurven

Werfen wir einen Blick auf die Lernkurven der Trainings- und Validierungsgenauigkeit/-verlust, wenn das MobileNet V2-Basismodell als Extraktor fester Funktionen verwendet wird.

acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.ylabel('Accuracy')
plt.ylim([min(plt.ylim()),1])
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.ylabel('Cross Entropy')
plt.ylim([0,1.0])
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

png

In geringerem Maße liegt dies auch daran, dass Trainingsmetriken den Durchschnitt für eine Epoche angeben, während Validierungsmetriken nach der Epoche ausgewertet werden, sodass Validierungsmetriken ein Modell sehen, das etwas länger trainiert wurde.

Feintuning

Im Feature-Extraktionsexperiment haben Sie nur einige Layer auf einem MobileNet V2-Basismodell trainiert. Die Gewichte der vortrainiert Netzwerk wurden nicht während des Trainings aktualisiert.

Eine Möglichkeit, die Leistung noch weiter zu steigern, besteht darin, die Gewichtungen der obersten Schichten des vortrainierten Modells neben dem Training des von Ihnen hinzugefügten Klassifikators zu trainieren (oder "fein abzustimmen"). Der Trainingsprozess erzwingt, dass die Gewichtungen von generischen Feature-Maps auf Features abgestimmt werden, die speziell mit dem Dataset verknüpft sind.

Außerdem sollten Sie versuchen, eine kleine Anzahl von Top-Layer zu optimieren, anstatt das gesamte MobileNet-Modell. In den meisten Faltungsnetzwerken gilt: Je höher eine Schicht ist, desto spezialisierter ist sie. Die ersten paar Ebenen lernen sehr einfache und generische Funktionen, die sich auf fast alle Arten von Bildern verallgemeinern lassen. Je höher Sie gehen, desto spezifischer werden die Features für das Dataset, auf dem das Modell trainiert wurde. Das Ziel der Feinabstimmung besteht darin, diese speziellen Funktionen an die Arbeit mit dem neuen Datensatz anzupassen, anstatt das generische Lernen zu überschreiben.

Entfrosten Sie die oberen Schichten des Modells

Alles , was Sie tun müssen , ist auftauen die base_model und die unteren Schichten gesetzt untrainierbar zu sein. Anschließend sollten Sie das Modell neu kompilieren (erforderlich, damit diese Änderungen wirksam werden) und das Training wieder aufnehmen.

base_model.trainable = True
# Let's take a look to see how many layers are in the base model
print("Number of layers in the base model: ", len(base_model.layers))

# Fine-tune from this layer onwards
fine_tune_at = 100

# Freeze all the layers before the `fine_tune_at` layer
for layer in base_model.layers[:fine_tune_at]:
  layer.trainable =  False
Number of layers in the base model:  154

Kompilieren Sie das Modell

Da Sie ein viel größeres Modell trainieren und die vortrainierten Gewichte neu anpassen möchten, ist es wichtig, in dieser Phase eine niedrigere Lernrate zu verwenden. Andernfalls könnte Ihr Modell sehr schnell überanpassungen.

model.compile(loss=tf.keras.losses.BinaryCrossentropy(from_logits=True),
              optimizer = tf.keras.optimizers.RMSprop(lr=base_learning_rate/10),
              metrics=['accuracy'])
model.summary()
Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
input_2 (InputLayer)         [(None, 160, 160, 3)]     0         
_________________________________________________________________
sequential (Sequential)      (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.truediv (TFOpLambda) (None, 160, 160, 3)       0         
_________________________________________________________________
tf.math.subtract (TFOpLambda (None, 160, 160, 3)       0         
_________________________________________________________________
mobilenetv2_1.00_160 (Functi (None, 5, 5, 1280)        2257984   
_________________________________________________________________
global_average_pooling2d (Gl (None, 1280)              0         
_________________________________________________________________
dropout (Dropout)            (None, 1280)              0         
_________________________________________________________________
dense (Dense)                (None, 1)                 1281      
=================================================================
Total params: 2,259,265
Trainable params: 1,862,721
Non-trainable params: 396,544
_________________________________________________________________
len(model.trainable_variables)
56

Trainieren Sie das Modell weiter

Wenn Sie früher auf Konvergenz trainiert haben, verbessert dieser Schritt Ihre Genauigkeit um einige Prozentpunkte.

fine_tune_epochs = 10
total_epochs =  initial_epochs + fine_tune_epochs

history_fine = model.fit(train_dataset,
                         epochs=total_epochs,
                         initial_epoch=history.epoch[-1],
                         validation_data=validation_dataset)
Epoch 10/20
63/63 [==============================] - 6s 38ms/step - loss: 0.1514 - accuracy: 0.9375 - val_loss: 0.0548 - val_accuracy: 0.9740
Epoch 11/20
63/63 [==============================] - 2s 27ms/step - loss: 0.1262 - accuracy: 0.9490 - val_loss: 0.0597 - val_accuracy: 0.9691
Epoch 12/20
63/63 [==============================] - 2s 26ms/step - loss: 0.1073 - accuracy: 0.9595 - val_loss: 0.0444 - val_accuracy: 0.9777
Epoch 13/20
63/63 [==============================] - 2s 27ms/step - loss: 0.0997 - accuracy: 0.9590 - val_loss: 0.0437 - val_accuracy: 0.9777
Epoch 14/20
63/63 [==============================] - 2s 26ms/step - loss: 0.0983 - accuracy: 0.9610 - val_loss: 0.0466 - val_accuracy: 0.9790
Epoch 15/20
63/63 [==============================] - 2s 27ms/step - loss: 0.0715 - accuracy: 0.9740 - val_loss: 0.0378 - val_accuracy: 0.9814
Epoch 16/20
63/63 [==============================] - 2s 26ms/step - loss: 0.0753 - accuracy: 0.9720 - val_loss: 0.0465 - val_accuracy: 0.9765
Epoch 17/20
63/63 [==============================] - 2s 27ms/step - loss: 0.0727 - accuracy: 0.9710 - val_loss: 0.0348 - val_accuracy: 0.9827
Epoch 18/20
63/63 [==============================] - 2s 26ms/step - loss: 0.0638 - accuracy: 0.9715 - val_loss: 0.0393 - val_accuracy: 0.9839
Epoch 19/20
63/63 [==============================] - 2s 27ms/step - loss: 0.0552 - accuracy: 0.9755 - val_loss: 0.0349 - val_accuracy: 0.9851
Epoch 20/20
63/63 [==============================] - 2s 27ms/step - loss: 0.0605 - accuracy: 0.9730 - val_loss: 0.0419 - val_accuracy: 0.9864

Werfen wir einen Blick auf die Lernkurven der Trainings- und Validierungsgenauigkeit/-verlust bei der Feinabstimmung der letzten Schichten des MobileNet V2-Basismodells und dem darauf aufbauenden Training des Klassifikators. Der Validierungsverlust ist viel höher als der Trainingsverlust, daher kann es zu einer Überanpassung kommen.

Es kann auch zu einer Überanpassung kommen, da der neue Trainingssatz relativ klein ist und den ursprünglichen MobileNet V2-Datensätzen ähnelt.

Nach der Feinabstimmung erreicht das Modell auf dem Validierungssatz eine Genauigkeit von fast 98%.

acc += history_fine.history['accuracy']
val_acc += history_fine.history['val_accuracy']

loss += history_fine.history['loss']
val_loss += history_fine.history['val_loss']
plt.figure(figsize=(8, 8))
plt.subplot(2, 1, 1)
plt.plot(acc, label='Training Accuracy')
plt.plot(val_acc, label='Validation Accuracy')
plt.ylim([0.8, 1])
plt.plot([initial_epochs-1,initial_epochs-1],
          plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(2, 1, 2)
plt.plot(loss, label='Training Loss')
plt.plot(val_loss, label='Validation Loss')
plt.ylim([0, 1.0])
plt.plot([initial_epochs-1,initial_epochs-1],
         plt.ylim(), label='Start Fine Tuning')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.xlabel('epoch')
plt.show()

png

Auswertung und Vorhersage

Schließlich können Sie die Leistung des Modells mit neuen Daten mit einem Testsatz überprüfen.

loss, accuracy = model.evaluate(test_dataset)
print('Test accuracy :', accuracy)
6/6 [==============================] - 0s 13ms/step - loss: 0.0438 - accuracy: 0.9792
Test accuracy : 0.9791666865348816

Und jetzt können Sie dieses Modell verwenden, um vorherzusagen, ob Ihr Haustier eine Katze oder ein Hund ist.

#Retrieve a batch of images from the test set
image_batch, label_batch = test_dataset.as_numpy_iterator().next()
predictions = model.predict_on_batch(image_batch).flatten()

# Apply a sigmoid since our model returns logits
predictions = tf.nn.sigmoid(predictions)
predictions = tf.where(predictions < 0.5, 0, 1)

print('Predictions:\n', predictions.numpy())
print('Labels:\n', label_batch)

plt.figure(figsize=(10, 10))
for i in range(9):
  ax = plt.subplot(3, 3, i + 1)
  plt.imshow(image_batch[i].astype("uint8"))
  plt.title(class_names[predictions[i]])
  plt.axis("off")
Predictions:
 [0 1 0 0 1 0 1 1 0 0 0 0 0 1 1 0 1 0 1 1 1 0 1 1 1 1 1 1 0 1 0 1]
Labels:
 [0 1 0 0 0 0 1 1 0 0 0 0 0 1 0 0 1 0 1 1 1 0 1 1 1 1 1 1 0 1 0 1]

png

Zusammenfassung

  • Ein vortrainiert Modell zur Merkmalsextraktion unter Verwendung: Wenn mit einem kleinen Datenmenge arbeitet, ist es eine gängige Praxis Vorteil von Merkmalen durch ein Modell trainiert auf einer größere Datenmenge in der gleichen Domäne gelernt zu nehmen. Dies geschieht durch Instanziieren des vortrainierten Modells und Hinzufügen eines vollständig verbundenen Klassifikators darüber. Das vortrainierte Modell wird "eingefroren" und nur die Gewichte des Klassifikators werden während des Trainings aktualisiert. In diesem Fall hat die Faltungsbasis alle mit jedem Bild verknüpften Merkmale extrahiert und Sie haben gerade einen Klassifikator trainiert, der die Bildklasse anhand dieses Satzes extrahierter Merkmale bestimmt.

  • Feinabstimmung ein vortrainiert Modells: Um die Leistung weiter zu verbessern, könnte man will die obersten Ebene der vortrainiert Modelle auf den neuen Datensatz über die Feinabstimmung umfunktionieren. In diesem Fall haben Sie Ihre Gewichtungen so abgestimmt, dass Ihr Modell für das Dataset spezifische High-Level-Features gelernt hat. Diese Technik wird normalerweise empfohlen, wenn das Trainings-Dataset groß ist und dem ursprünglichen Dataset, auf dem das vortrainierte Modell trainiert wurde, sehr ähnlich ist.

Um mehr zu erfahren, besuchen Sie die Übertragung Lernbegleiter .

# MIT License
#
# Copyright (c) 2017 François Chollet                                                                                                                    # IGNORE_COPYRIGHT: cleared by OSS licensing
#
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the "Software"),
# to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense,
# and/or sell copies of the Software, and to permit persons to whom the
# Software is furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
# FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.