ด่านจำลอง

ความสามารถในการบันทึกและกู้คืนสถานะของแบบจำลองมีความสำคัญต่อแอปพลิเคชันจำนวนหนึ่ง เช่น ในการถ่ายโอนการเรียนรู้ หรือสำหรับการอนุมานโดยใช้แบบจำลองที่ได้รับการฝึกอบรมมาแล้ว การบันทึกพารามิเตอร์ของโมเดล (น้ำหนัก อคติ ฯลฯ) ในไฟล์จุดตรวจสอบหรือไดเร็กทอรีเป็นวิธีหนึ่งในการบรรลุเป้าหมายนี้

โมดูลนี้มีอินเทอร์เฟซระดับสูงสำหรับการโหลดและบันทึกจุดตรวจสอบ รูปแบบ TensorFlow v2 รวมถึงส่วนประกอบระดับล่างที่เขียนและอ่านจากรูปแบบไฟล์นี้

กำลังโหลดและบันทึกโมเดลอย่างง่าย

ด้วยการปฏิบัติตามโปรโตคอล Checkpointable โมเดลง่ายๆ จำนวนมากจึงสามารถซีเรียลไลซ์ไปยังจุดตรวจสอบได้โดยไม่ต้องใช้รหัสเพิ่มเติม:

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

จากนั้นจุดตรวจสอบเดียวกันนั้นสามารถอ่านได้โดยใช้:

try model.readCheckpoint(from: directory, name: "LeNet")

การใช้งานเริ่มต้นสำหรับการโหลดและบันทึกโมเดลนี้จะใช้โครงร่างการตั้งชื่อตามเส้นทางสำหรับแต่ละเทนเซอร์ในโมเดลที่ขึ้นอยู่กับชื่อของคุณสมบัติภายในโครงสร้างของโมเดล ตัวอย่างเช่น น้ำหนักและอคติภายในการหมุนครั้งแรกใน โมเดล LeNet-5 จะถูกบันทึกด้วยชื่อ conv1/filter และ conv1/bias ตามลำดับ เมื่อโหลด เครื่องอ่านจุดตรวจจะค้นหาเทนเซอร์ด้วยชื่อเหล่านี้

การปรับแต่งการโหลดและการบันทึกโมเดล

หากคุณต้องการควบคุมเทนเซอร์ที่จะบันทึกและโหลดได้มากขึ้น หรือการตั้งชื่อเทนเซอร์เหล่านั้น โปรโตคอล Checkpointable เสนอการปรับแต่งสองสามจุด

หากต้องการละเว้นคุณสมบัติของบางประเภท คุณสามารถจัดเตรียมการใช้งานของ ignoredTensorPaths บนโมเดลของคุณที่ส่งคืนชุดสตริงในรูปแบบของ Type.property ตัวอย่างเช่น หากต้องการละเว้นคุณสมบัติ scale ในทุกเลเยอร์ Attention คุณสามารถส่งคืน ["Attention.scale"] ได้

ตามค่าเริ่มต้น เครื่องหมายทับจะใช้เพื่อแยกแต่ละระดับที่ลึกกว่าในแบบจำลอง ซึ่งสามารถปรับแต่งได้โดยใช้ checkpointSeparator กับโมเดลของคุณและระบุสตริงใหม่เพื่อใช้สำหรับตัวคั่นนี้

สุดท้ายนี้ เพื่อการปรับแต่งการตั้งชื่อเทนเซอร์ในระดับสูงสุด คุณสามารถใช้ tensorNameMap และจัดเตรียมฟังก์ชันที่จับคู่จากชื่อสตริงเริ่มต้นที่สร้างขึ้นสำหรับเทนเซอร์ในโมเดลกับชื่อสตริงที่ต้องการในจุดตรวจสอบ โดยทั่วไป สิ่งนี้จะถูกใช้เพื่อทำงานร่วมกับจุดตรวจสอบที่สร้างด้วยเฟรมเวิร์กอื่นๆ ซึ่งแต่ละจุดมีแบบแผนการตั้งชื่อและโครงสร้างแบบจำลองของตัวเอง ฟังก์ชันการแมปแบบกำหนดเองช่วยให้ปรับแต่งวิธีการตั้งชื่อเทนเซอร์เหล่านี้ได้ในระดับสูงสุด

