
TensorLabel

public class TensorLabel

TensorLabel is a utility wrapper for TensorBuffers that attaches meaningful labels to an axis.

For example, an image classification model may have an output tensor with shape {1, 10}, where 1 is the batch size and 10 is the number of categories. On the second axis, each sub-tensor can be labeled with the name or description of its corresponding category. TensorLabel converts the plain tensor in a TensorBuffer into a map from predefined labels to sub-tensors. In this case, given 10 labels for the second axis, TensorLabel converts the original {1, 10} tensor into a 10-element map, each value of which is a tensor of shape {} (a scalar). Usage example:

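   import java.util.List;
   import java.util.Map;
   import org.tensorflow.lite.support.common.FileUtil;
   import org.tensorflow.lite.support.label.TensorLabel;
   import org.tensorflow.lite.support.tensorbuffer.TensorBuffer;

   TensorBuffer outputTensor = ...;
   // FileUtil.loadLabels throws IOException, which the caller must handle.
   List<String> labels = FileUtil.loadLabels(context, labelFilePath);
   // Labels the first axis whose size is greater than 1.
   TensorLabel labeled = new TensorLabel(labels, outputTensor);
   // If each sub-tensor has a flat size of 1, read each value directly as a float.
   Map<String, Float> probabilities = labeled.getMapWithFloatValue();
   // Or, when each sub-tensor has more than one element, get the sub-tensors as TensorBuffers.
   Map<String, TensorBuffer> subTensors = labeled.getMapWithTensorBuffer();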

Note: tensor-to-map conversion is currently supported only for the first axis with size greater than 1.

Public Constructors

TensorLabel(Map<Integer, List<String>> axisLabels, TensorBuffer tensorBuffer)
Creates a TensorLabel object that can label the axes of a multi-dimensional tensor.
TensorLabel(List<String> axisLabels, TensorBuffer tensorBuffer)
Creates a TensorLabel object that can label one axis of a multi-dimensional tensor.

Public Methods

List<Category>
getCategoryList()
Gets a list of Category objects from the TensorLabel object.
Map<String, Float>
getMapWithFloatValue()
Gets a map from each label to a float value.
Map<String, TensorBuffer>
getMapWithTensorBuffer()
Gets a map from each label to its corresponding TensorBuffer.

Public Constructors

public TensorLabel (Map<Integer, List<String>> axisLabels, TensorBuffer tensorBuffer)

Creates a TensorLabel object that can label the axes of a multi-dimensional tensor.

Parameters
axisLabels A map whose keys are axis ids (starting from 0) and whose values are the corresponding labels. Note: the number of labels for an axis must equal the size of the tensor on that axis.
tensorBuffer The TensorBuffer to be labeled.
Throws
NullPointerException if axisLabels or tensorBuffer is null, or any value in axisLabels is null.
IllegalArgumentException if any key in axisLabels is out of range (compared to the shape of tensorBuffer), or if any value (list of labels) differs in size from tensorBuffer on the given dimension.
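The validation rules listed above can be illustrated with a short, library-free sketch. It is not the TensorFlow Lite Support Library's implementation: the tensor is represented only by its shape (an `int[]`), and the class `AxisLabelValidation` and its `validate` method are hypothetical names chosen for illustration.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;

// Hypothetical, library-free sketch of the checks documented for
// TensorLabel(Map<Integer, List<String>> axisLabels, TensorBuffer tensorBuffer).
// The tensor is represented only by its shape.
public class AxisLabelValidation {

    static void validate(Map<Integer, List<String>> axisLabels, int[] shape) {
        // NullPointerException for a null map or a null shape (stand-in for a null tensorBuffer).
        Objects.requireNonNull(axisLabels, "axisLabels must not be null");
        Objects.requireNonNull(shape, "shape must not be null");
        for (Map.Entry<Integer, List<String>> entry : axisLabels.entrySet()) {
            int axis = entry.getKey();
            // NullPointerException for a null list of labels.
            List<String> labels =
                    Objects.requireNonNull(entry.getValue(), "labels for axis " + axis + " must not be null");
            // IllegalArgumentException when the axis id is outside [0, rank).
            if (axis < 0 || axis >= shape.length) {
                throw new IllegalArgumentException(
                        "Axis " + axis + " is out of range for a tensor of rank " + shape.length);
            }
            // IllegalArgumentException when the label count differs from the axis size.
            if (labels.size() != shape[axis]) {
                throw new IllegalArgumentException(
                        "Axis " + axis + " has size " + shape[axis] + " but got " + labels.size() + " labels");
            }
        }
    }

    public static void main(String[] args) {
        Map<Integer, List<String>> axisLabels = new HashMap<>();
        axisLabels.put(1, List.of("cat", "dog", "bird"));
        validate(axisLabels, new int[] {1, 3}); // passes: axis 1 has size 3

        axisLabels.put(2, List.of("extra")); // axis 2 does not exist for a rank-2 tensor
        try {
            validate(axisLabels, new int[] {1, 3});
        } catch (IllegalArgumentException e) {
            System.out.println("Rejected: " + e.getMessage());
        }
    }
}
```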

public TensorLabel (List<String> axisLabels, TensorBuffer tensorBuffer)

Creates a TensorLabel object that can label one axis of a multi-dimensional tensor.

Note: the labels are applied to the first axis whose size is larger than 1. For example, if the shape of the tensor is [1, 10, 3], the labels are applied to axis 1 (axis ids start from 0), so axisLabels must contain 10 labels.

Parameters
axisLabels A list of labels whose size must equal the size of the tensor on the axis to be labeled.
tensorBuffer The TensorBuffer to be labeled.
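As a rough sketch of the axis-selection rule in the note above, the following plain Java code picks the first axis whose size is larger than 1 and checks the label count against it. The names `LabeledAxisSelection`, `firstAxisWithSizeGreaterThanOne` and `checkLabels` are hypothetical, and throwing IllegalArgumentException when no axis qualifies is an assumption of this sketch rather than documented behavior.

```java
import java.util.Collections;
import java.util.List;

// Hypothetical sketch of how the List<String> constructor chooses its axis:
// the first axis (0-based) whose size is larger than 1.
public class LabeledAxisSelection {

    static int firstAxisWithSizeGreaterThanOne(int[] shape) {
        for (int i = 0; i < shape.length; i++) {
            if (shape[i] > 1) {
                return i;
            }
        }
        // Assumption: a shape with no axis larger than 1 has nothing to label.
        throw new IllegalArgumentException("No axis with size greater than 1 to label");
    }

    // Returns the labeled axis after checking that the label count matches its size.
    static int checkLabels(List<String> labels, int[] shape) {
        int axis = firstAxisWithSizeGreaterThanOne(shape);
        if (labels.size() != shape[axis]) {
            throw new IllegalArgumentException(
                    "Expected " + shape[axis] + " labels for axis " + axis + ", got " + labels.size());
        }
        return axis;
    }

    public static void main(String[] args) {
        int[] shape = {1, 10, 3};
        List<String> labels = Collections.nCopies(10, "label");
        System.out.println(checkLabels(labels, shape)); // prints 1
    }
}
```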

Public Methods

public List<Category> getCategoryList ()

Gets a list of Category from the TensorLabel object.

The labeled axis must effectively be the last axis (every sub-tensor specified by this axis must have a flat size of 1), so that each labeled sub-tensor can be converted into a float score. For example, a TensorLabel with shape {2, 5, 3} is valid when axis 2 is labeled; if axis 1 or 0 is labeled, it cannot be converted into a list of Category.

getMapWithFloatValue() is an alternative that returns the result as a Map.

Throws
IllegalStateException if the sub-tensor for any label does not have a flat size of 1.
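One reading of the "effectively the last axis" condition that agrees with the {2, 5, 3} example is that every axis after the labeled one has size 1, so the slice under each label holds a single element. The hypothetical helper below (`EffectivelyLastAxis.isEffectivelyLastAxis`) checks exactly that on a plain shape array; it is an illustration under that assumption, not the library's code.

```java
// Hypothetical sketch: an axis is "effectively last" when the product of the
// sizes of all axes after it is 1, i.e. each labeled slice is a single value.
public class EffectivelyLastAxis {

    static boolean isEffectivelyLastAxis(int[] shape, int axis) {
        int trailing = 1;
        for (int i = axis + 1; i < shape.length; i++) {
            trailing *= shape[i];
        }
        return trailing == 1;
    }

    public static void main(String[] args) {
        int[] shape = {2, 5, 3};
        for (int axis = 0; axis < shape.length; axis++) {
            // Prints false for axes 0 and 1, true for axis 2.
            System.out.println("axis " + axis + ": " + isEffectivelyLastAxis(shape, axis));
        }
    }
}
```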

public Map<String, Float> getMapWithFloatValue ()

Gets a map from each label to a float value. Mapping is only allowed on the first axis with size greater than 1, and that axis must effectively be the last axis (every sub-tensor specified by this axis must have a flat size of 1).

getCategoryList() is an alternative API to get the result.

Throws
IllegalStateException if the sub-tensor for any label does not have a flat size of 1.
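The label-to-float mapping described above can be sketched without the library by treating the tensor as a row-major `float[]` plus its shape. `LabelToFloatSketch` and `mapWithFloatValue` are hypothetical names; keeping labels in axis order with a `LinkedHashMap`, and throwing IllegalArgumentException on a label-count mismatch, are assumptions of this sketch.

```java
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical, library-free sketch of mapping labels to float scores.
// The data array is assumed to be row-major and consistent with the shape.
public class LabelToFloatSketch {

    static Map<String, Float> mapWithFloatValue(float[] data, int[] shape, List<String> labels) {
        // Find the first axis with size greater than 1.
        int axis = -1;
        for (int i = 0; i < shape.length; i++) {
            if (shape[i] > 1) {
                axis = i;
                break;
            }
        }
        if (axis < 0 || labels.size() != shape[axis]) {
            throw new IllegalArgumentException("Labels do not match the first axis with size greater than 1");
        }
        // Every axis after the labeled one must have size 1 (effectively the last axis).
        for (int i = axis + 1; i < shape.length; i++) {
            if (shape[i] != 1) {
                throw new IllegalStateException("Each label's sub-tensor must have a flat size of 1");
            }
        }
        // All other axes have size 1, so data[i] is the score for labels.get(i).
        Map<String, Float> result = new LinkedHashMap<>();
        for (int i = 0; i < labels.size(); i++) {
            result.put(labels.get(i), data[i]);
        }
        return result;
    }

    public static void main(String[] args) {
        float[] scores = {0.1f, 0.7f, 0.2f};
        Map<String, Float> probabilities =
                mapWithFloatValue(scores, new int[] {1, 3}, List.of("cat", "dog", "bird"));
        System.out.println(probabilities); // {cat=0.1, dog=0.7, bird=0.2}
    }
}
```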

public Map<String, TensorBuffer> getMapWithTensorBuffer ()

Gets a map from each label to its corresponding TensorBuffer. Currently, mapping is only supported on the first axis with size greater than 1.
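A library-free sketch of how such a split could work, assuming a row-major `float[]` and returning plain `float[]` slices instead of TensorBuffer objects. `LabelToSubTensorSketch` and `mapWithSubTensors` are hypothetical names. Because every axis before the labeled one has size 1, each label's sub-tensor is a contiguous run whose length is the product of the trailing axis sizes.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch: split a row-major tensor into one sub-tensor per label
// along the first axis with size greater than 1.
public class LabelToSubTensorSketch {

    static Map<String, float[]> mapWithSubTensors(float[] data, int[] shape, List<String> labels) {
        int axis = -1;
        for (int i = 0; i < shape.length; i++) {
            if (shape[i] > 1) {
                axis = i;
                break;
            }
        }
        if (axis < 0 || labels.size() != shape[axis]) {
            throw new IllegalArgumentException("Labels do not match the first axis with size greater than 1");
        }
        // Elements per sub-tensor: product of the sizes of all axes after `axis`.
        int block = 1;
        for (int i = axis + 1; i < shape.length; i++) {
            block *= shape[i];
        }
        // Axes before `axis` all have size 1, so label i owns data[i*block, (i+1)*block).
        Map<String, float[]> result = new LinkedHashMap<>();
        for (int i = 0; i < labels.size(); i++) {
            result.put(labels.get(i), Arrays.copyOfRange(data, i * block, (i + 1) * block));
        }
        return result;
    }

    public static void main(String[] args) {
        float[] data = {0f, 1f, 2f, 3f, 4f, 5f};
        Map<String, float[]> subTensors = mapWithSubTensors(data, new int[] {1, 2, 3}, List.of("a", "b"));
        // Prints "a -> [0.0, 1.0, 2.0]" then "b -> [3.0, 4.0, 5.0]".
        subTensors.forEach((label, sub) -> System.out.println(label + " -> " + Arrays.toString(sub)));
    }
}
```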