Creates a ClientWorkProcess
for federated averaging.
tff.learning.templates.build_model_delta_client_work(
    model_fn: Callable[[], tff.learning.models.VariableModel],
    optimizer: tff.learning.optimizers.Optimizer,
    client_weighting: tff.learning.ClientWeighting,
    metrics_aggregator: Optional[tff.learning.metrics.MetricsAggregatorType] = None,
    *,
    loop_implementation: tff.learning.LoopImplementation = tff.learning.LoopImplementation.DATASET_REDUCE
) -> tff.learning.templates.ClientWorkProcess
Args | |
---|---|
model_fn | A no-arg function that returns a tff.learning.models.VariableModel. This function must not capture TensorFlow tensors or variables and use them. The model must be constructed entirely from scratch on each invocation; returning the same pre-constructed model on each call will result in an error. |
optimizer | A tff.learning.optimizers.Optimizer. |
client_weighting | A tff.learning.ClientWeighting value. |
metrics_aggregator | A function that takes in the metric finalizers (i.e., tff.learning.models.VariableModel.metric_finalizers()) and returns a tff.Computation for aggregating the unfinalized metrics. If None, this is set to tff.learning.metrics.sum_then_finalize. |
loop_implementation | Changes the implementation of the generated training loop. See tff.learning.LoopImplementation for more details. |
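
The model_fn requirement can be pictured with a small plain-Python sketch. It makes no TFF calls, and TinyModel, fresh_model_fn, and shared_model_fn are hypothetical names introduced only for illustration. The real restriction concerns TensorFlow variables, which this sketch only mimics through object identity.

```python
class TinyModel:
    """Hypothetical stand-in for a model that owns its own variables."""

    def __init__(self) -> None:
        self.weights = [0.0, 0.0]


def fresh_model_fn() -> TinyModel:
    # Builds a brand-new model, with new weights, on every invocation.
    return TinyModel()


_SHARED_MODEL = TinyModel()  # constructed once, outside the function


def shared_model_fn() -> TinyModel:
    # Anti-pattern: hands back the same pre-constructed object every time.
    return _SHARED_MODEL
```

A function shaped like fresh_model_fn matches the documented contract. One shaped like shared_model_fn is the pattern the model_fn description warns will result in an error.
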
Returns | |
---|---|
A ClientWorkProcess. |
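
As a rough illustration of what "model delta" client work means, the following plain-Python sketch trains a one-parameter linear model locally and reports the change in its weight together with the number of examples seen. It is not TFF's implementation. The names client_work and aggregate_deltas, the toy model, and the sign convention (initial minus trained weight) are all assumptions made for this example. The weight_by_examples flag loosely mirrors the choice between weighting clients by example count and weighting them uniformly.

```python
from typing import Sequence, Tuple

Example = Tuple[float, float]  # (feature x, label y)


def client_work(
    initial_weight: float,
    batches: Sequence[Sequence[Example]],
    learning_rate: float = 0.1,
) -> Tuple[float, int]:
    """Runs local SGD on y_hat = w * x and returns (model delta, num examples)."""
    weight = initial_weight
    num_examples = 0
    for batch in batches:
        # Gradient of the batch mean squared error with respect to w.
        grad = sum(2.0 * (weight * x - y) * x for x, y in batch) / len(batch)
        weight -= learning_rate * grad
        num_examples += len(batch)
    # Assumed sign convention for this sketch: initial minus trained weight.
    return initial_weight - weight, num_examples


def aggregate_deltas(
    results: Sequence[Tuple[float, int]],
    weight_by_examples: bool = True,
) -> float:
    """Averages client deltas, weighted by example count or uniformly."""
    weights = [float(n) if weight_by_examples else 1.0 for _, n in results]
    total = sum(delta * w for (delta, _), w in zip(results, weights))
    return total / sum(weights)


if __name__ == "__main__":
    client_a = client_work(0.0, [[(1.0, 2.0), (2.0, 4.0)]])
    client_b = client_work(0.0, [[(1.0, 1.0)]])
    print(aggregate_deltas([client_a, client_b]))
    print(aggregate_deltas([client_a, client_b], weight_by_examples=False))
```

In this sketch the averaging step is shown next to the client step only for convenience; the process documented here covers only the client side.
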