SGD

public class SGD<Model: Differentiable>: Optimizer
where
  Model.TangentVector: VectorProtocol & ElementaryFunctions & KeyPathIterable,
  Model.TangentVector.VectorSpaceScalar == Float

A stochastic gradient descent (SGD) optimizer.

Implements the stochastic gradient descent algorithm with support for momentum, learning rate decay, and Nesterov momentum. Momentum and Nesterov momentum (a.k.a. the Nesterov accelerated gradient method) are first-order optimization methods that can improve the training speed and convergence rate of gradient descent.
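The following is a hedged summary of what one call to update(_:along:) computes. It assumes the Swift for TensorFlow implementation applies time-based decay and the Nesterov look-ahead in this form, so check it against the library source before relying on it. The symbols map to the members listed below: η is learningRate, μ is momentum, d is decay, v is velocity, t is step, θ stands for the model parameters, and g_t is the direction passed to update.

```latex
\begin{aligned}
t &\leftarrow t + 1 \\
\eta_t &= \frac{\eta}{1 + d\,t} \\
v_t &= \mu\, v_{t-1} - \eta_t\, g_t \\
\theta_t &=
\begin{cases}
  \theta_{t-1} + v_t & \text{classical momentum (nesterov = false)} \\
  \theta_{t-1} + \mu\, v_t - \eta_t\, g_t & \text{Nesterov momentum (nesterov = true)}
\end{cases}
\end{aligned}
```

With μ = 0 and d = 0, both branches reduce to plain gradient descent, θ_t = θ_{t-1} − η g_t.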

  • The model type.

    Declaration

    public typealias Model = Model
  • The learning rate.

    Declaration

    public var learningRate: Float
  • The momentum factor. It accelerates stochastic gradient descent in the relevant direction and dampens oscillations.

    Declaration

    public var momentum: Float
  • The learning rate decay.

    Declaration

    public var decay: Float
  • Use Nesterov momentum if true.

    Declaration

    public var nesterov: Bool
  • The velocity state of the model.

    Declaration

    public var velocity: Model.TangentVector
  • The number of steps taken.

    Declaration

    public var step: Int
  • Creates an instance for model.

    Declaration

    public init(
      for model: __shared Model,
      learningRate: Float = 0.01,
      momentum: Float = 0,
      decay: Float = 0,
      nesterov: Bool = false
    )

    Parameters

    learningRate

    The learning rate. The default value is 0.01.

    momentum

    The momentum factor that accelerates stochastic gradient descent in the relevant direction and dampens oscillations. The default value is 0.

    decay

    The learning rate decay. The default value is 0.

    nesterov

    Use Nesterov momentum if true. The default value is false.

  • Updates the given model along the given direction.

    Declaration

    public func update(_ model: inout Model, along direction: Model.TangentVector)
  • Creates a copy of other on the given Device.

    Declaration

    public required init(copying other: SGD, to device: Device)
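To make the interplay of learningRate, momentum, decay, nesterov, velocity, and step concrete, here is a minimal, dependency-free sketch. SGDSketch is a hypothetical class, not part of Swift for TensorFlow: it operates on plain [Float] arrays in place of Model.TangentVector, and its decay formula, Nesterov branch, and non-negativity checks are assumptions modeled on the documented members rather than the library's actual code.

```swift
// Hypothetical sketch: SGDSketch is NOT part of Swift for TensorFlow. It works on
// plain [Float] arrays in place of Model.TangentVector, and its decay and Nesterov
// formulas are assumptions modeled on the documented members of SGD.
final class SGDSketch {
  var learningRate: Float
  var momentum: Float
  var decay: Float
  var nesterov: Bool
  var velocity: [Float]  // one entry per parameter, initialized to zero
  var step = 0           // number of update calls performed so far

  init(
    parameterCount: Int,
    learningRate: Float = 0.01,
    momentum: Float = 0,
    decay: Float = 0,
    nesterov: Bool = false
  ) {
    // Assumption: hyperparameters are required to be non-negative.
    precondition(learningRate >= 0 && momentum >= 0 && decay >= 0)
    self.learningRate = learningRate
    self.momentum = momentum
    self.decay = decay
    self.nesterov = nesterov
    self.velocity = [Float](repeating: 0, count: parameterCount)
  }

  /// Moves `parameters` one step against the gradient `direction`.
  func update(_ parameters: inout [Float], along direction: [Float]) {
    precondition(parameters.count == direction.count && direction.count == velocity.count)
    step += 1
    // Time-based decay (assumed): the effective rate is learningRate / (1 + decay * step).
    let rate = learningRate / (1 + decay * Float(step))
    for i in parameters.indices {
      // The velocity keeps a momentum-scaled memory of previous steps.
      velocity[i] = momentum * velocity[i] - rate * direction[i]
      if nesterov {
        // Nesterov: apply the momentum-scaled new velocity plus the plain gradient step.
        parameters[i] += momentum * velocity[i] - rate * direction[i]
      } else {
        // Classical momentum: move by the velocity itself.
        parameters[i] += velocity[i]
      }
    }
  }
}

// Usage: minimize f(x) = x * x, whose gradient is 2 * x.
var x: [Float] = [1.0]
let sgd = SGDSketch(parameterCount: 1, learningRate: 0.1, momentum: 0.9, nesterov: true)
for _ in 0..<100 {
  let gradient = [2 * x[0]]
  sgd.update(&x, along: gradient)
}
print(x[0])  // a value very close to 0
```

With the defaults (momentum 0, decay 0, nesterov false), each call in this sketch reduces to the plain gradient step parameters[i] -= learningRate * direction[i].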