Multi-linear interpolation on a regular (constant spacing) grid.
tfp.substrates.jax.math.batch_interp_regular_nd_grid(
    x,
    x_ref_min,
    x_ref_max,
    y_ref,
    axis,
    fill_value='constant_extension',
    name=None
)
Given [a batch of] reference values, this function computes a multi-linear interpolant and evaluates it on [a batch of] new x values. This is a multi-dimensional generalization of bilinear interpolation.
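As a worked illustration of the bilinear case (nd = 2), the sketch below interpolates inside a single grid cell whose corners sit at local coordinates (0, 0), (1, 0), (0, 1) and (1, 1). The function `bilinear_in_cell` is a hypothetical helper, not part of TFP; it uses the standard bilinear weights, which the multi-linear interpolant generalizes to the 2**nd corners of a cell.

```python
def bilinear_in_cell(f00, f10, f01, f11, u, v):
    """Hypothetical helper: bilinear interpolation inside one unit grid cell.

    fab is the reference value at local corner (u, v) = (a, b); u and v are
    the fractional positions of the query point inside the cell, in [0, 1].
    """
    return ((1 - u) * (1 - v) * f00
            + u * (1 - v) * f10
            + (1 - u) * v * f01
            + u * v * f11)


# Corner values of f(u, v) = 1 + 2u + 3v + 4uv. This f is itself bilinear,
# so the interpolant reproduces it exactly: f(0.25, 0.5) = 3.5.
value = bilinear_in_cell(1.0, 3.0, 4.0, 10.0, u=0.25, v=0.5)
```

At a corner the weights collapse to a single 1, so the interpolant passes through every reference value.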
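The interpolant is built from reference values indexed by nd dimensions of y_ref, starting at axis (that is, dimensions axis, ..., axis + nd - 1). Here nd is the number of spatial dimensions, given by the size of the last dimension of x.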
The x grid span is defined by x_ref_min and x_ref_max. The number of grid points in each dimension is inferred from the shape of y_ref.
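Along dimension k the grid spacing follows from the span and the point count: spacing = (x_ref_max[k] - x_ref_min[k]) / (Ck - 1), where Ck is the size of the matching y_ref dimension. The hypothetical NumPy helper `fractional_grid_index` below (not a TFP function) maps query coordinates to fractional grid indices; the integer part picks a cell and the fractional part is the position inside it.

```python
import numpy as np


def fractional_grid_index(x, x_ref_min, x_ref_max, num_points):
    """Hypothetical helper: map coordinates to fractional grid indices.

    Grid index i along a dimension sits at x_ref_min + i * spacing, with
    spacing = (x_ref_max - x_ref_min) / (num_points - 1).
    """
    x = np.asarray(x, dtype=float)
    x_ref_min = np.asarray(x_ref_min, dtype=float)
    x_ref_max = np.asarray(x_ref_max, dtype=float)
    spacing = (x_ref_max - x_ref_min) / (np.asarray(num_points) - 1)
    return (x - x_ref_min) / spacing


# y_ref.shape = [11, 21] implies 11 and 21 grid points in the two dimensions.
idx = fractional_grid_index([[0.5, 1.0], [0.05, 1.95]],
                            x_ref_min=[0.0, 0.0], x_ref_max=[1.0, 2.0],
                            num_points=(11, 21))
# idx is approximately [[5.0, 10.0], [0.5, 19.5]]
```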
For example, take the case of a 2-D scalar-valued function and no leading batch dimensions. In this case, y_ref.shape = [C1, C2] and y_ref[i, j] is the reference value corresponding to the grid point

    [x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
     x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1)]
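A quick NumPy check of this formula, with arbitrary illustrative values for x_ref_min, x_ref_max, C1 and C2: the grid point for y_ref[i, j] coincides with entry [i, j] of a grid built from linspace and meshgrid with indexing='ij', which is how the two-variable example on this page builds y_ref using the TensorFlow equivalents. With the default 'xy' indexing the grid would have shape [C2, C1] instead.

```python
import numpy as np

x_ref_min = np.array([0.0, -1.0])
x_ref_max = np.array([1.0, 1.0])
C1, C2 = 5, 9  # y_ref.shape = [C1, C2]

# Grid point for y_ref[i, j], written exactly as in the formula above.
i, j = 3, 4
point = np.array([
    x_ref_min[0] + i * (x_ref_max[0] - x_ref_min[0]) / (C1 - 1),
    x_ref_min[1] + j * (x_ref_max[1] - x_ref_min[1]) / (C2 - 1),
])

# The same grid built with linspace + meshgrid; 'ij' keeps shape [C1, C2].
g0, g1 = np.meshgrid(np.linspace(x_ref_min[0], x_ref_max[0], C1),
                     np.linspace(x_ref_min[1], x_ref_max[1], C2),
                     indexing='ij')
```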
In the general case, dimensions to the left of axis in y_ref are broadcast with the leading dimensions of x, x_ref_min and x_ref_max.
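Returns | |
---|---|
y_interp | Interpolation between members of y_ref, at points x. Tensor of same dtype as x, and shape [..., D, B1, ..., BM], where D = x.shape[-2] is the number of query points and B1, ..., BM are the dimensions of y_ref that follow the nd interpolation dimensions.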
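The shape bookkeeping can be sketched with NumPy's broadcasting rules. The helper `interp_output_shape` below is hypothetical (not a TFP function), assumes NumPy 1.20 or later for `np.broadcast_shapes`, and assumes a negative axis is first converted to its non-negative equivalent. It broadcasts the leading dimensions of x, x_ref_min, x_ref_max and y_ref, then appends D and the trailing y_ref dimensions.

```python
import numpy as np


def interp_output_shape(x_shape, x_ref_min_shape, x_ref_max_shape,
                        y_ref_shape, axis):
    """Hypothetical helper: the documented output shape [..., D, B1, ..., BM]."""
    nd = x_shape[-1]   # number of interpolation dimensions
    d = x_shape[-2]    # number of query points per batch member
    rank = len(y_ref_shape)
    if axis < 0:
        axis += rank   # assumption: negative axis counts from the end
    # Dimensions left of `axis` in y_ref broadcast with the leading dims of
    # x ([..., D, nd]) and of x_ref_min / x_ref_max ([..., nd]).
    batch = np.broadcast_shapes(tuple(x_shape[:-2]),
                                tuple(x_ref_min_shape[:-1]),
                                tuple(x_ref_max_shape[:-1]),
                                tuple(y_ref_shape[:axis]))
    trailing = tuple(y_ref_shape[axis + nd:])  # B1, ..., BM
    return tuple(batch) + (d,) + trailing


# A batch of 4 vector-valued (length-3) functions on 5 x 6 grids.
shape = interp_output_shape((4, 7, 2), (2,), (2,), (4, 5, 6, 3), axis=1)
# shape == (4, 7, 3)
```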
Exceptions will be raised if shapes are statically determined to be wrong.
Raises | |
---|---|
ValueError | If rank(x) < 2.
ValueError | If axis is not a scalar.
ValueError | If axis + nd > rank(y_ref).
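The three documented conditions can be restated as a small static check. This is a hypothetical sketch, not TFP's own validation code; `check_static_shapes` is an illustrative name, and converting a negative axis before the last test is an assumption.

```python
import numpy as np


def check_static_shapes(x_shape, y_ref_shape, axis):
    """Hypothetical restatement of the documented ValueError conditions."""
    if len(x_shape) < 2:
        raise ValueError('rank(x) must be at least 2; got shape {}'.format(x_shape))
    if np.ndim(axis) != 0:
        raise ValueError('axis must be a scalar; got {!r}'.format(axis))
    nd = x_shape[-1]
    rank = len(y_ref_shape)
    axis = int(axis)
    if axis < 0:
        axis += rank  # assumption: negative axis counts from the end
    if axis + nd > rank:
        raise ValueError('axis + nd = {} exceeds rank(y_ref) = {}'.format(
            axis + nd, rank))


# Passes: 10 query points in 2-D against a 100 x 100 grid.
check_static_shapes((10, 2), (100, 100), axis=-2)
```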
Examples
Interpolate a function of one variable.
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

y_ref = tf.exp(tf.linspace(start=0., stop=10., num=20))
tfp.math.batch_interp_regular_nd_grid(
    # x.shape = [3, 1], x_ref_min/max.shape = [1]. Trailing `1` for `1-D`.
    x=[[6.0], [0.5], [3.3]], x_ref_min=[0.], x_ref_max=[10.], y_ref=y_ref,
    axis=0)
# ==> approx [exp(6.0), exp(0.5), exp(3.3)]
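With nd = 1, multi-linear interpolation reduces to piecewise-linear interpolation, so the example above can be cross-checked with `np.interp` (which, by default, also clamps to the end values outside the grid). This NumPy sketch is an illustration under that reading, not the TFP implementation. Because exp is convex, each linear interpolant lies slightly above the true value, which is why the TFP output is only approximate.

```python
import numpy as np

# Same reference data as the example above, built with NumPy.
x_grid = np.linspace(0.0, 10.0, num=20)
y_ref = np.exp(x_grid)

# Piecewise-linear interpolation between neighbouring grid points.
x = np.array([6.0, 0.5, 3.3])
y_interp = np.interp(x, x_grid, y_ref)
```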
Interpolate a scalar function of two variables.
x_ref_min = [0., 0.]
x_ref_max = [2 * np.pi, 2 * np.pi]

# Build y_ref on a 100 x 100 grid.
x0s, x1s = tf.meshgrid(
    tf.linspace(x_ref_min[0], x_ref_max[0], num=100),
    tf.linspace(x_ref_min[1], x_ref_max[1], num=100),
    indexing='ij')

def func(x0, x1):
    return tf.sin(x0) * tf.cos(x1)

y_ref = func(x0s, x1s)

# tf.random.stateless_uniform requires an explicit seed.
x = 2 * np.pi * tf.random.stateless_uniform(shape=(10, 2), seed=(1, 2))

tfp.math.batch_interp_regular_nd_grid(x, x_ref_min, x_ref_max, y_ref, axis=-2)
# ==> approx tf.sin(x[:, 0]) * tf.cos(x[:, 1])
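For readers who want to see the whole algorithm, here is a minimal NumPy sketch of multi-linear interpolation on a regular grid, restricted to the simplest case: no batch dimensions, axis = 0, scalar-valued y_ref, and at least 2 grid points per dimension. The name `multilinear_interp` is hypothetical, and clamping out-of-range points to the grid boundary is an assumption about what fill_value='constant_extension' does. Each query point takes a weighted sum over the 2**nd corners of its enclosing cell.

```python
import itertools

import numpy as np


def multilinear_interp(x, x_ref_min, x_ref_max, y_ref):
    """Hypothetical sketch: no batch dims, axis=0, scalar-valued y_ref.

    x: [D, nd] query points. y_ref: [C1, ..., Cnd] values on a regular grid,
    each Ck >= 2. Points outside the grid are clamped to its boundary.
    """
    x = np.asarray(x, dtype=float)
    x_ref_min = np.asarray(x_ref_min, dtype=float)
    x_ref_max = np.asarray(x_ref_max, dtype=float)
    y_ref = np.asarray(y_ref, dtype=float)
    nd = x.shape[-1]
    num_points = np.array(y_ref.shape[:nd])
    # Fractional grid index of every query point, clamped to the grid.
    idx = (x - x_ref_min) / (x_ref_max - x_ref_min) * (num_points - 1)
    idx = np.clip(idx, 0, num_points - 1)
    # Lower corner of the enclosing cell; capped so lo + 1 stays in range.
    lo = np.minimum(np.floor(idx).astype(int), num_points - 2)
    t = idx - lo  # position inside the cell, each component in [0, 1]
    result = np.zeros(x.shape[0])
    # Accumulate the 2**nd corner values with multi-linear weights.
    for corner in itertools.product((0, 1), repeat=nd):
        corner = np.array(corner)
        weight = np.prod(np.where(corner == 1, t, 1 - t), axis=-1)
        result += weight * y_ref[tuple((lo + corner).T)]
    return result


# Reproduce the two-variable example with NumPy data.
grid = np.linspace(0.0, 2 * np.pi, 100)
g0, g1 = np.meshgrid(grid, grid, indexing='ij')
y_ref = np.sin(g0) * np.cos(g1)
x = np.random.default_rng(0).uniform(0.0, 2 * np.pi, size=(10, 2))
y_interp = multilinear_interp(x, [0.0, 0.0], [2 * np.pi, 2 * np.pi], y_ref)
```

The loop over corners keeps the sketch readable for any nd; its cost grows as 2**nd per query point, which is inherent to multi-linear interpolation.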