Modèles de points de contrôle

La possibilité de sauvegarder et de restaurer l'état d'un modèle est vitale pour un certain nombre d'applications, telles que l'apprentissage par transfert ou la réalisation d'inférences à l'aide de modèles pré-entraînés. L'enregistrement des paramètres d'un modèle (poids, biais, etc.) dans un fichier ou un répertoire de points de contrôle est un moyen d'y parvenir.

Ce module fournit une interface de haut niveau pour charger et enregistrer les points de contrôle du format TensorFlow v2 , ainsi que des composants de niveau inférieur qui écrivent et lisent à partir de ce format de fichier.

Chargement et sauvegarde de modèles simples

En se conformant au protocole Checkpointable , de nombreux modèles simples peuvent être sérialisés en points de contrôle sans aucun code supplémentaire :

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

et ensuite ce même point de contrôle peut être lu en utilisant :

try model.readCheckpoint(from: directory, name: "LeNet")

Cette implémentation par défaut pour le chargement et l'enregistrement du modèle utilisera un schéma de dénomination basé sur le chemin pour chaque tenseur du modèle, basé sur les noms des propriétés dans les structures du modèle. Par exemple, les poids et les biais au sein de la première convolution du modèle LeNet-5 seront enregistrés respectivement sous les noms conv1/filter et conv1/bias . Lors du chargement, le lecteur de point de contrôle recherchera les tenseurs portant ces noms.

Personnalisation du chargement et de l'enregistrement du modèle

Si vous souhaitez avoir un meilleur contrôle sur les tenseurs enregistrés et chargés, ou sur la dénomination de ces tenseurs, le protocole Checkpointable propose quelques points de personnalisation.

Pour ignorer les propriétés sur certains types, vous pouvez fournir une implémentation de ignoredTensorPaths sur votre modèle qui renvoie un ensemble de chaînes sous la forme de Type.property . Par exemple, pour ignorer la propriété scale sur chaque couche Attention, vous pouvez renvoyer ["Attention.scale"] .

Par défaut, une barre oblique est utilisée pour séparer chaque niveau plus profond dans un modèle. Cela peut être personnalisé en implémentant checkpointSeparator sur votre modèle et en fournissant une nouvelle chaîne à utiliser pour ce séparateur.

Enfin, pour le plus grand degré de personnalisation de la dénomination du tenseur, vous pouvez implémenter tensorNameMap et fournir une fonction qui mappe le nom de chaîne par défaut généré pour un tenseur dans le modèle vers un nom de chaîne souhaité dans le point de contrôle. Le plus souvent, cela sera utilisé pour interagir avec des points de contrôle générés avec d'autres frameworks, chacun ayant ses propres conventions de dénomination et structures de modèle. Une fonction de mappage personnalisée offre le plus grand degré de personnalisation de la façon dont ces tenseurs sont nommés.

Certaines fonctions d'assistance standard sont fournies, comme la CheckpointWriter.identityMap par défaut (qui utilise simplement le nom de chemin du tenseur généré automatiquement pour les points de contrôle), ou la fonction CheckpointWriter.lookupMap(table:) , qui peut créer un mappage à partir d'un dictionnaire.

Pour un exemple de la façon dont un mappage personnalisé peut être réalisé, veuillez consulter le modèle GPT-2 , qui utilise une fonction de mappage pour correspondre au schéma de dénomination exact utilisé pour les points de contrôle d'OpenAI.

Les composants CheckpointReader et CheckpointWriter

Pour l'écriture de points de contrôle, l'extension fournie par le protocole Checkpointable utilise la réflexion et les chemins de clés pour parcourir les propriétés d'un modèle et générer un dictionnaire qui mappe les chemins du tenseur de chaîne aux valeurs du tenseur. Ce dictionnaire est fourni à un CheckpointWriter sous-jacent, avec un répertoire dans lequel écrire le point de contrôle. Ce CheckpointWriter gère la tâche de génération du point de contrôle sur le disque à partir de ce dictionnaire.

L'inverse de ce processus est la lecture, où un CheckpointReader reçoit l'emplacement d'un répertoire de point de contrôle sur le disque. Il lit ensuite à partir de ce point de contrôle et forme un dictionnaire qui mappe les noms des tenseurs au sein du point de contrôle avec leurs valeurs enregistrées. Ce dictionnaire est utilisé pour remplacer les tenseurs actuels d'un modèle par ceux de ce dictionnaire.

Pour le chargement et l'enregistrement, le protocole Checkpointable mappe les chemins de chaîne vers les tenseurs aux noms de tenseurs correspondants sur le disque à l'aide de la fonction de mappage décrite ci-dessus.

