Pos pemeriksaan model

Kemampuan untuk menyimpan dan memulihkan status model sangat penting untuk sejumlah aplikasi, seperti dalam pembelajaran transfer atau untuk melakukan inferensi menggunakan model yang telah dilatih sebelumnya. Menyimpan parameter model (bobot, bias, dll.) dalam file atau direktori pos pemeriksaan adalah salah satu cara untuk mencapai hal ini.

Modul ini menyediakan antarmuka tingkat tinggi untuk memuat dan menyimpan pos pemeriksaan format TensorFlow v2 , serta komponen tingkat rendah yang menulis dan membaca dari format file ini.

Memuat dan menyimpan model sederhana

Dengan menyesuaikan diri dengan protokol Checkpointable , banyak model sederhana dapat diserialkan ke pos pemeriksaan tanpa kode tambahan apa pun:

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

dan kemudian pos pemeriksaan yang sama dapat dibaca dengan menggunakan:

try model.readCheckpoint(from: directory, name: "LeNet")

Implementasi default untuk pemuatan dan penyimpanan model ini akan menggunakan skema penamaan berbasis jalur untuk setiap tensor dalam model yang didasarkan pada nama properti dalam struct model. Misalnya, bobot dan bias dalam konvolusi pertama dalam model LeNet-5 akan disimpan dengan nama conv1/filter dan conv1/bias . Saat memuat, pembaca pos pemeriksaan akan mencari tensor dengan nama ini.

Menyesuaikan pemuatan dan penyimpanan model

Jika Anda ingin memiliki kontrol lebih besar terhadap tensor mana yang disimpan dan dimuat, atau penamaan tensor tersebut, protokol Checkpointable menawarkan beberapa poin penyesuaian.

Untuk mengabaikan properti pada tipe tertentu, Anda dapat memberikan implementasi ignoredTensorPaths pada model Anda yang mengembalikan Kumpulan string dalam bentuk Type.property . Misalnya, untuk mengabaikan properti scale pada setiap lapisan Attention, Anda dapat mengembalikan ["Attention.scale"] .

Secara default, garis miring digunakan untuk memisahkan setiap level yang lebih dalam dalam suatu model. Ini dapat dikustomisasi dengan mengimplementasikan checkpointSeparator pada model Anda dan menyediakan string baru untuk digunakan sebagai pemisah ini.

Terakhir, untuk tingkat penyesuaian terbesar dalam penamaan tensor, Anda dapat mengimplementasikan tensorNameMap dan menyediakan fungsi yang memetakan dari nama string default yang dihasilkan untuk tensor dalam model ke nama string yang diinginkan di pos pemeriksaan. Umumnya, ini akan digunakan untuk berinteroperasi dengan pos pemeriksaan yang dihasilkan dengan kerangka kerja lain, yang masing-masing memiliki konvensi penamaan dan struktur modelnya sendiri. Fungsi pemetaan khusus memberikan tingkat penyesuaian terbesar pada cara pemberian nama tensor ini.

Beberapa fungsi pembantu standar disediakan, seperti CheckpointWriter.identityMap default (yang hanya menggunakan nama jalur tensor yang dibuat secara otomatis untuk pos pemeriksaan), atau fungsi CheckpointWriter.lookupMap(table:) yang dapat membuat pemetaan dari kamus.

Untuk contoh bagaimana pemetaan khusus dapat dilakukan, silakan lihat model GPT-2 , yang menggunakan fungsi pemetaan untuk mencocokkan skema penamaan persis yang digunakan untuk pos pemeriksaan OpenAI.

Komponen CheckpointReader dan CheckpointWriter

Untuk penulisan checkpoint, ekstensi yang disediakan oleh protokol Checkpointable menggunakan refleksi dan jalur kunci untuk melakukan iterasi pada properti model dan menghasilkan kamus yang memetakan jalur tensor string ke nilai Tensor. Kamus ini disediakan untuk CheckpointWriter yang mendasarinya, bersama dengan direktori untuk menulis pos pemeriksaan. CheckpointWriter tersebut menangani tugas menghasilkan pos pemeriksaan pada disk dari kamus tersebut.

Kebalikan dari proses ini adalah membaca, dimana CheckpointReader diberikan lokasi direktori checkpoint pada disk. Ia kemudian membaca dari pos pemeriksaan tersebut dan membentuk kamus yang memetakan nama tensor dalam pos pemeriksaan tersebut dengan nilai yang disimpannya. Kamus ini digunakan untuk mengganti tensor saat ini dalam suatu model dengan yang ada di kamus ini.

Untuk memuat dan menyimpan, protokol Checkpointable memetakan jalur string ke tensor ke nama tensor pada disk yang sesuai menggunakan fungsi pemetaan yang dijelaskan di atas.

