TensorLabel é um wrapper utilitário para TensorBuffers com rótulos significativos em um eixo.
Por exemplo, um modelo de classificação de imagem pode ter um tensor de saída com formato {1, 10}, onde 1 é o tamanho do lote e 10 é o número de categorias. Na verdade, no 2º eixo poderíamos rotular cada subtensor com o nome ou descrição de cada categoria correspondente. TensorLabel
pode ajudar a converter o Tensor simples no TensorBuffer
em um mapa de rótulos predefinidos para subtensores. Nesse caso, se forem fornecidos 10 rótulos para o segundo eixo, TensorLabel
poderá converter o Tensor {1, 10} original em um mapa de 10 elementos, cada valor do qual é o Tensor na forma {} (escalar). Exemplo de uso:
TensorBuffer outputTensor = ...; List<String> labels = FileUtil.loadLabels(context, labelFilePath); // labels the first axis with size greater than one TensorLabel labeled = new TensorLabel(labels, outputTensor); // If each sub-tensor has effectively size 1, we can directly get a float value Map<String, Float> probabilities = labeled.getMapWithFloatValue(); // Or get sub-tensors, when each sub-tensor has elements more than 1 Map<String, TensorBuffer> subTensors = labeled.getMapWithTensorBuffer();
Observação: atualmente, oferecemos suporte apenas à conversão de tensor em mapa para o primeiro rótulo com tamanho maior que 1.
Construtores Públicos
TensorLabel ( Mapa < Inteiro , Lista < String >> axisLabels, TensorBuffer tensorBuffer) Cria um objeto TensorLabel que é capaz de rotular os eixos de tensores multidimensionais. | |
TensorLabel ( Lista <String> axisLabels, TensorBuffer tensorBuffer) Cria um objeto TensorLabel que é capaz de rotular em um eixo de tensores multidimensionais. |
Métodos Públicos
Lista <Categoria> | getCategoryList () Obtém uma lista de Category do objeto TensorLabel . |
Mapa < String , Float > | getMapWithFloatValue () Obtém um mapa que mapeia o rótulo para flutuar. |
Mapa < String , TensorBuffer > | getMapWithTensorBuffer () Obtém o mapa com um par de rótulo e o TensorBuffer correspondente. |
Métodos herdados
Construtores Públicos
public TensorLabel ( Mapa < Inteiro , Lista < String >> axisLabels, TensorBuffer tensorBuffer)
Cria um objeto TensorLabel que é capaz de rotular os eixos de tensores multidimensionais.
Parâmetros
eixoLabels | Um mapa, cuja chave é o id do eixo (começando em 0) e o valor são os rótulos correspondentes. Nota: O tamanho dos rótulos deve ser igual ao tamanho do tensor nesse eixo. |
---|---|
tensorBuffer | O TensorBuffer a ser rotulado. |
Lança
Null Pointer Exception | se axisLabels ou tensorBuffer for nulo ou qualquer valor em axisLabels for nulo. |
---|---|
Exceção de argumento ilegal | se alguma chave em axisLabels estiver fora do intervalo (em comparação com a forma de tensorBuffer , ou qualquer valor (rótulos) tiver tamanho diferente do tensorBuffer na dimensão fornecida. |
public TensorLabel ( Lista <String> axisLabels, TensorBuffer tensorBuffer)
Cria um objeto TensorLabel que é capaz de rotular em um eixo de tensores multidimensionais.
Nota: Os rótulos são aplicados no primeiro eixo cujo tamanho é maior que 1. Por exemplo, se a forma do tensor for [1, 10, 3], os rótulos serão aplicados no eixo 1 (id começando em 0), e o tamanho de axisLabels
também deve ser 10.
Parâmetros
eixoLabels | Uma lista de rótulos, cujo tamanho deve ser igual ao tamanho do tensor no eixo a ser rotulado. |
---|---|
tensorBuffer | O TensorBuffer a ser rotulado. |
Métodos Públicos
lista pública <categoria> getCategoryList ()
Obtém uma lista de Category
do objeto TensorLabel
.
O eixo do rótulo deve ser efetivamente o último eixo (o que significa que cada subtensor especificado por este eixo deve ter um tamanho plano de 1), para que cada subtensor rotulado possa ser convertido em uma pontuação de valor flutuante. Exemplo: Um TensorLabel
com formato {2, 5, 3}
e eixo 2 é válido. Se eixo for 1 ou 0, ele não poderá ser convertido em Category
.
getMapWithFloatValue()
é uma alternativa, mas retorna um Map
como resultado.
Lança
IllegalStateException | se o tamanho de um subtensor em cada rótulo não for 1. |
---|
mapa público < String , Float > getMapWithFloatValue ()
Obtém um mapa que mapeia o rótulo para flutuar. Permita apenas o mapeamento no primeiro eixo com tamanho maior que 1, e o eixo deve ser efetivamente o último eixo (o que significa que cada subtensor especificado por este eixo deve ter um tamanho plano de 1).
getCategoryList()
é uma API alternativa para obter o resultado.
Lança
IllegalStateException | se o tamanho de um subtensor em cada rótulo não for 1. |
---|
mapa público < String , TensorBuffer > getMapWithTensorBuffer ()
Obtém o mapa com um par de rótulo e o TensorBuffer correspondente. Permitir apenas o mapeamento no primeiro eixo com tamanho maior que 1 atualmente.