public class AdaDelta<Model: Differentiable>: Optimizer
where
Model.TangentVector: VectorProtocol & PointwiseMultiplicative
& ElementaryFunctions & KeyPathIterable,
Model.TangentVector.VectorSpaceScalar == Float
Un optimiseur AdaDelta.
Implémente l'algorithme d'optimisation AdaDelta. AdaDelta est une méthode de descente de gradient stochastique basée sur les informations du premier ordre. Il adapte les taux d'apprentissage en fonction d'une fenêtre mobile de mises à jour des dégradés, au lieu d'accumuler tous les dégradés passés. Ainsi, AdaDelta continue d'apprendre même lorsque de nombreuses mises à jour ont été effectuées. Il s'adapte plus rapidement à la dynamique changeante de l'espace des problèmes d'optimisation.
Référence : « ADADELTA : Une méthode de taux d'apprentissage adaptatif » (Zeiler, 2012)
Déclaration
public typealias Model = Model
Le taux d'apprentissage.
Déclaration
public var learningRate: Float
Le facteur de décroissance, correspondant à la fraction de gradient à conserver à chaque pas de temps.
Déclaration
public var rho: Float
Un petit scalaire ajouté au dénominateur pour améliorer la stabilité numérique.
Déclaration
public var epsilon: Float
Le taux d’apprentissage diminue.
Déclaration
public var decay: Float
L'étape actuelle.
Déclaration
public var step: Int
La moyenne accumulée et en décroissance exponentielle des gradients carrés.
Déclaration
public var averageSquared: Model.TangentVector
Les paramètres accumulés sont mis à jour.
Déclaration
public var accumulatedDelta: Model.TangentVector
Crée une instance pour
model
.Déclaration
public init( for model: __shared Model, learningRate: Float = 1, rho: Float = 0.95, epsilon: Float = 1e-6, decay: Float = 0 )
Paramètres
learningRate
Le taux d'apprentissage. La valeur par défaut est
1
.rho
Le facteur de désintégration. La valeur par défaut est
0.95
.epsilon
Un petit scalaire ajouté au dénominateur pour améliorer la stabilité numérique. La valeur par défaut est
1e-6
.decay
Le taux d’apprentissage diminue. La valeur par défaut est
0
.Déclaration
public required init(copying other: AdaDelta, to device: Device)