Introducing X10

%install '.package(url: "https://github.com/tensorflow/swift-models", .branch("tensorflow-0.12"))' Datasets ImageClassificationModels
print("\u{001B}[2J")



By default, Swift for TensorFlow performs tensor operations using eager dispatch. This allows for rapid iteration, but isn't the most performant option for training machine learning models.

The X10 tensor library adds a high-performance backend to Swift for TensorFlow, leveraging tensor tracing and the XLA compiler. This tutorial will introduce X10 and guide you through the process of updating a training loop to run on GPUs or TPUs.

Eager vs. X10 tensors

Accelerated calculations in Swift for TensorFlow are performed through the Tensor type. Tensors can participate in a wide variety of operations, and are the fundamental building blocks of machine learning models.

By default, a Tensor uses eager execution to perform calculations on an operation-by-operation basis. Each Tensor has an associated Device that describes what hardware it is attached to and what backend is used for it.

import TensorFlow
import Foundation
let eagerTensor1 = Tensor([0.0, 1.0, 2.0])
let eagerTensor2 = Tensor([1.5, 2.5, 3.5])
let eagerTensorSum = eagerTensor1 + eagerTensor2
print(eagerTensorSum)
[1.5, 3.5, 5.5]

print(eagerTensor1.device)
Device(kind: .CPU, ordinal: 0, backend: .TF_EAGER)

If you are running this notebook on a GPU-enabled instance, you should see that hardware reflected in the device description above. The eager runtime does not support TPUs, so if you are using one as your accelerator, you will see the CPU listed as the hardware target.

When creating a Tensor, the default eager-mode device can be overridden by specifying an alternative. This is how you opt in to performing calculations using the X10 backend.

let x10Tensor1 = Tensor([0.0, 1.0, 2.0], on: Device.defaultXLA)
let x10Tensor2 = Tensor([1.5, 2.5, 3.5], on: Device.defaultXLA)
let x10TensorSum = x10Tensor1 + x10Tensor2
print(x10TensorSum)
[1.5, 3.5, 5.5]

print(x10Tensor1.device)
Device(kind: .CPU, ordinal: 0, backend: .XLA)

If you're running this in a GPU-enabled instance, you should see that accelerator listed in the X10 tensor's device. Unlike with eager execution, if you are running this in a TPU-enabled instance, you should now see that calculations are using that device. X10 is how you take advantage of TPUs within Swift for TensorFlow.

The default eager and X10 devices will attempt to use the first accelerator on the system. If you have GPUs attached, they will use the first available GPU. If TPUs are present, X10 will use the first TPU core by default. If no accelerator is found or supported, the default device will fall back to the CPU.
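
As a quick check, you can print the device each backend has selected (a minimal sketch; the exact output depends on your hardware):

// Inspect the default devices for each backend. On a machine with no
// attached accelerator, both will report the CPU.
print(Device.default)     // default device for eager execution
print(Device.defaultXLA)  // default device for the X10 backend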

Beyond the default eager and XLA devices, you can provide specific hardware and backend targets in a Device:

// let tpu1 = Device(kind: .TPU, ordinal: 1, backend: .XLA)
// let tpuTensor1 = Tensor([0.0, 1.0, 2.0], on: tpu1)

Training an eager-mode model

Let's take a look at how you'd set up and train a model using the default eager execution mode. In this example, we'll be using the simple LeNet-5 model from the swift-models repository and the MNIST handwritten digit classification dataset.

First, we'll set up and download the MNIST dataset.

import Datasets

let epochCount = 5
let batchSize = 128
let dataset = MNIST(batchSize: batchSize)
Loading resource: train-images-idx3-ubyte
File does not exist locally at expected path: /home/kbuilder/.cache/swift-models/datasets/MNIST/train-images-idx3-ubyte and must be fetched
Fetching URL: https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz...
Archive saved to: /home/kbuilder/.cache/swift-models/datasets/MNIST
Loading resource: train-labels-idx1-ubyte
File does not exist locally at expected path: /home/kbuilder/.cache/swift-models/datasets/MNIST/train-labels-idx1-ubyte and must be fetched
Fetching URL: https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz...
Archive saved to: /home/kbuilder/.cache/swift-models/datasets/MNIST
Loading resource: t10k-images-idx3-ubyte
File does not exist locally at expected path: /home/kbuilder/.cache/swift-models/datasets/MNIST/t10k-images-idx3-ubyte and must be fetched
Fetching URL: https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz...
Archive saved to: /home/kbuilder/.cache/swift-models/datasets/MNIST
Loading resource: t10k-labels-idx1-ubyte
File does not exist locally at expected path: /home/kbuilder/.cache/swift-models/datasets/MNIST/t10k-labels-idx1-ubyte and must be fetched
Fetching URL: https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz...
Archive saved to: /home/kbuilder/.cache/swift-models/datasets/MNIST

Next, we will configure the model and optimizer.

import ImageClassificationModels

var eagerModel = LeNet()
var eagerOptimizer = SGD(for: eagerModel, learningRate: 0.1)

Now, we will implement basic progress tracking and reporting. All intermediate statistics are kept as tensors on the same device where training is run, and scalarized() is called only during reporting. This will be especially important later when using X10, because it avoids unnecessary materialization of lazy tensors.

struct Statistics {
    var correctGuessCount = Tensor<Int32>(0, on: Device.default)
    var totalGuessCount = Tensor<Int32>(0, on: Device.default)
    var totalLoss = Tensor<Float>(0, on: Device.default)
    var batches: Int = 0
    // Scalarizing only here, at reporting time, avoids materializing
    // lazy tensors during the training loop itself.
    var accuracy: Float {
        Float(correctGuessCount.scalarized()) / Float(totalGuessCount.scalarized()) * 100
    }
    var averageLoss: Float { totalLoss.scalarized() / Float(batches) }

    init(on device: Device = Device.default) {
        correctGuessCount = Tensor<Int32>(0, on: device)
        totalGuessCount = Tensor<Int32>(0, on: device)
        totalLoss = Tensor<Float>(0, on: device)
    }

    mutating func update(logits: Tensor<Float>, labels: Tensor<Int32>, loss: Tensor<Float>) {
        // Accumulate running totals as tensors on the training device.
        let correct = logits.argmax(squeezingAxis: 1) .== labels
        correctGuessCount += Tensor<Int32>(correct).sum()
        totalGuessCount += Int32(labels.shape[0])
        totalLoss += loss
        batches += 1
    }
}

Finally, we'll run the model through a training loop for five epochs.

print("Beginning training...")

for (epoch, batches) in dataset.training.prefix(epochCount).enumerated() {
    let start = Date()
    var trainStats = Statistics()
    var testStats = Statistics()

    Context.local.learningPhase = .training
    for batch in batches {
        let (images, labels) = (batch.data, batch.label)
        let 𝛁model = TensorFlow.gradient(at: eagerModel) { eagerModel -> Tensor<Float> in
            let ŷ = eagerModel(images)
            let loss = softmaxCrossEntropy(logits: ŷ, labels: labels)
            trainStats.update(logits: ŷ, labels: labels, loss: loss)
            return loss
        }
        eagerOptimizer.update(&eagerModel, along: 𝛁model)
    }

    Context.local.learningPhase = .inference
    for batch in dataset.validation {
        let (images, labels) = (batch.data, batch.label)
        let ŷ = eagerModel(images)
        let loss = softmaxCrossEntropy(logits: ŷ, labels: labels)
        testStats.update(logits: ŷ, labels: labels, loss: loss)
    }

    print(
        """
        [Epoch \(epoch)] \
        Training Loss: \(String(format: "%.3f", trainStats.averageLoss)), \
        Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \
        (\(String(format: "%.1f", trainStats.accuracy))%), \
        Test Loss: \(String(format: "%.3f", testStats.averageLoss)), \
        Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \
        (\(String(format: "%.1f", testStats.accuracy))%) \
        seconds per epoch: \(String(format: "%.1f", Date().timeIntervalSince(start)))
        """)
}
Beginning training...
[Epoch 0] Training Loss: 0.528, Training Accuracy: 50154/59904 (83.7%), Test Loss: 0.168, Test Accuracy: 9468/10000 (94.7%) seconds per epoch: 11.9
[Epoch 1] Training Loss: 0.133, Training Accuracy: 57488/59904 (96.0%), Test Loss: 0.107, Test Accuracy: 9659/10000 (96.6%) seconds per epoch: 11.7
[Epoch 2] Training Loss: 0.092, Training Accuracy: 58193/59904 (97.1%), Test Loss: 0.069, Test Accuracy: 9782/10000 (97.8%) seconds per epoch: 11.8
[Epoch 3] Training Loss: 0.071, Training Accuracy: 58577/59904 (97.8%), Test Loss: 0.066, Test Accuracy: 9794/10000 (97.9%) seconds per epoch: 11.8
[Epoch 4] Training Loss: 0.059, Training Accuracy: 58800/59904 (98.2%), Test Loss: 0.064, Test Accuracy: 9800/10000 (98.0%) seconds per epoch: 11.8

As you can see, the model trained as we would expect, and its accuracy against the validation set increased each epoch. This is how Swift for TensorFlow models are defined and run using eager execution; now let's see what modifications are needed to take advantage of X10.

Training an X10 model

Datasets, models, and optimizers contain tensors that are initialized on the default eager execution device. To work with X10, we'll need to move these tensors to an X10 device.

let device = Device.defaultXLA
print(device)
Device(kind: .CPU, ordinal: 0, backend: .XLA)

For the dataset, we'll do that at the point where batches are processed in the training loop, so we can re-use the dataset from the eager execution model.

In the case of the model and optimizer, we'll initialize them with their internal tensors on the eager execution device, then move them over to the X10 device.

var x10Model = LeNet()
x10Model.move(to: device)

var x10Optimizer = SGD(for: x10Model, learningRate: 0.1)
x10Optimizer = SGD(copying: x10Optimizer, to: device)

The modifications needed for the training loop come at a few specific points. First, we'll need to move the batches of training data over to the X10 device. This is done via Tensor(copying:to:) when each batch is retrieved.
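
As a standalone sketch (using a small throwaway tensor rather than an actual batch), that copy looks like this:

// A minimal sketch of moving a single tensor to the X10 device; the
// training loop below performs this same copy for every batch.
let eagerExample = Tensor<Float>([0.0, 1.0, 2.0])
let x10Example = Tensor(copying: eagerExample, to: device)
print(x10Example.device)  // should now report the .XLA backend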

The next change is to indicate where to cut off the traces during the training loop. X10 works by tracing through the tensor calculations needed in your code and just-in-time compiling an optimized representation of that trace. In the case of a training loop, you're repeating the same operations over and over, an ideal section to trace, compile, and re-use.

In the absence of code that explicitly requests a value from a Tensor (these usually stand out as .scalars or .scalarized() calls), X10 will attempt to compile all loop iterations together. To prevent this, and to cut the trace at a specific point, we place an explicit LazyTensorBarrier() after the optimizer updates the model weights, and after the loss and accuracy are obtained during validation. This creates two re-used traces: one for each step in the training loop and one for each batch of inference during validation.

These modifications result in the following training loop.

print("Beginning training...")

for (epoch, batches) in dataset.training.prefix(epochCount).enumerated() {
    let start = Date()
    var trainStats = Statistics(on: device)
    var testStats = Statistics(on: device)

    Context.local.learningPhase = .training
    for batch in batches {
        let (eagerImages, eagerLabels) = (batch.data, batch.label)
        let images = Tensor(copying: eagerImages, to: device)
        let labels = Tensor(copying: eagerLabels, to: device)
        let 𝛁model = TensorFlow.gradient(at: x10Model) { x10Model -> Tensor<Float> in
            let ŷ = x10Model(images)
            let loss = softmaxCrossEntropy(logits: ŷ, labels: labels)
            trainStats.update(logits: ŷ, labels: labels, loss: loss)
            return loss
        }
        x10Optimizer.update(&x10Model, along: 𝛁model)
        LazyTensorBarrier()
    }

    Context.local.learningPhase = .inference
    for batch in dataset.validation {
        let (eagerImages, eagerLabels) = (batch.data, batch.label)
        let images = Tensor(copying: eagerImages, to: device)
        let labels = Tensor(copying: eagerLabels, to: device)
        let ŷ = x10Model(images)
        let loss = softmaxCrossEntropy(logits: ŷ, labels: labels)
        LazyTensorBarrier()
        testStats.update(logits: ŷ, labels: labels, loss: loss)
    }

    print(
        """
        [Epoch \(epoch)] \
        Training Loss: \(String(format: "%.3f", trainStats.averageLoss)), \
        Training Accuracy: \(trainStats.correctGuessCount)/\(trainStats.totalGuessCount) \
        (\(String(format: "%.1f", trainStats.accuracy))%), \
        Test Loss: \(String(format: "%.3f", testStats.averageLoss)), \
        Test Accuracy: \(testStats.correctGuessCount)/\(testStats.totalGuessCount) \
        (\(String(format: "%.1f", testStats.accuracy))%) \
        seconds per epoch: \(String(format: "%.1f", Date().timeIntervalSince(start)))
        """)
}
Beginning training...
[Epoch 0] Training Loss: 0.421, Training Accuracy: 51888/59904 (86.6%), Test Loss: 0.134, Test Accuracy: 9557/10000 (95.6%) seconds per epoch: 18.6
[Epoch 1] Training Loss: 0.117, Training Accuracy: 57733/59904 (96.4%), Test Loss: 0.085, Test Accuracy: 9735/10000 (97.3%) seconds per epoch: 14.9
[Epoch 2] Training Loss: 0.080, Training Accuracy: 58400/59904 (97.5%), Test Loss: 0.068, Test Accuracy: 9791/10000 (97.9%) seconds per epoch: 13.1
[Epoch 3] Training Loss: 0.064, Training Accuracy: 58684/59904 (98.0%), Test Loss: 0.056, Test Accuracy: 9804/10000 (98.0%) seconds per epoch: 13.5
[Epoch 4] Training Loss: 0.053, Training Accuracy: 58909/59904 (98.3%), Test Loss: 0.063, Test Accuracy: 9779/10000 (97.8%) seconds per epoch: 13.4

Training of the model using the X10 backend should have proceeded in the same manner as the eager execution model did before. You may have noticed a delay before the first batch and at the end of the first epoch, due to the just-in-time compilation of the unique traces at those points. If you're running this with an accelerator attached, you should have seen training proceed faster from that point on than in eager mode.

There is a tradeoff between initial trace compilation time and faster throughput, but in most machine learning models the increased throughput from repeated operations should more than offset the compilation overhead. In practice, we've seen throughput improve by over 4X with X10 in some training cases.

As noted above, using X10 makes it not only possible but easy to work with TPUs, unlocking that entire class of accelerators for your Swift for TensorFlow models.