Help protect the Great Barrier Reef with TensorFlow on Kaggle Join Challenge

保存和加载模型

TensorFlow.js 提供了保存和加载模型的功能,这些模型可以使用 Layers API 创建或从现有 TensorFlow 模型转换而来。可能是您自己训练的模型,也可能是其他人训练的模型。使用 Layers API 的一个主要好处是,使用它创建的模型是可序列化模型,这就是我们将在本教程中探讨的内容。

本教程将重点介绍如何保存和加载 TensorFlow.js 模型(可通过 JSON 文件识别)。我们也可以导入 TensorFlow Python 模型。以下两个教程介绍了如何加载这些模型:

保存 tf.Model

tf.Modeltf.Sequential 都提供了 model.save 函数,您可以借助该函数保存模型的拓扑权重

  • 拓扑:这是一个描述模型架构的文件(例如模型使用了哪些运算)。它包含对外部存储的模型权重的引用。

  • 权重:这些是以有效格式存储给定模型权重的二进制文件。它们通常存储在与拓扑相同的文件夹中。

我们来看看用于保存模型的代码:

const saveResult = await model.save('localstorage://my-model-1');

一些需要注意的地方:

  • save 方法采用以协议名称开头的类网址字符串参数。它描述了我们想保存模型的地址的类型。在上例中,协议名称为 localstorage://
  • 协议名称之后是路径。在上例中,路径是 my-model-1
  • save 方法是异步的。
  • model.save 的返回值是一个 JSON 对象,包含模型的拓扑和权重的字节大小等信息。
  • 用于保存模型的环境不会影响可以加载模型的环境。在 node.js 中保存模型不会阻碍在浏览器中加载模型。

我们将在下面查看不同协议名称。

本地存储空间(仅限浏览器)

协议名称: localstorage://

await model.save('localstorage://my-model');

这会在浏览器的本地存储空间中以名称 my-model 保存模型。这样能够在浏览器刷新后保持不变,而当存储空间成为问题时,用户或浏览器本身可以清除本地存储。每个浏览器还可为给定域设置本地存储空间中可以存储的数据量。

IndexedDB(仅限浏览器)

协议名称: indexeddb://

await model.save('indexeddb://my-model');

这会将模型保存到浏览器的 IndexedDB 存储空间中。与本地存储一样,它在刷新后仍然存在,同时所存储对象大小的上限更高。

文件下载(仅限浏览器)

协议名称: downloads://

await model.save('downloads://my-model');

这会让浏览器将模型文件下载至用户的机器上。将生成两个文件:

  1. 一个名为 [my-model].json 的 JSON 文本文件,其中包含模型拓扑和对下文所述权重文件的引用。
  2. 一个二进制文件,其中包含名为 [my-model].weights.bin 的权重值。

您可以更改 [my-model] 名称以获得一个名称不同的文件。

由于 .json 文件使用相对路径指向 .bin,因此两个文件应位于同一个文件夹中。

注:某些浏览器要求用户先授予权限,然后才能同时下载多个文件。

HTTP(S) 请求

协议名称http://https://

await model.save('http://model-server.domain/upload')

这将创建一个 Web 请求,以将模型保存到远程服务器。您应该控制该远程服务器,确保它能够处理该请求。

模型将通过 POST 请求发送至指定的 HTTP 服务器。POST 主体采用 multipart/form-data 格式并包含两个文件:

  1. 一个名为 model.json 的 JSON 文本文件,其中包含模型拓扑和对下文所述权重文件的引用。
  2. 一个二进制文件,其中包含名为 model.weights.bin 的权重值。

请注意,这两个文件的名称需要始终与上面所指定的完全相同(因为名称内置于函数中)。此 API 文档包含一个 Python 代码段,演示了如何使用 Flask Web 框架处理源自 save 的请求。

