TensorLabel è un wrapper di utilità per TensorBuffers con etichette significative su un asse.
Ad esempio, un modello di classificazione delle immagini può avere un tensore di output con forma come {1, 10}, dove 1 è la dimensione del batch e 10 è il numero di categorie. Infatti, sul 2° asse, potremmo etichettare ciascun sottotensore con il nome o la descrizione di ciascuna categoria corrispondente. TensorLabel
potrebbe aiutare a convertire il semplice tensore in TensorBuffer
in una mappa da etichette predefinite a sottotensori. In questo caso, se vengono fornite 10 etichette per il 2° asse, TensorLabel
potrebbe convertire il tensore {1, 10} originale in una mappa di 10 elementi, ciascun valore della quale è un tensore in forma {} (scalare). Esempio di utilizzo:
TensorBuffer outputTensor = ...; List<String> labels = FileUtil.loadLabels(context, labelFilePath); // labels the first axis with size greater than one TensorLabel labeled = new TensorLabel(labels, outputTensor); // If each sub-tensor has effectively size 1, we can directly get a float value Map<String, Float> probabilities = labeled.getMapWithFloatValue(); // Or get sub-tensors, when each sub-tensor has elements more than 1 Map<String, TensorBuffer> subTensors = labeled.getMapWithTensorBuffer();
Nota: attualmente supportiamo solo la conversione da tensore a mappa per la prima etichetta con dimensione maggiore di 1.
Costruttori pubblici
TensorLabel ( Mappa < Intero , Elenco < Stringa >>axisLabels, TensorBuffer tensorBuffer) Crea un oggetto TensorLabel in grado di etichettare sugli assi di tensori multidimensionali. | |
TensorLabel ( Lista < String > axisLabels, TensorBuffer tensorBuffer) Crea un oggetto TensorLabel che è in grado di etichettare su un asse di tensori multidimensionali. |
Metodi pubblici
Elenco <Categoria> | getCategoryList () Ottiene un elenco di Category dall'oggetto TensorLabel . |
Mappa < String , Float > | getMapWithFloatValue () Ottiene una mappa che mappa l'etichetta in modo mobile. |
Mappa < String , TensorBuffer > | getMapWithTensorBuffer () Ottiene la mappa con una coppia di etichette e il corrispondente TensorBuffer. |
Metodi ereditati
Costruttori pubblici
public TensorLabel ( Map < Integer , List < String >> axisLabels, TensorBuffer tensorBuffer)
Crea un oggetto TensorLabel in grado di etichettare sugli assi di tensori multidimensionali.
Parametri
assiEtichette | Una mappa, la cui chiave è l'id dell'asse (a partire da 0) e il valore sono le etichette corrispondenti. Nota: la dimensione delle etichette dovrebbe essere la stessa della dimensione del tensore su quell'asse. |
---|---|
tensorBuffer | Il TensorBuffer da etichettare. |
Lancia
NullPointerException | se axisLabels o tensorBuffer è nullo o qualsiasi valore in axisLabels è nullo. |
---|---|
IllegalArgumentException | se qualsiasi chiave in axisLabels è fuori intervallo (rispetto alla forma di tensorBuffer o qualsiasi valore (etichette) ha dimensioni diverse con tensorBuffer sulla dimensione data. |
public TensorLabel ( List < String > axisLabels, TensorBuffer tensorBuffer)
Crea un oggetto TensorLabel che è in grado di etichettare su un asse di tensori multidimensionali.
Nota: Le etichette vengono applicate sul primo asse la cui dimensione è maggiore di 1. Ad esempio, se la forma del tensore è [1, 10, 3], le etichette verranno applicate sull'asse 1 (id partendo da 0), e anche la dimensione di axisLabels
dovrebbe essere 10.
Parametri
assiEtichette | Un elenco di etichette, la cui dimensione dovrebbe essere uguale alla dimensione del tensore sull'asse da etichettare. |
---|---|
tensorBuffer | Il TensorBuffer da etichettare. |
Metodi pubblici
Elenco pubblico < Categoria > getCategoryList ()
Ottiene un elenco di Category
dall'oggetto TensorLabel
.
L'asse dell'etichetta dovrebbe essere effettivamente l'ultimo asse (il che significa che ogni sottotensore specificato da questo asse dovrebbe avere una dimensione piatta pari a 1), in modo che ogni sottotensore etichettato possa essere convertito in un punteggio con valore float. Esempio: un TensorLabel
con forma {2, 5, 3}
e asse 2 è valido. Se l'asse è 1 o 0, non può essere convertito in una Category
.
getMapWithFloatValue()
è un'alternativa ma restituisce una Map
come risultato.
Lancia
IllegalStateException | se la dimensione di un sottotensore su ciascuna etichetta non è 1. |
---|
mappa pubblica < String , Float > getMapWithFloatValue ()
Ottiene una mappa che mappa l'etichetta in modo mobile. Consenti la mappatura solo sul primo asse con dimensione maggiore di 1 e l'asse dovrebbe essere effettivamente l'ultimo asse (il che significa che ogni sottotensore specificato da questo asse dovrebbe avere una dimensione piatta pari a 1).
getCategoryList()
è un'API alternativa per ottenere il risultato.
Lancia
IllegalStateException | se la dimensione di un sottotensore su ciascuna etichetta non è 1. |
---|
mappa pubblica < String , TensorBuffer > getMapWithTensorBuffer ()
Ottiene la mappa con una coppia di etichette e il corrispondente TensorBuffer. Consenti attualmente solo la mappatura sul primo asse con dimensione maggiore di 1.