Jika protokol Checkpointable tidak memiliki fungsionalitas yang diperlukan, atau diperlukan lebih banyak kontrol atas proses pemuatan dan penyimpanan checkpoint, kelas CheckpointReader dan CheckpointWriter dapat digunakan sendiri.

Format pos pemeriksaan TensorFlow v2

Format pos pemeriksaan TensorFlow v2, seperti yang dijelaskan secara singkat di header ini , adalah format generasi kedua untuk pos pemeriksaan model TensorFlow. Format generasi kedua ini telah digunakan sejak akhir tahun 2016, dan memiliki sejumlah perbaikan dibandingkan format pos pemeriksaan v1. TensorFlow SavedModels menggunakan pos pemeriksaan v2 di dalamnya untuk menyimpan parameter model.

Pos pemeriksaan TensorFlow v2 terdiri dari direktori dengan struktur seperti berikut:

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

dengan file pertama menyimpan metadata untuk pos pemeriksaan dan file sisanya adalah pecahan biner yang menyimpan parameter serial untuk model tersebut.

File metadata indeks berisi jenis, ukuran, lokasi, dan nama string semua tensor berseri yang terdapat dalam pecahan. File indeks tersebut adalah bagian pos pemeriksaan yang paling kompleks secara struktural, dan didasarkan pada tensorflow::table , yang juga didasarkan pada SSTable / LevelDB. File indeks ini terdiri dari serangkaian pasangan nilai kunci, di mana kuncinya adalah string dan nilainya adalah buffer protokol. String diurutkan dan dikompresi awalan. Misalnya: jika entri pertama adalah conv1/weight dan berikutnya conv1/bias , entri kedua hanya menggunakan bagian bias .

File indeks keseluruhan ini terkadang dikompresi menggunakan kompresi Snappy . File SnappyDecompression.swift menyediakan implementasi Swift asli dari dekompresi Snappy dari instance Data terkompresi.

Metadata header indeks dan metadata tensor dikodekan sebagai buffer protokol dan dikodekan/didekode langsung melalui Swift Protobuf .

Kelas CheckpointIndexReader dan CheckpointIndexWriter menangani pemuatan dan penyimpanan file indeks ini sebagai bagian dari kelas CheckpointReader dan CheckpointWriter yang menyeluruh. Yang terakhir menggunakan file indeks sebagai dasar untuk menentukan apa yang harus dibaca dan ditulis ke pecahan biner yang secara struktural lebih sederhana yang berisi data tensor.

,

Kemampuan untuk menyimpan dan memulihkan status model sangat penting untuk sejumlah aplikasi, seperti dalam pembelajaran transfer atau untuk melakukan inferensi menggunakan model yang telah dilatih sebelumnya. Menyimpan parameter model (bobot, bias, dll.) dalam file atau direktori pos pemeriksaan adalah salah satu cara untuk mencapai hal ini.

Modul ini menyediakan antarmuka tingkat tinggi untuk memuat dan menyimpan pos pemeriksaan format TensorFlow v2 , serta komponen tingkat rendah yang menulis dan membaca dari format file ini.

Memuat dan menyimpan model sederhana

Dengan menyesuaikan diri dengan protokol Checkpointable , banyak model sederhana dapat diserialkan ke pos pemeriksaan tanpa kode tambahan apa pun:

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

dan kemudian pos pemeriksaan yang sama dapat dibaca dengan menggunakan:

try model.readCheckpoint(from: directory, name: "LeNet")

Implementasi default untuk pemuatan dan penyimpanan model ini akan menggunakan skema penamaan berbasis jalur untuk setiap tensor dalam model yang didasarkan pada nama properti dalam struct model. Misalnya, bobot dan bias dalam konvolusi pertama dalam model LeNet-5 akan disimpan dengan nama conv1/filter dan conv1/bias . Saat memuat, pembaca pos pemeriksaan akan mencari tensor dengan nama ini.

Menyesuaikan pemuatan dan penyimpanan model

Jika Anda ingin memiliki kontrol lebih besar terhadap tensor mana yang disimpan dan dimuat, atau penamaan tensor tersebut, protokol Checkpointable menawarkan beberapa poin penyesuaian.

Untuk mengabaikan properti pada tipe tertentu, Anda dapat memberikan implementasi ignoredTensorPaths pada model Anda yang mengembalikan Kumpulan string dalam bentuk Type.property . Misalnya, untuk mengabaikan properti scale pada setiap lapisan Attention, Anda dapat mengembalikan ["Attention.scale"] .

Secara default, garis miring digunakan untuk memisahkan setiap level yang lebih dalam dalam suatu model. Ini dapat dikustomisasi dengan mengimplementasikan checkpointSeparator pada model Anda dan menyediakan string baru untuk digunakan sebagai pemisah ini.

