安装 TensorFlow Java

TensorFlow Java 可在任何 JVM 上运行,用于构建、训练和部署机器学习模型。它既支持 CPU 执行也支持 GPU 执行(无论是图表模式下还是 eager 模式下),并为在 JVM 环境中使用 TensorFlow 提供了丰富的 API。世界各地的大小企业都在频繁使用 Java 和其他 JVM 语言(比如 Scala 和 Kotlin),有鉴于此,若要大规模采用机器学习,TensorFlow Java 无疑是一项战略选择。

要求

TensorFlow Java 可在 Java 8 及更高版本上运行,并可在以下平台上即开即用:

  • Ubuntu 16.04 或更高版本;64 位,x86
  • macOS 10.12.6 (Sierra) 或更高版本;64 位,x86
  • Windows 7 或更高版本;64 位,x86

版本

TensorFlow Java 有自己独立的发布周期,与 TensorFlow 运行时无关。因此,它的版本与作为其运行环境的 TensorFlow 运行时的版本不匹配。请查阅 TensorFlow Java 版本控制表,了解所有可用的版本以及它们与 TensorFlow 运行时的映射关系。

工件

您可通过多种方式将 TensorFlow Java 添加到您的项目中。最简便的方式是,添加 tensorflow-core-platform 工件的某个依赖项,该工件既包含 TensorFlow Java Core API,也包含在所有受支持的平台上运行时需要用到的原生依赖项。

您也可选择下列扩展程序之一(而非单纯的 CPU 版本):

  • tensorflow-core-platform-mkl:支持所有平台上的 Intel® MKL-DNN
  • tensorflow-core-platform-gpu:支持 Linux 和 Windows 平台上的 CUDA®
  • tensorflow-core-platform-mkl-gpu:支持 Linux 和 Windows 平台上的 Intel® MKL-DNN 和 CUDA®。

您还可再添加一个 tensorflow-framework 库依赖项,以便能够利用丰富多样的实用程序在 JVM 上进行基于 TensorFlow 的机器学习。

使用 Maven 进行安装

若要将 TensorFlow 纳入到您的 Maven 应用内,请将其 工件的某个依赖项添加到您项目的 pom.xml 文件中。 例如:

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>tensorflow-core-platform</artifactId>
  <version>0.2.0</version>
</dependency>

减少依赖项数量

请务必注意,添加 tensorflow-core-platform 工件的依赖项会导入所有受支持平台的原生库,这可能会导致您项目的规模显著增大。

若想仅定位到一部分可用平台,您可以使用 Maven 依赖项排除功能排除其他平台的非必要工件。

还有一种方法可供您用来选择要将哪些平台纳入到您的应用内,那就是:在您的 Maven 命令行或您的 pom.xml 中设置 JavaCPP 系统属性。如需了解更多详情,请参阅 JavaCPP 文档

使用快照

您可从 OSS Sonatype Nexus 库获取 TensorFlow Java 源代码库的最新 TensorFlow Java 开发快照。若要依赖这些工件,请务必在您的 pom.xml 中配置 OSS 快照库。

<repositories>
    <repository>
        <id>tensorflow-snapshots</id>
        <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
        <snapshots>
            <enabled>true</enabled>
        </snapshots>
    </repository>
</repositories>

<dependencies>
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow-core-platform</artifactId>
        <version>0.3.0-SNAPSHOT</version>
    </dependency>
</dependencies>

使用 Gradle 进行安装

若要将 TensorFlow 纳入到您的 Gradle 应用内,请将其 工件的某个依赖项添加到您项目的 build.gradle 文件中。 例如:

repositories {
    mavenCentral()
}

dependencies {
    compile group: 'org.tensorflow', name: 'tensorflow-core-platform', version: '0.2.0'
}

减少依赖项数量

在使用 Gradle 时排除 TensorFlow Java 原生工件并不像使用 Maven 时那般容易。我们建议您使用 Gradle JavaCPP 插件来减少这些依赖项的数量。

如需了解更多详情,请参阅 Gradle JavaCPP 文档

从源代码进行安装

若要从源代码构建 TensorFlow Java,并在可能的情况下自定义它,请阅读以下说明

示例程序

以下示例演示了如何使用 TensorFlow 构建 Apache Maven 项目。首先,将 TensorFlow 依赖项添加到项目的 pom.xml 文件中:

<project>
    <modelVersion>4.0.0</modelVersion>
    <groupId>org.myorg</groupId>
    <artifactId>hellotensorflow</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <exec.mainClass>HelloTensorFlow</exec.mainClass>
    <!-- Minimal version for compiling TensorFlow Java is JDK 8 -->
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>

    <dependencies>
    <!-- Include TensorFlow (pure CPU only) for all supported platforms -->
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-platform</artifactId>
            <version>0.2.0</version>
        </dependency>
    </dependencies>
</project>

创建源代码文件 src/main/java/HelloTensorFlow.java

import org.tensorflow.ConcreteFunction;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.math.Add;
import org.tensorflow.types.TInt32;

public class HelloTensorFlow {

  public static void main(String[] args) throws Exception {
    System.out.println("Hello TensorFlow " + TensorFlow.version());

    try (ConcreteFunction dbl = ConcreteFunction.create(HelloTensorFlow::dbl);
        Tensor<TInt32> x = TInt32.scalarOf(10);
        Tensor<TInt32> dblX = dbl.call(x).expect(TInt32.DTYPE)) {
      System.out.println(x.data().getInt() + " doubled is " + dblX.data().getInt());
    }
  }

  private static Signature dbl(Ops tf) {
    Placeholder<TInt32> x = tf.placeholder(TInt32.DTYPE);
    Add<TInt32> dblX = tf.math.add(x, x);
    return Signature.builder().input("x", x).output("dbl", dblX).build();
  }
}

编译并执行:

mvn -q compile exec:java

该命令会输出:TensorFlow version and a simple calculation.

大功告成!TensorFlow Java 已配置完毕。