Surowe operatory TensorFlow

Zobacz na TensorFlow.org Zobacz źródło w GitHub

Opierając się na TensorFlow, Swift dla TensorFlow przyjmuje świeże podejście do projektowania API. Interfejsy API są starannie wybierane z uznanych bibliotek i łączone z nowymi idiomami językowymi. Oznacza to, że nie wszystkie interfejsy API TensorFlow będą bezpośrednio dostępne jako interfejsy API Swift, a rozwój naszych interfejsów API wymaga czasu i wysiłku. Nie martw się jednak, jeśli Twój ulubiony operator TensorFlow nie jest dostępny w Swift — biblioteka TensorFlow Swift zapewnia przezroczysty dostęp do większości operatorów TensorFlow w przestrzeni nazw _Raw .

Aby rozpocząć, zaimportuj TensorFlow .

import TensorFlow

Wywoływanie surowych operatorów

Po prostu znajdź potrzebną funkcję w przestrzeni nazw _Raw poprzez uzupełnienie kodu.

print(_Raw.mul(Tensor([2.0, 3.0]), Tensor([5.0, 6.0])))
[10.0, 18.0]

Definiowanie nowego operatora mnożenia

Multiply jest już dostępne jako operator * na Tensor , ale załóżmy, że chcieliśmy udostępnić go pod nową nazwą jako .* . Swift umożliwia retroaktywne dodawanie metod lub obliczonych właściwości do istniejących typów przy użyciu deklaracji extension .

Teraz dodajmy .* do Tensor , deklarując rozszerzenie i udostępnijmy je, gdy typ Scalar tensora jest zgodny z Numeric .

infix operator .* : MultiplicationPrecedence

extension Tensor where Scalar: Numeric {
    static func .* (_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
        return _Raw.mul(lhs, rhs)
    }
}

let x: Tensor<Double> = [[1.0, 2.0], [3.0, 4.0]]
let y: Tensor<Double> = [[8.0, 7.0], [6.0, 5.0]]
print(x .* y)
[[ 8.0, 14.0],
 [18.0, 20.0]]

Definiowanie pochodnej opakowanej funkcji

Nie tylko możesz łatwo zdefiniować interfejs API Swift dla surowego operatora TensorFlow, możesz także sprawić, że będzie on różnicowany do pracy z najwyższej klasy automatycznym różnicowaniem Swift.

Aby .* było różniczkowalne, użyj atrybutu @derivative w funkcji pochodnej i określ oryginalną funkcję jako argument atrybutu pod etykietą of: Ponieważ operator .* jest zdefiniowany, gdy typ ogólny Scalar jest zgodny z Numeric , nie wystarczy, aby Tensor<Scalar> był zgodny z protokołem Differentiable . Swift, urodzony z bezpieczeństwem typów, przypomni nam o dodaniu ogólnego ograniczenia atrybutu @differentiable , aby wymagać Scalar zgodności z protokołem TensorFlowFloatingPoint , co sprawi, że Tensor<Scalar> będzie zgodny z Differentiable .

@differentiable(where Scalar: TensorFlowFloatingPoint)
infix operator .* : MultiplicationPrecedence

extension Tensor where Scalar: Numeric {
    @differentiable(where Scalar: TensorFlowFloatingPoint)
    static func .* (_ lhs: Tensor,  _ rhs: Tensor) -> Tensor {
        return _Raw.mul(lhs, rhs)
    }
}

extension Tensor where Scalar : TensorFlowFloatingPoint { 
    @derivative(of: .*)
    static func multiplyDerivative(
        _ lhs: Tensor, _ rhs: Tensor
    ) -> (value: Tensor, pullback: (Tensor) -> (Tensor, Tensor)) {
        return (lhs * rhs, { v in
            ((rhs * v).unbroadcasted(to: lhs.shape),
            (lhs * v).unbroadcasted(to: rhs.shape))
        })
    }
}

// Now, we can take the derivative of a function that calls `.*` that we just defined.
print(gradient(at: x, y) { x, y in
    (x .* y).sum()
})
(0.0, 0.0)

Więcej przykładów

let matrix = Tensor<Float>([[1, 2], [3, 4]])

print(_Raw.matMul(matrix, matrix, transposeA: true, transposeB: true))
print(_Raw.matMul(matrix, matrix, transposeA: true, transposeB: false))
print(_Raw.matMul(matrix, matrix, transposeA: false, transposeB: true))
print(_Raw.matMul(matrix, matrix, transposeA: false, transposeB: false))
[[ 7.0, 15.0],
 [10.0, 22.0]]
[[10.0, 14.0],
 [14.0, 20.0]]
[[ 5.0, 11.0],
 [11.0, 25.0]]
[[ 7.0, 10.0],
 [15.0, 22.0]]