Mô hình đã lưu có thể tái sử dụng

Giới thiệu

TensorFlow Hub lưu trữ SavingModels cho TensorFlow 2, cùng với các nội dung khác. Chúng có thể được tải lại vào chương trình Python với obj = hub.load(url) [ tìm hiểu thêm ]. obj được trả về là kết quả của tf.saved_model.load() (xem hướng dẫn SavingModel của TensorFlow). Đối tượng này có thể có các thuộc tính tùy ý là tf.functions, tf.Variables (được khởi tạo từ các giá trị được đào tạo trước của chúng), các tài nguyên khác và, theo cách đệ quy, nhiều đối tượng như vậy hơn.

Trang này mô tả một giao diện sẽ được triển khai bởi obj đã tải để tái sử dụng trong chương trình TensorFlow Python. Các SavingModel phù hợp với giao diện này được gọi là SavingModels có thể tái sử dụng .

Tái sử dụng có nghĩa là xây dựng một mô hình lớn hơn xung quanh obj , bao gồm cả khả năng tinh chỉnh nó. Tinh chỉnh có nghĩa là huấn luyện thêm các trọng số trong obj được tải như một phần của mô hình xung quanh. Hàm mất mát và trình tối ưu hóa được xác định bởi mô hình xung quanh; obj chỉ xác định ánh xạ kích hoạt đầu vào đến đầu ra ("chuyển tiếp"), có thể bao gồm các kỹ thuật như bỏ học hoặc chuẩn hóa hàng loạt.

Nhóm TensorFlow Hub khuyên bạn nên triển khai giao diện SavingModel có thể tái sử dụng trong tất cả các SavingModel được thiết kế để tái sử dụng theo nghĩa trên. Nhiều tiện ích từ thư viện tensorflow_hub , đặc biệt là hub.KerasLayer , yêu cầu SavingModels triển khai nó.

Mối quan hệ với SignatureDefs

Giao diện này xét về mặt tf.functions và các tính năng TF2 khác tách biệt với các chữ ký của SavingModel, đã có từ TF1 và tiếp tục được sử dụng trong TF2 để suy luận (chẳng hạn như triển khai SavingModels cho TF Serve hoặc TF Lite). Chữ ký để suy luận không đủ biểu cảm để hỗ trợ tinh chỉnh và tf.function cung cấp API Python tự nhiên và biểu cảm hơn cho mô hình được sử dụng lại.

Liên quan đến thư viện xây dựng mô hình

SavingModel có thể tái sử dụng chỉ sử dụng các nguyên hàm TensorFlow 2, độc lập với bất kỳ thư viện xây dựng mô hình cụ thể nào như Keras hoặc Sonnet. Điều này tạo điều kiện thuận lợi cho việc tái sử dụng trên các thư viện xây dựng mô hình, không bị phụ thuộc vào mã xây dựng mô hình ban đầu.

Sẽ cần một số mức độ thích ứng để tải các Mô hình đã lưu có thể tái sử dụng vào hoặc lưu chúng từ bất kỳ thư viện xây dựng mô hình cụ thể nào. Đối với Keras, hub.KerasLayer cung cấp khả năng tải và tính năng lưu tích hợp của Keras ở định dạng SavingModel đã được thiết kế lại cho TF2 với mục tiêu cung cấp siêu bộ giao diện này (xem RFC từ tháng 5 năm 2019).

Liên quan đến "API mô hình lưu chung" dành riêng cho nhiệm vụ

Định nghĩa giao diện trên trang này cho phép bất kỳ số lượng và loại đầu vào và đầu ra nào. API Common SavingModel dành cho TF Hub tinh chỉnh giao diện chung này với các quy ước sử dụng cho các tác vụ cụ thể để làm cho các mô hình có thể dễ dàng thay thế cho nhau.

Định nghĩa giao diện

Thuộc tính

Một SavingModel có thể tái sử dụng là một SavingModel của TensorFlow 2 sao cho obj = tf.saved_model.load(...) trả về một đối tượng có các thuộc tính sau

  • __call__ . Yêu cầu. Một tf.function triển khai tính toán của mô hình ("chuyển tiếp") tuân theo thông số kỹ thuật bên dưới.

  • variables : Danh sách các đối tượng tf.Variable, liệt kê tất cả các biến được sử dụng bởi bất kỳ lệnh gọi __call__ nào có thể có, bao gồm cả những biến có thể huấn luyện và không thể huấn luyện.

    Danh sách này có thể được bỏ qua nếu trống.

  • trainable_variables : Danh sách các đối tượng tf.Variable sao cho v.trainable đúng với tất cả các phần tử. Các biến này phải là tập hợp con của variables . Đây là các biến cần được huấn luyện khi tinh chỉnh đối tượng. Người tạo SavingModel có thể chọn bỏ qua một số biến ở đây mà ban đầu có thể huấn luyện được để chỉ ra rằng những biến này không nên sửa đổi trong quá trình tinh chỉnh.

    Danh sách này có thể được bỏ qua nếu trống, đặc biệt nếu SavingModel không hỗ trợ tinh chỉnh.

  • regularization_losses : Một danh sách các tf.functions, mỗi hàm nhận đầu vào bằng 0 và trả về một tensor float vô hướng duy nhất. Để tinh chỉnh, người dùng SavingModel nên đưa những điều này làm thuật ngữ chính quy bổ sung vào phần mất (trong trường hợp đơn giản nhất mà không cần mở rộng thêm). Thông thường, chúng được sử dụng để đại diện cho việc điều chỉnh cân nặng. (Vì thiếu đầu vào, các tf.functions này không thể biểu thị các bộ điều chỉnh hoạt động.)

    Danh sách này có thể được bỏ qua nếu trống, đặc biệt, nếu SavingModel không hỗ trợ tinh chỉnh hoặc không muốn quy định việc điều chỉnh trọng số.

