Recursively prunes keys from wide if they don't appear in narrow.
tf_agents.utils.nest_utils.prune_extra_keys(
    narrow, wide
)
Often used as preprocessing prior to calling tf.nest.flatten
or tf.nest.map_structure.
This function is more forgiving than the ones in nest: if the types or
structures of two corresponding substructures don't agree, the pair is
considered invalid and prune_extra_keys returns the wide substructure as is.
Additional checking is therefore typically needed; call
nest.assert_same_structure(narrow, prune_extra_keys(narrow, wide))
to ensure that the pruned result has the same structure as narrow.
Examples:
wide = [{"a": "a", "b": "b"}]
# Narrows 'wide'
assert prune_extra_keys([{"a": 1}], wide) == [{"a": "a"}]
# 'wide' lacks "c", is considered invalid.
assert prune_extra_keys([{"c": 1}], wide) == wide
# 'wide' contains a different type from 'narrow', is considered invalid
assert prune_extra_keys("scalar", wide) == wide
# 'wide' substructure for key "d" does not match the one in 'narrow' and
# therefore is returned unmodified.
assert (prune_extra_keys({"a": {"b": 1}, "d": None},
                         {"a": {"b": "b", "c": "c"}, "d": [1, 2]})
        == {"a": {"b": "b"}, "d": [1, 2]})
# assert prune_extra_keys((), wide) == ()
# assert prune_extra_keys({"a": ()}, wide) == {"a": ()}
| Args | |
|---|---|
| narrow | A nested structure. |
| wide | A nested structure that may contain dicts with more fields than narrow. |
| Returns | |
|---|---|
| A structure with the same nested substructures as wide, but with dicts whose entries are limited to the keys found in the associated substructures of narrow. In case of substructure or size mismatches, the affected substructures of wide are returned as is. Note that ObjectProxy-wrapped objects are considered equivalent to their non-ObjectProxy types. |
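
The pruning rules described above can be approximated in plain Python. The following is a minimal sketch under stated assumptions, not the TF-Agents implementation: the name prune_extra_keys_sketch is hypothetical, only dicts, lists, and tuples are treated as containers, and ObjectProxy wrappers and the commented-out empty-tuple cases from the examples are not handled.

```python
from collections.abc import Mapping


def prune_extra_keys_sketch(narrow, wide):
    """Illustrative stand-in for prune_extra_keys (hypothetical, simplified)."""
    # Dicts: keep only the keys of `narrow`, provided `wide` has all of them.
    if isinstance(narrow, Mapping):
        if not isinstance(wide, Mapping) or not set(narrow).issubset(wide):
            return wide  # Mismatch: return the wide substructure as is.
        return {k: prune_extra_keys_sketch(narrow[k], wide[k]) for k in narrow}

    # Lists and tuples: recurse element-wise when type and length agree.
    if isinstance(narrow, (list, tuple)):
        if type(wide) is not type(narrow) or len(wide) != len(narrow):
            return wide  # Type or size mismatch: return wide as is.
        items = [prune_extra_keys_sketch(n, w) for n, w in zip(narrow, wide)]
        if hasattr(wide, "_fields"):  # namedtuples take positional fields
            return type(wide)(*items)
        return type(wide)(items)

    # Leaves (or unsupported container types): nothing to prune.
    return wide


wide = [{"a": "a", "b": "b"}]
print(prune_extra_keys_sketch([{"a": 1}], wide))  # [{'a': 'a'}]
```

Because a mismatch returns the wide substructure unchanged instead of raising, a result that still contains extra keys is possible. That is why the real function is normally followed by a structure check against narrow.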