AdaDelta

public class AdaDelta<Model: Differentiable>: Optimizer
where
  Model.TangentVector: VectorProtocol & PointwiseMultiplicative
    & ElementaryFunctions & KeyPathIterable,
  Model.TangentVector.VectorSpaceScalar == Float

بهینه ساز AdaDelta.

الگوریتم بهینه سازی AdaDelta را پیاده سازی می کند. AdaDelta یک روش نزولی گرادیان تصادفی بر اساس اطلاعات مرتبه اول است. نرخ های یادگیری را بر اساس یک پنجره متحرک از به روز رسانی های گرادیان، به جای انباشته کردن همه گرادیان های گذشته، تطبیق می دهد. بنابراین، AdaDelta حتی زمانی که به روز رسانی های زیادی انجام شده است به یادگیری ادامه می دهد. سریعتر با دینامیک متغیر فضای مسئله بهینه سازی سازگار می شود.

مرجع: "ADADELTA: یک روش نرخ یادگیری تطبیقی" (Zeiler، 2012)

  • اعلام

    public typealias Model = Model
  • میزان یادگیری

    اعلام

    public var learningRate: Float
  • rho

    ضریب واپاشی، مربوط به کسری از گرادیان که باید در هر مرحله زمانی حفظ شود.

    اعلام

    public var rho: Float
  • یک اسکالر کوچک برای بهبود ثبات عددی به مخرج اضافه شده است.

    اعلام

    public var epsilon: Float
  • کاهش نرخ یادگیری

    اعلام

    public var decay: Float
  • مرحله فعلی.

    اعلام

    public var step: Int
  • میانگین انباشته شده و در حال فروپاشی نمایی گرادیان های مربع.

    اعلام

    public var averageSquared: Model.TangentVector
  • پارامتر انباشته به روز رسانی می شود.

    اعلام

    public var accumulatedDelta: Model.TangentVector
  • یک نمونه برای model ایجاد می کند.

    اعلام

    public init(
      for model: __shared Model,
      learningRate: Float = 1,
      rho: Float = 0.95,
      epsilon: Float = 1e-6,
      decay: Float = 0
    )

    مولفه های

    learningRate

    میزان یادگیری مقدار پیش فرض 1 است.

    rho

    عامل پوسیدگی مقدار پیش فرض 0.95 است.

    epsilon

    یک اسکالر کوچک برای بهبود ثبات عددی به مخرج اضافه شده است. مقدار پیش فرض 1e-6 است.

    decay

    کاهش نرخ یادگیری مقدار پیش فرض 0 است.

  • اعلام

    public func update(_ model: inout Model, along direction: Model.TangentVector)
  • اعلام

    public required init(copying other: AdaDelta, to device: Device)