توانایی ذخیره و بازیابی وضعیت یک مدل برای تعدادی از کاربردها، مانند در یادگیری انتقال یا برای انجام استنتاج با استفاده از مدل های از پیش آموزش دیده، حیاتی است. ذخیره پارامترهای یک مدل (وزنها، بایاسها و غیره) در یک فایل یا دایرکتوری نقطه چک یکی از راههای انجام این کار است.
این ماژول یک رابط سطح بالا برای بارگیری و ذخیره نقاط چک با فرمت TensorFlow v2 ، و همچنین اجزای سطح پایین تر که در این فرمت فایل می نویسند و می خوانند، فراهم می کند.
بارگیری و ذخیره مدل های ساده
با انطباق با پروتکل Checkpointable
، بسیاری از مدلهای ساده را میتوان بدون هیچ کد اضافی در نقاط بازرسی سریال کرد:
import Checkpoints
import ImageClassificationModels
extension LeNet: Checkpointable {}
var model = LeNet()
...
try model.writeCheckpoint(to: directory, name: "LeNet")
و سپس همان نقطه بازرسی را می توان با استفاده از:
try model.readCheckpoint(from: directory, name: "LeNet")
این پیادهسازی پیشفرض برای بارگذاری و ذخیره مدل، از یک طرح نامگذاری مبتنی بر مسیر برای هر تانسور در مدل استفاده میکند که بر اساس نام ویژگیهای ساختارهای مدل است. برای مثال، وزنها و بایاسها در اولین پیچیدگی در مدل LeNet-5 به ترتیب با نامهای conv1/filter
و conv1/bias
ذخیره میشوند. هنگام بارگیری، بازرسی خوان تانسورهایی با این نام ها را جستجو می کند.
سفارشی کردن بارگیری و ذخیره مدل
اگر می خواهید کنترل بیشتری روی ذخیره و بارگذاری تانسورها یا نامگذاری آن تانسورها داشته باشید، پروتکل Checkpointable
چند نکته سفارشی سازی را ارائه می دهد.
برای نادیده گرفتن ویژگیها در انواع خاصی، میتوانید پیادهسازی ignoredTensorPaths
را در مدل خود ارائه دهید که مجموعهای از رشتهها را به شکل Type.property
برمیگرداند. به عنوان مثال، برای نادیده گرفتن ویژگی scale
در هر لایه توجه، می توانید ["Attention.scale"]
را برگردانید.
به طور پیش فرض، یک اسلش رو به جلو برای جدا کردن هر سطح عمیق تر در یک مدل استفاده می شود. این را می توان با اجرای checkpointSeparator
در مدل خود و ارائه یک رشته جدید برای استفاده برای این جداکننده سفارشی کرد.
در نهایت، برای بیشترین درجه سفارشیسازی در نامگذاری تانسور، میتوانید tensorNameMap
پیادهسازی کنید و تابعی را ارائه کنید که از نام رشته پیشفرض تولید شده برای یک تانسور در مدل به نام رشته دلخواه در نقطه بازرسی نگاشت میشود. معمولاً، از این برای تعامل با نقاط بازرسی تولید شده با چارچوب های دیگر استفاده می شود، که هر یک دارای قراردادهای نامگذاری و ساختار مدل خود هستند. یک تابع نگاشت سفارشی بیشترین درجه سفارشی سازی را برای نحوه نامگذاری این تانسورها ارائه می دهد.
برخی از توابع کمکی استاندارد ارائه شدهاند، مانند پیشفرض CheckpointWriter.identityMap
(که به سادگی از نام مسیر تانسور تولید شده بهطور خودکار برای نقاط بازرسی استفاده میکند)، یا تابع CheckpointWriter.lookupMap(table:)
که میتواند نقشهبرداری را از فرهنگ لغت بسازد.
برای مثالی از نحوه انجام نقشهبرداری سفارشی، لطفاً مدل GPT-2 را ببینید که از یک تابع نقشهبرداری برای مطابقت با طرح نامگذاری دقیق مورد استفاده برای نقاط بازرسی OpenAI استفاده میکند.
اجزای CheckpointReader و CheckpointWriter
برای نوشتن نقطه بازرسی، برنامه افزودنی ارائه شده توسط پروتکل Checkpointable
از بازتاب و مسیرهای کلیدی برای تکرار بر روی ویژگی های یک مدل استفاده می کند و فرهنگ لغتی ایجاد می کند که مسیرهای تانسور رشته را به مقادیر Tensor ترسیم می کند. این فرهنگ لغت در اختیار یک CheckpointWriter
زیرزمینی قرار گرفته است، همراه با دایرکتوری که در آن چکپوینت را بنویسید. آن CheckpointWriter
وظیفه تولید ایست بازرسی روی دیسک را از آن فرهنگ لغت بر عهده دارد.
معکوس این فرآیند خواندن است، جایی که به یک CheckpointReader
مکان دایرکتوری ایست بازرسی روی دیسک داده می شود. سپس از آن ایست بازرسی می خواند و دیکشنری تشکیل می دهد که نام تانسورهای داخل ایست بازرسی را با مقادیر ذخیره شده آنها ترسیم می کند. این دیکشنری برای جایگزینی تانسورهای فعلی در یک مدل با موارد موجود در این فرهنگ لغت استفاده می شود.
هم برای بارگذاری و هم برای ذخیره، پروتکل Checkpointable
مسیرهای رشته به تانسورها را به نام های تانسور مربوطه روی دیسک با استفاده از تابع نگاشت فوق توصیف می کند.
اگر پروتکل Checkpointable
فاقد عملکرد مورد نیاز باشد، یا کنترل بیشتری روی فرآیند بارگیری و ذخیره نقطه بازرسی مورد نظر باشد، کلاس های CheckpointReader
و CheckpointWriter
می توانند به تنهایی مورد استفاده قرار گیرند.
فرمت ایست بازرسی TensorFlow v2
فرمت ایست بازرسی TensorFlow v2، همانطور که به طور خلاصه در این هدر توضیح داده شده است، فرمت نسل دوم برای نقاط بازرسی مدل TensorFlow است. این فرمت نسل دوم از اواخر سال 2016 مورد استفاده قرار گرفته است و نسبت به فرمت چک پوینت v1 بهبودهایی دارد. TensorFlow SavedModels از چک پوینتهای v2 در داخل خود برای ذخیره پارامترهای مدل استفاده میکنند.
یک چک پوینت TensorFlow v2 از یک دایرکتوری با ساختاری مانند زیر تشکیل شده است:
checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002
جایی که فایل اول متادیتا را برای نقطه بازرسی ذخیره میکند و فایلهای باقیمانده خردههای باینری هستند که پارامترهای سریالی مدل را در خود نگه میدارند.
فایل فراداده شاخص شامل انواع، اندازهها، مکانها و نام رشتههای تمام تانسورهای سریالی موجود در خردهها است. این فایل فهرست پیچیده ترین بخش از نظر ساختاری است و بر اساس tensorflow::table
است که خود بر اساس SSTable / LevelDB است. این فایل فهرست از یک سری جفت کلید-مقدار تشکیل شده است که کلیدها رشتهها و مقادیر بافرهای پروتکل هستند. رشته ها مرتب شده و با پیشوند فشرده شده اند. به عنوان مثال: اگر ورودی اول conv1/weight
و بعدی conv1/bias
باشد، ورودی دوم فقط از قسمت bias
استفاده می کند.
این فایل فهرست کلی گاهی اوقات با استفاده از فشرده سازی Snappy فشرده می شود. فایل SnappyDecompression.swift
یک پیاده سازی بومی Swift از فشرده سازی Snappy را از یک نمونه داده فشرده ارائه می دهد.
فراداده سرصفحه شاخص و فراداده تانسور به عنوان بافرهای پروتکل کدگذاری می شوند و مستقیماً از طریق Swift Protobuf کدگذاری / رمزگشایی می شوند.
کلاس های CheckpointIndexReader
و CheckpointIndexWriter
بارگیری و ذخیره این فایل های فهرست را به عنوان بخشی از کلاس های فراگیر CheckpointReader
و CheckpointWriter
انجام می دهند. دومی از فایلهای فهرست بهعنوان مبنایی برای تعیین اینکه از چه چیزی باید بخواند و در تکههای باینری سادهتر ساختاری که حاوی دادههای تانسور هستند، بخواند و بنویسد، استفاده میکند.