Kemampuan untuk menyimpan dan memulihkan status model sangat penting untuk sejumlah aplikasi, seperti dalam pembelajaran transfer atau untuk melakukan inferensi menggunakan model yang telah dilatih sebelumnya. Menyimpan parameter model (bobot, bias, dll.) dalam file atau direktori pos pemeriksaan adalah salah satu cara untuk mencapai hal ini.
Modul ini menyediakan antarmuka tingkat tinggi untuk memuat dan menyimpan pos pemeriksaan format TensorFlow v2 , serta komponen tingkat rendah yang menulis dan membaca dari format file ini.
Memuat dan menyimpan model sederhana
Dengan menyesuaikan diri dengan protokol Checkpointable
, banyak model sederhana dapat diserialkan ke pos pemeriksaan tanpa kode tambahan apa pun:
import Checkpoints
import ImageClassificationModels
extension LeNet: Checkpointable {}
var model = LeNet()
...
try model.writeCheckpoint(to: directory, name: "LeNet")
dan kemudian pos pemeriksaan yang sama dapat dibaca dengan menggunakan:
try model.readCheckpoint(from: directory, name: "LeNet")
Implementasi default untuk pemuatan dan penyimpanan model ini akan menggunakan skema penamaan berbasis jalur untuk setiap tensor dalam model yang didasarkan pada nama properti dalam struct model. Misalnya, bobot dan bias dalam konvolusi pertama dalam model LeNet-5 akan disimpan dengan nama conv1/filter
dan conv1/bias
. Saat memuat, pembaca pos pemeriksaan akan mencari tensor dengan nama ini.
Menyesuaikan pemuatan dan penyimpanan model
Jika Anda ingin memiliki kontrol lebih besar terhadap tensor mana yang disimpan dan dimuat, atau penamaan tensor tersebut, protokol Checkpointable
menawarkan beberapa poin penyesuaian.
Untuk mengabaikan properti pada tipe tertentu, Anda dapat memberikan implementasi ignoredTensorPaths
pada model Anda yang mengembalikan Kumpulan string dalam bentuk Type.property
. Misalnya, untuk mengabaikan properti scale
pada setiap lapisan Attention, Anda dapat mengembalikan ["Attention.scale"]
.
Secara default, garis miring digunakan untuk memisahkan setiap level yang lebih dalam dalam suatu model. Ini dapat dikustomisasi dengan mengimplementasikan checkpointSeparator
pada model Anda dan menyediakan string baru untuk digunakan sebagai pemisah ini.
Terakhir, untuk tingkat penyesuaian terbesar dalam penamaan tensor, Anda dapat mengimplementasikan tensorNameMap
dan menyediakan fungsi yang memetakan dari nama string default yang dihasilkan untuk tensor dalam model ke nama string yang diinginkan di pos pemeriksaan. Umumnya, ini akan digunakan untuk berinteroperasi dengan pos pemeriksaan yang dihasilkan dengan kerangka kerja lain, yang masing-masing memiliki konvensi penamaan dan struktur modelnya sendiri. Fungsi pemetaan khusus memberikan tingkat penyesuaian terbesar pada cara pemberian nama tensor ini.
Beberapa fungsi pembantu standar disediakan, seperti CheckpointWriter.identityMap
default (yang hanya menggunakan nama jalur tensor yang dibuat secara otomatis untuk pos pemeriksaan), atau fungsi CheckpointWriter.lookupMap(table:)
yang dapat membuat pemetaan dari kamus.
Untuk contoh bagaimana pemetaan khusus dapat dilakukan, silakan lihat model GPT-2 , yang menggunakan fungsi pemetaan untuk mencocokkan skema penamaan persis yang digunakan untuk pos pemeriksaan OpenAI.
Komponen CheckpointReader dan CheckpointWriter
Untuk penulisan checkpoint, ekstensi yang disediakan oleh protokol Checkpointable
menggunakan refleksi dan jalur kunci untuk melakukan iterasi pada properti model dan menghasilkan kamus yang memetakan jalur tensor string ke nilai Tensor. Kamus ini disediakan untuk CheckpointWriter
yang mendasarinya, bersama dengan direktori untuk menulis pos pemeriksaan. CheckpointWriter
tersebut menangani tugas menghasilkan pos pemeriksaan pada disk dari kamus tersebut.
Kebalikan dari proses ini adalah membaca, dimana CheckpointReader
diberikan lokasi direktori checkpoint pada disk. Ia kemudian membaca dari pos pemeriksaan tersebut dan membentuk kamus yang memetakan nama tensor dalam pos pemeriksaan tersebut dengan nilai yang disimpannya. Kamus ini digunakan untuk mengganti tensor saat ini dalam suatu model dengan yang ada di kamus ini.
Untuk memuat dan menyimpan, protokol Checkpointable
memetakan jalur string ke tensor ke nama tensor pada disk yang sesuai menggunakan fungsi pemetaan yang dijelaskan di atas.
Jika protokol Checkpointable
tidak memiliki fungsionalitas yang diperlukan, atau diperlukan lebih banyak kontrol atas proses pemuatan dan penyimpanan checkpoint, kelas CheckpointReader
dan CheckpointWriter
dapat digunakan sendiri.
Format pos pemeriksaan TensorFlow v2
Format pos pemeriksaan TensorFlow v2, seperti yang dijelaskan secara singkat di header ini , adalah format generasi kedua untuk pos pemeriksaan model TensorFlow. Format generasi kedua ini telah digunakan sejak akhir tahun 2016, dan memiliki sejumlah perbaikan dibandingkan format pos pemeriksaan v1. TensorFlow SavedModels menggunakan pos pemeriksaan v2 di dalamnya untuk menyimpan parameter model.
Pos pemeriksaan TensorFlow v2 terdiri dari direktori dengan struktur seperti berikut:
checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002
dengan file pertama menyimpan metadata untuk pos pemeriksaan dan file sisanya adalah pecahan biner yang menyimpan parameter serial untuk model tersebut.
File metadata indeks berisi jenis, ukuran, lokasi, dan nama string semua tensor berseri yang terdapat dalam pecahan. File indeks tersebut adalah bagian pos pemeriksaan yang paling kompleks secara struktural, dan didasarkan pada tensorflow::table
, yang juga didasarkan pada SSTable / LevelDB. File indeks ini terdiri dari serangkaian pasangan nilai kunci, di mana kuncinya adalah string dan nilainya adalah buffer protokol. String diurutkan dan dikompresi awalan. Misalnya: jika entri pertama adalah conv1/weight
dan berikutnya conv1/bias
, entri kedua hanya menggunakan bagian bias
.
File indeks keseluruhan ini terkadang dikompresi menggunakan kompresi Snappy . File SnappyDecompression.swift
menyediakan implementasi Swift asli dari dekompresi Snappy dari instance Data terkompresi.
Metadata header indeks dan metadata tensor dikodekan sebagai buffer protokol dan dikodekan/didekode langsung melalui Swift Protobuf .
Kelas CheckpointIndexReader
dan CheckpointIndexWriter
menangani pemuatan dan penyimpanan file indeks ini sebagai bagian dari kelas CheckpointReader
dan CheckpointWriter
yang menyeluruh. Yang terakhir menggunakan file indeks sebagai dasar untuk menentukan apa yang harus dibaca dan ditulis ke pecahan biner yang secara struktural lebih sederhana yang berisi data tensor.
Kemampuan untuk menyimpan dan memulihkan status model sangat penting untuk sejumlah aplikasi, seperti dalam pembelajaran transfer atau untuk melakukan inferensi menggunakan model yang telah dilatih sebelumnya. Menyimpan parameter model (bobot, bias, dll.) dalam file atau direktori pos pemeriksaan adalah salah satu cara untuk mencapai hal ini.
Modul ini menyediakan antarmuka tingkat tinggi untuk memuat dan menyimpan pos pemeriksaan format TensorFlow v2 , serta komponen tingkat rendah yang menulis dan membaca dari format file ini.
Memuat dan menyimpan model sederhana
Dengan menyesuaikan diri dengan protokol Checkpointable
, banyak model sederhana dapat diserialkan ke pos pemeriksaan tanpa kode tambahan apa pun:
import Checkpoints
import ImageClassificationModels
extension LeNet: Checkpointable {}
var model = LeNet()
...
try model.writeCheckpoint(to: directory, name: "LeNet")
dan kemudian pos pemeriksaan yang sama dapat dibaca dengan menggunakan:
try model.readCheckpoint(from: directory, name: "LeNet")
Implementasi default untuk pemuatan dan penyimpanan model ini akan menggunakan skema penamaan berbasis jalur untuk setiap tensor dalam model yang didasarkan pada nama properti dalam struct model. Misalnya, bobot dan bias dalam konvolusi pertama dalam model LeNet-5 akan disimpan dengan nama conv1/filter
dan conv1/bias
. Saat memuat, pembaca pos pemeriksaan akan mencari tensor dengan nama ini.
Menyesuaikan pemuatan dan penyimpanan model
Jika Anda ingin memiliki kontrol lebih besar terhadap tensor mana yang disimpan dan dimuat, atau penamaan tensor tersebut, protokol Checkpointable
menawarkan beberapa poin penyesuaian.
Untuk mengabaikan properti pada tipe tertentu, Anda dapat memberikan implementasi ignoredTensorPaths
pada model Anda yang mengembalikan Kumpulan string dalam bentuk Type.property
. Misalnya, untuk mengabaikan properti scale
pada setiap lapisan Attention, Anda dapat mengembalikan ["Attention.scale"]
.
Secara default, garis miring digunakan untuk memisahkan setiap level yang lebih dalam dalam suatu model. Ini dapat dikustomisasi dengan mengimplementasikan checkpointSeparator
pada model Anda dan menyediakan string baru untuk digunakan sebagai pemisah ini.
Terakhir, untuk tingkat penyesuaian terbesar dalam penamaan tensor, Anda dapat mengimplementasikan tensorNameMap
dan menyediakan fungsi yang memetakan dari nama string default yang dihasilkan untuk tensor dalam model ke nama string yang diinginkan di pos pemeriksaan. Umumnya, ini akan digunakan untuk berinteroperasi dengan pos pemeriksaan yang dihasilkan dengan kerangka kerja lain, yang masing-masing memiliki konvensi penamaan dan struktur modelnya sendiri. Fungsi pemetaan khusus memberikan tingkat penyesuaian terbesar pada cara pemberian nama tensor ini.
Beberapa fungsi pembantu standar disediakan, seperti CheckpointWriter.identityMap
default (yang hanya menggunakan nama jalur tensor yang dibuat secara otomatis untuk pos pemeriksaan), atau fungsi CheckpointWriter.lookupMap(table:)
yang dapat membuat pemetaan dari kamus.
Untuk contoh bagaimana pemetaan khusus dapat dilakukan, silakan lihat model GPT-2 , yang menggunakan fungsi pemetaan untuk mencocokkan skema penamaan persis yang digunakan untuk pos pemeriksaan OpenAI.
Komponen CheckpointReader dan CheckpointWriter
Untuk penulisan checkpoint, ekstensi yang disediakan oleh protokol Checkpointable
menggunakan refleksi dan jalur kunci untuk melakukan iterasi pada properti model dan menghasilkan kamus yang memetakan jalur tensor string ke nilai Tensor. Kamus ini disediakan untuk CheckpointWriter
yang mendasarinya, bersama dengan direktori untuk menulis pos pemeriksaan. CheckpointWriter
tersebut menangani tugas menghasilkan pos pemeriksaan pada disk dari kamus tersebut.
Kebalikan dari proses ini adalah membaca, dimana CheckpointReader
diberikan lokasi direktori checkpoint pada disk. Ia kemudian membaca dari pos pemeriksaan tersebut dan membentuk kamus yang memetakan nama tensor dalam pos pemeriksaan tersebut dengan nilai yang disimpannya. Kamus ini digunakan untuk mengganti tensor saat ini dalam suatu model dengan yang ada di kamus ini.
Untuk memuat dan menyimpan, protokol Checkpointable
memetakan jalur string ke tensor ke nama tensor pada disk yang sesuai menggunakan fungsi pemetaan yang dijelaskan di atas.
Jika protokol Checkpointable
tidak memiliki fungsionalitas yang diperlukan, atau diperlukan lebih banyak kontrol atas proses pemuatan dan penyimpanan checkpoint, kelas CheckpointReader
dan CheckpointWriter
dapat digunakan sendiri.
Format pos pemeriksaan TensorFlow v2
Format pos pemeriksaan TensorFlow v2, seperti yang dijelaskan secara singkat di header ini , adalah format generasi kedua untuk pos pemeriksaan model TensorFlow. Format generasi kedua ini telah digunakan sejak akhir tahun 2016, dan memiliki sejumlah perbaikan dibandingkan format pos pemeriksaan v1. TensorFlow SavedModels menggunakan pos pemeriksaan v2 di dalamnya untuk menyimpan parameter model.
Pos pemeriksaan TensorFlow v2 terdiri dari direktori dengan struktur seperti berikut:
checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002
dengan file pertama menyimpan metadata untuk pos pemeriksaan dan file sisanya adalah pecahan biner yang menyimpan parameter serial untuk model tersebut.
File metadata indeks berisi jenis, ukuran, lokasi, dan nama string semua tensor berseri yang terdapat dalam pecahan. File indeks tersebut adalah bagian pos pemeriksaan yang paling kompleks secara struktural, dan didasarkan pada tensorflow::table
, yang juga didasarkan pada SSTable / LevelDB. File indeks ini terdiri dari serangkaian pasangan nilai kunci, di mana kuncinya adalah string dan nilainya adalah buffer protokol. String diurutkan dan dikompresi awalan. Misalnya: jika entri pertama adalah conv1/weight
dan berikutnya conv1/bias
, entri kedua hanya menggunakan bagian bias
.
File indeks keseluruhan ini terkadang dikompresi menggunakan kompresi Snappy . File SnappyDecompression.swift
menyediakan implementasi Swift asli dari dekompresi Snappy dari instance Data terkompresi.
Metadata header indeks dan metadata tensor dikodekan sebagai buffer protokol dan dikodekan/didekode langsung melalui Swift Protobuf .
Kelas CheckpointIndexReader
dan CheckpointIndexWriter
menangani pemuatan dan penyimpanan file indeks ini sebagai bagian dari kelas CheckpointReader
dan CheckpointWriter
yang menyeluruh. Yang terakhir menggunakan file indeks sebagai dasar untuk menentukan apa yang harus dibaca dan ditulis ke pecahan biner yang secara struktural lebih sederhana yang berisi data tensor.