มีฟังก์ชันตัวช่วยมาตรฐานบางอย่างมาให้ เช่น CheckpointWriter.identityMap เริ่มต้น (ซึ่งใช้ชื่อพาธเทนเซอร์ที่สร้างขึ้นโดยอัตโนมัติสำหรับจุดตรวจสอบ) หรือฟังก์ชัน CheckpointWriter.lookupMap(table:) ซึ่งสามารถสร้างการแมปจากพจนานุกรมได้

สำหรับตัวอย่างวิธีการแมปแบบกำหนดเองให้สำเร็จ โปรดดู โมเดล GPT-2 ซึ่งใช้ฟังก์ชันการแมปเพื่อให้ตรงกับรูปแบบการตั้งชื่อที่ใช้สำหรับจุดตรวจสอบของ OpenAI

ส่วนประกอบ CheckpointReader และ CheckpointWriter

สำหรับการเขียนจุดตรวจ ส่วนขยายที่จัดทำโดยโปรโตคอล Checkpointable จะใช้การสะท้อนและเส้นทางคีย์เพื่อวนซ้ำคุณสมบัติของแบบจำลอง และสร้างพจนานุกรมที่แมปเส้นทางสตริงเทนเซอร์กับค่าเทนเซอร์ พจนานุกรมนี้จัดทำขึ้นสำหรับ CheckpointWriter พื้นฐาน พร้อมด้วยไดเร็กทอรีสำหรับเขียนจุดตรวจสอบ CheckpointWriter นั้นจัดการงานสร้างจุดตรวจสอบบนดิสก์จากพจนานุกรมนั้น

สิ่งที่ตรงกันข้ามของกระบวนการนี้คือการอ่าน โดยที่ CheckpointReader จะได้รับตำแหน่งของไดเร็กทอรีจุดตรวจสอบบนดิสก์ จากนั้นจะอ่านจากจุดตรวจนั้นและสร้างพจนานุกรมที่จับคู่ชื่อของเทนเซอร์ภายในจุดตรวจด้วยค่าที่บันทึกไว้ พจนานุกรมนี้ใช้เพื่อแทนที่เทนเซอร์ปัจจุบันในแบบจำลองด้วยพจนานุกรมในพจนานุกรมนี้

สำหรับทั้งการโหลดและการบันทึก โปรโตคอล Checkpointable จะแมปพาธสตริงกับเทนเซอร์กับชื่อเทนเซอร์บนดิสก์ที่สอดคล้องกัน โดยใช้ฟังก์ชันการแมปที่อธิบายไว้ข้างต้น

หากโปรโตคอล Checkpointable ขาดฟังก์ชันการทำงานที่จำเป็น หรือต้องการการควบคุมเพิ่มเติมเกี่ยวกับกระบวนการโหลดและบันทึกจุดตรวจ คลาส CheckpointReader และ CheckpointWriter สามารถใช้งานได้ด้วยตัวเอง

รูปแบบจุดตรวจสอบ TensorFlow v2

รูปแบบจุดตรวจสอบ TensorFlow v2 ดังที่อธิบายไว้สั้นๆ ใน ส่วนหัวนี้ คือรูปแบบรุ่นที่สองสำหรับจุดตรวจสอบโมเดล TensorFlow รูปแบบรุ่นที่สองนี้มีการใช้งานมาตั้งแต่ปลายปี 2016 และมีการปรับปรุงหลายประการจากรูปแบบจุดตรวจสอบ v1 TensorFlow SavedModels ใช้จุดตรวจสอบ v2 ภายในเพื่อบันทึกพารามิเตอร์โมเดล

จุดตรวจสอบ TensorFlow v2 ประกอบด้วยไดเร็กทอรีที่มีโครงสร้างดังนี้:

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

โดยที่ไฟล์แรกเก็บข้อมูลเมตาสำหรับจุดตรวจสอบและไฟล์ที่เหลือเป็นไบนารีชาร์ดที่เก็บพารามิเตอร์ซีเรียลไลซ์สำหรับโมเดล