通常,您必须向 HTTP 服务器传递更多参数或请求头(例如,用于身份验证,或者如果要指定应保存模型的文件夹)。您可以通过替换 tf.io.browserHTTPRequest 中的网址字符串参数来获得对来自 save 的请求在这些方面的细粒度控制。此 API 在控制 HTTP 请求方面提供了更大的灵活性。

例如:

await model.save(tf.io.browserHTTPRequest(
    'http://model-server.domain/upload',
    {method: 'PUT', headers: {'header_key_1': 'header_value_1'} }));

原生文件系统(仅限 Node.js)

协议名称: file://

await model.save('file:///path/to/my-model');

在 Node.js 上运行时,我们还可以直接访问文件系统并保存模型。上面的命令会将两个文件保存到在 scheme 后指定的 path 中。

  1. 一个名为 [model].json 的 JSON 文本文件,其中包含模型拓扑和对下文所述权重文件的引用。
  2. 一个二进制文件,其中包含名为 [model].weights.bin 的权重值。

请注意,这两个文件的名称需要始终与上面所指定的完全相同(因为名称内置于函数中)。

加载 tf.Model

给定一个使用上述方法之一保存的模型,我们可以使用 tf.loadLayersModel API 加载它。

我们来看看加载模型的代码:

const model = await tf.loadLayersModel('localstorage://my-model-1');

一些需要注意的地方:

  • model.save() 类似,loadLayersModel 函数也采用以协议名称开头的类网址字符串参数。它描述了我们想要从中加载模型的目标类型。
  • 协议名称之后是路径。在上例中,路径是 my-model-1
  • 类网址字符串可以替换为与 IOHandler 接口匹配的对象。
  • tf.loadLayersModel() 函数是异步的。
  • tf.loadLayersModel 的返回值为 tf.Model

我们将在下面查看不同协议名称。

本地存储空间(仅限浏览器)

协议名称: localstorage://

const model = await tf.loadLayersModel('localstorage://my-model');

这将从浏览器的本地存储空间加载一个名为 my-model 的模型。

IndexedDB(仅限浏览器)

协议名称: indexeddb://

const model = await tf.loadLayersModel('indexeddb://my-model');

这将从浏览器的 IndexedDB 存储空间加载一个模型。

HTTP(S)

协议名称http://https://

const model = await tf.loadLayersModel('http://model-server.domain/download/model.json');

这将从 HTTP 端点加载模型。加载 json 文件后,函数将请求 json 文件引用的对应 .bin 文件。

注:此实现依赖于 fetch 方法,如果您的环境没有提供原生 fetch 方法,您可以提供满足接口要求的全局方法名称 fetch,或者使用类似于 (node-fetch)[https://www.npmjs.com/package/node-fetch] 的库。

原生文件系统(仅限 Node.js)

协议名称: file://

const model = await tf.loadLayersModel('file://path/to/my-model/model.json');

在 Node.js 上运行时,我们还可以直接访问文件系统并加载模型。请注意,在上面的函数调用中,我们引用 model.json 文件本身(在保存时,我们指定一个文件夹)。对应的 .bin 文件应与 json 文件位于同一个文件夹中。

使用 IOHandler 加载模型

如果上述协议名称没有满足您的需求,您可以使用 IOHandler 实现自定义加载行为。Tensorflow.js 提供的一个 IOHandlertf.io.browserFiles,它允许浏览器用户在浏览器中上传模型文件。请参阅文档了解更多信息。

使用自定义 IOHandler 保存或加载模型

如果上述协议名称没有满足您的保存或加载需求,您可以通过实现 IOHandler 来实现自定义序列化行为。

IOHandler 是一个包含 saveload 方法的对象。

save 函数采用一个与 ModelArtifacts 接口匹配的参数,应返回一个解析为 SaveResult 对象的 promise。

load 函数不采用参数,应返回一个解析为 ModelArtifacts 对象的 promise。这是传递给 save 的同一对象。

请参阅 BrowserHTTPRequest 获取如何实现 IOHandler 的示例。