Updates the shape of a tensor and checks at runtime that the shape holds.
```python
tf.ensure_shape(
    x, shape, name=None
)
```
With eager execution, this is a shape assertion that returns the input:
```python
>>> x = tf.constant([1,2,3])
>>> print(x.shape)
(3,)
>>> x = tf.ensure_shape(x, [3])
>>> x = tf.ensure_shape(x, [5])
Traceback (most recent call last):
...
tf.errors.InvalidArgumentError: Shape of tensor dummy_input [3] is not compatible with expected shape [5]. [Op:EnsureShape]
```
Inside a `tf.function` or `v1.Graph` context, it checks both the build-time and
runtime shapes. This is stricter than `tf.Tensor.set_shape`, which only
checks the build-time shape.
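The difference from `tf.Tensor.set_shape` shows up when a traced function receives a tensor whose runtime shape does not match what the code declares. The following is a minimal sketch assuming TensorFlow 2.x; the function names `sum_with_set_shape` and `sum_with_ensure_shape` are illustrative only.

```python
import tensorflow as tf

# The length of the input vector is unknown at trace time.
spec = tf.TensorSpec([None], tf.int32)

@tf.function(input_signature=[spec])
def sum_with_set_shape(x):
    # Refines the static shape to (3,) but adds no runtime check.
    x.set_shape([3])
    return tf.reduce_sum(x)

@tf.function(input_signature=[spec])
def sum_with_ensure_shape(x):
    # Refines the static shape to (3,) and adds a runtime check.
    x = tf.ensure_shape(x, [3])
    return tf.reduce_sum(x)

five = tf.constant([1, 2, 3, 4, 5])

# set_shape adds no runtime check, so a length-5 input is not rejected.
print(sum_with_set_shape(five).numpy())

# ensure_shape raises when the graph actually runs on the bad input.
try:
    sum_with_ensure_shape(five)
except tf.errors.InvalidArgumentError:
    print("InvalidArgumentError raised")
```

Because `set_shape` only records a claim about the shape, a wrong claim can go unnoticed. `tf.ensure_shape` turns the same claim into an op that fails at runtime.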
For example, when loading images of a known size:
```python
>>> @tf.function
... def decode_image(png):
...   image = tf.image.decode_png(png, channels=3)
...   # the `print` executes during tracing.
...   print("Initial shape: ", image.shape)
...   image = tf.ensure_shape(image,[28, 28, 3])
...   print("Final shape: ", image.shape)
...   return image
```
When a function is traced, no ops are executed, so shapes may be unknown. See the Concrete Functions Guide for details.
```python
>>> concrete_decode = decode_image.get_concrete_function(
...     tf.TensorSpec([], dtype=tf.string))
Initial shape:  (None, None, 3)
Final shape:  (28, 28, 3)
```

```python
>>> image = tf.random.uniform(maxval=255, shape=[28, 28, 3], dtype=tf.int32)
>>> image = tf.cast(image,tf.uint8)
>>> png = tf.image.encode_png(image)
>>> image2 = concrete_decode(png)
>>> print(image2.shape)
(28, 28, 3)
```

```python
>>> image = tf.concat([image,image], axis=0)
>>> print(image.shape)
(56, 28, 3)
>>> png = tf.image.encode_png(image)
>>> image2 = concrete_decode(png)
Traceback (most recent call last):
...
tf.errors.InvalidArgumentError: Shape of tensor DecodePng [56,28,3] is not compatible with expected shape [28,28,3].
```
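Static shapes are also lost when values come from Python code through `tf.py_function`. The following is a minimal sketch assuming a `tf.data` pipeline and a hypothetical loader `load_row` that always returns three floats. It shows `tf.ensure_shape` restoring the static shape and checking it for every element.

```python
import numpy as np
import tensorflow as tf

def load_row(i):
    # Hypothetical loader standing in for arbitrary Python code.
    return np.full([3], float(i.numpy()), dtype=np.float32)

ds = tf.data.Dataset.range(4)
# tf.py_function hides the output shape from TensorFlow.
ds = ds.map(lambda i: tf.py_function(load_row, [i], tf.float32))
print(ds.element_spec.shape)  # <unknown>

# Restore the static shape; each element is also checked at runtime.
ds = ds.map(lambda x: tf.ensure_shape(x, [3]))
print(ds.element_spec.shape)  # (3,)
```

If `load_row` ever returned a row of a different length, iterating over `ds` would raise `tf.errors.InvalidArgumentError` instead of passing a wrongly shaped element downstream.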
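If you don't use the tensor returned by `tf.ensure_shape`, the static shape is not updated and the runtime check may not run:

```python
>>> @tf.function
... def bad_decode_image(png):
...   image = tf.image.decode_png(png, channels=3)
...   # the `print` executes during tracing.
...   print("Initial shape: ", image.shape)
...   # BAD: forgot to use the returned tensor.
...   tf.ensure_shape(image,[28, 28, 3])
...   print("Final shape: ", image.shape)
...   return image
```

```python
>>> image = bad_decode_image(png)
Initial shape:  (None, None, 3)
Final shape:  (None, None, 3)
>>> print(image.shape)
(56, 28, 3)
```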
| Args | |
|---|---|
| `x` | A `Tensor`. |
| `shape` | A `TensorShape` representing the shape of this tensor, a `TensorShapeProto`, a list, a tuple, or `None`. |
| `name` | A name for this operation (optional). Defaults to `"EnsureShape"`. |
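Because `shape` is interpreted as a `TensorShape`, a list may contain `None` for dimensions that should stay unconstrained. The following is a short sketch assuming TensorFlow 2.x; `rows_of_three` is an illustrative name.

```python
import tensorflow as tf

# Any number of rows is allowed; the last axis must have size 3.
rows = tf.zeros([4, 3])
checked = tf.ensure_shape(rows, [None, 3])
print(checked.shape)  # (4, 3)

# A mismatch in a constrained dimension is still rejected.
try:
    tf.ensure_shape(tf.zeros([4, 2]), [None, 3])
except tf.errors.InvalidArgumentError:
    print("last dimension must be 3")

# Inside a traced function, only the constrained dimension becomes static.
@tf.function(input_signature=[tf.TensorSpec([None, None], tf.float32)])
def rows_of_three(x):
    return tf.ensure_shape(x, [None, 3])

print(rows_of_three.get_concrete_function().structured_outputs.shape)  # (None, 3)
```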
| Returns |
|---|
| A `Tensor`. Has the same type and contents as `x`. At runtime, raises a `tf.errors.InvalidArgumentError` if `shape` is incompatible with the shape of `x`. |
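Because the op returns its input unchanged and signals a mismatch only through `tf.errors.InvalidArgumentError`, eager code can turn the check into a boolean. The helper `has_shape` below is hypothetical and not part of the TensorFlow API.

```python
import tensorflow as tf

def has_shape(x, shape):
    """Hypothetical helper: True if x is compatible with shape (eager mode only)."""
    try:
        tf.ensure_shape(x, shape)
        return True
    except tf.errors.InvalidArgumentError:
        return False

x = tf.constant([[1, 2], [3, 4]], dtype=tf.int64)
y = tf.ensure_shape(x, [2, 2])

# The result keeps the input's dtype and values.
print(y.dtype == x.dtype, y.numpy().tolist())  # True [[1, 2], [3, 4]]
print(has_shape(x, [2, None]), has_shape(x, [3, 2]))  # True False
```

The helper is meant for eager code. Inside a `tf.function`, the runtime check fails when the graph executes, so a Python `try` around the op at trace time does not catch that error.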