ไฟล์ข้อมูลเมตาดัชนีประกอบด้วยประเภท ขนาด ตำแหน่ง และชื่อสตริงของเทนเซอร์ที่ต่อเนื่องกันทั้งหมดที่มีอยู่ในชาร์ด ไฟล์ดัชนีนั้นเป็นส่วนที่มีโครงสร้างซับซ้อนที่สุดของจุดตรวจสอบ และอิงตาม tensorflow::table ซึ่งตัวมันเองอิงตาม SSTable / LevelDB ไฟล์ดัชนีนี้ประกอบด้วยชุดของคู่คีย์-ค่า โดยที่คีย์เป็นสตริง และค่าเป็นบัฟเฟอร์โปรโตคอล สตริงจะถูกจัดเรียงและบีบอัดคำนำหน้า ตัวอย่างเช่น หากรายการแรกคือ conv1/weight และ conv1/bias ถัดไป รายการที่สองจะใช้เฉพาะส่วน bias เท่านั้น

ไฟล์ดัชนีโดยรวมนี้บางครั้งถูกบีบอัดโดยใช้ Snappy Compression ไฟล์ SnappyDecompression.swift นำเสนอการใช้งาน Swift แบบเนทีฟของการบีบอัด Snappy จากอินสแตนซ์ข้อมูลที่บีบอัด

ข้อมูลเมตาของส่วนหัวดัชนีและข้อมูลเมตาของเทนเซอร์ได้รับการเข้ารหัสเป็นบัฟเฟอร์โปรโตคอลและเข้ารหัส / ถอดรหัสโดยตรงผ่าน Swift Protobuf

คลาส CheckpointIndexReader และ CheckpointIndexWriter จัดการการโหลดและบันทึกไฟล์ดัชนีเหล่านี้โดยเป็นส่วนหนึ่งของคลาส CheckpointReader และ CheckpointWriter ที่ครอบคลุม อย่างหลังใช้ไฟล์ดัชนีเป็นพื้นฐานในการกำหนดว่าจะอ่านและเขียนอะไรลงในไบนารี่ชาร์ดที่มีโครงสร้างง่ายกว่าซึ่งมีข้อมูลเทนเซอร์

,

ความสามารถในการบันทึกและกู้คืนสถานะของแบบจำลองมีความสำคัญต่อแอปพลิเคชันจำนวนหนึ่ง เช่น ในการถ่ายโอนการเรียนรู้ หรือสำหรับการอนุมานโดยใช้แบบจำลองที่ได้รับการฝึกอบรมมาแล้ว การบันทึกพารามิเตอร์ของโมเดล (น้ำหนัก อคติ ฯลฯ) ในไฟล์จุดตรวจสอบหรือไดเร็กทอรีเป็นวิธีหนึ่งในการบรรลุเป้าหมายนี้

โมดูลนี้มีอินเทอร์เฟซระดับสูงสำหรับการโหลดและบันทึกจุดตรวจสอบ รูปแบบ TensorFlow v2 รวมถึงส่วนประกอบระดับล่างที่เขียนและอ่านจากรูปแบบไฟล์นี้

กำลังโหลดและบันทึกโมเดลอย่างง่าย

ด้วยการปฏิบัติตามโปรโตคอล Checkpointable โมเดลง่ายๆ จำนวนมากจึงสามารถซีเรียลไลซ์ไปยังจุดตรวจสอบได้โดยไม่ต้องใช้รหัสเพิ่มเติม:

import Checkpoints
import ImageClassificationModels

extension LeNet: Checkpointable {}

var model = LeNet()

...

try model.writeCheckpoint(to: directory, name: "LeNet")

จากนั้นจุดตรวจสอบเดียวกันนั้นสามารถอ่านได้โดยใช้:

try model.readCheckpoint(from: directory, name: "LeNet")

การใช้งานเริ่มต้นสำหรับการโหลดและบันทึกโมเดลนี้จะใช้โครงร่างการตั้งชื่อตามเส้นทางสำหรับแต่ละเทนเซอร์ในโมเดลที่ขึ้นอยู่กับชื่อของคุณสมบัติภายในโครงสร้างของโมเดล ตัวอย่างเช่น น้ำหนักและอคติภายในการหมุนครั้งแรกใน โมเดล LeNet-5 จะถูกบันทึกด้วยชื่อ conv1/filter และ conv1/bias ตามลำดับ เมื่อโหลด เครื่องอ่านจุดตรวจจะค้นหาเทนเซอร์ด้วยชื่อเหล่านี้

