TensorVisitorPlan

public struct TensorVisitorPlan<Base>

TensorVisitorPlan approximates [WritableKeyPath<Base, Tensor<Float>>] but is more efficient. It is useful for writing generic optimizers that map over the gradients, the existing weights, and an index that can be used to look up weights kept in auxiliary storage. In practice it is about 2x more efficient than a plain key-path list, and there is room for improvement: it accepts slightly higher per-access overhead (an extra pointer dereference) in exchange for avoiding the O(depth_of_tree) work that a plain list needs to resolve each individual KeyPath.
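
For comparison, here is a minimal sketch of the plain key-path-list approach that TensorVisitorPlan approximates, used by a generic momentum-SGD step. To stay self-contained it does not import TensorFlow: FakeTensor is a hypothetical stand-in for Tensor<Float>, and Model, Dense, keyPaths, and sgdStep are illustrative names, not part of the library. Each `model[keyPath: kp]` access walks the path from the root of Model, which is the per-tensor O(depth_of_tree) work described above.

```swift
// FakeTensor is a hypothetical stand-in for Tensor<Float> so this sketch
// compiles without Swift for TensorFlow.
struct FakeTensor {
    var scalars: [Float]
}

struct Dense {
    var weight: FakeTensor
    var bias: FakeTensor
}

struct Model {
    var layer1: Dense
    var layer2: Dense
}

// The plain list that TensorVisitorPlan approximates: one key path per tensor.
// Every access through one of these paths starts again from the root of Model.
let keyPaths: [WritableKeyPath<Model, FakeTensor>] = [
    \Model.layer1.weight, \Model.layer1.bias,
    \Model.layer2.weight, \Model.layer2.bias,
]

// Generic momentum-SGD step: maps over gradients and weights, and uses the
// position `i` to find the matching auxiliary `velocity` entry.
func sgdStep(
    _ model: inout Model, gradient: Model, velocity: inout [FakeTensor],
    learningRate: Float = 0.1, momentum: Float = 0.9
) {
    for (i, kp) in keyPaths.enumerated() {
        let g = gradient[keyPath: kp].scalars
        for j in g.indices {
            // v <- momentum * v - learningRate * g
            velocity[i].scalars[j] = momentum * velocity[i].scalars[j] - learningRate * g[j]
            // w <- w + v
            model[keyPath: kp].scalars[j] += velocity[i].scalars[j]
        }
    }
}
```

Because `i` is a position in keyPaths, the auxiliary velocity list must be built in the same order and with the same count.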

  • Flatten out the plan as a single [WritableKeyPath<Base, Tensor<Float>>].

    Declaration

    public var allTensorKeyPaths: [WritableKeyPath<Base, Tensor<Float>>] { get }
  • Efficiently collect all the tensors.

    Declaration

    public func allTensors(_ v: Base) -> [Tensor<Float>]
  • Efficiently map over two values of type Base and apply a mapping function. Returns the number of tensors. The Int argument passed to the mapping function is the tensor's index, which can be used to index into an auxiliary list of Tensors with the same tensor count as the plan.

    Declaration

    @discardableResult
    public func mapTensors(
      _ v1: inout Base, _ v2: Base, _ fn: (inout Tensor<Float>, Tensor<Float>, Int) -> Void
    ) -> Int
  • Declaration

    func populateMask<Base>(_ mask: inout [Bool], _ kp: WritableKeyPath<Base, Tensor<Float>>)
  • Find all keys ending with a particular key-path.

    Declaration

    public func keysEnding<Base>(with kp: WritableKeyPath<Base, Tensor<Float>>) -> [Bool]
  • Declaration

    func findFirstIndex<TrueBase, T>(
      _ rootKeyPath: WritableKeyPath<TrueBase, Base>,
      _ prefix: WritableKeyPath<TrueBase, T>, _ i: inout Int
    ) -> Bool
  • Find the index of the first keypath starting with a particular prefix. Note: All array layers support 1-past-the-end indexing.

    Declaration

    func firstIndex<T>(withPrefix prefix: WritableKeyPath<Base, T>) -> Int
  • Find all key indices in the range defined by two KeyPath prefixes, [lower, upper).

    Declaration

    public func allKeysBetween<T, U>(lower: WritableKeyPath<Base, T>, upper: WritableKeyPath<Base, U>)
      -> [Bool]
  • Creates a plan to visit all the tensors in a particular instance of Base. This plan is transferable to structurally equivalent versions of Base.

    Declaration

    public init(_ obj: Base)
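
The semantics of several members above can be mimicked over a plain key-path list, as in the sketch below. It is not the Swift for TensorFlow implementation: ListPlan, FakeTensor (a stand-in for Tensor<Float>), and the toy Model are hypothetical. mapTensors keeps the documented shape (an inout tensor from the first value, the matching tensor from the second, the tensor index, and a returned tensor count), and mask(lower:upper:) produces a [Bool] with one entry per tensor for an index range [lower, upper), which is one plausible reading of what allKeysBetween returns once both prefixes have been turned into indices. As with the real init, the plan is built once and reused on any structurally equivalent Model value.

```swift
// FakeTensor is a hypothetical stand-in for Tensor<Float>.
struct FakeTensor {
    var scalars: [Float]
}

struct Dense {
    var weight: FakeTensor
    var bias: FakeTensor
}

struct Model {
    var layer1: Dense
    var layer2: Dense
}

// Hypothetical ListPlan: mirrors a few TensorVisitorPlan members over a plain
// key-path list. Any Model value has the same structure, so one plan fits all.
struct ListPlan<Base> {
    let allTensorKeyPaths: [WritableKeyPath<Base, FakeTensor>]

    func allTensors(_ v: Base) -> [FakeTensor] {
        allTensorKeyPaths.map { v[keyPath: $0] }
    }

    // fn receives the tensor from v1 (mutable), the matching tensor from v2,
    // and the tensor's index; returns the number of tensors visited.
    @discardableResult
    func mapTensors(
        _ v1: inout Base, _ v2: Base,
        _ fn: (inout FakeTensor, FakeTensor, Int) -> Void
    ) -> Int {
        for (i, kp) in allTensorKeyPaths.enumerated() {
            fn(&v1[keyPath: kp], v2[keyPath: kp], i)
        }
        return allTensorKeyPaths.count
    }

    // One Bool per tensor: true for indices in [lower, upper).
    func mask(lower: Int, upper: Int) -> [Bool] {
        allTensorKeyPaths.indices.map { $0 >= lower && $0 < upper }
    }
}

let plan = ListPlan<Model>(allTensorKeyPaths: [
    \Model.layer1.weight, \Model.layer1.bias,
    \Model.layer2.weight, \Model.layer2.bias,
])

var model = Model(
    layer1: Dense(weight: FakeTensor(scalars: [1, 2]), bias: FakeTensor(scalars: [0])),
    layer2: Dense(weight: FakeTensor(scalars: [3]), bias: FakeTensor(scalars: [1])))
let grad = Model(
    layer1: Dense(weight: FakeTensor(scalars: [1, 1]), bias: FakeTensor(scalars: [1])),
    layer2: Dense(weight: FakeTensor(scalars: [1]), bias: FakeTensor(scalars: [1])))

// Plain SGD step, plus extra shrinking only for the tensors selected by the mask.
let decay = plan.mask(lower: 2, upper: 4)
let count = plan.mapTensors(&model, grad) { w, g, i in
    for j in w.scalars.indices {
        w.scalars[j] -= 0.5 * g.scalars[j]
        if decay[i] { w.scalars[j] *= 0.5 }
    }
}
```

Here the mask restricts the extra shrinking to layer2's tensors (indices 2 and 3) while the same mapping function handles every tensor.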