Si le protocole Checkpointable ne dispose pas des fonctionnalités nécessaires ou si un contrôle accru est souhaité sur le processus de chargement et de sauvegarde des points de contrôle, les classes CheckpointReader et CheckpointWriter peuvent être utilisées seules.

Le format de point de contrôle TensorFlow v2

Le format de point de contrôle TensorFlow v2, tel que brièvement décrit dans cet en-tête , est le format de deuxième génération pour les points de contrôle du modèle TensorFlow. Ce format de deuxième génération est utilisé depuis fin 2016 et présente un certain nombre d'améliorations par rapport au format de point de contrôle v1. Les TensorFlow SavedModels utilisent des points de contrôle v2 pour enregistrer les paramètres du modèle.

Un point de contrôle TensorFlow v2 consiste en un répertoire avec une structure comme la suivante :

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

où le premier fichier stocke les métadonnées du point de contrôle et les fichiers restants sont des fragments binaires contenant les paramètres sérialisés du modèle.

Le fichier de métadonnées d'index contient les types, les tailles, les emplacements et les noms de chaîne de tous les tenseurs sérialisés contenus dans les fragments. Ce fichier d'index est la partie structurellement la plus complexe du point de contrôle et est basé sur tensorflow::table , qui est lui-même basé sur SSTable / LevelDB. Ce fichier d'index est composé d'une série de paires clé-valeur, où les clés sont des chaînes et les valeurs sont des tampons de protocole. Les chaînes sont triées et compressées avec préfixe. Par exemple : si la première entrée est conv1/weight et la suivante conv1/bias , la deuxième entrée utilise uniquement la partie bias .

Ce fichier d'index global est parfois compressé à l'aide de la compression Snappy . Le fichier SnappyDecompression.swift fournit une implémentation Swift native de la décompression Snappy à partir d'une instance de données compressée.

Les métadonnées d'en-tête d'index et les métadonnées de tenseur sont codées sous forme de tampons de protocole et codées/décodées directement via Swift Protobuf .

Les classes CheckpointIndexReader et CheckpointIndexWriter gèrent le chargement et l'enregistrement de ces fichiers d'index dans le cadre des classes globales CheckpointReader et CheckpointWriter . Ces derniers utilisent les fichiers d'index comme base pour déterminer ce qu'il faut lire et écrire sur les fragments binaires structurellement plus simples qui contiennent les données tensorielles.

,

La possibilité de sauvegarder et de restaurer l'état d'un modèle est vitale pour un certain nombre d'applications, telles que l'apprentissage par transfert ou la réalisation d'inférences à l'aide de modèles pré-entraînés. L'enregistrement des paramètres d'un modèle (poids, biais, etc.) dans un fichier ou un répertoire de points de contrôle est un moyen d'y parvenir.

Ce module fournit une interface de haut niveau pour charger et enregistrer les points de contrôle du format TensorFlow v2 , ainsi que des composants de niveau inférieur qui écrivent et lisent à partir de ce format de fichier.

Chargement et sauvegarde de modèles simples

En se conformant au protocole Checkpointable , de nombreux modèles simples peuvent être sérialisés en points de contrôle sans aucun code supplémentaire :

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

puis ce même point de contrôle peut être lu en utilisant :

try model.readCheckpoint(from: directory, name: "LeNet")

Cette implémentation par défaut pour le chargement et l'enregistrement du modèle utilisera un schéma de dénomination basé sur le chemin pour chaque tenseur du modèle, basé sur les noms des propriétés dans les structures du modèle. Par exemple, les poids et les biais au sein de la première convolution du modèle LeNet-5 seront enregistrés respectivement sous les noms conv1/filter et conv1/bias . Lors du chargement, le lecteur de point de contrôle recherchera les tenseurs portant ces noms.

Personnalisation du chargement et de l'enregistrement du modèle

Si vous souhaitez avoir un meilleur contrôle sur les tenseurs enregistrés et chargés, ou sur la dénomination de ces tenseurs, le protocole Checkpointable propose quelques points de personnalisation.

Pour ignorer les propriétés sur certains types, vous pouvez fournir une implémentation de ignoredTensorPaths sur votre modèle qui renvoie un ensemble de chaînes sous la forme de Type.property . Par exemple, pour ignorer la propriété scale sur chaque couche Attention, vous pouvez renvoyer ["Attention.scale"] .

Par défaut, une barre oblique est utilisée pour séparer chaque niveau plus profond dans un modèle. Cela peut être personnalisé en implémentant checkpointSeparator sur votre modèle et en fournissant une nouvelle chaîne à utiliser pour ce séparateur.

