Ver en TensorFlow.org | Ver código fuente en GitHub |
Basándose en TensorFlow, Swift para TensorFlow adopta un nuevo enfoque para el diseño de API. Las API se seleccionan cuidadosamente a partir de bibliotecas establecidas y se combinan con nuevos lenguajes. Esto significa que no todas las API de TensorFlow estarán disponibles directamente como API Swift, y nuestra selección de API necesita tiempo y esfuerzo dedicado para evolucionar. Sin embargo, no se preocupe si su operador TensorFlow favorito no está disponible en Swift: la biblioteca TensorFlow Swift le brinda acceso transparente a la mayoría de los operadores TensorFlow, bajo el espacio de nombres _Raw
.
Importe TensorFlow
para comenzar.
import TensorFlow
Llamar a operadores sin formato
Simplemente busque la función que necesita en el espacio de nombres _Raw
completando código.
print(_Raw.mul(Tensor([2.0, 3.0]), Tensor([5.0, 6.0])))
[10.0, 18.0]
Definiendo un nuevo operador multiplicador
Multiply ya está disponible como operador *
en Tensor
, pero supongamos que queremos que esté disponible con un nuevo nombre como .*
. Swift le permite agregar métodos o propiedades calculadas de forma retroactiva a tipos existentes mediante declaraciones extension
.
Ahora, agreguemos .*
a Tensor
declarando una extensión y haciéndola disponible cuando el tipo Scalar
del tensor se ajuste a Numeric
.
infix operator .* : MultiplicationPrecedence
extension Tensor where Scalar: Numeric {
static func .* (_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
return _Raw.mul(lhs, rhs)
}
}
let x: Tensor<Double> = [[1.0, 2.0], [3.0, 4.0]]
let y: Tensor<Double> = [[8.0, 7.0], [6.0, 5.0]]
print(x .* y)
[[ 8.0, 14.0], [18.0, 20.0]]
Definir una derivada de una función envuelta
No solo puede definir fácilmente una API Swift para un operador TensorFlow sin formato, sino que también puede hacerla diferenciable para que funcione con la diferenciación automática de primera clase de Swift.
Para hacer .*
diferenciable, use el atributo @derivative
en la función derivada y especifique la función original como un argumento de atributo bajo la etiqueta of:
Dado que el operador .*
se define cuando el tipo genérico Scalar
se ajusta a Numeric
, no es suficiente para hacer que Tensor<Scalar>
se ajuste al protocolo Differentiable
. Nacido con seguridad de tipos, Swift nos recordará que agreguemos una restricción genérica en el atributo @differentiable
para requerir que Scalar
se ajuste al protocolo TensorFlowFloatingPoint
, lo que haría que Tensor<Scalar>
se ajuste a Differentiable
.
@differentiable(where Scalar: TensorFlowFloatingPoint)
infix operator .* : MultiplicationPrecedence
extension Tensor where Scalar: Numeric {
@differentiable(where Scalar: TensorFlowFloatingPoint)
static func .* (_ lhs: Tensor, _ rhs: Tensor) -> Tensor {
return _Raw.mul(lhs, rhs)
}
}
extension Tensor where Scalar : TensorFlowFloatingPoint {
@derivative(of: .*)
static func multiplyDerivative(
_ lhs: Tensor, _ rhs: Tensor
) -> (value: Tensor, pullback: (Tensor) -> (Tensor, Tensor)) {
return (lhs * rhs, { v in
((rhs * v).unbroadcasted(to: lhs.shape),
(lhs * v).unbroadcasted(to: rhs.shape))
})
}
}
// Now, we can take the derivative of a function that calls `.*` that we just defined.
print(gradient(at: x, y) { x, y in
(x .* y).sum()
})
(0.0, 0.0)
Más ejemplos
let matrix = Tensor<Float>([[1, 2], [3, 4]])
print(_Raw.matMul(matrix, matrix, transposeA: true, transposeB: true))
print(_Raw.matMul(matrix, matrix, transposeA: true, transposeB: false))
print(_Raw.matMul(matrix, matrix, transposeA: false, transposeB: true))
print(_Raw.matMul(matrix, matrix, transposeA: false, transposeB: false))
[[ 7.0, 15.0], [10.0, 22.0]] [[10.0, 14.0], [14.0, 20.0]] [[ 5.0, 11.0], [11.0, 25.0]] [[ 7.0, 10.0], [15.0, 22.0]]