Hàm __call__

Một mô hình đã lưu được khôi phục obj có thuộc tính obj.__call__ là một tf.function được khôi phục và cho phép gọi obj như sau.

Tóm tắt (mã giả):

outputs = obj(inputs, trainable=..., **kwargs)

Tranh luận

Các lập luận như sau.

  • Có một đối số bắt buộc, vị trí với một loạt kích hoạt đầu vào của SavingModel. Loại của nó là một trong

    • một Tensor duy nhất cho một đầu vào,
    • một danh sách các Tensors cho một chuỗi các đầu vào chưa được đặt tên theo thứ tự,
    • một lệnh Tensors được khóa bởi một tập hợp tên đầu vào cụ thể.

    (Các bản sửa đổi trong tương lai của giao diện này có thể cho phép các tổ tổng quát hơn.) Người tạo SavingModel chọn một trong những thứ đó và các hình dạng tensor và dtype. Khi hữu ích, một số kích thước của hình dạng sẽ không được xác định (đặc biệt là kích thước lô).

  • Có thể có một training đối số từ khóa tùy chọn chấp nhận boolean Python, True hoặc False . Mặc định này False . Nếu mô hình hỗ trợ tinh chỉnh và nếu tính toán của nó khác nhau giữa hai mô hình (ví dụ như trong trường hợp bỏ học và chuẩn hóa hàng loạt), thì sự khác biệt đó sẽ được thực hiện với đối số này. Nếu không, lập luận này có thể vắng mặt.

    Không bắt buộc __call__ phải chấp nhận đối số training có giá trị Tensor. Người gọi sẽ phải sử dụng tf.cond() nếu cần để gửi giữa chúng.

  • Người tạo SavingModel có thể chọn chấp nhận nhiều kwargs tùy chọn hơn với các tên cụ thể.

    • Đối với các đối số có giá trị Tensor, trình tạo SavingModel xác định các kiểu và hình dạng được phép của chúng. tf.function chấp nhận giá trị mặc định của Python trên đối số được theo dõi bằng đầu vào tf.TensorSpec. Các đối số như vậy có thể được sử dụng để cho phép tùy chỉnh các siêu tham số số liên quan đến __call__ (ví dụ: tỷ lệ bỏ học).

    • Đối với các đối số có giá trị Python, trình tạo SavingModel xác định các giá trị được phép của chúng. Các đối số như vậy có thể được sử dụng làm cờ để đưa ra các lựa chọn riêng biệt trong hàm được theo dõi (nhưng hãy lưu ý đến sự bùng nổ tổ hợp của các dấu vết).

Hàm __call__ được khôi phục phải cung cấp dấu vết cho tất cả các tổ hợp đối số được phép. training lật giữa TrueFalse không được làm thay đổi khả năng cho phép của các đối số.

Kết quả

outputs từ việc gọi obj có thể là

  • một Tensor duy nhất cho một đầu ra duy nhất,
  • một danh sách các Tensors cho một chuỗi các đầu ra chưa được đặt tên theo thứ tự,
  • một lệnh của Tensors được khóa bởi một tập hợp tên đầu ra cụ thể.

(Các bản sửa đổi trong tương lai của giao diện này có thể cho phép các tổ tổng quát hơn.) Kiểu trả về có thể khác nhau tùy thuộc vào kwargs có giá trị Python. Điều này cho phép cờ tạo ra đầu ra bổ sung. Trình tạo SavingModel xác định các kiểu và hình dạng đầu ra cũng như sự phụ thuộc của chúng vào đầu vào.

Có thể gọi được đặt tên

Một SavingModel có thể tái sử dụng có thể cung cấp nhiều phần mô hình theo cách được mô tả ở trên bằng cách đặt chúng vào các đối tượng con được đặt tên, ví dụ: obj.foo , obj.bar , v.v. Mỗi đối tượng con cung cấp một phương thức __call__ và các thuộc tính hỗ trợ về các biến, v.v. dành riêng cho phần mô hình đó. Đối với ví dụ trên, sẽ có obj.foo.__call__ , obj.foo.variables , v.v.

Lưu ý rằng giao diện này không bao gồm cách thêm trực tiếp tf.function dưới dạng tf.foo .

Người dùng SavingModels có thể tái sử dụng chỉ được yêu cầu xử lý một cấp độ lồng nhau ( obj.bar chứ không phải obj.bar.baz ). (Các bản sửa đổi trong tương lai của giao diện này có thể cho phép lồng sâu hơn và có thể loại bỏ yêu cầu rằng đối tượng cấp cao nhất có thể gọi được chính nó.)

Đóng nhận xét

Liên quan đến các API trong quá trình

Tài liệu này mô tả giao diện của một lớp Python bao gồm các nguyên hàm như tf.function và tf.Variable tồn tại trong một chuyến đi khứ hồi thông qua tuần tự hóa thông qua tf.saved_model.save()tf.saved_model.load() . Tuy nhiên, giao diện đã có sẵn trên đối tượng ban đầu được chuyển đến tf.saved_model.save() . Việc thích ứng với giao diện đó cho phép trao đổi các phần mô hình trên các API xây dựng mô hình trong một chương trình TensorFlow duy nhất.