Enfin, pour le plus grand degré de personnalisation de la dénomination du tenseur, vous pouvez implémenter tensorNameMap et fournir une fonction qui mappe le nom de chaîne par défaut généré pour un tenseur dans le modèle vers un nom de chaîne souhaité dans le point de contrôle. Le plus souvent, cela sera utilisé pour interagir avec des points de contrôle générés avec d'autres frameworks, chacun ayant ses propres conventions de dénomination et structures de modèle. Une fonction de mappage personnalisée offre le plus grand degré de personnalisation de la façon dont ces tenseurs sont nommés.

Certaines fonctions d'assistance standard sont fournies, comme la CheckpointWriter.identityMap par défaut (qui utilise simplement le nom de chemin du tenseur généré automatiquement pour les points de contrôle), ou la fonction CheckpointWriter.lookupMap(table:) , qui peut créer un mappage à partir d'un dictionnaire.

Pour un exemple de la façon dont un mappage personnalisé peut être réalisé, veuillez consulter le modèle GPT-2 , qui utilise une fonction de mappage pour correspondre au schéma de dénomination exact utilisé pour les points de contrôle d'OpenAI.

Les composants CheckpointReader et CheckpointWriter

Pour l'écriture de points de contrôle, l'extension fournie par le protocole Checkpointable utilise la réflexion et les chemins de clés pour parcourir les propriétés d'un modèle et générer un dictionnaire qui mappe les chemins du tenseur de chaîne aux valeurs du tenseur. Ce dictionnaire est fourni à un CheckpointWriter sous-jacent, avec un répertoire dans lequel écrire le point de contrôle. Ce CheckpointWriter gère la tâche de génération du point de contrôle sur le disque à partir de ce dictionnaire.

L'inverse de ce processus est la lecture, où un CheckpointReader reçoit l'emplacement d'un répertoire de point de contrôle sur le disque. Il lit ensuite à partir de ce point de contrôle et forme un dictionnaire qui mappe les noms des tenseurs au sein du point de contrôle avec leurs valeurs enregistrées. Ce dictionnaire est utilisé pour remplacer les tenseurs actuels d'un modèle par ceux de ce dictionnaire.

Pour le chargement et l'enregistrement, le protocole Checkpointable mappe les chemins de chaîne vers les tenseurs aux noms de tenseurs correspondants sur le disque à l'aide de la fonction de mappage décrite ci-dessus.

Si le protocole Checkpointable ne dispose pas des fonctionnalités nécessaires ou si un contrôle accru est souhaité sur le processus de chargement et de sauvegarde des points de contrôle, les classes CheckpointReader et CheckpointWriter peuvent être utilisées seules.

Le format de point de contrôle TensorFlow v2

Le format de point de contrôle TensorFlow v2, tel que brièvement décrit dans cet en-tête , est le format de deuxième génération pour les points de contrôle du modèle TensorFlow. Ce format de deuxième génération est utilisé depuis fin 2016 et présente un certain nombre d'améliorations par rapport au format de point de contrôle v1. Les TensorFlow SavedModels utilisent des points de contrôle v2 pour enregistrer les paramètres du modèle.

Un point de contrôle TensorFlow v2 consiste en un répertoire avec une structure comme la suivante :

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

où le premier fichier stocke les métadonnées du point de contrôle et les fichiers restants sont des fragments binaires contenant les paramètres sérialisés du modèle.

Le fichier de métadonnées d'index contient les types, les tailles, les emplacements et les noms de chaîne de tous les tenseurs sérialisés contenus dans les fragments. Ce fichier d'index est la partie structurellement la plus complexe du point de contrôle et est basé sur tensorflow::table , qui est lui-même basé sur SSTable / LevelDB. Ce fichier d'index est composé d'une série de paires clé-valeur, où les clés sont des chaînes et les valeurs sont des tampons de protocole. Les chaînes sont triées et compressées avec préfixe. Par exemple : si la première entrée est conv1/weight et la suivante conv1/bias , la deuxième entrée utilise uniquement la partie bias .

Ce fichier d'index global est parfois compressé à l'aide de la compression Snappy . Le fichier SnappyDecompression.swift fournit une implémentation Swift native de la décompression Snappy à partir d'une instance de données compressée.

Les métadonnées d'en-tête d'index et les métadonnées de tenseur sont codées sous forme de tampons de protocole et codées/décodées directement via Swift Protobuf .

Les classes CheckpointIndexReader et CheckpointIndexWriter gèrent le chargement et l'enregistrement de ces fichiers d'index dans le cadre des classes globales CheckpointReader et CheckpointWriter . Ces derniers utilisent les fichiers d'index comme base pour déterminer ce qu'il faut lire et écrire sur les fragments binaires structurellement plus simples qui contiennent les données tensorielles.