tfp.experimental.vi.build_affine_surrogate_posterior_stateless

Builds a joint variational posterior with a given event_shape.

This function builds a surrogate posterior by applying a trainable transformation to a standard base distribution and constraining the samples with a bijector. The surrogate posterior has event shape equal to the input event_shape.
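The construction described above can be sketched in plain Python for a two-dimensional posterior. This is a minimal illustration under stated assumptions, not the library's implementation: the names `loc`, `scale_tril`, and `sample_surrogate` are hypothetical, the affine map is a fixed lower-triangular matrix rather than a trainable LinearOperator, and `math.exp` stands in for the constraining bijector.

```python
import math
import random

def sample_surrogate(loc, scale_tril, rng):
    """Draws one sample: standard normal -> affine map -> exp constraint."""
    n = len(loc)
    # Base distribution: independent standard normals (loc=0, scale=1).
    z = [rng.gauss(0.0, 1.0) for _ in range(n)]
    # Affine transformation with a lower-triangular scale matrix
    # (hypothetical fixed values; the real parameters are trainable).
    unconstrained = [
        loc[i] + sum(scale_tril[i][j] * z[j] for j in range(i + 1))
        for i in range(n)
    ]
    # Constrain each coordinate to be positive, as an `Exp` bijector would.
    return [math.exp(u) for u in unconstrained]

rng = random.Random(0)
loc = [0.0, -1.0]
scale_tril = [[0.5, 0.0],
              [0.3, 0.2]]  # the off-diagonal entry couples the two coordinates
samples = [sample_surrogate(loc, scale_tril, rng) for _ in range(1000)]
```

With a diagonal `scale_tril` the coordinates would be independent, which corresponds to the mean-field case; the lower-triangular form corresponds to the full-covariance case.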

This function is a convenience wrapper around build_affine_surrogate_posterior_from_base_distribution that allows the user to pass in the desired posterior event_shape instead of pre-constructed base distributions (at the expense of full control over the base distribution types and parameterizations).

Args

event_shape (Nested) event shape of the posterior.
operators Either a string or a list/tuple containing LinearOperator subclasses, LinearOperator instances, or callables returning LinearOperator instances. Supported string values are "diag" (to create a mean-field surrogate posterior) and "tril" (to create a full-covariance surrogate posterior). A list/tuple may be passed to induce other posterior covariance structures. If the list is flat, a tf.linalg.LinearOperatorBlockDiag instance will be created and applied to the base distribution. Otherwise the list must be singly-nested and have a first element of length 1, second element of length 2, etc.; the elements of the outer list are interpreted as rows of a lower-triangular block structure, and a tf.linalg.LinearOperatorBlockLowerTriangular instance is created. For complete documentation and examples, see tfp.experimental.vi.util.build_trainable_linear_operator_block, which receives the operators arg if it is list-like. Default value: "diag".
bijector tfb.Bijector instance, or nested structure of tfb.Bijector instances, that maps (nested) values in R^n to the support of the posterior. (This can be the experimental_default_event_space_bijector of the distribution over the prior latent variables.) Default value: None (i.e., the posterior is over R^n).
base_distribution A tfd.Distribution subclass parameterized by loc and scale. The base distribution of the transformed surrogate has loc=0. and scale=1. Default value: tfd.Normal.
dtype The dtype of the surrogate posterior. Default value: tf.float32.
batch_shape Batch shape (Python tuple, list, or int) of the surrogate posterior, to enable parallel optimization from multiple initializations. Default value: ().
validate_args Python bool. Whether to validate input with asserts. This imposes a runtime cost. If validate_args is False, and the inputs are invalid, correct behavior is not guaranteed. Default value: False.
name Python str name prefixed to ops created by this function. Default value: None (i.e., 'build_affine_surrogate_posterior').
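The shape rules for a list-valued operators argument can be checked with a small helper. The sketch below is hypothetical (`block_structure` is not part of TFP) and mirrors only the rules stated in the operators description above; the library's own validation may differ. Placeholder strings stand in for LinearOperator classes or instances.

```python
def block_structure(operators):
    """Returns which structure a given `operators` value implies."""
    if isinstance(operators, str):
        if operators not in ("diag", "tril"):
            raise ValueError(f"Unsupported string: {operators!r}")
        return operators
    rows = list(operators)
    if not any(isinstance(row, (list, tuple)) for row in rows):
        # Flat list: one operator per diagonal block.
        return "LinearOperatorBlockDiag"
    # Singly-nested list: row i must contain exactly i + 1 blocks.
    for i, row in enumerate(rows):
        if not isinstance(row, (list, tuple)) or len(row) != i + 1:
            raise ValueError(f"Row {i} must have {i + 1} blocks.")
    return "LinearOperatorBlockLowerTriangular"

# Placeholders only; real code would pass LinearOperator subclasses,
# instances, or callables returning instances.
flat = ["A", "B"]
nested = [["A"], ["C", "B"]]
flat_kind = block_structure(flat)
nested_kind = block_structure(nested)
```

Here `flat` describes two independent diagonal blocks, while `nested` adds an off-diagonal block `"C"` that couples the second variable to the first.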

Returns

init_fn Python callable with signature initial_parameters = init_fn(seed).
apply_fn Python callable with signature instance = apply_fn(*parameters).
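The init_fn / apply_fn contract can be illustrated with a toy pair written in plain Python. This is a sketch of the calling pattern only, assuming nothing about TFP internals: `build_toy_stateless` is a hypothetical name, the integer seed is a simplification of this toy, and the returned "instance" is a simple function rather than a distribution.

```python
import random

def build_toy_stateless(num_dims):
    """Returns an (init_fn, apply_fn) pair for a toy diagonal affine map."""
    def init_fn(seed):
        # All randomness comes from `seed`; no variables are created or stored.
        rng = random.Random(seed)
        loc = tuple(rng.uniform(-0.1, 0.1) for _ in range(num_dims))
        scale = tuple(1.0 for _ in range(num_dims))
        return loc, scale

    def apply_fn(loc, scale):
        # Pure function of the parameters: builds the object they describe.
        return lambda z: [m + s * zi for m, s, zi in zip(loc, scale, z)]

    return init_fn, apply_fn

init_fn, apply_fn = build_toy_stateless(num_dims=2)
parameters = init_fn(seed=42)
transform = apply_fn(*parameters)
shifted = transform([0.0, 0.0])  # equals the initial `loc`
```

Because the parameters are returned to the caller rather than stored in variables, the same seed always reproduces the same initialization, and an optimizer can update the parameters as ordinary values before passing them back to apply_fn.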

Examples

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions
tfb = tfp.bijectors

# Define a joint probabilistic model.
Root = tfd.JointDistributionCoroutine.Root
def model_fn():
  concentration = yield Root(tfd.Exponential(1.))
  rate = yield Root(tfd.Exponential(1.))
  y = yield tfd.Sample(
      tfd.Gamma(concentration=concentration, rate=rate),
      sample_shape=4)
model = tfd.JointDistributionCoroutine(model_fn)

# Assume the `y` are observed, such that the posterior is a joint distribution
# over `concentration` and `rate`. The posterior event shape is then equal to
# the first two components of the model's event shape.
posterior_event_shape = model.event_shape_tensor()[:-1]

# Constrain the posterior values to be positive using the `Exp` bijector.
bijector = [tfb.Exp(), tfb.Exp()]

# Build a full-covariance surrogate posterior.
surrogate_posterior = (
  tfp.experimental.vi.build_affine_surrogate_posterior(
      event_shape=posterior_event_shape,
      operators='tril',
      bijector=bijector))

# For an example defining `'operators'` as a list to express an alternative
# covariance structure, see
# `build_affine_surrogate_posterior_from_base_distribution`.

# Fit the model.
y = [0.2, 0.5, 0.3, 0.7]
target_model = model.experimental_pin(y=y)
losses = tfp.vi.fit_surrogate_posterior(
    target_model.unnormalized_log_prob,
    surrogate_posterior,
    num_steps=100,
    optimizer=tf.optimizers.Adam(0.1),
    sample_size=10)