Builds a joint variational posterior with a given event_shape.
tfp.experimental.vi.build_affine_surrogate_posterior_stateless(
event_shape,
operators='diag',
bijector=None,
base_distribution=tfp.distributions.Normal,
dtype=tf.float32,
batch_shape=(),
validate_args=False,
name=None
)
This function builds a surrogate posterior by applying a trainable
transformation to a standard base distribution and constraining the samples
with bijector. The surrogate posterior has event shape equal to the input
event_shape.
This function is a convenience wrapper around
build_affine_surrogate_posterior_from_base_distribution that allows the user
to pass in the desired posterior event_shape instead of pre-constructed base
distributions (at the expense of full control over the base distribution
types and parameterizations).
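The construction described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the library's implementation: the helper name `sample_affine_surrogate` and the parameter values are hypothetical, the parameters are fixed rather than trainable, and `math.exp` stands in for an `Exp` bijector.

```python
import math
import random


def sample_affine_surrogate(loc, scale_tril, rng):
    """Draws one sample exp(loc + scale_tril @ z), with z ~ Normal(0, 1)."""
    n = len(loc)
    # Standard base distribution: independent Normal(loc=0, scale=1) draws.
    z = [rng.gauss(0.0, 1.0) for _ in range(n)]
    # Affine transformation with a lower-triangular scale matrix.
    unconstrained = [
        loc[i] + sum(scale_tril[i][j] * z[j] for j in range(i + 1))
        for i in range(n)
    ]
    # Constrain every component to be positive, mirroring an Exp bijector.
    return [math.exp(u) for u in unconstrained]


rng = random.Random(0)
loc = [0.0, -1.0]
# A diagonal scale gives independent components (the 'diag' case); a non-zero
# off-diagonal entry couples them (the 'tril' case).
diag_scale = [[0.5, 0.0], [0.0, 0.3]]
tril_scale = [[0.5, 0.0], [0.4, 0.3]]

mean_field_samples = [
    sample_affine_surrogate(loc, diag_scale, rng) for _ in range(1000)]
full_cov_samples = [
    sample_affine_surrogate(loc, tril_scale, rng) for _ in range(1000)]
```

In the actual surrogate posterior, the counterparts of `loc` and the scale matrix are the trainable quantities that optimization adjusts.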
Args | |
---|---|
event_shape
|
(Nested) event shape of the posterior. |
operators
|
Either a string or a list/tuple containing LinearOperator
subclasses, LinearOperator instances, or callables returning
LinearOperator instances. Supported string values are "diag" (to create
a mean-field surrogate posterior) and "tril" (to create a full-covariance
surrogate posterior). A list/tuple may be passed to induce other
posterior covariance structures. If the list is flat, a
tf.linalg.LinearOperatorBlockDiag instance will be created and applied
to the base distribution. Otherwise the list must be singly-nested and
have a first element of length 1, second element of length 2, etc.; the
elements of the outer list are interpreted as rows of a lower-triangular
block structure, and a tf.linalg.LinearOperatorBlockLowerTriangular
instance is created. For complete documentation and examples, see
tfp.experimental.vi.util.build_trainable_linear_operator_block, which
receives the operators arg if it is list-like.
Default value: "diag".
|
bijector
|
tfb.Bijector instance, or nested structure of tfb.Bijector
instances, that maps (nested) values in R^n to the support of the
posterior. (This can be the experimental_default_event_space_bijector of
the distribution over the prior latent variables.)
Default value: None (i.e., the posterior is over R^n).
|
base_distribution
|
A tfd.Distribution subclass parameterized by loc and
scale. The base distribution of the transformed surrogate has loc=0.
and scale=1.
Default value: tfd.Normal.
|
dtype
|
The dtype of the surrogate posterior.
Default value: tf.float32.
|
batch_shape
|
Batch shape (Python tuple, list, or int) of the surrogate
posterior, to enable parallel optimization from multiple initializations.
Default value: ().
|
validate_args
|
Python bool. Whether to validate input with asserts. This
imposes a runtime cost. If validate_args is False and the inputs are
invalid, correct behavior is not guaranteed.
Default value: False.
|
name
|
Python str name prefixed to ops created by this function.
Default value: None (i.e., 'build_affine_surrogate_posterior').
|
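The structural rules for the operators argument can be summarized with a small checker. This is a minimal sketch: `describe_operators` is a hypothetical helper written for illustration, not a TFP function, and the `object()` placeholders stand in for LinearOperator classes, instances, or callables.

```python
def describe_operators(operators):
    """Classifies an `operators` argument by the rules described above.

    Hypothetical helper for illustration only; not part of TFP.
    """
    if isinstance(operators, str):
        if operators not in ("diag", "tril"):
            raise ValueError(f"Unsupported string value: {operators!r}")
        return "mean-field" if operators == "diag" else "full-covariance"
    if not isinstance(operators, (list, tuple)):
        raise TypeError("operators must be a string, list, or tuple")
    rows_are_nested = [isinstance(row, (list, tuple)) for row in operators]
    if not any(rows_are_nested):
        # Flat list: one operator per block on the diagonal.
        return "block-diagonal"
    if all(rows_are_nested) and all(
            len(row) == i + 1 for i, row in enumerate(operators)):
        # Row i holds i + 1 operators: rows of a lower-triangular block layout.
        return "block-lower-triangular"
    raise ValueError("nested rows must have lengths 1, 2, 3, ...")


# Placeholders standing in for LinearOperator classes or instances.
op_a, op_b, op_c = object(), object(), object()

flat_result = describe_operators([op_a, op_b])
nested_result = describe_operators([[op_a], [op_b, op_c]])
```

Here `flat_result` is "block-diagonal" and `nested_result` is "block-lower-triangular"; a nested list whose rows do not have lengths 1, 2, ... raises a ValueError.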
Returns | |
---|---|
init_fn
|
Python callable with signature initial_parameters = init_fn(seed).
|
apply_fn
|
Python callable with signature instance = apply_fn(*parameters).
|
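The init_fn / apply_fn contract can be illustrated with a toy builder in plain Python. Everything here is an assumption made for illustration: `build_toy_affine_stateless` is a hypothetical stand-in, its parameters are simple lists, and its apply_fn returns a sampling function rather than a distribution instance.

```python
import random


def build_toy_affine_stateless(event_size):
    """Toy stateless builder (hypothetical; not TFP's implementation)."""

    def init_fn(seed):
        # All randomness comes from `seed`, so equal seeds give equal parameters.
        rng = random.Random(seed)
        loc = [rng.gauss(0.0, 0.1) for _ in range(event_size)]
        scale = [1.0] * event_size
        return loc, scale

    def apply_fn(*parameters):
        # Rebuilds the object from explicit parameter values; no hidden state.
        loc, scale = parameters

        def sample(rng):
            return [m + s * rng.gauss(0.0, 1.0) for m, s in zip(loc, scale)]

        return sample

    return init_fn, apply_fn


init_fn, apply_fn = build_toy_affine_stateless(event_size=2)
initial_parameters = init_fn(seed=42)
instance = apply_fn(*initial_parameters)
draw = instance(random.Random(0))
```

Because the parameters are passed explicitly instead of living inside variables, an optimizer can update them as ordinary values and call apply_fn again with the new ones.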
Examples
import tensorflow as tf
import tensorflow_probability as tfp
tfd = tfp.distributions
tfb = tfp.bijectors
# Define a joint probabilistic model.
Root = tfd.JointDistributionCoroutine.Root
def model_fn():
  concentration = yield Root(tfd.Exponential(1.))
  rate = yield Root(tfd.Exponential(1.))
  y = yield tfd.Sample(
      tfd.Gamma(concentration=concentration, rate=rate),
      sample_shape=4)
model = tfd.JointDistributionCoroutine(model_fn)
# Assume the `y` are observed, such that the posterior is a joint distribution
# over `concentration` and `rate`. The posterior event shape is then equal to
# the first two components of the model's event shape.
posterior_event_shape = model.event_shape_tensor()[:-1]
# Constrain the posterior values to be positive using the `Exp` bijector.
bijector = [tfb.Exp(), tfb.Exp()]
# Build a full-covariance surrogate posterior.
surrogate_posterior = (
    tfp.experimental.vi.build_affine_surrogate_posterior(
        event_shape=posterior_event_shape,
        operators='tril',
        bijector=bijector))
# For an example defining `'operators'` as a list to express an alternative
# covariance structure, see
# `build_affine_surrogate_posterior_from_base_distribution`.
# Fit the model.
y = [0.2, 0.5, 0.3, 0.7]
target_model = model.experimental_pin(y=y)
losses = tfp.vi.fit_surrogate_posterior(
    target_model.unnormalized_log_prob,
    surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)