tf_agents.utils.nest_utils.prune_extra_keys

Recursively prunes keys from wide if they don't appear in narrow.

It is often used as a preprocessing step before calling tf.nest.flatten or tf.nest.map_structure, so that the leaves of wide line up one-to-one with those of narrow.
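The following plain-Python sketch shows why this preprocessing matters. It uses a hypothetical flatten_sketch helper that mimics one property of tf.nest.flatten: dict values are emitted in sorted-key order. It is not TensorFlow code and only handles dicts, lists and tuples. An extra key in wide shifts its leaves, so they no longer pair up with the leaves of narrow.

```python
def flatten_sketch(structure):
    """Hypothetical stand-in for tf.nest.flatten on plain Python containers.

    Dict values are emitted in sorted-key order, list/tuple items in order,
    and anything else is treated as a single leaf.
    """
    if isinstance(structure, dict):
        return [leaf for key in sorted(structure)
                for leaf in flatten_sketch(structure[key])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten_sketch(item)]
    return [structure]


narrow = {"obs": 1, "reward": 2}
wide = {"obs": "o", "reward": "r", "info": "i"}

# Flattening both as-is pairs narrow's leaves with the wrong entries of wide,
# because the extra "info" key sorts first.
misaligned = list(zip(flatten_sketch(narrow), flatten_sketch(wide)))
print(misaligned)  # [(1, 'i'), (2, 'o')]

# For a single flat dict, pruning reduces to keeping only narrow's keys.
pruned = {key: wide[key] for key in narrow}
aligned = list(zip(flatten_sketch(narrow), flatten_sketch(pruned)))
print(aligned)  # [(1, 'o'), (2, 'r')]
```

Nested dicts need this pruning applied recursively at every level, and prune_extra_keys does exactly that.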

This function is more forgiving than the ones in nest. If the types or structures of two corresponding substructures disagree, the pair is considered invalid and prune_extra_keys returns the wide substructure as-is instead of raising an error. Additional checking is therefore usually needed: use nest.assert_same_structure(narrow, prune_extra_keys(narrow, wide)) to verify that the pruned result still has the same structure as narrow.
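The following minimal sketch illustrates these rules for plain dicts, lists and tuples only. prune_extra_keys_sketch and same_structure are hypothetical helpers written for illustration; they are not the TF-Agents implementation or the nest API. The sketch returns wide unchanged on any type, key or length mismatch. same_structure stands in for the nest.assert_same_structure check, returning a bool instead of raising.

```python
from collections.abc import Mapping


def prune_extra_keys_sketch(narrow, wide):
    """Hypothetical, simplified version of the pruning rules.

    Handles plain dicts, lists and tuples only; unlike the TF-Agents function
    it ignores ObjectProxy wrappers, list subclasses, namedtuples, attrs
    classes and defaultdicts.
    """
    both_mappings = isinstance(narrow, Mapping) and isinstance(wide, Mapping)
    if type(narrow) is not type(wide) and not both_mappings:
        return wide  # types disagree: treat as invalid, return wide as-is
    if both_mappings:
        if not set(narrow).issubset(wide):
            return wide  # wide lacks a key that narrow has: return as-is
        return type(wide)(
            (key, prune_extra_keys_sketch(value, wide[key]))
            for key, value in narrow.items())
    if type(narrow) in (list, tuple):
        if len(narrow) != len(wide):
            return wide  # lengths disagree: return as-is
        return type(wide)(
            prune_extra_keys_sketch(n, w) for n, w in zip(narrow, wide))
    return wide  # leaf: keep wide's value


def same_structure(a, b):
    """Hypothetical boolean stand-in for nest.assert_same_structure."""
    if isinstance(a, Mapping) and isinstance(b, Mapping):
        return set(a) == set(b) and all(same_structure(a[k], b[k]) for k in a)
    if isinstance(a, (list, tuple)) or isinstance(b, (list, tuple)):
        return (type(a) is type(b) and len(a) == len(b)
                and all(same_structure(x, y) for x, y in zip(a, b)))
    return not isinstance(a, Mapping) and not isinstance(b, Mapping)


wide = [{"a": "a", "b": "b"}]
pruned = prune_extra_keys_sketch([{"a": 1}], wide)
print(pruned)  # [{'a': 'a'}]
print(same_structure([{"a": 1}], pruned))  # True

# "c" is missing from wide, so wide comes back unpruned and the check fails.
unpruned = prune_extra_keys_sketch([{"c": 1}], wide)
print(unpruned == wide, same_structure([{"c": 1}], unpruned))  # True False
```

The second call shows why the extra check matters: pruning quietly returns the unpruned wide, and only the structure comparison reveals the mismatch.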

Examples:

```python
from tf_agents.utils.nest_utils import prune_extra_keys

wide = [{"a": "a", "b": "b"}]
# Narrows 'wide'.
assert prune_extra_keys([{"a": 1}], wide) == [{"a": "a"}]
# 'wide' lacks key "c", so it is considered invalid and returned unchanged.
assert prune_extra_keys([{"c": 1}], wide) == wide
# 'narrow' is a string while 'wide' is a list; the type mismatch makes it
# invalid, so 'wide' is returned unchanged.
assert prune_extra_keys("scalar", wide) == wide
# 'wide' substructure for key "d" does not match the one in 'narrow' and
# therefore is returned unmodified.
assert (prune_extra_keys({"a": {"b": 1}, "d": None},
                         {"a": {"b": "b", "c": "c"}, "d": [1, 2]})
        == {"a": {"b": "b"}, "d": [1, 2]})
# assert prune_extra_keys((), wide) == ()
# assert prune_extra_keys({"a": ()}, wide) == {"a": ()}
```

Args:
  narrow: A nested structure.
  wide: A nested structure that may contain dicts with more fields than narrow.

Returns:
  A structure with the same nested substructures as wide, but with dicts whose entries are limited to the keys found in the corresponding substructures of narrow.

  In case of a substructure or size mismatch, the affected substructure of wide is returned as-is. Note that ObjectProxy-wrapped objects are considered equivalent to their non-ObjectProxy types.