Nhập mô hình TensorFlow vào TensorFlow.js

Các mô hình dựa trên TensorFlow GraphDef (thường được tạo thông qua API Python) có thể được lưu ở một trong các định dạng sau:

  1. TensorFlow SavingMô hình
  2. Người mẫu đông lạnh
  3. Mô-đun trung tâm Tensorflow

Tất cả các định dạng trên có thể được chuyển đổi bằng trình chuyển đổi TensorFlow.js thành định dạng có thể tải trực tiếp vào TensorFlow.js để suy luận.

(Lưu ý: TensorFlow không còn dùng định dạng gói phiên nữa. Vui lòng di chuyển mô hình của bạn sang định dạng SavingModel.)

Yêu cầu

Quy trình chuyển đổi yêu cầu môi trường Python; bạn có thể muốn giữ một cái riêng biệt bằng cách sử dụng pipenv hoặc virtualenv .

Để cài đặt bộ chuyển đổi, hãy chạy lệnh sau:

 pip install tensorflowjs

Nhập mô hình TensorFlow vào TensorFlow.js là một quá trình gồm hai bước. Đầu tiên, chuyển đổi mô hình hiện có sang định dạng web TensorFlow.js, sau đó tải mô hình đó vào TensorFlow.js.

Bước 1. Chuyển đổi mô hình TensorFlow hiện có sang định dạng web TensorFlow.js

Chạy tập lệnh chuyển đổi được cung cấp bởi gói pip:

Ví dụ về SavingModel:

tensorflowjs_converter \
    --input_format=tf_saved_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    --saved_model_tags=serve \
    /mobilenet/saved_model \
    /mobilenet/web_model

Ví dụ về mô hình đông lạnh:

tensorflowjs_converter \
    --input_format=tf_frozen_model \
    --output_node_names='MobilenetV1/Predictions/Reshape_1' \
    /mobilenet/frozen_model.pb \
    /mobilenet/web_model

Ví dụ về mô-đun Tensorflow Hub:

tensorflowjs_converter \
    --input_format=tf_hub \
    'https://tfhub.dev/google/imagenet/mobilenet_v1_100_224/classification/1' \
    /mobilenet/web_model
Đối số vị trí Sự miêu tả
input_path Đường dẫn đầy đủ của thư mục mô hình đã lưu, thư mục gói phiên, tệp mô hình cố định hoặc đường dẫn hoặc điều khiển mô-đun TensorFlow Hub.
output_path Đường dẫn cho tất cả các tạo phẩm đầu ra.
Tùy chọn Sự miêu tả
--input_format Định dạng của mô hình đầu vào. Sử dụng tf_saved_model cho SavingModel, tf_frozen_model cho mô hình cố định, tf_session_bundle cho gói phiên, tf_hub cho mô-đun TensorFlow Hub và máy ảnh cho Keras HDF5.
--output_node_names Tên của các nút đầu ra, được phân tách bằng dấu phẩy.
--saved_model_tags Chỉ áp dụng cho chuyển đổi SavingModel. Các thẻ của MetaGraphDef cần tải, ở định dạng được phân tách bằng dấu phẩy. Mặc định để serve .
--signature_name Chỉ áp dụng cho chuyển đổi mô-đun TensorFlow Hub, chữ ký để tải. Mặc định là default . Xem https://www.tensorflow.org/hub/common_signatures/

Sử dụng lệnh sau để nhận thông báo trợ giúp chi tiết:

tensorflowjs_converter --help

Chuyển đổi tập tin được tạo

Tập lệnh chuyển đổi ở trên tạo ra hai loại tệp:

  • model.json : Biểu đồ luồng dữ liệu và bảng kê khai trọng số
  • group1-shard\*of\* : Tập hợp các tệp trọng số nhị phân

Ví dụ: đây là kết quả từ việc chuyển đổi MobileNet v2:

  output_directory/model.json
  output_directory/group1-shard1of5
  ...
  output_directory/group1-shard5of5

Bước 2: Load và chạy trên trình duyệt

  1. Cài đặt gói npm tfjs-converter:

yarn add @tensorflow/tfjs hoặc npm install @tensorflow/tfjs

  1. Khởi tạo lớp FrozenModel và chạy suy luận.
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';

const MODEL_URL = 'model_directory/model.json';

const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.browser.fromPixels(cat));

Hãy xem bản demo MobileNet .

API loadGraphModel chấp nhận tham số LoadOptions bổ sung, tham số này có thể được sử dụng để gửi thông tin xác thực hoặc tiêu đề tùy chỉnh cùng với yêu cầu. Để biết chi tiết, hãy xem tài liệu LoadGraphModel() .

Các hoạt động được hỗ trợ

Hiện tại, TensorFlow.js hỗ trợ một số hoạt động TensorFlow có giới hạn. Nếu mô hình của bạn sử dụng op không được hỗ trợ, tập lệnh tensorflowjs_converter sẽ không thành công và in ra danh sách các op không được hỗ trợ trong mô hình của bạn. Vui lòng gửi vấn đề cho từng hoạt động để cho chúng tôi biết bạn cần hỗ trợ cho hoạt động nào.

Chỉ tải trọng lượng

Nếu bạn chỉ muốn tải trọng số, bạn có thể sử dụng đoạn mã sau:

import * as tf from '@tensorflow/tfjs';

const weightManifestUrl = "https://example.org/model/weights_manifest.json";

const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
        this.weightManifest, "https://example.org/model");
// Use `weightMap` ...