การปรับแต่งการโหลดและการบันทึกโมเดล

หากคุณต้องการควบคุมเทนเซอร์ที่จะบันทึกและโหลดได้มากขึ้น หรือการตั้งชื่อเทนเซอร์เหล่านั้น โปรโตคอล Checkpointable เสนอการปรับแต่งสองสามจุด

หากต้องการละเว้นคุณสมบัติของบางประเภท คุณสามารถจัดเตรียมการใช้งานของ ignoredTensorPaths บนโมเดลของคุณที่ส่งคืนชุดสตริงในรูปแบบของ Type.property ตัวอย่างเช่น หากต้องการละเว้นคุณสมบัติ scale ในทุกเลเยอร์ Attention คุณสามารถส่งคืน ["Attention.scale"] ได้

ตามค่าเริ่มต้น เครื่องหมายทับจะใช้เพื่อแยกแต่ละระดับที่ลึกกว่าในแบบจำลอง ซึ่งสามารถปรับแต่งได้โดยใช้ checkpointSeparator กับโมเดลของคุณและระบุสตริงใหม่เพื่อใช้สำหรับตัวคั่นนี้

สุดท้ายนี้ เพื่อการปรับแต่งการตั้งชื่อเทนเซอร์ในระดับสูงสุด คุณสามารถใช้ tensorNameMap และจัดเตรียมฟังก์ชันที่จับคู่จากชื่อสตริงเริ่มต้นที่สร้างขึ้นสำหรับเทนเซอร์ในโมเดลกับชื่อสตริงที่ต้องการในจุดตรวจสอบ โดยทั่วไป สิ่งนี้จะถูกใช้เพื่อทำงานร่วมกับจุดตรวจสอบที่สร้างด้วยเฟรมเวิร์กอื่นๆ ซึ่งแต่ละจุดมีแบบแผนการตั้งชื่อและโครงสร้างแบบจำลองของตัวเอง ฟังก์ชันการแมปแบบกำหนดเองช่วยให้ปรับแต่งวิธีการตั้งชื่อเทนเซอร์เหล่านี้ได้ในระดับสูงสุด

มีฟังก์ชันตัวช่วยมาตรฐานบางอย่างมาให้ เช่น CheckpointWriter.identityMap เริ่มต้น (ซึ่งใช้ชื่อพาธเทนเซอร์ที่สร้างขึ้นโดยอัตโนมัติสำหรับจุดตรวจสอบ) หรือฟังก์ชัน CheckpointWriter.lookupMap(table:) ซึ่งสามารถสร้างการแมปจากพจนานุกรมได้

สำหรับตัวอย่างวิธีการแมปแบบกำหนดเองให้สำเร็จ โปรดดู โมเดล GPT-2 ซึ่งใช้ฟังก์ชันการแมปเพื่อให้ตรงกับรูปแบบการตั้งชื่อที่ใช้สำหรับจุดตรวจสอบของ OpenAI

ส่วนประกอบ CheckpointReader และ CheckpointWriter

สำหรับการเขียนจุดตรวจ ส่วนขยายที่จัดทำโดยโปรโตคอล Checkpointable จะใช้การสะท้อนและเส้นทางคีย์เพื่อวนซ้ำคุณสมบัติของแบบจำลอง และสร้างพจนานุกรมที่แมปเส้นทางสตริงเทนเซอร์กับค่าเทนเซอร์ พจนานุกรมนี้จัดทำขึ้นสำหรับ CheckpointWriter พื้นฐาน พร้อมด้วยไดเร็กทอรีสำหรับเขียนจุดตรวจสอบ CheckpointWriter นั้นจัดการงานสร้างจุดตรวจสอบบนดิสก์จากพจนานุกรมนั้น

สิ่งที่ตรงกันข้ามของกระบวนการนี้คือการอ่าน โดยที่ CheckpointReader จะได้รับตำแหน่งของไดเร็กทอรีจุดตรวจสอบบนดิสก์ จากนั้นจะอ่านจากจุดตรวจนั้นและสร้างพจนานุกรมที่จับคู่ชื่อของเทนเซอร์ภายในจุดตรวจด้วยค่าที่บันทึกไว้ พจนานุกรมนี้ใช้เพื่อแทนที่เทนเซอร์ปัจจุบันในแบบจำลองด้วยพจนานุกรมในพจนานุกรมนี้

