TensorFlow CPU
TensorFlow CPU 软件包可以按如下方式导入:
import * as tf from '@tensorflow/tfjs-node'
从此软件包导入 TensorFlow.js 时,您导入的模块将由 TensorFlow C 二进制文件加速并在 CPU 上运行。CPU 上的 TensorFlow 使用硬件加速来加速后台的线性代数运算。
此软件包可以在支持 TensorFlow 的 Linux、Windows 和 Mac 平台上运行。
注:您不必导入 '@tensorflow/tfjs' 或者将其添加到您的 package.json 中。它由 Node 库间接导入。
TensorFlow GPU
TensorFlow GPU 软件包可以按如下方式导入:
import * as tf from '@tensorflow/tfjs-node-gpu'
与 CPU 软件包一样,您导入的模块将由 TensorFlow C 二进制文件加速,但是它将在支持 CUDA 的 GPU 上运行张量运算,因此只能在 Linux 平台上运行。此绑定比其他绑定选项至少快一个数量级。
注:此软件包目前仅适用于 CUDA。在选择本方案之前,您需要在带有 NVIDIA 显卡的的计算机上安装 CUDA。
注:您不必导入 '@tensorflow/tfjs' 或者将其添加到您的 package.json 中。它由 Node 库间接导入。
普通 CPU
使用普通 CPU 运算运行的 TensorFlow.js 版本可以按如下方式导入:
import * as tf from '@tensorflow/tfjs'
此软件包与您在浏览器中使用的软件包相同。在此软件包中,运算在 CPU 上以原生 JavaScript 运行。此软件包比其他软件包小得多,因为它不需要 TensorFlow 二进制文件,但是速度要慢得多。
由于此软件包不依赖于 TensorFlow,因此它可用于支持 Node.js 的更多设备,而不仅仅是 Linux、Windows 和 Mac 平台。
生产考量因素
Node.js 绑定为 TensorFlow.js 提供了一个同步执行运算的后端。这意味着当您调用一个运算(例如 tf.matMul(a, b)
)时,它将阻塞主线程,直到运算完成。
因此,绑定当前非常适合脚本和离线任务。如果您要在正式应用(例如网络服务器)中使用 Node.js 绑定,应设置一个作业队列或设置一些工作进程线程,以便您的 TensorFlow.js 代码不会阻塞主线程。
API
一旦您在上面的任何选项中将软件包作为 tf 导入,所有普通的 TensorFlow.js 符号都将出现在导入的模块上。
tf.browser
在普通的 TensorFlow.js 软件包中,tf.browser.*
命名空间中的符号将在 Node.js 中不可用,因为它们使用浏览器特定的 API。
目前,存在以下 API:
- tf.browser.fromPixels
- tf.browser.toPixels
tf.node
两个 Node.js 软件包还提供了一个名为 tf.node
的命名空间,其中包含 Node 特定的 API。
TensorBoard 是一个值得注意的 Node.js 特定的 API 示例。
在 Node.js 中将摘要导出到 TensorBoard 的示例:
const model = tf.sequential();
model.add(tf.layers.dense({ units: 1, inputShape: [200] }));
model.compile({
loss: 'meanSquaredError',
optimizer: 'sgd',
metrics: ['MAE']
});
// Generate some random fake data for demo purpose.
const xs = tf.randomUniform([10000, 200]);
const ys = tf.randomUniform([10000, 1]);
const valXs = tf.randomUniform([1000, 200]);
const valYs = tf.randomUniform([1000, 1]);
// Start model training process.
async function train() {
await model.fit(xs, ys, {
epochs: 100,
validationData: [valXs, valYs],
// Add the tensorBoard callback here.
callbacks: tf.node.tensorBoard('/tmp/fit_logs_1')
});
}
train();