Add a batch dimension, if needed, to nested tensors while checking their specs.
tf_agents.utils.nest_utils.batch_nested_tensors(
    tensors, specs=None
)
If specs is None, a batch dimension is added to every tensor.
If specs is provided, each tensor's dimensions are checked against the
corresponding spec, and a batch dimension is added only if the tensor does not
already have one.
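A minimal sketch of the per-tensor decision follows. It assumes the check is rank-based: a tensor with the same rank as its spec is treated as unbatched, and one with exactly one extra leading dimension is treated as already batched. It works on plain shape tuples instead of real Tensors, and batched_shape and _is_compatible are hypothetical helpers, not part of TF-Agents.

```python
from typing import Optional, Tuple

# A shape is a tuple of dimension sizes; None marks an unknown dimension.
Shape = Tuple[Optional[int], ...]


def _is_compatible(shape: Shape, spec_shape: Shape) -> bool:
    """True if both shapes have the same rank and every known dim agrees."""
    if len(shape) != len(spec_shape):
        return False
    return all(a is None or b is None or a == b for a, b in zip(shape, spec_shape))


def batched_shape(shape: Shape, spec_shape: Optional[Shape] = None) -> Shape:
    """Hypothetical helper: the shape a tensor would have after batching."""
    if spec_shape is None:
        # No spec: always add a leading batch dimension of size 1.
        return (1,) + shape
    if len(shape) == len(spec_shape):
        # Same rank as the unbatched spec (assumed): no batch dim yet, so add one.
        if not _is_compatible(shape, spec_shape):
            raise ValueError(f"shape {shape} incompatible with spec {spec_shape}")
        return (1,) + shape
    if len(shape) == len(spec_shape) + 1:
        # One extra leading dim (assumed): treat it as the batch dim, check the rest.
        if not _is_compatible(shape[1:], spec_shape):
            raise ValueError(f"shape {shape} incompatible with spec {spec_shape}")
        return shape
    # Any other rank cannot be reconciled with the spec.
    raise ValueError(f"shape {shape} has the wrong rank for spec {spec_shape}")


print(batched_shape((3,), (3,)))    # (1, 3)
print(batched_shape((8, 3), (3,)))  # (8, 3)
print(batched_shape((3,)))          # (1, 3)
```

Mismatched shapes such as batched_shape((4,), (3,)) raise ValueError in this sketch, mirroring the Raises entry.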
Args:
  tensors: Nested list/tuple or dict of Tensors.
  specs: Nested list/tuple or dict of TensorSpecs, describing the shape of the
    non-batched Tensors.
Returns:
  A nested structure matching tensors, in which each tensor has a batch
  dimension.
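The returned value keeps the same nesting as tensors; only the leaves change. The sketch below models the specs=None case with a hypothetical FakeTensor stand-in and a hand-written map_nested helper. Neither is part of TF-Agents; in real TensorFlow code, tf.nest.map_structure performs this kind of structure-preserving traversal.

```python
from dataclasses import dataclass
from typing import Any, Callable, Optional, Tuple


@dataclass(frozen=True)
class FakeTensor:
    """Hypothetical stand-in for a tf.Tensor; only its shape matters here."""
    shape: Tuple[Optional[int], ...]


def add_batch_dim(t: FakeTensor) -> FakeTensor:
    # Models expanding dims at axis 0: prepend a batch dimension of size 1.
    return FakeTensor((1,) + t.shape)


def map_nested(fn: Callable[[Any], Any], structure: Any) -> Any:
    """Apply fn to every leaf of a nested dict/list/tuple, keeping the structure."""
    if isinstance(structure, dict):
        return {k: map_nested(fn, v) for k, v in structure.items()}
    if isinstance(structure, (list, tuple)):
        mapped = [map_nested(fn, v) for v in structure]
        if isinstance(structure, tuple) and hasattr(structure, "_fields"):
            # namedtuples take positional fields, not a single iterable.
            return type(structure)(*mapped)
        return type(structure)(mapped)
    # Anything else is a leaf.
    return fn(structure)


obs = {"image": FakeTensor((64, 64, 3)), "pos": [FakeTensor((2,)), FakeTensor(())]}
batched = map_nested(add_batch_dim, obs)
print(batched["image"].shape)   # (1, 64, 64, 3)
print(batched["pos"][1].shape)  # (1,)
```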
Raises:
  ValueError: If the tensors and specs have incompatible dimensions or shapes.