สำหรับทั้งการโหลดและการบันทึก โปรโตคอล Checkpointable จะแมปพาธสตริงกับเทนเซอร์กับชื่อเทนเซอร์บนดิสก์ที่สอดคล้องกัน โดยใช้ฟังก์ชันการแมปที่อธิบายไว้ข้างต้น

หากโปรโตคอล Checkpointable ขาดฟังก์ชันการทำงานที่จำเป็น หรือต้องการการควบคุมเพิ่มเติมเกี่ยวกับกระบวนการโหลดและบันทึกจุดตรวจ คลาส CheckpointReader และ CheckpointWriter สามารถใช้งานได้ด้วยตัวเอง

รูปแบบจุดตรวจสอบ TensorFlow v2

รูปแบบจุดตรวจสอบ TensorFlow v2 ดังที่อธิบายไว้สั้นๆ ใน ส่วนหัวนี้ คือรูปแบบรุ่นที่สองสำหรับจุดตรวจสอบโมเดล TensorFlow รูปแบบรุ่นที่สองนี้มีการใช้งานมาตั้งแต่ปลายปี 2016 และมีการปรับปรุงหลายประการจากรูปแบบจุดตรวจสอบ v1 TensorFlow SavedModels ใช้จุดตรวจสอบ v2 ภายในเพื่อบันทึกพารามิเตอร์โมเดล

จุดตรวจสอบ TensorFlow v2 ประกอบด้วยไดเร็กทอรีที่มีโครงสร้างดังนี้:

checkpoint/modelname.index
checkpoint/modelname.data-00000-of-00002
checkpoint/modelname.data-00001-of-00002

โดยที่ไฟล์แรกเก็บข้อมูลเมตาสำหรับจุดตรวจสอบและไฟล์ที่เหลือเป็นไบนารีชาร์ดที่เก็บพารามิเตอร์ซีเรียลไลซ์สำหรับโมเดล

ไฟล์ข้อมูลเมตาดัชนีประกอบด้วยประเภท ขนาด ตำแหน่ง และชื่อสตริงของเทนเซอร์ที่ต่อเนื่องกันทั้งหมดที่มีอยู่ในชาร์ด ไฟล์ดัชนีนั้นเป็นส่วนที่มีโครงสร้างซับซ้อนที่สุดของจุดตรวจสอบ และอิงตาม tensorflow::table ซึ่งตัวมันเองอิงตาม SSTable / LevelDB ไฟล์ดัชนีนี้ประกอบด้วยชุดของคู่คีย์-ค่า โดยที่คีย์เป็นสตริง และค่าเป็นบัฟเฟอร์โปรโตคอล สตริงจะถูกจัดเรียงและบีบอัดคำนำหน้า ตัวอย่างเช่น หากรายการแรกคือ conv1/weight และ conv1/bias ถัดไป รายการที่สองจะใช้เฉพาะส่วน bias เท่านั้น

ไฟล์ดัชนีโดยรวมนี้บางครั้งถูกบีบอัดโดยใช้ Snappy Compression ไฟล์ SnappyDecompression.swift นำเสนอการใช้งาน Swift แบบเนทีฟของการบีบอัด Snappy จากอินสแตนซ์ข้อมูลที่บีบอัด

ข้อมูลเมตาของส่วนหัวดัชนีและข้อมูลเมตาของเทนเซอร์ได้รับการเข้ารหัสเป็นบัฟเฟอร์โปรโตคอลและเข้ารหัส / ถอดรหัสโดยตรงผ่าน Swift Protobuf

คลาส CheckpointIndexReader และ CheckpointIndexWriter จัดการการโหลดและบันทึกไฟล์ดัชนีเหล่านี้โดยเป็นส่วนหนึ่งของคลาส CheckpointReader และ CheckpointWriter ที่ครอบคลุม อย่างหลังใช้ไฟล์ดัชนีเป็นพื้นฐานในการกำหนดว่าจะอ่านและเขียนอะไรลงในไบนารี่ชาร์ดที่มีโครงสร้างง่ายกว่าซึ่งมีข้อมูลเทนเซอร์