Builds a joint variational posterior by splitting a normalizing flow.
tfp.experimental.vi.build_split_flow_surrogate_posterior(
    event_shape,
    trainable_bijector,
    constraining_bijector=None,
    base_distribution=tfp.distributions.Normal,
    batch_shape=(),
    dtype=tf.float32,
    validate_args=False,
    name=None
)
Args

event_shape
  (Nested) event shape of the surrogate posterior.

trainable_bijector
  A trainable tfb.Bijector instance that operates on Tensors (not
  structures), e.g. tfb.MaskedAutoregressiveFlow or tfb.RealNVP. This
  bijector transforms the base distribution before it is split.

constraining_bijector
  tfb.Bijector instance, or nested structure of tfb.Bijector instances,
  that maps (nested) values in R^n to the support of the posterior. (This
  can be the experimental_default_event_space_bijector of the distribution
  over the prior latent variables.)
  Default value: None (i.e., the posterior is over R^n).

base_distribution
  A tfd.Distribution subclass parameterized by loc and scale. The base
  distribution for the transformed surrogate has loc=0. and scale=1.
  Default value: tfd.Normal.

batch_shape
  The batch_shape of the output distribution.
  Default value: ().

dtype
  The dtype of the surrogate posterior.
  Default value: tf.float32.

validate_args
  Python bool. Whether to validate input with asserts. This imposes a
  runtime cost. If validate_args is False, and the inputs are invalid,
  correct behavior is not guaranteed.
  Default value: False.

name
  Python str name prefixed to ops created by this function.
  Default value: None (i.e., 'build_split_flow_surrogate_posterior').
Returns

surrogate_distribution
  Trainable tfd.TransformedDistribution with event shape equal to
  event_shape.
Examples
import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Train a normalizing flow on the Eight Schools model [1].
treatment_effects = [28., 8., -3., 7., -1., 1., 18., 12.]
treatment_stddevs = [15., 10., 16., 11., 9., 11., 10., 18.]
model = tfd.JointDistributionNamed({
'avg_effect':
tfd.Normal(loc=0., scale=10., name='avg_effect'),
'log_stddev':
tfd.Normal(loc=5., scale=1., name='log_stddev'),
'school_effects':
lambda log_stddev, avg_effect: (
tfd.Independent(
tfd.Normal(
loc=avg_effect[..., None] * tf.ones(8),
scale=tf.exp(log_stddev[..., None]) * tf.ones(8),
name='school_effects'),
reinterpreted_batch_ndims=1)),
'treatment_effects': lambda school_effects: tfd.Independent(
tfd.Normal(loc=school_effects, scale=treatment_stddevs),
reinterpreted_batch_ndims=1)
})
# Pin the observed values in the model.
target_model = model.experimental_pin(treatment_effects=treatment_effects)
# Create a Masked Autoregressive Flow bijector.
net = tfb.AutoregressiveNetwork(2, hidden_units=[16, 16], dtype=tf.float32)
maf = tfb.MaskedAutoregressiveFlow(shift_and_log_scale_fn=net)
# Build and fit the surrogate posterior.
surrogate_posterior = (
tfp.experimental.vi.build_split_flow_surrogate_posterior(
event_shape=target_model.event_shape_tensor(),
trainable_bijector=maf,
constraining_bijector=(
target_model.experimental_default_event_space_bijector())))
losses = tfp.vi.fit_surrogate_posterior(
target_model.unnormalized_log_prob,
surrogate_posterior,
num_steps=100,
optimizer=tf.optimizers.Adam(0.1),
sample_size=10)
References
[1] Andrew Gelman, John Carlin, Hal Stern, David Dunson, Aki Vehtari, and
Donald Rubin. Bayesian Data Analysis, Third Edition.
Chapman and Hall/CRC, 2013.