Esegue più predittori di insieme di regressione additiva su istanze di input e
calcola l'aggiornamento ai logit memorizzati nella cache. È progettato per essere utilizzato durante l'allenamento. Attraversa gli alberi a partire dall'ID dell'albero memorizzato nella cache e dall'ID del nodo memorizzato nella cache e calcola gli aggiornamenti da inviare alla cache.
Metodi pubblici
static BoostedTreesTrainingPredict | create ( Scope scope, Operand <?> treeEnsembleHandle, Operand <Integer> cachedTreeIds, Operand <Integer> cachedNodeIds, Iterable < Operand <Integer>> bucketizedFeatures, Long logitsDimension) Metodo Factory per creare una classe che esegue il wrapping di una nuova operazione BoostedTreesTrainingPredict. |
Uscita <Integer> | nodeIds () Tensore di rango 1 contenente nuovi ID nodo nei nuovi tree_ids. |
Uscita <Float> | partialLogits () Rango 2 Tensore contenente l'aggiornamento dei logit (rispetto ai valori memorizzati nella cache) per ogni esempio. |
Uscita <Integer> | treeIds () Tensore di rango 1 contenente nuovi ID albero per ogni esempio. |
Metodi ereditati
Metodi pubblici
public static BoostedTreesTrainingPredict create ( ambito ambito, Operando <?> treeEnsembleHandle, Operando <Integer> cachedTreeIds, Operando <Integer> cachedNodeIds, Iterable < Operand <Integer>> bucketizedFeatures, Long logitsDimension)
Metodo Factory per creare una classe che esegue il wrapping di una nuova operazione BoostedTreesTrainingPredict.
Parametri
scopo | ambito attuale |
---|---|
cachedTreeIds | Tensore di rango 1 contenente gli ID dell'albero memorizzati nella cache che è l'albero iniziale della previsione. |
cachedNodeIds | Rango 1 Tensore contenente l'ID del nodo memorizzato nella cache che è il nodo iniziale della previsione. |
bucketizedFeatures | Un elenco di tensori di livello 1 contenente l'ID del bucket per ciascuna funzionalità. |
logitsDimension | scalare, dimensione dei logiti, da utilizzare per la forma dei logiti parziali. |
ritorna
- una nuova istanza di BoostedTreesTrainingPredict
output pubblico <Integer> nodeIds ()
Tensore di rango 1 contenente nuovi ID nodo nei nuovi tree_ids.
output pubblico <Float> partialLogits ()
Rango 2 Tensore contenente l'aggiornamento dei logit (rispetto ai valori memorizzati nella cache) per ogni esempio.
output pubblico <Integer> treeIds ()
Tensore di rango 1 contenente nuovi ID albero per ogni esempio.