AdaDelta

public class AdaDelta<Model: Differentiable>: Optimizer
where
  Model.TangentVector: VectorProtocol & PointwiseMultiplicative
    & ElementaryFunctions & KeyPathIterable,
  Model.TangentVector.VectorSpaceScalar == Float

بهینه ساز AdaDelta.

الگوریتم بهینه سازی AdaDelta را پیاده سازی می کند. AdaDelta یک روش نزولی گرادیان تصادفی بر اساس اطلاعات مرتبه اول است. نرخ های یادگیری را بر اساس یک پنجره متحرک از به روز رسانی های گرادیان، به جای انباشته کردن همه گرادیان های گذشته، تطبیق می دهد. بنابراین، AdaDelta حتی زمانی که به روز رسانی های زیادی انجام شده است به یادگیری ادامه می دهد. سریعتر با دینامیک متغیر فضای مسئله بهینه سازی سازگار می شود.

مرجع: "ADADELTA: یک روش نرخ یادگیری تطبیقی" (Zeiler، 2012)

  • public typealias Model = Model
  • میزان یادگیری

    اعلامیه

    public var learningRate: Float
  • rho

    ضریب واپاشی، مربوط به کسری از گرادیان که باید در هر مرحله زمانی حفظ شود.

    اعلامیه

    public var rho: Float
  • یک اسکالر کوچک برای بهبود ثبات عددی به مخرج اضافه شده است.

    اعلامیه

    public var epsilon: Float
  • کاهش نرخ یادگیری

    اعلامیه

    public var decay: Float
  • مرحله فعلی.

    اعلامیه

    public var step: Int
  • میانگین انباشته شده و در حال فروپاشی نمایی گرادیان های مربع.

    اعلامیه

    public var averageSquared: Model.TangentVector
  • پارامتر انباشته به روز رسانی می شود.

    اعلامیه

    public var accumulatedDelta: Model.TangentVector
  • یک نمونه برای model ایجاد می کند.

    اعلامیه

    public init(
      for model: __shared Model,
      learningRate: Float = 1,
      rho: Float = 0.95,
      epsilon: Float = 1e-6,
      decay: Float = 0
    )

    پارامترها

    learningRate

    میزان یادگیری مقدار پیش فرض 1 است.

    rho

    عامل پوسیدگی مقدار پیش فرض 0.95 است.

    epsilon

    یک اسکالر کوچک برای بهبود ثبات عددی به مخرج اضافه شده است. مقدار پیش فرض 1e-6 است.

    decay

    کاهش نرخ یادگیری مقدار پیش فرض 0 است.

  • اعلامیه

    public func update(_ model: inout Model, along direction: Model.TangentVector)
  • اعلامیه

    public required init(copying other: AdaDelta, to device: Device)