TF-NumPy Type Promotion


Overview

There are four options for type promotion in TensorFlow.

  • By default, TensorFlow raises errors instead of promoting types for mixed type operations.
  • Running tf.experimental.numpy.experimental_enable_numpy_behavior() switches TensorFlow to use NumPy type promotion rules.
  • This doc describes two new options that are available in TensorFlow 2.15 and later (or currently in tf-nightly); a sketch of how all four options are selected follows the install command below:
pip install -q tf_nightly
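For orientation, here is a minimal sketch of how each of the four options is selected. Choose exactly one, at the very beginning of the program, before any other TF APIs are used:

import tensorflow.experimental.numpy as tnp

# Option 1: do nothing - by default, mixed-type operations raise errors.
# Option 2: tnp.experimental_enable_numpy_behavior()                              # NumPy promotion rules
# Option 3: tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")   # new lattice-based promotion
# Option 4: tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")  # "risky" promotions disallowed

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")  # for example, ALL mode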

Setup

import numpy as np
import tensorflow as tf
import tensorflow.experimental.numpy as tnp

print("Using TensorFlow version %s" % tf.__version__)
Using TensorFlow version 2.18.0-dev20240718

Enabling the new type promotion

In order to use the JAX-like type promotion in TF-NumPy, specify either 'all' or 'safe' as the dtype conversion mode when enabling NumPy behavior for TensorFlow.

This new system (with dtype_conversion_mode="all") is associative, commutative, and makes it easy to control what width of float you end up with (it doesn't automatically convert to wider floats). It does introduce some risks of overflows and precision loss, but dtype_conversion_mode="safe" forces you to handle those cases explicitly. The two modes are explained more in detail in the next section.

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
WARNING:tensorflow:UserWarning: enabling the new type promotion must happen at the beginning of the program. Please ensure no TF APIs have been used yet.

Two Modes: ALL mode vs SAFE mode

In the new type promotion system, we introduce two modes: ALL mode and SAFE mode. SAFE mode is used to mitigate the concerns of "risky" promotions that can result in precision loss or bit-widening.

Dtypes

We will be using the following abbreviations for brevity:

  • iN means tf.intN (for example, i32 means tf.int32)
  • uN means tf.uintN (for example, u8 means tf.uint8)
  • fN means tf.floatN (for example, f16 means tf.float16), and bf16 means tf.bfloat16
  • cN means tf.complexN (for example, c128 means tf.complex128)

The asterisk (*) denotes that the corresponding type is “weak” - such a dtype is temporarily inferred by the system, and could defer to other dtypes. This concept is explained in more detail in the WeakTensor section below.

Example of precision-losing operations

In the following example, i32 + f32 is allowed in ALL mode but not in SAFE mode due to the risk of precision loss.

# i32 + f32 returns a f32 result in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, dtype = tf.int32)
b = tf.constant(5.0, dtype = tf.float32)
a + b  # <tf.Tensor: shape=(), dtype=float32, numpy=15.0>
<tf.Tensor: shape=(), dtype=float32, numpy=15.0>
# This promotion is not allowed in SAFE mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")
a = tf.constant(10, dtype = tf.int32)
b = tf.constant(5.0, dtype = tf.float32)
try:
  a + b
except TypeError as e:
   print(f'{type(e)}: {e}')  # TypeError: explicitly specify the dtype or switch to ALL mode.
<class 'TypeError'>: In promotion mode PromoMode.SAFE, implicit dtype promotion between (<dtype: 'int32'>, weak=False) and (<dtype: 'float32'>, weak=False) is disallowed. You need to explicitly specify the dtype in your op, or relax your dtype promotion rules (such as from SAFE mode to ALL mode).

Example of bit-widening operations

In the following example, i8 + u32 is allowed in ALL mode but not in SAFE mode because of bit-widening, which means producing a result that uses more bits than either of the inputs. Note that the new type promotion semantics only allows necessary bit-widening.

# i8 + u32 returns an i64 result in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, dtype = tf.int8)
b = tf.constant(5, dtype = tf.uint32)
a + b
<tf.Tensor: shape=(), dtype=int64, numpy=15>
# This promotion is not allowed in SAFE mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="safe")
a = tf.constant(10, dtype = tf.int8)
b = tf.constant(5, dtype = tf.uint32)
try:
  a + b
except TypeError as e:
   print(f'{type(e)}: {e}')  # TypeError: explicitly specify the dtype or switch to ALL mode.
<class 'TypeError'>: In promotion mode PromoMode.SAFE, implicit dtype promotion between (<dtype: 'int8'>, weak=False) and (<dtype: 'uint32'>, weak=False) is disallowed. You need to explicitly specify the dtype in your op, or relax your dtype promotion rules (such as from SAFE mode to ALL mode).

A System Based on a Lattice

Type Promotion Lattice

The new type promotion behavior is determined via the following type promotion lattice:

[Diagram: the type promotion lattice]

More specifically, promotion between any two types is determined by finding the first common child of the two nodes (including the nodes themselves).

For example, in the diagram above, the first common child of i8 and i32 is i32 because the two nodes intersect for the first time at i32 when following the direction of the arrows.

As another example, the promotion result of u64 and f16 is f16.
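These two lattice examples can be checked directly in ALL mode (a hedged sketch; expected dtypes are noted in the comments):

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")

# i8 and i32 first meet at i32 in the lattice.
print((tf.constant(1, tf.int8) + tf.constant(1, tf.int32)).dtype.name)        # int32

# u64 and f16 first meet at f16 in the lattice.
print((tf.constant(1, tf.uint64) + tf.constant(1.0, tf.float16)).dtype.name)  # float16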

Type Promotion Table

Following the lattice generates the binary promotion table below:

[Table: the binary type promotion results generated by the lattice]
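If you want to reproduce a slice of this table yourself, one hedged approach (assuming ALL mode is enabled) is to combine zero-valued constants of each dtype pair and inspect the result dtype:

# Print the i8 row of the promotion table for a few other dtypes.
for d in [tf.int8, tf.int32, tf.uint32, tf.float16, tf.float32]:
  result = tf.constant(0, tf.int8) + tf.constant(0, d)
  print(f"i8 + {d.name} -> {result.dtype.name}")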

Advantages of The New Type Promotion

We adopt a JAX-like lattice-based system for our new type promotion, which offers the following advantages:

Advantages of Lattice-Based System

First, using a lattice-based system ensures three very important properties:

  • Existence: There is a unique result promotion type for any combination of types.
  • Commutativity: a + b = b + a
  • Associativity: a + (b + c) = (a + b) + c

These three properties are critical for constructing a type promotion semantics that is consistent and predictable.
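A quick sanity check of commutativity and associativity at the dtype level (a hedged sketch, assuming ALL mode is enabled):

a = tf.constant(1, tf.int8)
b = tf.constant(1, tf.uint32)
c = tf.constant(1.0, tf.float16)

# Commutativity: the result dtype does not depend on operand order.
assert (a + b).dtype == (b + a).dtype

# Associativity: the result dtype does not depend on grouping.
assert (a + (b + c)).dtype == ((a + b) + c).dtype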

Advantages of JAX-like Lattice System

Another crucial advantage of the JAX-like lattice system is that, aside from unsigned ints, it avoids all wider-than-necessary promotions. This means you cannot get 64-bit results without 64-bit inputs. This is especially beneficial for working on accelerators as it avoids unnecessary 64-bit values, which were frequent in the old type promotion.

However, this comes with a trade-off: mixed float/integer promotion is very prone to precision loss. For instance, in the example below, i64 + f16 results in promoting i64 to f16.

# The first input is promoted to f16 in ALL mode.
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
tf.constant(1, tf.int64) + tf.constant(3.2, tf.float16)  # <tf.Tensor: shape=(), dtype=float16, numpy=4.2>
<tf.Tensor: shape=(), dtype=float16, numpy=4.2>

To mitigate this concern, we introduced a SAFE mode that disallows these "risky" promotions.

WeakTensor

Overview

Weak tensors are Tensors that are "weakly typed", similar to a concept in JAX.

A WeakTensor's dtype is temporarily inferred by the system and can defer to other dtypes. This concept is introduced in the new type promotion to prevent unwanted type promotion in binary operations between TF values and values with no explicitly user-specified type, such as Python scalar literals.

For instance, in the example below, tf.constant(1.2) is considered "weak" because it doesn't have a specific dtype. Therefore, tf.constant(1.2) defers to the type of tf.constant(3.1, tf.float16), resulting in an f16 output.

tf.constant(1.2) + tf.constant(3.1, tf.float16)  # <tf.Tensor: shape=(), dtype=float16, numpy=4.3>
<tf.Tensor: shape=(), dtype=float16, numpy=4.3>

WeakTensor Construction

A WeakTensor is created when you create a tensor without specifying a dtype. You can check whether a Tensor is "weak" or not by looking for the weak attribute at the end of the Tensor's string representation.

First Case: When tf.constant is called with an input with no user-specified dtype.

tf.constant(5)  # <tf.Tensor: shape=(), dtype=int32, numpy=5, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=5, weak=True>
tf.constant([5.0, 10.0, 3])  # <tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 5., 10.,  3.], dtype=float32), weak=True>
<tf.Tensor: shape=(3,), dtype=float32, numpy=array([ 5., 10.,  3.], dtype=float32), weak=True>
# A normal Tensor is created when dtype arg is specified.
tf.constant(5, tf.int32)  # <tf.Tensor: shape=(), dtype=int32, numpy=5>
<tf.Tensor: shape=(), dtype=int32, numpy=5>

Second Case: When an input with no user-specified dtype is passed into a WeakTensor-supporting API.

tf.math.abs([100.0, 4.0])  # <tf.Tensor: shape=(2,), dtype=float32, numpy=array([100., 4.], dtype=float32), weak=True>
<tf.Tensor: shape=(2,), dtype=float32, numpy=array([100.,   4.], dtype=float32), weak=True>

Effects of turning on the new type promotion

Below is a non-exhaustive list of changes that result from turning on the new type promotion.

  • More consistent and predictable promotion results.
  • Reduced risk of bit-widening.
  • tf.Tensor mathematical dunder methods use the new type promotion.
  • tf.constant can return a WeakTensor.
  • tf.constant allows implicit conversions when a Tensor input with a dtype different from the dtype arg is passed in.
  • tf.Variable in-place ops (assign, assign_add, assign_sub) allow implicit conversions.
  • tnp.array(1) and tnp.array(1.0) return 32-bit WeakTensors.
  • WeakTensors will be created and used for WeakTensor-supporting unary and binary APIs.

More consistent and predictable promotion results

Using a lattice-based system allows the new type promotion to produce consistent and predictable type promotion results.

Old Type Promotion

Changing the order of operations produces inconsistent results with the old type promotion.

# Setup
tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
# (a + b) + c throws an InvalidArgumentError.
try:
  tf.add(tf.add(a, b), c)
except tf.errors.InvalidArgumentError as e:
  print(f'{type(e)}: {e}')  # InvalidArgumentError
<class 'tensorflow.python.framework.errors_impl.InvalidArgumentError'>: cannot compute AddV2 as input #1(zero-based) was expected to be a int8 tensor but is a int32 tensor [Op:AddV2] name:
# (b + a) + c returns an i32 result.
tf.add(tf.add(b, a), c)  # <tf.Tensor: shape=(), dtype=int32, numpy=3>
<tf.Tensor: shape=(), dtype=int32, numpy=3>

New Type Promotion

New type promotion produces consistent results regardless of the order.

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = np.array(1, dtype=np.int8)
b = tf.constant(1)
c = np.array(1, dtype=np.float16)
# (a + b) + c returns a f16 result.
tf.add(tf.add(a, b), c)  # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>
# (b + a) + c also returns a f16 result.
tf.add(tf.add(b, a), c)  # <tf.Tensor: shape=(), dtype=float16, numpy=3.0>
<tf.Tensor: shape=(), dtype=float16, numpy=3.0>

Reduced risk of bit-widening

Old Type Promotion

Old type promotion often resulted in 64-bit results.

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="legacy")
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50)  # <tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>
<tf.Tensor: shape=(), dtype=float64, numpy=54.19921875>

New Type Promotion

The new type promotion returns results with the minimal number of bits necessary.

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
np.array(3.2, np.float16) + tf.constant(1, tf.int8) + tf.constant(50)  # <tf.Tensor: shape=(), dtype=float16, numpy=54.2>
<tf.Tensor: shape=(), dtype=float16, numpy=54.2>

tf.Tensor mathematical dunder methods

All tf.Tensor mathematical dunder methods will follow the new type promotion.

-tf.constant(5)  # <tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=-5, weak=True>
tf.constant(5, tf.int16) - tf.constant(1, tf.float32)  # <tf.Tensor: shape=(), dtype=float32, numpy=4.0>
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>

tf.Variable in-place ops

Implicit conversions will be allowed in tf.Variable in-place ops.

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.Variable(10, dtype=tf.int32)
a.assign_add(tf.constant(5, tf.int16))  # <tf.Variable shape=() dtype=int32, numpy=15>
<tf.Variable 'UnreadVariable' shape=() dtype=int32, numpy=15>

tf.constant implicit conversions

In the old type promotion, tf.constant required an input Tensor to have the same dtype as the dtype argument. However, in the new type promotion, the input Tensor is implicitly converted to the specified dtype.

tnp.experimental_enable_numpy_behavior(dtype_conversion_mode="all")
a = tf.constant(10, tf.int16)
tf.constant(a, tf.float32)  # <tf.Tensor: shape=(), dtype=float32, numpy=10.0>
<tf.Tensor: shape=(), dtype=float32, numpy=10.0>

TF-NumPy Array

tnp.array defaults to i32* and f32* for Python inputs when using the new type promotion.

tnp.array(1)  # <tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
<tf.Tensor: shape=(), dtype=int32, numpy=1, weak=True>
tnp.array(1.0)  # <tf.Tensor: shape=(), dtype=float32, numpy=1.0, weak=True>
<tf.Tensor: shape=(), dtype=float32, numpy=1.0, weak=True>

Input Type Inference

This is how different inputs' types are inferred in the new type promotion (see the sketch after this list).

  • tf.Tensor: Since tf.Tensor has a dtype property, we don't do further inference.
  • NumPy types: This includes types like np.array(1), np.int16(1), and np.float64(1). Since NumPy inputs also have a dtype property, we take the dtype property as the result inference type. Note that NumPy defaults to i64 and f64.
  • Python scalars/Nested types: This includes types like 1, [1, 2, 3], and (1.0, 2.0).
    • Python int is inferred as i32*.
    • Python float is inferred as f32*.
    • Python complex is inferred as c128*.
  • If the input doesn't fall into any of the above categories but has a dtype property, we take the dtype property as the result inference type.
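A hedged illustration of these inference rules in ALL mode (the expected result dtypes are noted in the comments):

# A Python int is inferred as weak i32*, so it defers to the other operand's dtype.
print((tf.constant(1, tf.int16) + 1).dtype.name)            # int16

# A Python float is inferred as weak f32*, so it defers to f16 here.
print((tf.constant(1.0, tf.float16) + 1.0).dtype.name)      # float16

# NumPy inputs keep their own dtype (NumPy defaults to 64-bit), so this widens to i64.
print((tf.constant(1, tf.int32) + np.array(1)).dtype.name)  # int64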

Further Reading

The new type promotion closely resembles JAX-NumPy's type promotion. If you want to know more details about the new type promotion and the design choices, check out the resources below.

References

WeakTensor-supporting APIs

Below is a list of APIs that support WeakTensor.

For a unary op, this means that if an input with no user-specified type is passed in, it returns a WeakTensor.

For a binary op, it follows the promotion table above. It may or may not return a WeakTensor depending on the promotion result of the two inputs.
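For example (a hedged sketch, assuming ALL mode is enabled and that the ops below are on the supported list):

# Unary op with an untyped input: the result is a WeakTensor.
print(tf.math.abs([1.0, 2.0]))                            # dtype=float32, weak=True

# Unary op with a typed Tensor input: the result is a normal Tensor.
print(tf.math.abs(tf.constant([1.0, 2.0], tf.float64)))   # dtype=float64

# Binary op: the weak operand (the Python int) defers to the typed operand.
print(tf.math.add(5, tf.constant(5, tf.int16)))           # dtype=int16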