Splits batched nested tensors along the batch (outer) dimension into a list.
tf_agents.utils.nest_utils.split_nested_tensors(
tensors, specs, num_or_size_splits
)
Args |
tensors
|
Nested list/tuple or dict of batched Tensors.
|
specs
|
Nested list/tuple or dict of TensorSpecs, describing the shape of the
non-batched Tensors.
|
num_or_size_splits
|
Same as the corresponding argument of tf.split. Either a Python integer
giving the number of splits along the batch dimension, or a list of integer
Tensors giving the size of each output tensor along the batch dimension. If an
integer, it must evenly divide the size of the batch dimension; otherwise the
sizes must sum to the size of the batch dimension. For
SparseTensor inputs, num_or_size_splits must be the scalar num_split
(see the documentation of tf.sparse.split for more details).
|
Returns |
A list of nested structures holding the non-batched version of each tensor,
where each list item corresponds to one batch item.
|
Raises |
ValueError
|
If the tensors and specs have incompatible dimensions or shapes.
|
ValueError
|
If a non-scalar num_or_size_splits is passed and the structure contains
SparseTensors.
|
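The two forms of num_or_size_splits can be easier to see on concrete data. The following is a minimal NumPy sketch of the splitting semantics only, not the TF-Agents implementation: `split_nested` is a hypothetical helper, it handles a flat dict rather than arbitrary nesting, it ignores specs and SparseTensors, and each piece keeps its leading dimension as `np.split` does.

```python
import numpy as np


def split_nested(tensors, num_or_size_splits):
    """Hypothetical helper: split every array in a flat dict along axis 0."""

    def split_one(value):
        if isinstance(num_or_size_splits, int):
            # Integer form: number of equal pieces. np.split raises
            # ValueError if it does not evenly divide the batch size.
            return np.split(value, num_or_size_splits, axis=0)
        # List form: explicit sizes that must sum to the batch size.
        sizes = list(num_or_size_splits)
        if sum(sizes) != value.shape[0]:
            raise ValueError("split sizes must sum to the batch size")
        # np.split expects split positions, so convert sizes to boundaries.
        boundaries = np.cumsum(sizes)[:-1]
        return np.split(value, boundaries, axis=0)

    pieces = {key: split_one(value) for key, value in tensors.items()}
    num_pieces = len(next(iter(pieces.values())))
    # Regroup so each list item has the same structure as the input.
    return [{key: pieces[key][i] for key in pieces} for i in range(num_pieces)]


batch = {
    "observation": np.arange(12).reshape(4, 3),  # batch of 4, item shape (3,)
    "reward": np.arange(4.0),                    # batch of 4, scalar items
}

even = split_nested(batch, 2)         # two pieces, each with batch size 2
uneven = split_nested(batch, [1, 3])  # pieces with batch sizes 1 and 3
```

With TF-Agents installed, the equivalent call for the first case would presumably take the form `split_nested_tensors(tensors, specs, 2)`, with specs describing the non-batched shapes `(3,)` and `()`.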