
tf.nest.assert_same_structure

Asserts that two structures are nested in the same way.

Note that the method does not check the types of the data (the leaf values) inside the structures.
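To make the idea concrete, here is a minimal pure-Python sketch of what "nested in the same way" can mean. It is not TensorFlow's implementation. `same_structure` is a hypothetical helper that only understands lists, tuples, and dicts, and it treats everything else (numbers, strings, arrays) as a leaf. It does not cover namedtuple-to-dict matching or composite tensors such as ragged tensors. Its `check_types` flag is assumed to compare container types only, never leaf types.

```python
def same_structure(a, b, check_types=True):
    """Hypothetical helper: True if a and b nest the same way.

    Leaf values are never compared by type or by value; only the
    containers (lists, tuples, dicts) around them are inspected.
    """
    a_is_seq = isinstance(a, (list, tuple))
    b_is_seq = isinstance(b, (list, tuple))
    a_is_map = isinstance(a, dict)
    b_is_map = isinstance(b, dict)

    # Two leaves always match, whatever their types.
    if not (a_is_seq or a_is_map) and not (b_is_seq or b_is_map):
        return True

    if a_is_seq and b_is_seq:
        # Assumption: check_types compares container types (list vs. tuple).
        if check_types and type(a) is not type(b):
            return False
        return len(a) == len(b) and all(
            same_structure(x, y, check_types) for x, y in zip(a, b))

    if a_is_map and b_is_map:
        if check_types and type(a) is not type(b):
            return False
        # Key order does not matter, but the key sets must agree.
        return set(a) == set(b) and all(
            same_structure(a[k], b[k], check_types) for k in a)

    # A container on one side and a leaf (or a different kind of
    # container) on the other side never match.
    return False


structure1 = (((1, 2), 3), 4, (5, 6))
structure3 = [(("a", "b"), "c"), "d", ["e", "f"]]
print(same_structure(1.5, "abc"))                                 # True
print(same_structure(structure1, structure3))                     # False
print(same_structure(structure1, structure3, check_types=False))  # True
```

In this sketch a float and a string count as matching leaves, while a tuple and a list with the same shape match only when `check_types=False`.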

Examples:

  • These scalar vs. scalar comparisons will pass:
  tf.nest.assert_same_structure(1.5, tf.Variable(1, dtype=tf.uint32))
  tf.nest.assert_same_structure("abc", np.array([1, 2]))
    
  • These sequence vs. sequence comparisons will pass:
  structure1 = (((1, 2), 3), 4, (5, 6))
  structure2 = ((("foo1", "foo2"), "foo3"), "foo4", ("foo5", "foo6"))
  structure3 = [(("a", "b"), "c"), "d", ["e", "f"]]
  tf.nest.assert_same_structure(structure1, structure2)
  tf.nest.assert_same_structure(structure1, structure3, check_types=False)
    
  import collections
  tf.nest.assert_same_structure(
      collections.namedtuple("bar", "a b")(1, 2),
      collections.namedtuple("foo", "a b")(2, 3),
      check_types=False)
    
  tf.nest.assert_same_structure(
      collections.namedtuple("bar", "a b")(1, 2),
      { "a": 1, "b": 2 },
      check_types=False)
    
  tf.nest.assert_same_structure(
      { "a": 1, "b": 2, "c": 3 },
      { "c": 6, "b": 5, "a": 4 })
    
  ragged_tensor1 = tf.RaggedTensor.from_row_splits(
      values=[3, 1, 4, 1, 5, 9, 2, 6],
      row_splits=[0, 4, 4, 7, 8, 8])
  ragged_tensor2 = tf.RaggedTensor.from_row_splits(
      values=[3, 1, 4],
      row_splits=[0, 3])
  tf.nest.assert_same_structure(
      ragged_tensor1,
      ragged_tensor2,
      expand_composites=True)
    
  • These examples will raise exceptions:
  tf.nest.assert_same_structure([0, 1], np.array([0, 1]))
    Traceback (most recent call last):
    ...
    ValueError: The two structures don't have the same nested structure
    
  tf.nest.assert_same_structure(