Compute the mean and variance of a time series along its last (time) axis, excluding masked entries.
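Written out for a single batch member, with $x_t$ the value at step $t$, $T$ = `num_timesteps`, and $m_t \in \{0, 1\}$ the mask, the moments are sketched below. This assumes $m_t = 1$ marks an excluded entry and that the variance divides by the number of unmasked entries $N$ (population variance).

$$
N = \sum_{t=1}^{T} (1 - m_t), \qquad
\mu = \frac{1}{N} \sum_{t=1}^{T} (1 - m_t)\, x_t, \qquad
\sigma^2 = \frac{1}{N} \sum_{t=1}^{T} (1 - m_t)\, (x_t - \mu)^2
$$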
tfp.substrates.jax.sts.moments_of_masked_time_series(
time_series_tensor, broadcast_mask
)
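| Args | |
|---|---|
| `time_series_tensor` | float `Tensor` time series of shape `concat([batch_shape, [num_timesteps]])`. |
| `broadcast_mask` | bool `Tensor` of the same shape as `time_series_tensor`. Entries where the mask is `True` are treated as missing and excluded from the moments. |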
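To make the expected shapes concrete, here is a hypothetical NumPy construction of inputs with `batch_shape = [2]` and `num_timesteps = 5`. Using NaN to mark missing observations, and `True` in the mask to mean "exclude this entry", are assumptions made for illustration only.

```python
import numpy as np

# Hypothetical example inputs: batch_shape = [2], num_timesteps = 5.
# NaN is used here only as a placeholder for "missing"; this NaN convention
# is an illustrative assumption, not part of the documented signature.
raw = np.array([[1.0, 2.0, np.nan, 4.0, 5.0],
                [3.0, np.nan, np.nan, 3.0, 6.0]], dtype=np.float32)

# broadcast_mask: True marks an entry to exclude; same shape as the series.
broadcast_mask = np.isnan(raw)

# Replace masked slots with a finite placeholder so the array has no NaNs;
# entries under a True mask are meant to be ignored anyway.
time_series_tensor = np.where(broadcast_mask, 0.0, raw).astype(np.float32)

batch_shape = time_series_tensor.shape[:-1]
num_timesteps = time_series_tensor.shape[-1]
print(batch_shape, num_timesteps)  # (2,) 5
```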
| Returns | |
|---|---|
| `mean` | float `Tensor` of shape `batch_shape`. |
| `variance` | float `Tensor` of shape `batch_shape`. |
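The following is a minimal NumPy sketch of masked moments along the time axis, not TFP's implementation. It assumes `True` in the mask marks an entry to exclude and that the variance divides by the number of unmasked entries (population variance); swapping the import for `jax.numpy` should carry it over to the JAX substrate. The helper name `masked_moments_sketch` is hypothetical.

```python
import numpy as np

def masked_moments_sketch(time_series_tensor, broadcast_mask):
    """Hypothetical NumPy sketch: masked mean/variance over the last axis.

    Assumes True in `broadcast_mask` marks an entry to exclude, and that the
    variance divides by the number of unmasked entries (population variance).
    """
    x = np.asarray(time_series_tensor)
    mask = np.asarray(broadcast_mask, dtype=bool)
    # Number of unmasked entries per batch member, shape batch_shape.
    n = np.sum(~mask, axis=-1).astype(x.dtype)
    # Zero out masked entries before summing so they don't contribute.
    mean = np.sum(np.where(mask, 0.0, x), axis=-1) / n
    # Squared deviations from each series' mean, again ignoring masked slots.
    sq_dev = (x - mean[..., np.newaxis]) ** 2
    variance = np.sum(np.where(mask, 0.0, sq_dev), axis=-1) / n
    return mean, variance

x = np.array([[1.0, 2.0, 99.0, 4.0, 5.0],
              [3.0, 0.0, 0.0, 3.0, 6.0]])
mask = np.array([[False, False, True, False, False],
                 [False, True, True, False, False]])
mean, variance = masked_moments_sketch(x, mask)
print(mean)      # [3. 4.]
print(variance)  # [2.5 2. ]
```

In this sketch, a series whose entries are all masked gives `n = 0` and therefore NaN moments, so each series should keep at least one unmasked entry.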