Under the hood, TensorFlow 2 follows a fundamentally different programming paradigm from TF1.x.
This guide describes the fundamental differences between TF1.x and TF2 in terms of behaviors and the APIs, and how these all relate to your migration journey.
High-level summary of major changes
Fundamentally, TF1.x and TF2 use a different set of runtime behaviors around execution (eager in TF2), variables, control flow, tensor shapes, and tensor equality comparisons. To be TF2 compatible, your code must be compatible with the full set of TF2 behaviors. During migration, you can enable or disable most of these behaviors individually via the tf.compat.v1.enable_* or tf.compat.v1.disable_* APIs. The one exception is the removal of collections, which is a side effect of enabling/disabling eager execution.
At a high level, TensorFlow 2:
- Removes redundant APIs.
- Makes APIs more consistent - for example, Unified RNNs and Unified Optimizers.
- Prefers functions over sessions and integrates better with the Python runtime, with eager execution enabled by default along with tf.function, which provides automatic control dependencies for graphs and compilation.
- Deprecates global graph collections.
- Alters Variable concurrency semantics by using ResourceVariables over ReferenceVariables.
- Supports function-based and differentiable control flow (Control Flow v2).
- Simplifies the TensorShape API to hold ints instead of tf.compat.v1.Dimension objects.
- Updates tensor equality mechanics. In TF1.x the == operator on tensors and variables checks for object reference equality. In TF2 it checks for value equality. Additionally, tensors/variables are no longer hashable, but you can get hashable object references to them via var.ref() if you need to use them in sets or as dict keys.
The sections below provide some more context on the differences between TF1.x and TF2. To learn more about the design process behind TF2, read the RFCs and the design docs.
API cleanup
Many APIs are either gone or moved in TF2. Some of the major changes include removing tf.app, tf.flags, and tf.logging in favor of the now open-source absl-py, rehoming projects that lived in tf.contrib, and cleaning up the main tf.* namespace by moving lesser used functions into subpackages like tf.math. Some APIs have been replaced with their TF2 equivalents - tf.summary, tf.keras.metrics, and tf.keras.optimizers.
tf.compat.v1: Legacy and Compatibility API Endpoints
Symbols under the tf.compat and tf.compat.v1 namespaces are not considered TF2 APIs. These namespaces expose a mix of compatibility symbols, as well as legacy API endpoints from TF1.x. These are intended to aid migration from TF1.x to TF2. However, as none of these compat.v1 APIs are idiomatic TF2 APIs, do not use them for writing brand-new TF2 code.
Individual tf.compat.v1 symbols may be TF2 compatible because they continue to work even with TF2 behaviors enabled (such as tf.compat.v1.losses.mean_squared_error), while others are incompatible with TF2 (such as tf.compat.v1.metrics.accuracy). Many compat.v1 symbols (though not all) contain dedicated migration information in their documentation that explains their degree of compatibility with TF2 behaviors, as well as how to migrate them to TF2 APIs.
The TF2 upgrade script can map many compat.v1 API symbols to equivalent TF2 APIs in the case where they are aliases or have the same arguments but with a different ordering. You can also use the upgrade script to automatically rename TF1.x APIs.
False friend APIs
There are a set of "false-friend" symbols found in the TF2 tf namespace (not under compat.v1) that actually ignore TF2 behaviors under-the-hood, and/or are not fully compatible with the full set of TF2 behaviors. As such, these APIs are likely to misbehave with TF2 code, potentially in silent ways.
- tf.estimator.*: Estimators create and use graphs and sessions under the hood. As such, these should not be considered TF2-compatible. If your code is running estimators, it is not using TF2 behaviors.
- keras.Model.model_to_estimator(...): This creates an Estimator under the hood, which as mentioned above is not TF2-compatible.
- tf.Graph().as_default(): This enters TF1.x graph behaviors and does not follow standard TF2-compatible tf.function behaviors. Code that enters graphs like this will generally run them via Sessions, and should not be considered TF2-compatible.
- tf.feature_column.*: The feature column APIs generally rely on TF1-style tf.compat.v1.get_variable variable creation and assume that the created variables will be accessed via global collections. As TF2 does not support collections, APIs may not work correctly when running them with TF2 behaviors enabled.
Other API changes
- TF2 features significant improvements to the device placement algorithms, which renders the usage of tf.colocate_with unnecessary. If removing it causes a performance degradation, please file a bug.
- Replace all usage of tf.compat.v1.ConfigProto with equivalent functions from tf.config.
Eager execution
TF1.x required you to manually stitch together an abstract syntax tree (the graph) by making tf.* API calls, and then manually compile it by passing a set of output tensors and input tensors to a session.run call. TF2 executes eagerly (like Python normally does) and makes graphs and sessions feel like implementation details.
One notable byproduct of eager execution is that tf.control_dependencies is no longer required, as all lines of code execute in order (within a tf.function, code with side effects executes in the order written).
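For instance, here is a minimal sketch of what eager execution looks like in practice: each op runs immediately and returns a concrete value you can inspect with ordinary Python, no graph or session required.

import tensorflow as tf

a = tf.constant([[1.0, 2.0]])
b = tf.constant([[3.0], [4.0]])
c = tf.matmul(a, b)   # runs immediately, no graph or session needed
print(c)              # tf.Tensor([[11.]], shape=(1, 1), dtype=float32)
print(c.numpy())      # interoperates directly with Python/NumPy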
No more globals
TF1.x relied heavily on implicit global namespaces and collections. When you called tf.Variable, it would be put into a collection in the default graph, and it would remain there even if you lost track of the Python variable pointing to it. You could then recover that tf.Variable, but only if you knew the name that it had been created with. This was difficult to do if you were not in control of the variable's creation. As a result, all sorts of mechanisms proliferated to attempt to help you find your variables again, and for frameworks to find user-created variables. Some of these include: variable scopes, global collections, helper methods like tf.get_global_step and tf.global_variables_initializer, optimizers implicitly computing gradients over all trainable variables, and so on. TF2 eliminates all of these mechanisms (Variables 2.0 RFC) in favor of the default mechanism: you keep track of your variables. If you lose track of a tf.Variable, it gets garbage collected.
The requirement to track variables creates some extra work, but with tools like the modeling shims and behaviors like implicit object-oriented variable collections in tf.Modules and tf.keras.layers.Layers, the burden is minimized.
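As a minimal sketch (the Dense class below is illustrative, not a TF API), a tf.Module automatically tracks the tf.Variables assigned to its attributes, so you recover variables from the object that owns them rather than from a global collection:

class Dense(tf.Module):
  def __init__(self, in_features, out_features):
    super().__init__()
    # Variables assigned as attributes are tracked by the module automatically.
    self.w = tf.Variable(tf.random.normal([in_features, out_features]), name='w')
    self.b = tf.Variable(tf.zeros([out_features]), name='b')

  def __call__(self, x):
    return tf.matmul(x, self.w) + self.b

layer = Dense(in_features=3, out_features=2)
# No collections: the variables are found through the owning object.
print([v.name for v in layer.trainable_variables])  # e.g. ['b:0', 'w:0']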
Functions, not sessions
A session.run call is almost like a function call: you specify the inputs and the function to be called, and you get back a set of outputs. In TF2, you can decorate a Python function using tf.function to mark it for JIT compilation so that TensorFlow runs it as a single graph (Functions 2.0 RFC). This mechanism allows TF2 to gain all of the benefits of graph mode:
- Performance: The function can be optimized (node pruning, kernel fusion, etc.)
- Portability: The function can be exported/reimported (SavedModel 2.0 RFC), allowing you to reuse and share modular TensorFlow functions.
# TF1.x
outputs = session.run(f(placeholder), feed_dict={placeholder: input})
# TF2
outputs = f(input)
With the power to freely intersperse Python and TensorFlow code, you can take advantage of Python's expressiveness. However, portable TensorFlow executes in contexts without a Python interpreter, such as mobile, C++, and JavaScript. To help avoid rewriting your code when adding tf.function, use AutoGraph to convert a subset of Python constructs into their TensorFlow equivalents:
- for/while -> tf.while_loop (break and continue are supported)
- if -> tf.cond
- for _ in dataset -> dataset.reduce
AutoGraph supports arbitrary nestings of control flow, which makes it possible to performantly and concisely implement many complex ML programs such as sequence models, reinforcement learning, custom training loops, and more.
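As a small sketch of AutoGraph at work, the Python loop and branch below are converted into tf.while_loop and tf.cond when the function is traced on a tensor input:

@tf.function
def sum_even(items):
  s = tf.constant(0)
  for c in items:        # converted to a tf.while_loop over the tensor
    if c % 2 == 0:       # converted to a tf.cond
      s += c
  return s

print(sum_even(tf.constant([10, 12, 15, 20])))  # tf.Tensor(42, shape=(), dtype=int32)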
Adapting to TF 2.x Behavior Changes
Your migration to TF2 is only complete once you have migrated to the full set of TF2 behaviors. The full set of behaviors can be enabled or disabled via tf.compat.v1.enable_v2_behavior and tf.compat.v1.disable_v2_behavior. The sections below discuss each major behavior change in detail.
Using tf.functions
The largest changes to your programs during migration are likely to come from the fundamental programming model paradigm shift from graphs and sessions to eager execution and tf.function. Refer to the TF2 migration guides to learn more about moving from APIs that are incompatible with eager execution and tf.function to APIs that are compatible with them.
Below are some common program patterns, not tied to any one API, that may cause problems when switching from tf.Graphs and tf.compat.v1.Sessions to eager execution with tf.functions.
Pattern 1: Python object manipulation and variable creation intended to be done only once get run multiple times
In TF1.x programs that rely on graphs and sessions, the expectation is usually that all Python logic in your program will only run once. However, with eager execution and tf.function it is fair to expect that your Python logic will be run at least once, but possibly more times (either multiple times eagerly, or multiple times across different tf.function traces). Sometimes, tf.function will even trace twice on the same input, causing unexpected behaviors (see Examples 1 and 2). Refer to the tf.function guide for more details.
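A small sketch of what "run more than once" means in practice: the Python body of a tf.function executes on every trace, and a new trace happens whenever the input signature changes.

@tf.function
def double(x):
  print("Tracing with", x)  # Python side effect: only runs while tracing
  return x * 2

double(tf.constant(1))    # traces for int32
double(tf.constant(1.5))  # traces again for float32
double(tf.constant(2))    # reuses the int32 trace; nothing is printed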
Example 1: Variable creation
Consider the example below, where the function creates a variable when called:
def f():
  v = tf.Variable(1.0)
  return v

with tf.Graph().as_default():
  with tf.compat.v1.Session() as sess:
    res = f()
    sess.run(tf.compat.v1.global_variables_initializer())
    sess.run(res)
However, naively wrapping the above function that contains variable creation with tf.function is not allowed. tf.function only supports singleton variable creations on the first call. To enforce this, when tf.function detects variable creation in the first call, it will attempt to trace again and raise an error if there is variable creation in the second trace.
@tf.function
def f():
  print("trace") # This will print twice because the python body is run twice
  v = tf.Variable(1.0)
  return v

try:
  f()
except ValueError as e:
  print(e)
A workaround is caching and reusing the variable after it is created in the first call.
class Model(tf.Module):
  def __init__(self):
    self.v = None

  @tf.function
  def __call__(self):
    print("trace") # This will print twice because the python body is run twice
    if self.v is None:
      self.v = tf.Variable(0)
    return self.v

m = Model()
m()
Example 2: Out-of-scope Tensors due to tf.function retracing
As demonstrated in Example 1, tf.function will retrace when it detects variable creation in the first call. This can cause extra confusion, because the two tracings will create two graphs. When the second graph from retracing attempts to access a Tensor from the graph generated during the first tracing, TensorFlow will raise an error complaining that the Tensor is out of scope. To demonstrate the scenario, the code below creates a dataset on the first tf.function call. This would run as expected.
class Model(tf.Module):
  def __init__(self):
    self.dataset = None

  @tf.function
  def __call__(self):
    print("trace") # This will print once: only traced once
    if self.dataset is None:
      self.dataset = tf.data.Dataset.from_tensors([1, 2, 3])
    it = iter(self.dataset)
    return next(it)

m = Model()
m()
However, if we also attempt to create a variable on the first tf.function call, the code will raise an error complaining that the dataset is out of scope. This is because the dataset is in the first graph, while the second graph is also attempting to access it.
class Model(tf.Module):
  def __init__(self):
    self.v = None
    self.dataset = None

  @tf.function
  def __call__(self):
    print("trace") # This will print twice because the python body is run twice
    if self.v is None:
      self.v = tf.Variable(0)
    if self.dataset is None:
      self.dataset = tf.data.Dataset.from_tensors([1, 2, 3])
    it = iter(self.dataset)
    return [self.v, next(it)]

m = Model()
try:
  m()
except TypeError as e:
  print(e) # <tf.Tensor ...> is out of scope and cannot be used here.
The most straightforward solution is ensuring that the variable creation and dataset creation are both outside of the tf.function call. For example:
class Model(tf.Module):
  def __init__(self):
    self.v = None
    self.dataset = None

  def initialize(self):
    if self.dataset is None:
      self.dataset = tf.data.Dataset.from_tensors([1, 2, 3])
    if self.v is None:
      self.v = tf.Variable(0)

  @tf.function
  def __call__(self):
    it = iter(self.dataset)
    return [self.v, next(it)]

m = Model()
m.initialize()
m()
However, sometimes creating variables inside tf.function is unavoidable (such as slot variables in some TF Keras optimizers). Still, you can simply move the dataset creation outside of the tf.function call. The reason this works is that tf.function receives the dataset as an implicit input, so both graphs can access it properly.
class Model(tf.Module):
  def __init__(self):
    self.v = None
    self.dataset = None

  def initialize(self):
    if self.dataset is None:
      self.dataset = tf.data.Dataset.from_tensors([1, 2, 3])

  @tf.function
  def __call__(self):
    if self.v is None:
      self.v = tf.Variable(0)
    it = iter(self.dataset)
    return [self.v, next(it)]

m = Model()
m.initialize()
m()
Example 3: Unexpected TensorFlow object re-creations due to dict usage
tf.function has very poor support for Python side effects such as appending to a list, or checking/adding to a dictionary. More details are in "Better performance with tf.function". In the example below, the code uses dictionaries to cache datasets and iterators. For the same key, each call to the model will return the same iterator of the dataset.
class Model(tf.Module):
  def __init__(self):
    self.datasets = {}
    self.iterators = {}

  def __call__(self, key):
    if key not in self.datasets:
      self.datasets[key] = tf.compat.v1.data.Dataset.from_tensor_slices([1, 2, 3])
      self.iterators[key] = self.datasets[key].make_initializable_iterator()
    return self.iterators[key]

with tf.Graph().as_default():
  with tf.compat.v1.Session() as sess:
    m = Model()
    it = m('a')
    sess.run(it.initializer)
    for _ in range(3):
      print(sess.run(it.get_next())) # prints 1, 2, 3
However, the pattern above will not work as expected in tf.function. During tracing, tf.function will ignore the Python side effect of adding to the dictionaries. Instead, it only remembers the creation of a new dataset and iterator. As a result, each call to the model will always return a new iterator. This issue is hard to notice unless it noticeably affects numerical results or performance. Hence, think carefully about the code before naively wrapping tf.function around it.
class Model(tf.Module):
  def __init__(self):
    self.datasets = {}
    self.iterators = {}

  @tf.function
  def __call__(self, key):
    if key not in self.datasets:
      self.datasets[key] = tf.data.Dataset.from_tensor_slices([1, 2, 3])
      self.iterators[key] = iter(self.datasets[key])
    return self.iterators[key]

m = Model()
for _ in range(3):
  print(next(m('a'))) # prints 1, 1, 1
We can use tf.init_scope to lift the dataset and iterator creation outside of the graph, to achieve the expected behavior:
class Model(tf.Module):
  def __init__(self):
    self.datasets = {}
    self.iterators = {}

  @tf.function
  def __call__(self, key):
    if key not in self.datasets:
      # Lifts ops out of function-building graphs
      with tf.init_scope():
        self.datasets[key] = tf.data.Dataset.from_tensor_slices([1, 2, 3])
        self.iterators[key] = iter(self.datasets[key])
    return self.iterators[key]

m = Model()
for _ in range(3):
  print(next(m('a'))) # prints 1, 2, 3
The general rule of thumb is to avoid relying on Python side effects in your logic and only use them to debug your traces.
Example 4: Manipulating a global Python list
The following TF1.x code uses a global list of losses, relying on it to hold only the losses generated by the current training step. Note that the Python logic that appends losses to the list is only called once, regardless of how many training steps the session is run for.
all_losses = []

class Model():
  def __call__(...):
    ...
    all_losses.append(regularization_loss)
    all_losses.append(label_loss_a)
    all_losses.append(label_loss_b)
    ...

g = tf.Graph()
with g.as_default():
  ...
  # initialize all objects
  model = Model()
  optimizer = ...
  ...
  # train step
  model(...)
  total_loss = tf.reduce_sum(all_losses)
  optimizer.minimize(total_loss)
  ...

...
sess = tf.compat.v1.Session(graph=g)
sess.run(...)
However, if this Python logic is naively mapped to TF2 with eager execution, the global list of losses will have new values appended to it in each training step. This means the training step code which previously expected the list to only contain losses from the current training step now actually sees the list of losses from all training steps run so far. This is an unintended behavior change, and the list will either need to be cleared at the start of each step or made local to the training step.
all_losses = []

class Model():
  def __call__(...):
    ...
    all_losses.append(regularization_loss)
    all_losses.append(label_loss_a)
    all_losses.append(label_loss_b)
    ...

# initialize all objects
model = Model()
optimizer = ...

def train_step(...):
  ...
  model(...)
  total_loss = tf.reduce_sum(all_losses) # global list is never cleared,
  # accidentally accumulates the loss across all training steps
  optimizer.minimize(total_loss)
  ...
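One way to adapt this pattern to TF2 (a minimal, runnable sketch with an illustrative Model, not the original code) is to return the per-step losses from the model and keep the list local to the training step, so it can never accumulate across steps:

class Model(tf.Module):
  def __init__(self):
    self.w = tf.Variable(2.0)

  def __call__(self, x):
    pred = self.w * x
    regularization_loss = 0.01 * tf.square(self.w)
    label_loss = tf.square(pred - 1.0)
    # Return this step's losses instead of appending them to a global list.
    return [regularization_loss, label_loss]

model = Model()
optimizer = tf.keras.optimizers.SGD(0.1)

@tf.function
def train_step(x):
  with tf.GradientTape() as tape:
    losses = model(x)                    # local to this step, rebuilt on every call
    total_loss = tf.reduce_sum(losses)   # only this step's losses
  grads = tape.gradient(total_loss, model.trainable_variables)
  optimizer.apply_gradients(zip(grads, model.trainable_variables))
  return total_loss

for _ in range(3):
  print(train_step(tf.constant(3.0)).numpy())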
Pattern 2: A symbolic tensor meant to be recomputed every step in TF1.x is accidentally cached with the initial value when switching to eager.
This pattern usually causes your code to silently misbehave when executing eagerly outside of tf.functions, but raises an InaccessibleTensorError if the initial value caching occurs inside of a tf.function. However, be aware that in order to avoid Pattern 1 above, you will often inadvertently structure your code in such a way that this initial value caching will happen outside of any tf.function that would be able to raise an error. So, take extra care if you know your program may be susceptible to this pattern.
The general solution to this pattern is to restructure the code or use Python callables if necessary to make sure the value is recomputed each time instead of being accidentally cached.
Example 1: Learning rate/hyperparameter/etc. schedules that depend on global step
In the following code snippet, the expectation is that every time the session is run, the most recent global_step value will be read and a new learning rate will be computed.
g = tf.Graph()
with g.as_default():
  ...
  global_step = tf.Variable(0)
  learning_rate = 1.0 / global_step
  opt = tf.compat.v1.train.GradientDescentOptimizer(learning_rate)
  ...
  global_step.assign_add(1)
  ...

sess = tf.compat.v1.Session(graph=g)
sess.run(...)
However, when trying to switch to eager, be wary of ending up with the learning rate only being computed once then reused, rather than following the intended schedule:
global_step = tf.Variable(0)
learning_rate = 1.0 / global_step # Wrong! Only computed once!
opt = tf.keras.optimizers.SGD(learning_rate)

def train_step(...):
  ...
  opt.apply_gradients(...)
  global_step.assign_add(1)
  ...
Because this specific example is a common pattern and optimizers should only be initialized once rather than at each training step, TF2 optimizers support tf.keras.optimizers.schedules.LearningRateSchedule schedules or Python callables as arguments for the learning rate and other hyperparameters.
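For example, assuming a tf.keras.optimizers.SGD optimizer, both of the following recompute the learning rate when it is needed rather than freezing its initial value (a minimal sketch):

global_step = tf.Variable(1, trainable=False, dtype=tf.float32)

# Option 1: a Python callable, re-evaluated every time the optimizer needs the rate.
opt = tf.keras.optimizers.SGD(learning_rate=lambda: 1.0 / global_step)

# Option 2: a built-in LearningRateSchedule, computed from the optimizer's step count.
schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate=1.0, decay_steps=1000, decay_rate=0.96)
opt_with_schedule = tf.keras.optimizers.SGD(learning_rate=schedule)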
Example 2: Symbolic random number initializations assigned as object attributes then reused via pointer are accidentally cached when switching to eager
Consider the following NoiseAdder module:
class NoiseAdder(tf.Module):
  def __init__(self, shape, mean):
    self.noise_distribution = tf.random.normal(shape=shape, mean=mean)
    self.trainable_scale = tf.Variable(1.0, trainable=True)

  def add_noise(self, input):
    return (self.noise_distribution + input) * self.trainable_scale
Using it as follows in TF1.x will compute a new random noise tensor every time the session is run:
g = tf.Graph()
with g.as_default():
  ...
  # initialize all variable-containing objects
  noise_adder = NoiseAdder(shape, mean)
  ...
  # computation pass
  x_with_noise = noise_adder.add_noise(x)
  ...

...
sess = tf.compat.v1.Session(graph=g)
sess.run(...)
However, in TF2 initializing the noise_adder at the beginning will cause the noise_distribution to be computed only once and get frozen for all training steps:
...
# initialize all variable-containing objects
noise_adder = NoiseAdder(shape, mean) # Freezes `self.noise_distribution`!
...
# computation pass
x_with_noise = noise_adder.add_noise(x)
...
To fix this, refactor NoiseAdder to call tf.random.normal every time a new random tensor is needed, instead of referring to the same tensor object each time.
class NoiseAdder(tf.Module):
  def __init__(self, shape, mean):
    self.noise_distribution = lambda: tf.random.normal(shape=shape, mean=mean)
    self.trainable_scale = tf.Variable(1.0, trainable=True)

  def add_noise(self, input):
    return (self.noise_distribution() + input) * self.trainable_scale
Pattern 3: TF1.x code directly relies on and looks up tensors by name
It is common for TF1.x code tests to rely on checking what tensors or operations are present in a graph. In some rare cases, modeling code will also rely on these lookups by name.
Tensor names are not generated at all when executing eagerly outside of tf.function, so all usages of tf.Tensor.name must happen inside of a tf.function. Keep in mind that the actual generated names are very likely to differ between TF1.x and TF2, even within the same tf.function, and API guarantees do not ensure stability of the generated names across TF versions.
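A small sketch of the difference: accessing .name on an eager tensor fails, while the same access inside a tf.function returns a graph-generated name whose exact value you should not depend on across versions.

x = tf.constant(1.0)
try:
  print(x.name)                 # eager tensors have no graph names
except AttributeError as e:
  print(e)

@tf.function
def f():
  y = tf.constant(1.0)
  print("Name inside tf.function:", y.name)  # printed at trace time; exact value may vary
  return y

f()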
Pattern 4: TF1.x session selectively runs only part of the generated graph
In TF1.x, you can construct a graph and then selectively run only a subset of it with a session by choosing a set of inputs and outputs that do not require running every op in the graph.
For example, you may have both a generator and a discriminator inside of a single graph, and use separate tf.compat.v1.Session.run calls to alternate between only training the discriminator or only training the generator.
In TF2, due to automatic control dependencies in tf.function and eager execution, there is no selective pruning of tf.function traces. A full graph containing all variable updates would get run even if, for example, only the output of the discriminator or the generator is output from the tf.function.
So, you would need to either use multiple tf.functions containing different parts of the program (as sketched below), or a conditional argument to the tf.function that you branch on so as to execute only the things you actually want to have run.
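As a minimal sketch of that first option (the names and losses here are purely illustrative), you can put the generator and discriminator updates into separate tf.functions so that each call only runs the updates you intend:

generator_w = tf.Variable(1.0)
discriminator_w = tf.Variable(1.0)
gen_opt = tf.keras.optimizers.SGD(0.1)
disc_opt = tf.keras.optimizers.SGD(0.1)

@tf.function
def generator_train_step(x):
  with tf.GradientTape() as tape:
    loss = tf.square(generator_w * x - 1.0)      # placeholder generator loss
  gen_opt.apply_gradients([(tape.gradient(loss, generator_w), generator_w)])
  return loss

@tf.function
def discriminator_train_step(x):
  with tf.GradientTape() as tape:
    loss = tf.square(discriminator_w * x + 1.0)  # placeholder discriminator loss
  disc_opt.apply_gradients([(tape.gradient(loss, discriminator_w), discriminator_w)])
  return loss

# Alternate the calls; each tf.function only contains (and runs) its own updates.
discriminator_train_step(tf.constant(0.5))
generator_train_step(tf.constant(0.5))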
Collections Removal
When eager execution is enabled, graph collection-related compat.v1 APIs (including those that read or write to collections under the hood, such as tf.compat.v1.trainable_variables) are no longer available. Some may raise ValueErrors, while others may silently return empty lists.
The most standard usage of collections in TF1.x is to maintain initializers, the global step, weights, regularization losses, model output losses, and variable updates that need to be run, such as from BatchNormalization layers.
To handle each of these standard usages:
- Initializers - Ignore. Manual variable initialization is not required with eager execution enabled.
- Global step - See the documentation of tf.compat.v1.train.get_or_create_global_step for migration instructions.
- Weights - Map your models to tf.Modules/tf.keras.layers.Layers/tf.keras.Models by following the guidance in the model mapping guide and then use their respective weight-tracking mechanisms such as tf.Module.trainable_variables (see the sketch after this list).
- Regularization losses - Map your models to tf.Modules/tf.keras.layers.Layers/tf.keras.Models by following the guidance in the model mapping guide and then use tf.keras.losses. Alternatively, you can also manually track your regularization losses.
- Model output losses - Use tf.keras.Model loss management mechanisms or separately track your losses without using collections.
- Weight updates - Ignore this collection. Eager execution and tf.function (with autograph and auto-control-dependencies) mean all variable updates will get run automatically. So, you will not have to explicitly run all weight updates at the end, but note that it means the weight updates may happen at a different time than they did in your TF1.x code, depending on how you were using control dependencies.
- Summaries - Refer to the migrating summary API guide.
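As a brief sketch of the weight and regularization-loss replacements above (using a standard Keras layer for illustration), the tracking that collections used to provide now lives on the objects themselves:

layer = tf.keras.layers.Dense(4, kernel_regularizer=tf.keras.regularizers.L2(0.01))
_ = layer(tf.ones([2, 3]))            # build the layer so its variables exist

# Replaces tf.compat.v1.trainable_variables: weights are tracked on the layer/module.
print([v.shape for v in layer.trainable_variables])   # kernel (3, 4) and bias (4,)

# Replaces the regularization loss collection: losses are tracked on the layer.
print(layer.losses)                                    # [<tf.Tensor ... L2 penalty>]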
More complex collections usage (such as using custom collections) may require you to refactor your code to either maintain your own global stores, or to make it not rely on global stores at all.
ResourceVariables instead of ReferenceVariables
ResourceVariables have stronger read-write consistency guarantees than ReferenceVariables. This leads to more predictable, easier-to-reason-about semantics about whether or not you will observe the result of a previous write when using your variables. This change is extremely unlikely to cause existing code to raise errors or to break silently.
However, it is possible though unlikely that these stronger consistency guarantees may increase the memory usage of your specific program. Please file an issue if you find this to be the case. Additionally, if you have unit tests relying on exact string comparisons against the operator names in a graph corresponding to variable reads, be aware that enabling resource variables may slightly change the name of these operators.
To isolate the impact of this behavior change on your code, if eager execution is disabled you can use tf.compat.v1.disable_resource_variables() and tf.compat.v1.enable_resource_variables() to globally disable or enable this behavior change. ResourceVariables will always be used if eager execution is enabled.
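A small sketch of the kind of read-after-write ordering that resource variables (together with automatic control dependencies) make predictable:

v = tf.Variable(1.0)

@tf.function
def assign_then_read():
  v.assign_add(1.0)
  # With resource variables and automatic control dependencies, this read
  # observes the assign_add above.
  return v.read_value()

print(assign_then_read().numpy())  # 2.0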
Control flow v2
In TF1.x, control flow ops such as tf.cond and tf.while_loop inline low-level ops such as Switch, Merge, etc. TF2 provides improved functional control flow ops that are implemented with separate tf.function traces for every branch and support higher-order differentiation.
To isolate the impact of this behavior change on your code, if eager execution is disabled you can use tf.compat.v1.disable_control_flow_v2() and tf.compat.v1.enable_control_flow_v2() to globally disable or enable this behavior change. However, you can only disable control flow v2 if eager execution is also disabled. If it is enabled, control flow v2 will always be used.
This behavior change can dramatically change the structure of generated TF programs that use control flow, as they will contain several nested function traces rather than one flat graph. So, any code that is highly dependent on the exact semantics of produced traces may require some modification. This includes:
- Code relying on operator and tensor names
- Code referring to tensors created within a TensorFlow control flow branch from outside of that branch. This is likely to produce an InaccessibleTensorError.
This behavior change is intended to be performance neutral to positive, but if you run into an issue where control flow v2 performs worse for you than TF1.x control flow then please file an issue with reproduction steps.
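For instance, here is a minimal sketch of the functional tf.cond in action, including the second-order gradient that control flow v2 supports:

@tf.function
def f(x):
  # Under control flow v2, each branch of tf.cond is traced as its own function.
  return tf.cond(x > 0, lambda: x * x, lambda: -x)

x = tf.Variable(3.0)
with tf.GradientTape() as outer_tape:
  with tf.GradientTape() as inner_tape:
    y = f(x)
  dy_dx = inner_tape.gradient(y, x)        # 2 * x = 6.0
d2y_dx2 = outer_tape.gradient(dy_dx, x)    # 2.0
print(dy_dx.numpy(), d2y_dx2.numpy())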
TensorShape API behavior changes
The TensorShape class was simplified to hold ints, instead of tf.compat.v1.Dimension objects. So there is no need to call .value to get an int.
Individual tf.compat.v1.Dimension objects are still accessible from tf.TensorShape.dims.
To isolate the impact of this behavior change on your code, you can use tf.compat.v1.disable_v2_tensorshape() and tf.compat.v1.enable_v2_tensorshape() to globally disable or enable this behavior change.
The following examples demonstrate the differences between TF1.x and TF2.
import tensorflow as tf
# Create a shape and choose an index
i = 0
shape = tf.TensorShape([16, None, 256])
shape
TensorShape([16, None, 256])
If you had this in TF1.x:
value = shape[i].value
Then do this in TF2:
value = shape[i]
value
16
If you had this in TF1.x:
for dim in shape:
  value = dim.value
  print(value)
Then, do this in TF2:
for value in shape:
  print(value)
16
None
256
If you had this in TF1.x (or used any other dimension method):
dim = shape[i]
dim.assert_is_compatible_with(other_dim)
Then do this in TF2:
other_dim = 16
Dimension = tf.compat.v1.Dimension

if shape.rank is None:
  dim = Dimension(None)
else:
  dim = shape.dims[i]
dim.is_compatible_with(other_dim) # or any other dimension method
True
shape = tf.TensorShape(None)

if shape:
  dim = shape.dims[i]
  dim.is_compatible_with(other_dim) # or any other dimension method
The boolean value of a tf.TensorShape is True if the rank is known, False otherwise.
print(bool(tf.TensorShape([]))) # Scalar
print(bool(tf.TensorShape([0]))) # 0-length vector
print(bool(tf.TensorShape([1]))) # 1-length vector
print(bool(tf.TensorShape([None]))) # Unknown-length vector
print(bool(tf.TensorShape([1, 10, 100]))) # 3D tensor
print(bool(tf.TensorShape([None, None, None]))) # 3D tensor with no known dimensions
print()
print(bool(tf.TensorShape(None))) # A tensor with unknown rank.
True
True
True
True
True
True

False
Potential errors due to TensorShape changes
The TensorShape behavior changes are unlikely to silently break your code. However, you may see shape-related code begin to raise AttributeErrors, as ints and Nones do not have the same attributes that tf.compat.v1.Dimensions do. Below are some examples of these AttributeErrors:
try:
  # Create a shape and choose an index
  shape = tf.TensorShape([16, None, 256])
  value = shape[0].value
except AttributeError as e:
  # 'int' object has no attribute 'value'
  print(e)
'int' object has no attribute 'value'
try:
  # Create a shape and choose an index
  shape = tf.TensorShape([16, None, 256])
  dim = shape[1]
  other_dim = shape[2]
  dim.assert_is_compatible_with(other_dim)
except AttributeError as e:
  # 'NoneType' object has no attribute 'assert_is_compatible_with'
  print(e)
'NoneType' object has no attribute 'assert_is_compatible_with'
Tensor Equality by Value
The binary == and != operators on variables and tensors were changed to compare by value in TF2, rather than comparing by object reference like in TF1.x. Additionally, tensors and variables are no longer directly hashable or usable in sets or dict keys, because it may not be possible to hash them by value. Instead, they expose a .ref() method that you can use to get a hashable reference to the tensor or variable.
To isolate the impact of this behavior change, you can use tf.compat.v1.disable_tensor_equality() and tf.compat.v1.enable_tensor_equality() to globally disable or enable this behavior change.
For example, in TF1.x, two variables with the same value will return False when you use the == operator:
tf.compat.v1.disable_tensor_equality()
x = tf.Variable(0.0)
y = tf.Variable(0.0)
x == y
False
While in TF2 with tensor equality checks enabled, x == y will return True.
tf.compat.v1.enable_tensor_equality()
x = tf.Variable(0.0)
y = tf.Variable(0.0)
x == y
<tf.Tensor: shape=(), dtype=bool, numpy=True>
So, in TF2, if you need to compare by object reference, make sure to use is and is not:
tf.compat.v1.enable_tensor_equality()
x = tf.Variable(0.0)
y = tf.Variable(0.0)
x is y
False
Hashing tensors and variables
With TF1.x behaviors, you used to be able to directly add variables and tensors to data structures that require hashing, such as sets and dict keys.
x = tf.Variable(0.0)
set([x, tf.constant(2.0)])
However, in TF2 with tensor equality enabled, tensors and variables are made unhashable due to the == and != operator semantics changing to value equality checks.
tf.compat.v1.enable_tensor_equality()
x = tf.Variable(0.0)

try:
  set([x, tf.constant(2.0)])
except TypeError as e:
  # TypeError: Variable is unhashable. Instead, use tensor.ref() as the key.
  print(e)
Variable is unhashable. Instead, use variable.ref() as the key. (Variable: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>)
So, in TF2, if you need to use tensor or variable objects as keys or set contents, you can use tensor.ref() to get a hashable reference that can be used as a key:
tf.compat.v1.enable_tensor_equality()
x = tf.Variable(0.0)
tensor_set = set([x.ref(), tf.constant(2.0).ref()])
assert x.ref() in tensor_set
tensor_set
{<Reference wrapping <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>>, <Reference wrapping <tf.Tensor: shape=(), dtype=float32, numpy=2.0>>}
If needed, you can also get the tensor or variable back from the reference by using reference.deref():
referenced_var = x.ref().deref()
assert referenced_var is x
referenced_var
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=0.0>
Resources and further reading
- Visit the Migrate to TF2 section to read more about migrating to TF2 from TF1.x.
- Read the model mapping guide to learn more about mapping your TF1.x models to work in TF2 directly.