テンソルフロー::作戦::ギャザV2

#include <array_ops.h>

indicesに従ってparams axis axisからスライスを収集します

まとめ

indices 、任意の次元 (通常は 0 次元または 1 次元) の整数テンソルでなければなりません。形状params.shape[:axis] + indices.shape + params.shape[axis + 1:]の出力テンソルを生成します。ここで、

    # Scalar indices (output is rank(params) - 1).
    output[a_0, ..., a_n, b_0, ..., b_n] =
      params[a_0, ..., a_n, indices, b_0, ..., b_n]

    # Vector indices (output is rank(params)).
    output[a_0, ..., a_n, i, b_0, ..., b_n] =
      params[a_0, ..., a_n, indices[i], b_0, ..., b_n]

    # Higher rank indices (output is rank(params) + rank(indices) - 1).
    output[a_0, ..., a_n, i, ..., j, b_0, ... b_n] =
      params[a_0, ..., a_n, indices[i, ..., j], b_0, ..., b_n]

CPU では、範囲外のインデックスが見つかった場合、エラーが返されることに注意してください。 GPU では、範囲外のインデックスが見つかった場合、対応する出力値に 0 が格納されます。

tf.batch_gatherおよびtf.gather_ndも参照してください。

引数:

  • スコープ:スコープオブジェクト
  • params: 値を収集するテンソル。少なくともランクaxis + 1である必要があります。
  • インデックス: インデックス テンソル。 [0, params.shape[axis])の範囲内にある必要があります。
  • axis: indicesを収集するparams内の軸。デフォルトは最初の次元です。負のインデックスをサポートします。

戻り値:

  • Output : indicesで指定されたインデックスから収集されたparamsの値。形状はparams.shape[:axis] + indices.shape + params.shape[axis + 1:]です。

コンストラクターとデストラクター

GatherV2 (const :: tensorflow::Scope & scope, :: tensorflow::Input params, :: tensorflow::Input indices, :: tensorflow::Input axis)
GatherV2 (const :: tensorflow::Scope & scope, :: tensorflow::Input params, :: tensorflow::Input indices, :: tensorflow::Input axis, const GatherV2::Attrs & attrs)

パブリック属性

operation
output

公共機能

node () const
::tensorflow::Node *
operator::tensorflow::Input () const
operator::tensorflow::Output () const

パブリック静的関数

BatchDims (int64 x)

構造体

tensorflow:: ops:: GatherV2:: Attrs

GatherV2のオプションの属性セッター。

パブリック属性

手術

Operation operation

出力

::tensorflow::Output output

公共機能

ギャザV2

 GatherV2(
  const ::tensorflow::Scope & scope,
  ::tensorflow::Input params,
  ::tensorflow::Input indices,
  ::tensorflow::Input axis
)

ギャザV2

 GatherV2(
  const ::tensorflow::Scope & scope,
  ::tensorflow::Input params,
  ::tensorflow::Input indices,
  ::tensorflow::Input axis,
  const GatherV2::Attrs & attrs
)

ノード

::tensorflow::Node * node() const 

演算子::tensorflow::入力

 operator::tensorflow::Input() const 

演算子::tensorflow::出力

 operator::tensorflow::Output() const 

パブリック静的関数

バッチディム

Attrs BatchDims(
  int64 x
)