Terakhir, untuk tingkat penyesuaian terbesar dalam penamaan tensor, Anda dapat mengimplementasikan tensorNameMap dan menyediakan fungsi yang memetakan dari nama string default yang dihasilkan untuk tensor dalam model ke nama string yang diinginkan di pos pemeriksaan. Umumnya, ini akan digunakan untuk berinteroperasi dengan pos pemeriksaan yang dihasilkan dengan kerangka kerja lain, yang masing-masing memiliki konvensi penamaan dan struktur modelnya sendiri. Fungsi pemetaan khusus memberikan tingkat penyesuaian terbesar pada cara pemberian nama tensor ini.

Beberapa fungsi pembantu standar disediakan, seperti CheckpointWriter.identityMap default (yang hanya menggunakan nama jalur tensor yang dibuat secara otomatis untuk pos pemeriksaan), atau fungsi CheckpointWriter.lookupMap(table:) yang dapat membuat pemetaan dari kamus.

Untuk contoh bagaimana pemetaan khusus dapat dilakukan, silakan lihat model GPT-2 , yang menggunakan fungsi pemetaan untuk mencocokkan skema penamaan persis yang digunakan untuk pos pemeriksaan OpenAI.

Komponen CheckpointReader dan CheckpointWriter

Untuk penulisan checkpoint, ekstensi yang disediakan oleh protokol Checkpointable menggunakan refleksi dan jalur kunci untuk melakukan iterasi pada properti model dan menghasilkan kamus yang memetakan jalur tensor string ke nilai Tensor. Kamus ini disediakan untuk CheckpointWriter yang mendasarinya, bersama dengan direktori untuk menulis pos pemeriksaan. CheckpointWriter tersebut menangani tugas menghasilkan pos pemeriksaan pada disk dari kamus tersebut.

Kebalikan dari proses ini adalah membaca, dimana CheckpointReader diberikan lokasi direktori checkpoint pada disk. Ia kemudian membaca dari pos pemeriksaan tersebut dan membentuk kamus yang memetakan nama tensor dalam pos pemeriksaan tersebut dengan nilai yang disimpannya. Kamus ini digunakan untuk mengganti tensor saat ini dalam suatu model dengan yang ada di kamus ini.

Untuk memuat dan menyimpan, protokol Checkpointable memetakan jalur string ke tensor ke nama tensor pada disk yang sesuai menggunakan fungsi pemetaan yang dijelaskan di atas.

Jika protokol Checkpointable tidak memiliki fungsionalitas yang diperlukan, atau diperlukan lebih banyak kontrol atas proses pemuatan dan penyimpanan checkpoint, kelas CheckpointReader dan CheckpointWriter dapat digunakan sendiri.

Format pos pemeriksaan TensorFlow v2

Format pos pemeriksaan TensorFlow v2, seperti yang dijelaskan secara singkat di header ini , adalah format generasi kedua untuk pos pemeriksaan model TensorFlow. Format generasi kedua ini telah digunakan sejak akhir tahun 2016, dan memiliki sejumlah perbaikan dibandingkan format pos pemeriksaan v1. TensorFlow SavedModels menggunakan pos pemeriksaan v2 di dalamnya untuk menyimpan parameter model.

Pos pemeriksaan TensorFlow v2 terdiri dari direktori dengan struktur seperti berikut:

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

dengan file pertama menyimpan metadata untuk pos pemeriksaan dan file sisanya adalah pecahan biner yang menyimpan parameter serial untuk model tersebut.

File metadata indeks berisi jenis, ukuran, lokasi, dan nama string semua tensor berseri yang terdapat dalam pecahan. File indeks tersebut adalah bagian pos pemeriksaan yang paling kompleks secara struktural, dan didasarkan pada tensorflow::table , yang juga didasarkan pada SSTable / LevelDB. File indeks ini terdiri dari serangkaian pasangan nilai kunci, di mana kuncinya adalah string dan nilainya adalah buffer protokol. String diurutkan dan dikompresi awalan. Misalnya: jika entri pertama adalah conv1/weight dan berikutnya conv1/bias , entri kedua hanya menggunakan bagian bias .

File indeks keseluruhan ini terkadang dikompresi menggunakan kompresi Snappy . File SnappyDecompression.swift menyediakan implementasi Swift asli dari dekompresi Snappy dari instance Data terkompresi.

Metadata header indeks dan metadata tensor dikodekan sebagai buffer protokol dan dikodekan/didekode langsung melalui Swift Protobuf .

Kelas CheckpointIndexReader dan CheckpointIndexWriter menangani pemuatan dan penyimpanan file indeks ini sebagai bagian dari kelas CheckpointReader dan CheckpointWriter yang menyeluruh. Yang terakhir menggunakan file indeks sebagai dasar untuk menentukan apa yang harus dibaca dan ditulis ke pecahan biner yang secara struktural lebih sederhana yang berisi data tensor.