Turns a (potentially nested) structure of distributions into a single distribution.
tfp.substrates.jax.distributions.independent_joint_distribution_from_structure(
structure_of_distributions, batch_ndims=None, validate_args=False
)
Args | Description
--- | ---
structure_of_distributions | An instance of tfd.Distribution, or a nested structure (tuple, list, dict, etc.) in which all leaves are tfd.Distribution instances.
batch_ndims | Optional integer Tensor: the number of leftmost batch dimensions shared across all members of the input structure. If specified, the returned joint distribution is an autobatched distribution with the given batch rank, and all other dimensions are absorbed into the event.
validate_args | Python bool. Whether the joint distribution should validate input with asserts. This imposes a runtime cost. If validate_args is False and the inputs are invalid, correct behavior is not guaranteed. Default value: False.
Returns | Description
--- | ---
distribution | An instance of tfd.Distribution such that distribution.sample() is equivalent in distribution to tf.nest.map_structure(lambda d: d.sample(), structure_of_distributions). If structure_of_distributions is a structure (as opposed to a single Distribution instance), this is a JointDistribution with the corresponding structure.
Raises | Description
--- | ---
TypeError | If any leaves of the input structure are not tfd.Distribution instances.