TF Hub 的通用 SavedModel API

简介

TensorFlow Hub 托管了适用于各种任务的模型。建议为同一任务的模型实现一个通用 API,以便模型使用者能够轻松地换用不同模型,而无需修改使用这些模型的代码,即使模型来自不同的发布者也是如此。

目的是让使用者能够针对同一任务轻松换用不同的模型,就像切换采用字符串值的超参数一样简单。这样,模型使用者可以轻松找到最能解决其问题的模型。

此目录下收录了采用 TF2 SavedModel 格式的模型的通用 API 规范。(它取代现已弃用的 TF1 Hub 格式通用签名。)

可重复使用的 SavedModel:通用基础

Reusable SavedModel API 定义了如何将 SavedModel 重新加载到 Python 程序并将其作为整个 TensorFlow 模型的一部分重复使用的一般规则。

基本用法:

obj = hub.load("path/to/model")  # That's tf.saved_model.load() after download.
outputs = obj(inputs, training=False)  # Invokes the tf.function obj.__call__.

对于 Keras 用户,hub.KerasLayer 类会借助此 API 将可重复使用的 SavedModel 封装为 Keras 层(让 Keras 用户无需处理细节),其中输入和输出视下文列出的任务特有 API 而定。

任务特有的 API

这些 API 可依据特定机器学习任务和数据类型的惯例优化 Reusable SavedModel SavedModel API