tf_agents.utils.common.soft_variables_update
Performs a soft/hard update of variables from the source to the target.
tf_agents.utils.common.soft_variables_update(
    source_variables,
    target_variables,
    tau=1.0,
    tau_non_trainable=None,
    sort_variables_by_name=False
)
For each variable v_t in target_variables and its corresponding variable v_s
in source_variables, a soft update is:
    v_t = (1 - tau) * v_t + tau * v_s
When tau is 1.0 (the default), this reduces to a hard update, i.e. a plain copy:
    v_t = v_s
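To make the arithmetic concrete, here is a minimal sketch of the update rule on plain Python floats. The helper name soft_update is hypothetical and is not part of TF-Agents; the real function operates on tf.Variable objects and returns a TensorFlow op, whereas this sketch only mirrors the formula.

```python
# Hypothetical helper: mirrors v_t = (1 - tau) * v_t + tau * v_s on plain floats.
# This is an illustration of the formula, not the TF-Agents implementation.

def soft_update(target, source, tau=1.0):
    """Return (1 - tau) * target + tau * source, elementwise."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError("tau must be in [0, 1]")
    return [(1.0 - tau) * t + tau * s for t, s in zip(target, source)]


target = [0.0, 10.0]
source = [1.0, 20.0]

# Soft update with tau = 0.1 moves each target value 10% of the way to the source.
print(soft_update(target, source, tau=0.1))  # approximately [0.1, 11.0]

# Hard update (tau = 1.0) copies the source values exactly.
print(soft_update(target, source, tau=1.0))  # [1.0, 20.0]
```

With a small tau, repeated calls make the target track a slowly moving average of the source, which is why small values are typical for target networks.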
Args
  source_variables: A list of source variables.
  target_variables: A list of target variables, one for each source variable.
  tau: A float scalar in [0, 1]. When tau is 1.0 (the default), a hard update
    is performed. This rate is applied to trainable variables.
  tau_non_trainable: A float scalar in [0, 1] applied to non-trainable
    variables. If None, the value of tau is used.
  sort_variables_by_name: A bool. When True, the variables are sorted by name
    before the update is performed.
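The sketch below illustrates how tau, tau_non_trainable and sort_variables_by_name interact. The Var class and update_variables function are hypothetical stand-ins, not TF-Agents APIs. As an assumption for illustration, the sketch picks the rate from the target variable's trainable flag; it also assumes both lists are sorted when sort_variables_by_name is True.

```python
from dataclasses import dataclass


@dataclass
class Var:
    """Hypothetical stand-in for a tf.Variable: a name, a float value, a trainable flag."""
    name: str
    value: float
    trainable: bool = True


def update_variables(source, target, tau=1.0, tau_non_trainable=None,
                     sort_variables_by_name=False):
    """Sketch of how the three tuning arguments combine (not the library code)."""
    if tau_non_trainable is None:
        tau_non_trainable = tau  # fall back to tau, as the Args section describes
    if sort_variables_by_name:
        # Pair variables by sorted name instead of by list position.
        source = sorted(source, key=lambda v: v.name)
        target = sorted(target, key=lambda v: v.name)
    for v_s, v_t in zip(source, target):
        # Assumption: the rate is chosen from the target variable's trainable flag.
        rate = tau if v_t.trainable else tau_non_trainable
        v_t.value = (1.0 - rate) * v_t.value + rate * v_s.value


# The source list is in a different order than the target list.
source = [Var("b/moving_mean", 4.0, trainable=False), Var("a/kernel", 2.0)]
target = [Var("a/kernel", 0.0), Var("b/moving_mean", 0.0, trainable=False)]

update_variables(source, target, tau=0.5, tau_non_trainable=1.0,
                 sort_variables_by_name=True)
print([(v.name, v.value) for v in target])
# [('a/kernel', 1.0), ('b/moving_mean', 4.0)]
```

Without sort_variables_by_name=True, this sketch would pair a/kernel with b/moving_mean purely by position, which is the mismatch the flag is meant to avoid when the two lists are built in different orders.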
Returns
  An operation that updates the target variables from the source variables.
Raises
  ValueError: if tau is not in [0, 1].
  ValueError: if len(source_variables) != len(target_variables).
  ValueError: if called inside a replica context, with the message
    "Method requires being in cross-replica context, use
    get_replica_context().merge_call()".
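The first two error conditions are plain argument checks, sketched below with a hypothetical helper, check_update_args, that is not part of TF-Agents. The error messages here are illustrative, not the library's exact wording. The third condition depends on tf.distribute replica contexts and is not modeled in this pure-Python sketch.

```python
def check_update_args(source_variables, target_variables, tau=1.0):
    """Hypothetical helper: raise ValueError for the documented bad-argument cases."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError(f"tau must be in [0, 1], got {tau}")
    if len(source_variables) != len(target_variables):
        raise ValueError(
            f"got {len(source_variables)} source variables but "
            f"{len(target_variables)} target variables")


# Out-of-range tau, mismatched lengths, then a valid call.
for args in [([1.0], [2.0], 1.5), ([1.0, 2.0], [3.0], 0.5), ([1.0], [2.0], 0.5)]:
    try:
        check_update_args(*args)
        print("ok")
    except ValueError as err:
        print("ValueError:", err)
```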
Except as otherwise noted, the content of this page is licensed under the Creative Commons Attribution 4.0 License, and code samples are licensed under the Apache 2.0 License. For details, see the Google Developers Site Policies. Java is a registered trademark of Oracle and/or its affiliates.
Last updated 2024-04-26 UTC.