tff.learning.metrics.create_functional_metric_fns

Turn a Keras metric construction method into a tuple of pure functions.

This can be used to convert Keras metrics for use in tff.learning.models.FunctionalModel. The method traces the metric logic into three tf.functions with explicit state parameters. These parameters replace the metric's closure over the internal tf.Variables of the tf.keras.metrics.Metric.
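The idea behind the conversion can be illustrated with a minimal pure-Python sketch that uses no TensorFlow. It shows an accuracy metric written two ways. The first is a stateful class that mutates its own attributes, much as a tf.keras.metrics.Metric mutates its tf.Variables. The second is three pure functions that take and return an explicit state. The names StatefulAccuracy, AccuracyState, initialize, update, and finalize are hypothetical. This is not how TFF implements the conversion, and to keep the example short, update here takes raw predictions rather than a BatchOutput.

```python
from typing import NamedTuple, Sequence


class AccuracyState(NamedTuple):
    """Hypothetical explicit state: the values a Keras metric keeps in variables."""
    matches: float
    count: float


class StatefulAccuracy:
    """Closure-over-state style, analogous to a tf.keras.metrics.Metric."""

    def __init__(self) -> None:
        self.matches = 0.0
        self.count = 0.0

    def update_state(self, y_true: Sequence[float], y_pred: Sequence[float]) -> None:
        # Mutates the object's own attributes in place.
        self.matches += sum(1.0 for t, p in zip(y_true, y_pred) if t == p)
        self.count += len(y_true)

    def result(self) -> float:
        return self.matches / self.count if self.count else 0.0


# Pure-function style: state is passed in and a new state is returned.
def initialize() -> AccuracyState:
    return AccuracyState(matches=0.0, count=0.0)


def update(state: AccuracyState, y_true: Sequence[float],
           y_pred: Sequence[float]) -> AccuracyState:
    matches = sum(1.0 for t, p in zip(y_true, y_pred) if t == p)
    return AccuracyState(state.matches + matches, state.count + len(y_true))


def finalize(state: AccuracyState) -> float:
    return state.matches / state.count if state.count else 0.0


metric = StatefulAccuracy()
metric.update_state([1.0, 1.0], [0.0, 1.0])
print(metric.result())  # 0.5

state = update(initialize(), [1.0, 1.0], [0.0, 1.0])
print(finalize(state))  # 0.5
```

Because the pure functions never touch hidden variables, the caller fully controls where the state lives. This is what allows a framework to carry metric state alongside model weights.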

>>> metric = tf.keras.metrics.Accuracy()
>>> metric.update_state([1.0, 1.0], [0.0, 1.0])
>>> metric.result()  # == 0.5
>>>
>>> metric_fns = tff.learning.metrics.create_functional_metric_fns(
...     tf.keras.metrics.Accuracy)
>>> initialize, update, finalize = metric_fns
>>> state = initialize()
>>> batch_output = tff.learning.models.BatchOutput(predictions=[0.0, 1.0])
>>> state = update(state, [1.0, 1.0], batch_output)
>>> finalize(state)  # == 0.5

Args:
metrics_constructor: One of: a no-arg callable that returns a tf.keras.metrics.Metric; a no-arg callable that returns an OrderedDict of str names to tf.keras.metrics.Metric instances; or an OrderedDict of str names to no-arg callables that return tf.keras.metrics.Metric instances. A no-arg callable can be the metric class itself (e.g. tf.keras.metrics.Accuracy), in which case the default metric configuration is used. Lambdas and functools.partial are also supported for providing alternate metric configurations.
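The different ways of spelling a no-arg constructor can be shown without TensorFlow. With Keras, this might look like functools.partial(tf.keras.metrics.BinaryAccuracy, threshold=0.7). The sketch below instead uses a hypothetical stand-in class, ThresholdAccuracy, so it runs on its own. It builds an OrderedDict of names to constructors: the bare class (default configuration), a lambda, and a functools.partial. It then calls each constructor to obtain an instance.

```python
import collections
import functools


class ThresholdAccuracy:
    """Hypothetical stand-in for a configurable tf.keras.metrics.Metric."""

    def __init__(self, threshold: float = 0.5, name: str = "threshold_accuracy"):
        self.threshold = threshold
        self.name = name


# Three no-arg constructors: the class itself uses the default configuration,
# while the lambda and the partial bind alternate configurations.
default_ctor = ThresholdAccuracy
lambda_ctor = lambda: ThresholdAccuracy(threshold=0.7)
partial_ctor = functools.partial(ThresholdAccuracy, threshold=0.9, name="strict")

# An OrderedDict of str names to no-arg constructors.
constructors = collections.OrderedDict(
    default=default_ctor, lenient=lambda_ctor, strict=partial_ctor)

# Calling each constructor with no arguments yields a fresh metric instance.
metrics = collections.OrderedDict(
    (name, ctor()) for name, ctor in constructors.items())
print({name: m.threshold for name, m in metrics.items()})
# {'default': 0.5, 'lenient': 0.7, 'strict': 0.9}
```

Passing constructors rather than instances lets the caller create fresh metric objects whenever it needs them, for example once per tracing.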

Returns:
A 3-tuple of tf.functions, namely (initialize, update, finalize). initialize is a no-arg function used to create the algebraic "zero" before reducing the metric over batches of examples. update takes three arguments: the state, the labels, and the tff.learning.models.BatchOutput structure from the model's forward pass. It adds an observation to the metric and returns the new state. finalize takes only a state argument and returns the final metric value based on the observations previously added.
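The reduction over batches can be sketched in pure Python as a fold. Start from the "zero" returned by initialize, thread the state through update once per batch, then call finalize. This sketch assumes a mean-absolute-error metric whose state is a (total_error, count) pair. BatchOutput here is a hypothetical stand-in for tff.learning.models.BatchOutput. merge is a hypothetical helper, not part of the returned 3-tuple. It is included only to show what makes the zero "algebraic": merging any state with the zero leaves that state unchanged.

```python
import functools
from typing import NamedTuple, Sequence


class BatchOutput(NamedTuple):
    """Hypothetical stand-in for tff.learning.models.BatchOutput."""
    predictions: Sequence[float]


class MaeState(NamedTuple):
    total_error: float
    count: float


def initialize() -> MaeState:
    # The algebraic "zero": an identity element for merging states.
    return MaeState(0.0, 0.0)


def update(state: MaeState, labels: Sequence[float],
           batch_output: BatchOutput) -> MaeState:
    errors = [abs(y - p) for y, p in zip(labels, batch_output.predictions)]
    return MaeState(state.total_error + sum(errors), state.count + len(errors))


def merge(a: MaeState, b: MaeState) -> MaeState:
    # Hypothetical helper: combines two partial states.
    return MaeState(a.total_error + b.total_error, a.count + b.count)


def finalize(state: MaeState) -> float:
    return state.total_error / state.count if state.count else 0.0


batches = [
    ([1.0, 0.0], BatchOutput(predictions=[0.5, 0.0])),
    ([2.0], BatchOutput(predictions=[1.0])),
]

# Fold update over the batches, starting from the zero state.
state = functools.reduce(lambda s, batch: update(s, *batch), batches, initialize())
print(finalize(state))  # (0.5 + 0.0 + 1.0) / 3 = 0.5

# Reducing each batch separately and merging gives the same state.
partial_a = update(initialize(), *batches[0])
partial_b = update(initialize(), *batches[1])
print(merge(partial_a, partial_b) == state)  # True
print(merge(initialize(), state) == state)  # True
```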

Raises:
TypeError: If metrics_constructor is not a callable or an OrderedDict, or if metrics_constructor is a callable that returns a value of the wrong type.
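The documented TypeError conditions can be mirrored by a small validator. The sketch below is hypothetical: check_metrics_constructor and the stand-in Metric and Accuracy classes are not TFF or Keras APIs. It checks the accepted forms as read from the argument description, and it instantiates each constructor to inspect the result. The real function may perform these checks differently.

```python
import collections
from typing import Any


class Metric:
    """Hypothetical stand-in for tf.keras.metrics.Metric."""


class Accuracy(Metric):
    """Hypothetical stand-in for tf.keras.metrics.Accuracy."""


def check_metrics_constructor(metrics_constructor: Any) -> None:
    """Raises TypeError for unsupported inputs (illustrative sketch only)."""
    if isinstance(metrics_constructor, collections.OrderedDict):
        # Form 3: OrderedDict of names to no-arg callables returning a Metric.
        for name, ctor in metrics_constructor.items():
            if not callable(ctor):
                raise TypeError(f"Value for {name!r} is not callable: {ctor!r}")
            if not isinstance(ctor(), Metric):
                raise TypeError(f"Constructor for {name!r} did not return a Metric")
        return
    if not callable(metrics_constructor):
        raise TypeError(
            "metrics_constructor must be a callable or an OrderedDict, "
            f"got {type(metrics_constructor).__name__}")
    result = metrics_constructor()
    # Form 1: callable returning a single Metric.
    if isinstance(result, Metric):
        return
    # Form 2: callable returning an OrderedDict of names to Metric instances.
    if isinstance(result, collections.OrderedDict) and all(
            isinstance(m, Metric) for m in result.values()):
        return
    raise TypeError(f"Callable returned an unsupported type: {type(result).__name__}")


check_metrics_constructor(Accuracy)  # passes: the class is a no-arg callable
try:
    check_metrics_constructor({"acc": Accuracy})  # plain dict, not OrderedDict
except TypeError as e:
    print(e)  # metrics_constructor must be a callable or an OrderedDict, got dict
```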