Performs a soft/hard update of variables from the source to the target.
```python
tf_agents.utils.common.soft_variables_update(
    source_variables,
    target_variables,
    tau=1.0,
    tau_non_trainable=None,
    sort_variables_by_name=False
)
```
For each variable v_t in target variables and its corresponding variable v_s
in source variables, a soft update is:
v_t = (1 - tau) * v_t + tau * v_s
When tau is 1.0 (the default), this reduces to a hard update:
v_t = v_s
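The update rule can be checked on plain Python floats. The sketch below is a framework-free illustration, not the TF-Agents implementation: `soft_update_value` is a hypothetical helper that works on floats rather than TensorFlow variables. It also demonstrates a direct consequence of the formula: with a fixed source value, each update shrinks the gap `v_s - v_t` by a factor of `(1 - tau)`.

```python
def soft_update_value(v_t: float, v_s: float, tau: float) -> float:
    """Hypothetical helper: returns (1 - tau) * v_t + tau * v_s for floats."""
    return (1.0 - tau) * v_t + tau * v_s


# tau = 1.0: the result is exactly the source value (hard update).
hard = soft_update_value(v_t=3.0, v_s=10.0, tau=1.0)

# tau = 0.1: the target moves 10% of the way toward the source.
soft = soft_update_value(v_t=0.0, v_s=10.0, tau=0.1)

# Repeated soft updates toward a fixed source value: after n steps the
# remaining gap is (1 - tau) ** n times the initial gap.
v_target, v_source, rate = 0.0, 10.0, 0.1
for _ in range(20):
    v_target = soft_update_value(v_target, v_source, rate)
gap_after_20 = v_source - v_target
expected_gap = (1.0 - rate) ** 20 * (v_source - 0.0)
```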
Args

| Argument | Description |
|---|---|
| `source_variables` | List of source variables. |
| `target_variables` | List of target variables. |
| `tau` | A float scalar in [0, 1]. When `tau` is 1.0 (the default), a hard update is performed. This value is used for trainable variables. |
| `tau_non_trainable` | A float scalar in [0, 1] for non-trainable variables. If `None`, the value of `tau` is used. |
| `sort_variables_by_name` | A bool. When `True`, the variables are sorted by name before the update is applied. |
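The interplay of `tau`, `tau_non_trainable`, and `sort_variables_by_name` can be illustrated with a framework-free sketch. It assumes a hypothetical `SketchVariable` record (with `name`, `value`, and `trainable` fields) standing in for TensorFlow variables, and a hypothetical `sketch_soft_update` function that pairs source and target variables by position after the optional sort. This only mirrors the behavior described in the argument table; it is not the TF-Agents code.

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class SketchVariable:
    """Hypothetical stand-in for a tf.Variable: name, float value, trainable flag."""
    name: str
    value: float
    trainable: bool = True


def sketch_soft_update(
    source: List[SketchVariable],
    target: List[SketchVariable],
    tau: float = 1.0,
    tau_non_trainable: Optional[float] = None,
    sort_variables_by_name: bool = False,
) -> None:
    """Updates target values in place from source values (illustration only)."""
    if tau_non_trainable is None:
        # Non-trainable variables fall back to the trainable rate.
        tau_non_trainable = tau
    if sort_variables_by_name:
        # Sorting both lists makes the positional pairing follow names
        # instead of the order in which the lists were built.
        source = sorted(source, key=lambda v: v.name)
        target = sorted(target, key=lambda v: v.name)
    for v_s, v_t in zip(source, target):
        # Assumption of this sketch: the target's trainable flag picks the rate.
        rate = tau if v_t.trainable else tau_non_trainable
        v_t.value = (1.0 - rate) * v_t.value + rate * v_s.value


# Usage: lists built in different orders, paired correctly by sorting.
src = [SketchVariable("net/b", 10.0), SketchVariable("net/a", 20.0, trainable=False)]
tgt = [SketchVariable("tgt/a", 0.0, trainable=False), SketchVariable("tgt/b", 0.0)]
sketch_soft_update(src, tgt, tau=0.5, tau_non_trainable=1.0, sort_variables_by_name=True)
# tgt/a is non-trainable, so it is hard-copied from net/a;
# tgt/b is trainable, so it moves halfway toward net/b.
```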
Returns

An operation that updates the target variables from the source variables.
Raises

| Exception | Condition |
|---|---|
| `ValueError` | If `tau` is not in [0, 1]. |
| `ValueError` | If `len(source_variables) != len(target_variables)`. |
| `ValueError` | If used inside a replica context: "Method requires being in cross-replica context, use get_replica_context().merge_call()". |
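The first two error conditions are plain argument checks. A minimal sketch of equivalent validation follows; `validate_soft_update_args` is a hypothetical helper and its messages are illustrative, not the exact TF-Agents messages. The third condition depends on TensorFlow's distribution-strategy context and is not modeled here.

```python
from typing import Sequence


def validate_soft_update_args(
    source_variables: Sequence[object],
    target_variables: Sequence[object],
    tau: float,
) -> None:
    """Hypothetical checks mirroring the first two documented ValueError cases."""
    if not 0.0 <= tau <= 1.0:
        raise ValueError(f"tau must be in [0, 1], got {tau}")
    if len(source_variables) != len(target_variables):
        raise ValueError(
            "source_variables and target_variables must have the same length, "
            f"got {len(source_variables)} and {len(target_variables)}"
        )
```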