tfp.substrates.jax.math.reduce_log_harmonic_mean_exp

Computes log(1 / mean(1 / exp(input_tensor))).
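Writing 1 / exp(x_i) as e^{-x_i} turns the definition into a negated log-mean-exp of -x, which can then be evaluated with the usual max-shift. In this derivation, x_1, ..., x_n are the entries of one reduced slice and m is a shift introduced only for illustration:

```latex
\log\frac{1}{\frac{1}{n}\sum_{i=1}^{n} e^{-x_i}}
  = -\log\Big(\frac{1}{n}\sum_{i=1}^{n} e^{-x_i}\Big)
  = \log n - \operatorname{logsumexp}(-x)
  = -\Big(m + \log\frac{1}{n}\sum_{i=1}^{n} e^{-x_i - m}\Big),
  \qquad m = \max_i(-x_i)
```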

Reduces input_tensor along the dimensions given in axis. Unless keepdims is true, the rank of the tensor is reduced by 1 for each entry in axis. If keepdims is true, the reduced dimensions are retained with length 1.

If axis is None (the default), all dimensions are reduced, and a tensor with a single element is returned.
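These shape rules should match the conventions of NumPy's reductions, which jax.numpy mirrors. The following minimal sketch uses NumPy as a stand-in for the JAX substrate. The hypothetical helper `direct_log_harmonic_mean_exp` transcribes the defining formula literally, only to show the output shapes produced by different axis and keepdims settings. It is not numerically stable and is not TFP's implementation.

```python
import numpy as np

def direct_log_harmonic_mean_exp(x, axis=None, keepdims=False):
    # Hypothetical helper: a literal transcription of
    # log(1 / mean(1 / exp(x))), used here only to illustrate output shapes.
    # It is NOT numerically stable for large-magnitude inputs.
    return np.log(1.0 / np.mean(1.0 / np.exp(x), axis=axis, keepdims=keepdims))

x = np.zeros((2, 3, 4))
print(direct_log_harmonic_mean_exp(x, axis=1).shape)                 # (2, 4)
print(direct_log_harmonic_mean_exp(x, axis=1, keepdims=True).shape)  # (2, 1, 4)
print(direct_log_harmonic_mean_exp(x, axis=(0, 2)).shape)            # (3,)
print(direct_log_harmonic_mean_exp(x).shape)                         # ()
print(direct_log_harmonic_mean_exp(x, keepdims=True).shape)          # (1, 1, 1)
```

With axis=(0, 2), two entries are reduced, so the rank drops from 3 to 1. With keepdims=True, each reduced dimension stays in place with length 1.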

This function is more numerically stable than computing log(1 / mean(1 / exp(input))) directly. It avoids overflows caused by taking the exp of large inputs and underflows caused by taking the log of small inputs.
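One way to get this stability is to rewrite the definition as -log_mean_exp(-x) and subtract the per-slice maximum before exponentiating, which is the standard log-sum-exp trick. The sketch below assumes NumPy as a stand-in for jax.numpy. `stable_log_harmonic_mean_exp` is a hypothetical name, and the code illustrates the technique rather than reproducing TFP's source.

```python
import numpy as np

def stable_log_harmonic_mean_exp(x, axis=None, keepdims=False):
    # Hypothetical helper (a sketch, not TFP's code): evaluates
    # log(1 / mean(1 / exp(x))) as -log_mean_exp(-x), shifting by the
    # per-slice maximum m of -x so that, for finite slices, exp only
    # sees arguments <= 0.
    y = -np.asarray(x, dtype=np.float64)
    m = np.max(y, axis=axis, keepdims=True)
    # If a slice's maximum is +/-inf, skip the shift; the result is then
    # the correct infinite value instead of NaN from inf - inf.
    m = np.where(np.isfinite(m), m, 0.0)
    out = np.log(np.mean(np.exp(y - m), axis=axis, keepdims=True)) + m
    if not keepdims:
        out = np.squeeze(out, axis=axis)
    return -out

big = np.array([1000.0, 1000.0])
tiny = np.array([-1000.0, -1000.0])
print(stable_log_harmonic_mean_exp(big))   # 1000.0
print(stable_log_harmonic_mean_exp(tiny))  # -1000.0
with np.errstate(over="ignore", divide="ignore"):
    # The direct formula breaks down: exp(1000) overflows to inf, and
    # exp(-1000) underflows to 0 so that 1 / exp(x) becomes inf.
    print(np.log(1.0 / np.mean(1.0 / np.exp(big))))   # inf
    print(np.log(1.0 / np.mean(1.0 / np.exp(tiny))))  # -inf
```

Shifting by the maximum means the largest term in each finite slice contributes exactly exp(0) = 1. The mean is therefore at least 1/n and can neither overflow nor collapse to zero before the log is taken.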

input_tensor The tensor to reduce. Should have numeric type.
axis The dimensions to reduce. If None (the default), reduces all dimensions. Must be in the range [-rank(input_tensor), rank(input_tensor)).
keepdims Boolean. Whether to keep the axis as singleton dimensions. Default value: False (i.e., squeeze the reduced dimensions).
experimental_named_axis A str or list of str axis names to additionally reduce over. Providing None will not reduce over any axes.
experimental_allow_all_gather Allow using an all_gather-based fallback under TensorFlow when computing the distributed maximum. This fallback is only efficient when axis reduces away most of the dimensions of input_tensor.
name Python str name prefixed to Ops created by this function. Default value: None (i.e., 'reduce_log_harmonic_mean_exp').

log_mean_exp The reduced tensor.