Zobacz na TensorFlow.org | Zobacz źródło w GitHub |
Opierając się na TensorFlow, Swift dla TensorFlow przyjmuje świeże podejście do projektowania API. Interfejsy API są starannie wybierane z uznanych bibliotek i łączone z nowymi idiomami językowymi. Oznacza to, że nie wszystkie interfejsy API TensorFlow będą bezpośrednio dostępne jako interfejsy API Swift, a rozwój naszych interfejsów API wymaga czasu i wysiłku. Nie martw się jednak, jeśli Twój ulubiony operator TensorFlow nie jest dostępny w Swift — biblioteka TensorFlow Swift zapewnia przezroczysty dostęp do większości operatorów TensorFlow w przestrzeni nazw _Raw
.
Aby rozpocząć, zaimportuj TensorFlow
.
import TensorFlow
Wywoływanie surowych operatorów
Po prostu znajdź potrzebną funkcję w przestrzeni nazw _Raw
poprzez uzupełnienie kodu.
print(_Raw.mul(Tensor([2.0, 3.0]), Tensor([5.0, 6.0])))
[10.0, 18.0]
Definiowanie nowego operatora mnożenia
Multiply jest już dostępne jako operator *
na Tensor
, ale załóżmy, że chcieliśmy udostępnić go pod nową nazwą jako .*
. Swift umożliwia retroaktywne dodawanie metod lub obliczonych właściwości do istniejących typów przy użyciu deklaracji extension
.
Teraz dodajmy .*
do Tensor
, deklarując rozszerzenie i udostępnijmy je, gdy typ Scalar
tensora jest zgodny z Numeric
.
infix operator .* : MultiplicationPrecedence
extension Tensor where Scalar: Numeric {
static func .* (_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
return _Raw.mul(lhs, rhs)
}
}
let x: Tensor<Double> = [[1.0, 2.0], [3.0, 4.0]]
let y: Tensor<Double> = [[8.0, 7.0], [6.0, 5.0]]
print(x .* y)
[[ 8.0, 14.0], [18.0, 20.0]]
Definiowanie pochodnej opakowanej funkcji
Nie tylko możesz łatwo zdefiniować interfejs API Swift dla surowego operatora TensorFlow, możesz także sprawić, że będzie on różnicowany do pracy z najwyższej klasy automatycznym różnicowaniem Swift.
Aby .*
było różniczkowalne, użyj atrybutu @derivative
w funkcji pochodnej i określ oryginalną funkcję jako argument atrybutu pod etykietą of:
Ponieważ operator .*
jest zdefiniowany, gdy typ ogólny Scalar
jest zgodny z Numeric
, nie wystarczy, aby Tensor<Scalar>
był zgodny z protokołem Differentiable
. Swift, urodzony z bezpieczeństwem typów, przypomni nam o dodaniu ogólnego ograniczenia atrybutu @differentiable
, aby wymagać Scalar
zgodności z protokołem TensorFlowFloatingPoint
, co sprawi, że Tensor<Scalar>
będzie zgodny z Differentiable
.
@differentiable(where Scalar: TensorFlowFloatingPoint)
infix operator .* : MultiplicationPrecedence
extension Tensor where Scalar: Numeric {
@differentiable(where Scalar: TensorFlowFloatingPoint)
static func .* (_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
return _Raw.mul(lhs, rhs)
}
}
extension Tensor where Scalar : TensorFlowFloatingPoint {
@derivative(of: .*)
static func multiplyDerivative(
_ lhs: Tensor, _ rhs: Tensor
) -> (value: Tensor, pullback: (Tensor) -> (Tensor, Tensor)) {
return (lhs * rhs, { v in
((rhs * v).unbroadcasted(to: lhs.shape),
(lhs * v).unbroadcasted(to: rhs.shape))
})
}
}
// Now, we can take the derivative of a function that calls `.*` that we just defined.
print(gradient(at: x, y) { x, y in
(x .* y).sum()
})
(0.0, 0.0)
Więcej przykładów
let matrix = Tensor<Float>([[1, 2], [3, 4]])
print(_Raw.matMul(matrix, matrix, transposeA: true, transposeB: true))
print(_Raw.matMul(matrix, matrix, transposeA: true, transposeB: false))
print(_Raw.matMul(matrix, matrix, transposeA: false, transposeB: true))
print(_Raw.matMul(matrix, matrix, transposeA: false, transposeB: false))
[[ 7.0, 15.0], [10.0, 22.0]] [[10.0, 14.0], [14.0, 20.0]] [[ 5.0, 11.0], [11.0, 25.0]] [[ 7.0, 10.0], [15.0, 22.0]]