TensorFlow GraphDef模型(一般是通过Python API创建的)可以保存成以下几种格式:
以上所有格式都可以被TensorFlow.js converter转换成TensorFlow.js可读取的模型格式,并用于推算(inference)。
(注意:TensorFlow已经淘汰了session bundle格式,请将您的模型转换成SavedModel格式。)
必要条件
模型转换的工作需要用到Python环境;你可以用pipenv 或 virtualenv创建一个隔离的环境。用这条命令安装模型转换器:
pip install tensorflowjs
将TensorFlow模型引入到TensorFlow.js需要两个步骤。首先,将您的模型转换为TensorFlow.js可用的web格式,然后载入到TensorFlow.js中。
第一步:将TensorFlow模型转换至TensorFlow.js可用的 web 格式模型
运行转换器提供的转换脚本:
用法:以SavedModel为例:
tensorflowjs_converter \
--input_format=tf_saved_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
--saved_model_tags=serve \
/mobilenet/saved_model \
/mobilenet/web_model
Frozen model 为例:
tensorflowjs_converter \
--input_format=tf_frozen_model \
--output_node_names='MobilenetV1/Predictions/Reshape_1' \
/mobilenet/frozen_model.pb \
/mobilenet/web_model
Tensorflow Hub module 为例:
tensorflowjs_converter \
--input_format=tf_hub \
'https://hub.tensorflow.google.cn/google/imagenet/mobilenet_v1_100_224/classification/1' \
/mobilenet/web_model
脚本参数 | 描述 |
---|---|
input_path |
saved model, session bundle 或 frozen model的完整的路径,或TensorFlow Hub模块的路径。 |
output_path |
输出文件的保存路径。 |
选项 | 描述 |
---|---|
--input_format |
要转换的模型的格式。SavedModel 为 tf_saved_model, frozen model 为 tf_frozen_model, session bundle 为 tf_session_bundle, TensorFlow Hub module 为 tf_hub,Keras HDF5 为 keras。 |
--output_node_names |
输出节点的名字,每个名字用逗号分离。 |
--saved_model_tags |
只对SavedModel转换用的选项:输入需要加载的MetaGraphDef相对应的tag,多个tag请用逗号分隔。默认为 serve 。 |
--signature_name |
只对TensorFlow Hub module转换用的选项:对应要加载的签名,默认为default 。请参考 https://tensorflow.google.cn/hub/common_signatures/ |
用以下命令查看帮助信息:
tensorflowjs_converter --help
转换器产生的文件
转换脚本会产生两种文件:
model.json
(数据流图和权重清单)group1-shard\*of\*
(二进制权重文件)
这里举例Mobilenet v2模型转换后输出的文件:
output_directory/model.json
output_directory/group1-shard1of5
...
output_directory/group1-shard5of5
第二步:在浏览器加载和运行模型
- 安装tfjs-convert npm包:
yarn add @tensorflow/tfjs
或 npm install @tensorflow/tfjs
- 创建 FrozenModel class 并开始推算:
import * as tf from '@tensorflow/tfjs';
import {loadGraphModel} from '@tensorflow/tfjs-converter';
const MODEL_URL = 'model_directory/model.json';
const model = await loadGraphModel(MODEL_URL);
const cat = document.getElementById('cat');
model.execute(tf.fromPixels(cat));
具体代码请参考 MobileNet 演示.
loadGraphModel
API中的LoadOptions
参数可以用来发送密钥或者自定义请求中的头文件。更多信息请参考 loadGraphModel() 文档。
支持的操作
目前,TensorFlow.js只支持部分TensorFlow算子。若您的模型包含了不被支持的算子,tensorflowjs_converter
脚本会报错并列出您的模型中不被支持的算子。请在github上发起 issue让我们知道您需要支持的算子。
加载模型权重
若您只需要加载模型的权重,请参考以下代码:
import * as tf from '@tensorflow/tfjs';
const weightManifestUrl = "https://example.org/model/weights_manifest.json";
const manifest = await fetch(weightManifestUrl);
this.weightManifest = await manifest.json();
const weightMap = await tf.io.loadWeights(
this.weightManifest, "https://example.org/model");