Returns (batched) matmul of a SparseTensor (or Tensor) with a Tensor.
tfp.substrates.jax.math.sparse_or_dense_matmul(
sparse_or_dense_a, dense_b, validate_args=False, name=None, **kwargs
)
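The function accepts either a sparse or a dense left operand and returns the same dense product in both cases. The sketch below mimics that dispatch in plain Python for a single, unbatched matrix, so it runs without JAX or TFP installed. SparseCOO and matmul_ref are hypothetical names standing in for tf.SparseTensor and the library call. They are not TFP's actual implementation.

```python
from dataclasses import dataclass
from typing import List, Tuple, Union

Matrix = List[List[float]]


@dataclass
class SparseCOO:
    """Hypothetical stand-in for tf.SparseTensor: COO indices, values, dense shape."""
    indices: List[Tuple[int, int]]
    values: List[float]
    dense_shape: Tuple[int, int]


def matmul_ref(a: Union[SparseCOO, Matrix], b: Matrix) -> Matrix:
    """Hypothetical reference: dense product of a sparse-or-dense `a` with dense `b`."""
    n_cols = len(b[0])
    if isinstance(a, SparseCOO):
        rows, inner = a.dense_shape
        if inner != len(b):
            raise ValueError("inner dimensions do not match")
        out = [[0.0] * n_cols for _ in range(rows)]
        # Only the stored (nonzero) entries of `a` contribute to the product.
        for (i, k), v in zip(a.indices, a.values):
            for j in range(n_cols):
                out[i][j] += v * b[k][j]
        return out
    # Dense path: ordinary row-by-column products.
    if len(a[0]) != len(b):
        raise ValueError("inner dimensions do not match")
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(n_cols)]
            for i in range(len(a))]


dense_a = [[1.0, 0.0, 2.0],
           [0.0, 3.0, 0.0]]
sparse_a = SparseCOO(indices=[(0, 0), (0, 2), (1, 1)],
                     values=[1.0, 2.0, 3.0],
                     dense_shape=(2, 3))
b = [[1.0, 2.0],
     [3.0, 4.0],
     [5.0, 6.0]]

print(matmul_ref(dense_a, b))   # [[11.0, 14.0], [9.0, 12.0]]
print(matmul_ref(sparse_a, b))  # [[11.0, 14.0], [9.0, 12.0]]
```

Storing only the nonzero entries lets the sparse path skip the zero terms of each inner sum. This is the usual reason to pass a SparseTensor rather than a dense Tensor. The result is dense either way.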
Args
  sparse_or_dense_a: SparseTensor or Tensor representing a (batch of) matrices.
  dense_b: Tensor representing a (batch of) matrices, with the same batch shape as sparse_or_dense_a. Its matrix dimensions must be compatible with those of sparse_or_dense_a for matrix multiplication, taking into account any transpose or adjoint requested through kwargs.
  validate_args: When True, additional assertions may be embedded in the graph.
    Default value: False (i.e., no graph assertions are added).
  name: Python str prefixed to ops created by this function.
    Default value: 'sparse_or_dense_matmul'.
  **kwargs: Keyword arguments passed through to tf.sparse_tensor_dense_matmul (when sparse_or_dense_a is a SparseTensor) or to tf.matmul (otherwise).
Returns
  product: A dense (batch of) matrix-shaped Tensor with the same batch shape and dtype as sparse_or_dense_a and dense_b. If sparse_or_dense_a or dense_b is adjointed through kwargs (for example, adjoint_b=True), the matrix dimensions of the result are adjusted accordingly.
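To illustrate how batch shapes line up and how an adjoint flag changes the shape of the result, here is a plain-Python reference for the dense path. It assumes that a batch is a list of matrices (batch shape [B]) and that the adjoint_a/adjoint_b flags mirror the tf.matmul keyword arguments of the same names. batched_matmul_ref, adjoint and shape are hypothetical helpers and are not part of TFP.

```python
from typing import List

# A batch is modelled as a list of matrices (batch shape [B]); each matrix is a
# list of rows. These helpers are hypothetical illustrations, not TFP code.
Matrix = List[List[float]]


def adjoint(m: Matrix) -> Matrix:
    """Conjugate transpose of one matrix (a plain transpose for real values)."""
    return [[m[i][j].conjugate() for i in range(len(m))] for j in range(len(m[0]))]


def batched_matmul_ref(a: List[Matrix], b: List[Matrix],
                       adjoint_a: bool = False, adjoint_b: bool = False) -> List[Matrix]:
    """Multiply matching matrices of two batches, optionally adjointing either side."""
    if len(a) != len(b):
        raise ValueError("batch shapes differ")
    out = []
    for x, y in zip(a, b):
        x = adjoint(x) if adjoint_a else x
        y = adjoint(y) if adjoint_b else y
        if len(x[0]) != len(y):
            raise ValueError("inner dimensions do not match")
        out.append([[sum(x[i][k] * y[k][j] for k in range(len(y)))
                     for j in range(len(y[0]))]
                    for i in range(len(x))])
    return out


def shape(batch: List[Matrix]):
    """(batch size, rows, columns) of a batch of equally sized matrices."""
    return (len(batch), len(batch[0]), len(batch[0][0]))


a = [[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]] for _ in range(2)]                  # shape (2, 2, 3)
b = [[[float(r + c) for c in range(3)] for r in range(4)] for _ in range(2)]  # shape (2, 4, 3)

out = batched_matmul_ref(a, b, adjoint_b=True)
print(shape(out))    # (2, 2, 4): the adjoint turns each 4x3 matrix of b into 3x4
print(out[0][0][0])  # 8.0 = 1*0 + 2*1 + 3*2
```

Without adjoint_b=True, the inner dimensions (3 and 4) disagree and the sketch raises ValueError. That is the kind of incompatibility the dense_b shape requirement rules out.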