テンソルフロー::作戦::すべての候補者サンプラー
#include <candidate_sampling_ops.h>
学習されたユニグラム分布を使用して候補サンプリングのラベルを生成します。
まとめ
go/candidate-sampling で候補サンプリングとデータ形式の説明を参照してください。
この操作は、バッチごとに、サンプリングされた候補ラベルの単一セットを選択します。
バッチごとに候補をサンプリングする利点は、単純さと効率的な密行列乗算の可能性です。欠点は、サンプリングされた候補がコンテキストや真のラベルとは独立して選択されなければならないことです。
引数:
- スコープ:スコープオブジェクト
- true_classes:batch_size * num_true 行列。各行には、対応する元のラベルの num_true target_classes の ID が含まれます。
- num_true: コンテキストごとの真のラベルの数。
- num_sampled: 生成する候補の数。
- unique: unique が true の場合、バッチ内のすべてのサンプリングされた候補が一意になるように、拒否を伴うサンプリングが行われます。これには、拒否後のサンプリング確率を推定するために何らかの近似が必要です。
オプションの属性 ( Attrs
を参照):
- シード: シードまたはシード 2 のいずれかが 0 以外に設定されている場合、乱数ジェネレーターには指定されたシードがシードされます。それ以外の場合は、ランダム シードによってシードされます。
- シード2: シードの衝突を避けるための 2 番目のシード。
戻り値:
-
Output
sampled_candidates: 長さnum_sampledのベクトル。各要素はサンプリングされた候補のIDです。 -
Output
true_expected_count: サンプリングされた候補のバッチ内で各候補が出現すると予想される回数を表す、batch_size * num_true 行列。 unique=true の場合、これは確率です。 -
Output
sampled_expected_count: サンプルされた候補ごとに、サンプルされた候補のバッチ内で候補が出現すると予想される回数を表す、長さ num_sampled のベクトル。 unique=true の場合、これは確率です。
コンストラクターとデストラクター | |
---|---|
AllCandidateSampler (const :: tensorflow::Scope & scope, :: tensorflow::Input true_classes, int64 num_true, int64 num_sampled, bool unique) | |
AllCandidateSampler (const :: tensorflow::Scope & scope, :: tensorflow::Input true_classes, int64 num_true, int64 num_sampled, bool unique, const AllCandidateSampler::Attrs & attrs) |
パブリック属性 | |
---|---|
operation | |
sampled_candidates | |
sampled_expected_count | |
true_expected_count |
パブリック静的関数 | |
---|---|
Seed (int64 x) | |
Seed2 (int64 x) |
構造体 | |
---|---|
tensorflow:: ops:: AllCandidateSampler:: Attrs | AllCandidateSamplerのオプションの属性セッター。 |
パブリック属性
手術
Operation operation
サンプルされた候補者
::tensorflow::Output sampled_candidates
サンプル予想数
::tensorflow::Output sampled_expected_count
true_expected_count
::tensorflow::Output true_expected_count
公共機能
すべての候補者サンプラー
AllCandidateSampler( const ::tensorflow::Scope & scope, ::tensorflow::Input true_classes, int64 num_true, int64 num_sampled, bool unique )
すべての候補者サンプラー
AllCandidateSampler( const ::tensorflow::Scope & scope, ::tensorflow::Input true_classes, int64 num_true, int64 num_sampled, bool unique, const AllCandidateSampler::Attrs & attrs )
パブリック静的関数
シード
Attrs Seed( int64 x )
シード2
Attrs Seed2( int64 x )