Working with sparse tensors

When working with tensors that contain a lot of zero values, it is important to store them in a space- and time-efficient manner. Sparse tensors enable efficient storage and processing of tensors that contain a lot of zero values. Sparse tensors are used extensively in encoding schemes like TF-IDF as part of data pre-processing in NLP applications and for pre-processing images with a lot of dark pixels in computer vision applications.

Sparse tensors in TensorFlow

TensorFlow represents sparse tensors through the tf.sparse.SparseTensor object. Currently, sparse tensors in TensorFlow are encoded using the coordinate list (COO) format. This encoding format is optimized for hyper-sparse matrices such as embeddings.

The COO encoding for sparse tensors is comprised of:

  • values: A 1D tensor with shape [N] containing all nonzero values.
  • indices: A 2D tensor with shape [N, rank], containing the indices of the nonzero values.
  • dense_shape: A 1D tensor with shape [rank], specifying the shape of the tensor.

A zero value in the context of a tf.sparse.SparseTensor is a value that's not explicitly encoded. It is possible to explicitly include zero values in the values of a COO sparse matrix, but these "explicit zeros" are generally not included when referring to nonzero values in a sparse tensor.
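
For instance, here is a minimal sketch (the variable name st_explicit is illustrative) that stores an explicit zero. The 0 at position [0, 0] is kept in values, while every position absent from indices remains an implicit zero:

import tensorflow as tf

# The value 0 at [0, 0] is an "explicit zero": it is stored in `values`,
# whereas all positions missing from `indices` are implicit zeros.
st_explicit = tf.sparse.SparseTensor(indices=[[0, 0], [1, 2]],
                                     values=[0, 7],
                                     dense_shape=[2, 4])
print(st_explicit.values.numpy())  # [0 7] -- the explicit zero is kept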

Creating a tf.sparse.SparseTensor

Construct sparse tensors by directly specifying their values, indices, and dense_shape.

import tensorflow as tf
st1 = tf.sparse.SparseTensor(indices=[[0, 3], [2, 4]],
                             values=[10, 20],
                             dense_shape=[3, 10])

When you use the print() function to print a sparse tensor, it shows the contents of the three component tensors:

print(st1)
SparseTensor(indices=tf.Tensor(
[[0 3]
 [2 4]], shape=(2, 2), dtype=int64), values=tf.Tensor([10 20], shape=(2,), dtype=int32), dense_shape=tf.Tensor([ 3 10], shape=(2,), dtype=int64))

It is easier to understand the contents of a sparse tensor if the nonzero values are aligned with their corresponding indices. Define a helper function to pretty-print sparse tensors such that each nonzero value is shown on its own line.

def pprint_sparse_tensor(st):
  # Pretty-print a tf.sparse.SparseTensor with one nonzero value per line.
  s = f"<SparseTensor shape={st.dense_shape.numpy().tolist()} \n values={{"
  for (index, value) in zip(st.indices, st.values):
    s += f"\n  {index.numpy().tolist()}: {value.numpy().tolist()}"
  return s + "}>"
print(pprint_sparse_tensor(st1))
<SparseTensor shape=[3, 10] 
 values={
  [0, 3]: 10
  [2, 4]: 20}>

You can also construct sparse tensors from dense tensors by using tf.sparse.from_dense, and convert them back to dense tensors by using tf.sparse.to_dense.

st2 = tf.sparse.from_dense([[1, 0, 0, 8], [0, 0, 0, 0], [0, 0, 3, 0]])
print(pprint_sparse_tensor(st2))
<SparseTensor shape=[3, 4] 
 values={
  [0, 0]: 1
  [0, 3]: 8
  [2, 2]: 3}>
st3 = tf.sparse.to_dense(st2)
print(st3)
tf.Tensor(
[[1 0 0 8]
 [0 0 0 0]
 [0 0 3 0]], shape=(3, 4), dtype=int32)

Manipulating sparse tensors

Use the utilities in the tf.sparse package to manipulate sparse tensors. Ops like tf.math.add that you can use for arithmetic manipulation of dense tensors do not work with sparse tensors.

Add sparse tensors of the same shape by using tf.sparse.add.

st_a = tf.sparse.SparseTensor(indices=[[0, 2], [3, 4]],
                              values=[31, 2],
                              dense_shape=[4, 10])

st_b = tf.sparse.SparseTensor(indices=[[0, 2], [7, 0]],
                              values=[56, 38],
                              dense_shape=[4, 10])

st_sum = tf.sparse.add(st_a, st_b)

print(pprint_sparse_tensor(st_sum))
<SparseTensor shape=[4, 10] 
 values={
  [0, 2]: 87
  [3, 4]: 2
  [7, 0]: 38}>

Use tf.sparse.sparse_dense_matmul to multiply sparse tensors with dense matrices.

st_c = tf.sparse.SparseTensor(indices=[[0, 1], [1, 0], [1, 1]],
                              values=[13, 15, 17],
                              dense_shape=[2, 2])

mb = tf.constant([[4], [6]])
product = tf.sparse.sparse_dense_matmul(st_c, mb)

print(product)
tf.Tensor(
[[ 78]
 [162]], shape=(2, 1), dtype=int32)

Put sparse tensors together by using tf.sparse.concat and take them apart by using tf.sparse.slice.

sparse_pattern_A = tf.sparse.SparseTensor(indices=[[2, 4], [3, 3], [3, 4], [4, 3], [4, 4], [5, 4]],
                                          values=[1, 1, 1, 1, 1, 1],
                                          dense_shape=[8, 5])
sparse_pattern_B = tf.sparse.SparseTensor(indices=[[0, 2], [1, 1], [1, 3], [2, 0], [2, 4], [2, 5], [3, 5],
                                                   [4, 5], [5, 0], [5, 4], [5, 5], [6, 1], [6, 3], [7, 2]],
                                          values=[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
                                          dense_shape=[8, 6])
sparse_pattern_C = tf.sparse.SparseTensor(indices=[[3, 0], [4, 0]],
                                          values=[1, 1],
                                          dense_shape=[8, 6])

sparse_patterns_list = [sparse_pattern_A, sparse_pattern_B, sparse_pattern_C]
sparse_pattern = tf.sparse.concat(axis=1, sp_inputs=sparse_patterns_list)
print(tf.sparse.to_dense(sparse_pattern))
tf.Tensor(
[[0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]
 [0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0]
 [0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0]
 [0 0 0 1 1 0 0 0 0 0 1 1 0 0 0 0 0]
 [0 0 0 0 1 1 0 0 0 1 1 0 0 0 0 0 0]
 [0 0 0 0 0 0 1 0 1 0 0 0 0 0 0 0 0]
 [0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0]], shape=(8, 17), dtype=int32)
sparse_slice_A = tf.sparse.slice(sparse_pattern_A, start=[0, 0], size=[8, 5])
sparse_slice_B = tf.sparse.slice(sparse_pattern_B, start=[0, 5], size=[8, 6])
sparse_slice_C = tf.sparse.slice(sparse_pattern_C, start=[0, 10], size=[8, 6])
print(tf.sparse.to_dense(sparse_slice_A))
print(tf.sparse.to_dense(sparse_slice_B))
print(tf.sparse.to_dense(sparse_slice_C))
tf.Tensor(
[[0 0 0 0 0]
 [0 0 0 0 0]
 [0 0 0 0 1]
 [0 0 0 1 1]
 [0 0 0 1 1]
 [0 0 0 0 1]
 [0 0 0 0 0]
 [0 0 0 0 0]], shape=(8, 5), dtype=int32)
tf.Tensor(
[[0]
 [0]
 [1]
 [1]
 [1]
 [1]
 [0]
 [0]], shape=(8, 1), dtype=int32)
tf.Tensor([], shape=(8, 0), dtype=int32)

If you're using TensorFlow 2.4 or above, use tf.sparse.map_values for elementwise operations on nonzero values in sparse tensors.

st2_plus_5 = tf.sparse.map_values(tf.add, st2, 5)
print(tf.sparse.to_dense(st2_plus_5))
tf.Tensor(
[[ 6  0  0 13]
 [ 0  0  0  0]
 [ 0  0  8  0]], shape=(3, 4), dtype=int32)

Note that only the nonzero values were modified – the zero values stay zero.

Equivalently, you can follow the design pattern below for earlier versions of TensorFlow:

st2_plus_5 = tf.sparse.SparseTensor(
    st2.indices,
    st2.values + 5,
    st2.dense_shape)
print(tf.sparse.to_dense(st2_plus_5))
tf.Tensor(
[[ 6  0  0 13]
 [ 0  0  0  0]
 [ 0  0  8  0]], shape=(3, 4), dtype=int32)

Using tf.sparse.SparseTensor with other TensorFlow APIs

Sparse tensors work transparently with these TensorFlow APIs:

  • tf.keras
  • tf.data
  • tf.train.Example
  • tf.function

Examples are shown below for each of these APIs.

tf.keras

A subset of the tf.keras API supports sparse tensors without expensive casting or conversion ops. The Keras API lets you pass sparse tensors as inputs to a Keras model. Set sparse=True when calling tf.keras.Input or tf.keras.layers.InputLayer. You can pass sparse tensors between Keras layers, and also have Keras models return them as outputs. If you use sparse tensors in tf.keras.layers.Dense layers in your model, they will output dense tensors.

The example below shows you how to pass a sparse tensor as an input to a Keras model if you use only layers that support sparse inputs.

x = tf.keras.Input(shape=(4,), sparse=True)
y = tf.keras.layers.Dense(4)(x)
model = tf.keras.Model(x, y)

sparse_data = tf.sparse.SparseTensor(
    indices = [(0,0),(0,1),(0,2),
               (4,3),(5,0),(5,1)],
    values = [1,1,1,1,1,1],
    dense_shape = (6,4)
)

model(sparse_data)

model.predict(sparse_data)
1/1 [==============================] - 0s 96ms/step
array([[-1.9626052 ,  1.9385288 , -0.3864261 , -1.4728891 ],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.        ,  0.        ,  0.        ,  0.        ],
       [ 0.2567069 ,  0.1808129 ,  0.01684809,  0.69576544],
       [-1.2277668 ,  1.4192976 ,  0.0884105 , -0.865898  ]],
      dtype=float32)

tf.data

The tf.data API enables you to build complex input pipelines from simple, reusable pieces. Its core data structure is tf.data.Dataset, which represents a sequence of elements in which each element consists of one or more components.

Building datasets with sparse tensors

Build datasets from sparse tensors using the same methods that are used to build them from tf.Tensors or NumPy arrays, such as tf.data.Dataset.from_tensor_slices. This op preserves the sparsity (or sparse nature) of the data.

dataset = tf.data.Dataset.from_tensor_slices(sparse_data)
for element in dataset: 
  print(pprint_sparse_tensor(element))
<SparseTensor shape=[4] 
 values={
  [0]: 1
  [1]: 1
  [2]: 1}>
<SparseTensor shape=[4] 
 values={}>
<SparseTensor shape=[4] 
 values={}>
<SparseTensor shape=[4] 
 values={}>
<SparseTensor shape=[4] 
 values={
  [3]: 1}>
<SparseTensor shape=[4] 
 values={
  [0]: 1
  [1]: 1}>

Batching and unbatching datasets with sparse tensors

You can batch (combine consecutive elements into a single element) and unbatch datasets with sparse tensors using the Dataset.batch and Dataset.unbatch methods respectively.

batched_dataset = dataset.batch(2)
for element in batched_dataset:
  print(pprint_sparse_tensor(element))
<SparseTensor shape=[2, 4] 
 values={
  [0, 0]: 1
  [0, 1]: 1
  [0, 2]: 1}>
<SparseTensor shape=[2, 4] 
 values={}>
<SparseTensor shape=[2, 4] 
 values={
  [0, 3]: 1
  [1, 0]: 1
  [1, 1]: 1}>
unbatched_dataset = batched_dataset.unbatch()
for element in unbatched_dataset:
  print(pprint_sparse_tensor(element))
<SparseTensor shape=[4] 
 values={
  [0]: 1
  [1]: 1
  [2]: 1}>
<SparseTensor shape=[4] 
 values={}>
<SparseTensor shape=[4] 
 values={}>
<SparseTensor shape=[4] 
 values={}>
<SparseTensor shape=[4] 
 values={
  [3]: 1}>
<SparseTensor shape=[4] 
 values={
  [0]: 1
  [1]: 1}>

You can also use tf.data.experimental.dense_to_sparse_batch to batch dataset elements of varying shapes into sparse tensors.
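
As a minimal sketch (the small generator-backed dataset below is illustrative), apply this transformation to a dataset of variable-length dense rows:

# Batch three variable-length rows into a single sparse tensor.
# `row_shape=[4]` gives the per-row bound of the resulting dense_shape.
ragged_data = tf.data.Dataset.from_generator(
    lambda: iter([[1], [1, 2], [1, 2, 3]]),
    output_signature=tf.TensorSpec(shape=[None], dtype=tf.int32))
sparse_batch = ragged_data.apply(
    tf.data.experimental.dense_to_sparse_batch(batch_size=3, row_shape=[4]))
for element in sparse_batch:
  print(pprint_sparse_tensor(element))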

Transforming Datasets with sparse tensors

Transform and create sparse tensors in Datasets using Dataset.map.

transform_dataset = dataset.map(lambda x: x*2)
for i in transform_dataset:
  print(pprint_sparse_tensor(i))
<SparseTensor shape=[4] 
 values={
  [0]: 2
  [1]: 2
  [2]: 2}>
<SparseTensor shape=[4] 
 values={}>
<SparseTensor shape=[4] 
 values={}>
<SparseTensor shape=[4] 
 values={}>
<SparseTensor shape=[4] 
 values={
  [3]: 2}>
<SparseTensor shape=[4] 
 values={
  [0]: 2
  [1]: 2}>

tf.train.Example

tf.train.Example is a standard protobuf encoding for TensorFlow data. When using sparse tensors with tf.train.Example, you can:

  • Read variable-length data into a tf.sparse.SparseTensor using tf.io.VarLenFeature.
  • Read arbitrary sparse data into a tf.sparse.SparseTensor using tf.io.SparseFeature, which uses separate feature keys to store the indices, values, and dense_shape.
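
As a minimal sketch of the first approach (the feature key "scores" is illustrative), round-trip a variable-length feature through tf.train.Example and parse it back as a sparse tensor:

# Serialize an Example holding a variable-length int64 feature, then parse
# it with tf.io.VarLenFeature, which yields a tf.sparse.SparseTensor.
example = tf.train.Example(features=tf.train.Features(feature={
    "scores": tf.train.Feature(int64_list=tf.train.Int64List(value=[10, 20]))
}))
parsed = tf.io.parse_single_example(
    example.SerializeToString(),
    features={"scores": tf.io.VarLenFeature(tf.int64)})
print(pprint_sparse_tensor(parsed["scores"]))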

tf.function

The tf.function decorator precomputes TensorFlow graphs for Python functions, which can substantially improve the performance of your TensorFlow code. Sparse tensors work transparently with both tf.function and concrete functions.

@tf.function
def f(x, y):
  return tf.sparse.sparse_dense_matmul(x, y)

a = tf.sparse.SparseTensor(indices=[[0, 3], [2, 4]],
                           values=[15, 25],
                           dense_shape=[3, 10])

b = tf.sparse.to_dense(tf.sparse.transpose(a))

c = f(a, b)

print(c)
tf.Tensor(
[[225   0   0]
 [  0   0   0]
 [  0   0 625]], shape=(3, 3), dtype=int32)

Distinguishing missing values from zero values

Most ops on tf.sparse.SparseTensors treat missing values and explicit zero values identically. This is by design — a tf.sparse.SparseTensor is supposed to act just like a dense tensor.

However, there are a few cases where it can be useful to distinguish zero values from missing values. In particular, this allows for one way to encode missing/unknown data in your training data. For example, consider a use case where you have a tensor of scores (that can have any floating point value from -Inf to +Inf), with some missing scores. You can encode this tensor using a sparse tensor where the explicit zeros are known zero scores but the implicit zero values actually represent missing data and not zero.
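
For instance, here is a sketch of such an encoding (the numbers are illustrative). The explicit 0.0 at index 1 is a known zero score, while indices 2 and 3, which are absent, represent missing scores rather than zeros:

# Index 0 holds a known score and index 1 a known zero score (explicit zero);
# indices 2 and 3 are absent from the encoding, meaning those scores are
# missing rather than zero.
scores = tf.sparse.SparseTensor(indices=[[0], [1]],
                                values=[2.5, 0.0],
                                dense_shape=[4])
print(pprint_sparse_tensor(scores))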

Note that some ops, like tf.sparse.reduce_max, do not treat missing values as if they were zero. For example, when you run the code block below, you might expect the output to be 0, the maximum of the dense tensor [-5, 0, -3]. Because tf.sparse.reduce_max ignores missing values rather than treating them as zero, the output is -3.

print(tf.sparse.reduce_max(tf.sparse.from_dense([-5, 0, -3])))
tf.Tensor(-3, shape=(), dtype=int32)

In contrast, when you apply tf.math.reduce_max to a dense tensor, the output is 0 as expected.

print(tf.math.reduce_max([-5, 0, -3]))
tf.Tensor(0, shape=(), dtype=int32)

Further reading and resources