TensorFlow Java をインストールする

TensorFlow Java は、すべての JVM で動作させて、機械学習モデルの構築、トレーニング、デプロイに使用することができます。グラフモードまたは eager モードで CPU 実行と GPU 実行をサポートし、JVM 環境で TensorFlow を使用するためのリッチな API を提供します。Java や Scala、Kotlin のような JVM 言語は、世界中の大小さまざまな企業でよく使用されていることから、TensorFlow Java は機械学習を大規模に導入するための戦略的選択となります。

要件

TensorFlow Java は Java 8 以降で動作し、以下のプラットフォームの初期構成に対応しています。

  • Ubuntu 16.04 以上、64 ビット、x86
  • macOS 10.12.6(Sierra)以上、64 ビット、x86
  • Windows 7 以上、64 ビット、x86

バージョン

TensorFlow Java には、TensorFlow ランタイムから独立した、独自のリリース サイクルがあります。したがって、そのバージョンは、その下で実行される TensorFlow ランタイムのバージョンと一致しません。TensorFlow Java のバージョン対応表で、すべての利用可能なバージョンとそれに対応する TensorFlow ランタイムを確認してください。

アーティファクト

プロジェクトに TensorFlow Java を追加するには、いくつかの方法があります。最も簡単なのは、tensorflow-core-platform アーティファクトの依存関係を追加することです。これには、TensorFlow Java Core API と、すべての対応プラットフォームでの実行に必要なネイティブの依存関係の両方が含まれています。

純粋な CPU バージョンではなく、以下の拡張のいずれかを選択することもできます。

  • tensorflow-core-platform-mkl: すべてのプラットフォームでの Intel® MKL-DNN のサポート
  • tensorflow-core-platform-gpu: Linux と Windows のプラットフォームでの CUDA® のサポート
  • tensorflow-core-platform-mkl-gpu: Linux と Windows のプラットフォームでの Intel® MKL-DNN と CUDA® のサポート

上記に加えて、tensorflow-framework ライブラリの依存関係を別途追加すると、JVM 上で TensorFlow ベースの機械学習を行うのに役立つ豊富なユーティリティを使用できます。

Maven を使ったインストール

Maven アプリケーションに TensorFlow を追加するには、そのアーティファクトの依存関係をプロジェクトの pom.xml ファイルに追加します。次に例を示します。

<dependency>
  <groupId>org.tensorflow</groupId>
  <artifactId>tensorflow-core-platform</artifactId>
  <version>0.2.0</version>
</dependency>

依存関係の削減

tensorflow-core-platform アーティファクトの依存関係を追加すると、サポートしている全プラットフォームのネイティブ ライブラリが読み込まれ、プロジェクトが著しく肥大化する可能性があることに留意が必要です。

Maven の依存関係除外機能を使用すると、他のプラットフォームの不要なアーティファクトを除外して、対応プラットフォームの一部をターゲットとすることができます。

アプリケーションで対応するプラットフォームを選択する別の方法として、JavaCPP システム プロパティを Maven のコマンドラインまたは pom.xml で設定することができます。詳しくは、JavaCPP のドキュメントをご覧ください。

スナップショットの使用

TensorFlow Java のソース リポジトリをもとにした最新の開発スナップショットは、OSS Sonatype Nexus のリポジトリから入手できます。このアーティファクトに依存関係を設定するには、pom.xml で OSS スナップショットのリポジトリを設定してください。

<repositories>
    <repository>
        <id>tensorflow-snapshots</id>
        <url>https://oss.sonatype.org/content/repositories/snapshots/</url>
        <snapshots>
            <enabled>true</enabled>
        </snapshots>
    </repository>
</repositories>

<dependencies>
    <dependency>
        <groupId>org.tensorflow</groupId>
        <artifactId>tensorflow-core-platform</artifactId>
        <version>0.3.0-SNAPSHOT</version>
    </dependency>
</dependencies>

Gradle を使ったインストール

Gradle アプリケーションに TensorFlow を追加するには、そのアーティファクトの依存関係をプロジェクトの build.gradle ファイルに追加します。次に例を示します。

repositories {
    mavenCentral()
}

dependencies {
    compile group: 'org.tensorflow', name: 'tensorflow-core-platform', version: '0.2.0'
}

依存関係の削減

Gradle で TensorFlow Java のネイティブ アーティファクトを除外するのは、Maven の場合ほど簡単ではありません。この場合の依存関係を減らすには、Gradle JavaCPP プラグインの使用をおすすめします。

詳しくは、Gradle JavaCPP のドキュメントをご覧ください。

ソースからのインストール

ソースから TensorFlow Java をビルドし、あるいはさらにカスタマイズも行う方法については、以下の手順をご覧ください。

サンプル プログラム

この例では、TensorFlow を使用する Apache Maven プロジェクトをビルドする方法を示します。まず、TensorFlow の依存関係をプロジェクトの pom.xml ファイルに追加します。

<project>
    <modelVersion>4.0.0</modelVersion>
    <groupId>org.myorg</groupId>
    <artifactId>hellotensorflow</artifactId>
    <version>1.0-SNAPSHOT</version>

    <properties>
        <exec.mainClass>HelloTensorFlow</exec.mainClass>
    <!-- Minimal version for compiling TensorFlow Java is JDK 8 -->
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>

    <dependencies>
    <!-- Include TensorFlow (pure CPU only) for all supported platforms -->
        <dependency>
            <groupId>org.tensorflow</groupId>
            <artifactId>tensorflow-core-platform</artifactId>
            <version>0.2.0</version>
        </dependency>
    </dependencies>
</project>

ソースファイル(src/main/java/HelloTensorFlow.java)を作成します。

import org.tensorflow.ConcreteFunction;
import org.tensorflow.Signature;
import org.tensorflow.Tensor;
import org.tensorflow.TensorFlow;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Placeholder;
import org.tensorflow.op.math.Add;
import org.tensorflow.types.TInt32;

public class HelloTensorFlow {

  public static void main(String[] args) throws Exception {
    System.out.println("Hello TensorFlow " + TensorFlow.version());

    try (ConcreteFunction dbl = ConcreteFunction.create(HelloTensorFlow::dbl);
        Tensor<TInt32> x = TInt32.scalarOf(10);
        Tensor<TInt32> dblX = dbl.call(x).expect(TInt32.DTYPE)) {
      System.out.println(x.data().getInt() + " doubled is " + dblX.data().getInt());
    }
  }

  private static Signature dbl(Ops tf) {
    Placeholder<TInt32> x = tf.placeholder(TInt32.DTYPE);
    Add<TInt32> dblX = tf.math.add(x, x);
    return Signature.builder().input("x", x).output("dbl", dblX).build();
  }
}

コンパイルして実行します。

mvn -q compile exec:java

このコマンドから「TensorFlow version and a simple calculation.」と出力されます。

完了です。TensorFlow Java が構成されました。