Recursively prunes keys from wide if they don't appear in narrow.
tf_agents.utils.nest_utils.prune_extra_keys(
narrow, wide
)
Often used as preprocessing prior to calling tf.nest.flatten or tf.nest.map_structure.
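To illustrate the recursion and why it helps before flattening, here is a simplified pure-Python sketch. prune_extra_keys_sketch and flatten_sketch are hypothetical helpers, not part of TF-Agents or TensorFlow. They handle only dicts, lists, and tuples (the real function also deals with ObjectProxy, namedtuple, and attrs types), and flatten_sketch assumes tf.nest's convention of visiting dict keys in sorted order.

```python
from typing import Any


def prune_extra_keys_sketch(narrow: Any, wide: Any) -> Any:
    """Hypothetical, simplified pruning for dicts, lists and tuples only.

    Not the TF-Agents implementation: ObjectProxy, namedtuple, attrs and
    defaultdict handling are omitted.
    """
    if isinstance(narrow, dict) and isinstance(wide, dict):
        if not set(narrow).issubset(wide):
            # `wide` lacks a key that `narrow` requires: return it as is.
            return wide
        # Keep only narrow's keys, pruning each value recursively.
        return {k: prune_extra_keys_sketch(v, wide[k]) for k, v in narrow.items()}
    if isinstance(narrow, (list, tuple)) and type(narrow) is type(wide):
        if len(narrow) != len(wide):
            # Size mismatch: return `wide` as is.
            return wide
        return type(wide)(prune_extra_keys_sketch(n, w) for n, w in zip(narrow, wide))
    # `narrow` is a leaf, or the container types disagree: return `wide` unchanged.
    return wide


def flatten_sketch(structure: Any) -> list:
    """Flatten dicts (sorted-key order), lists and tuples into a list of leaves."""
    if isinstance(structure, dict):
        return [leaf for k in sorted(structure) for leaf in flatten_sketch(structure[k])]
    if isinstance(structure, (list, tuple)):
        return [leaf for item in structure for leaf in flatten_sketch(item)]
    return [structure]


spec = {"obs": "float", "reward": "float"}
batch = {"obs": 1.5, "reward": 0.5, "debug_info": "extra"}

# Without pruning, the extra key adds a leaf, so the flattened lists differ in length.
print(len(flatten_sketch(spec)), len(flatten_sketch(batch)))  # 2 3

pruned = prune_extra_keys_sketch(spec, batch)
pairs = list(zip(flatten_sketch(spec), flatten_sketch(pruned)))
print(pairs)  # [('float', 1.5), ('float', 0.5)]
```

Once the extra "debug_info" key is removed, the two flattened lists line up leaf for leaf, which is what pairing structures with tf.nest.flatten or tf.nest.map_structure relies on.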
This function is more forgiving than the ones in nest: if the types or structures of two corresponding substructures don't agree, the pair is considered invalid and prune_extra_keys returns the wide substructure as is instead of raising an error. Additional checking is therefore typically needed; use
nest.assert_same_structure(narrow, prune_extra_keys(narrow, wide))
to ensure the result of pruning is still a correct structure.
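The post-pruning check can be sketched without TensorFlow. assert_same_structure_sketch and structure_matches are hypothetical pure-Python stand-ins for nest.assert_same_structure, simplified to compare container types, dict keys, and sequence lengths. The pruned values below are taken from the documented examples rather than computed.

```python
from typing import Any


def assert_same_structure_sketch(a: Any, b: Any) -> None:
    """Hypothetical, simplified stand-in for nest.assert_same_structure.

    Raises ValueError if `a` and `b` differ in container type, dict keys,
    or sequence length. Leaves may hold any value.
    """
    if isinstance(a, dict) or isinstance(b, dict):
        if not (isinstance(a, dict) and isinstance(b, dict)) or set(a) != set(b):
            raise ValueError(f"dict mismatch: {a!r} vs {b!r}")
        for k in a:
            assert_same_structure_sketch(a[k], b[k])
    elif isinstance(a, (list, tuple)) or isinstance(b, (list, tuple)):
        if type(a) is not type(b) or len(a) != len(b):
            raise ValueError(f"sequence mismatch: {a!r} vs {b!r}")
        for x, y in zip(a, b):
            assert_same_structure_sketch(x, y)


def structure_matches(a: Any, b: Any) -> bool:
    """Return True if the simplified structure check passes."""
    try:
        assert_same_structure_sketch(a, b)
        return True
    except ValueError:
        return False


# Pruning succeeded: [{"a": 1}] narrowed [{"a": "a", "b": "b"}] to [{"a": "a"}].
print(structure_matches([{"a": 1}], [{"a": "a"}]))  # True
# Pruning was "forgiving": wide lacked "c", so it came back unchanged.
print(structure_matches([{"c": 1}], [{"a": "a", "b": "b"}]))  # False
```

The second call shows why the extra check matters: pruning alone raised no error, but the returned structure does not match narrow.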
Examples:
from tf_agents.utils.nest_utils import prune_extra_keys

wide = [{"a": "a", "b": "b"}]
# Narrows 'wide'.
assert prune_extra_keys([{"a": 1}], wide) == [{"a": "a"}]
# 'wide' lacks "c", so it is considered invalid and returned unchanged.
assert prune_extra_keys([{"c": 1}], wide) == wide
# 'wide' has a different type from 'narrow', so it is considered invalid
# and returned unchanged.
assert prune_extra_keys("scalar", wide) == wide
# The 'wide' substructure for key "d" does not match the one in 'narrow' and
# is therefore returned unmodified.
assert (prune_extra_keys({"a": {"b": 1}, "d": None},
                         {"a": {"b": "b", "c": "c"}, "d": [1, 2]})
        == {"a": {"b": "b"}, "d": [1, 2]})
# assert prune_extra_keys((), wide) == ()
# assert prune_extra_keys({"a": ()}, wide) == {"a": ()}
Args | Description
---|---
narrow | A nested structure.
wide | A nested structure that may contain dicts with more fields than narrow.
Returns |
---|
A structure with the same nested substructures as wide, but with dicts whose entries are limited to the keys found in the corresponding substructures of narrow. Where substructures differ in type or size, the corresponding substructure of wide is returned as is. ObjectProxy-wrapped objects are considered equivalent to their unwrapped (non-ObjectProxy) types. |