Compares tensors to specs to determine if all tensors are batched or not.
tf_agents.utils.nest_utils.is_batched_nested_tensors(
    tensors,
    specs,
    num_outer_dims=1,
    allow_extra_fields=False,
    check_dtypes=True
)
For each tensor, it checks the dimensions and dtypes with respect to specs.
Returns True if all tensors are batched and False if all tensors are unbatched.
Raises a ValueError if the shapes are incompatible or if a mix of batched and unbatched tensors is provided.
Raises a TypeError if tensors' dtypes do not match specs.
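As an illustration of the contract described above, the following is a minimal sketch in plain Python that models the batched/unbatched decision over shape tuples. It is not the library's implementation. The helper name `is_batched_shapes` is hypothetical, nests are assumed to be flat dicts with matching keys, shapes are compared by exact equality (real TensorShape compatibility also tolerates unknown dimensions), and dtypes are ignored.

```python
def is_batched_shapes(tensor_shapes, spec_shapes, num_outer_dims=1):
    """Toy model of the documented behaviour, using plain shape tuples.

    Hypothetical helper: both arguments are flat dicts with the same keys,
    mapping each name to a tuple of ints.
    """
    states = []
    for key, spec in spec_shapes.items():
        shape = tensor_shapes[key]
        if shape == spec:
            # Shape equals the spec exactly: the tensor is unbatched.
            states.append(False)
        elif (len(shape) == len(spec) + num_outer_dims
              and shape[num_outer_dims:] == spec):
            # num_outer_dims extra leading dims, inner dims match: batched.
            states.append(True)
        else:
            raise ValueError(
                f"{key}: shape {shape} is incompatible with spec {spec}")
    if all(states):
        return True
    if not any(states):
        return False
    raise ValueError("A mix of batched and unbatched tensors was provided.")


specs = {"obs": (4,), "reward": ()}
batched = is_batched_shapes({"obs": (3, 4), "reward": (3,)}, specs)    # True
unbatched = is_batched_shapes({"obs": (4,), "reward": ()}, specs)      # False
```

Passing `{"obs": (3, 4), "reward": ()}` to this sketch raises a ValueError, because one entry is batched and the other is not.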
Args |
tensors
|
Nested list/tuple/dict of Tensors.
|
specs
|
Nested list/tuple/dict of Tensors or CompositeTensors describing the
shape of unbatched tensors.
|
num_outer_dims
|
The integer number of dimensions that are considered batch
dimensions. Default 1.
|
allow_extra_fields
|
If True, then tensors may have extra subfields which are not in specs. In this case, the extra subfields will not be checked. For example:

    tensors = {"a": tf.zeros((3, 4), dtype=tf.float32),
               "b": tf.zeros((5, 6), dtype=tf.float32)}
    specs = {"a": tf.TensorSpec(shape=(4,), dtype=tf.float32)}
    assert is_batched_nested_tensors(tensors, specs, allow_extra_fields=True)

The above example would raise a ValueError if allow_extra_fields was False.
|
check_dtypes
|
If True, validates that tensors and specs have the same dtypes.
|
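To make the effect of allow_extra_fields more concrete, here is a small sketch of what "extra subfields will not be checked" could amount to: dropping every key of tensors that specs does not mention before any comparison takes place. The helper name `prune_to_spec` is hypothetical and the sketch handles only nested dicts; it is an illustration under those assumptions, not the library's code.

```python
def prune_to_spec(specs, tensors):
    """Recursively keep only the dict keys of `tensors` that appear in `specs`.

    Hypothetical helper for illustration; leaves (non-dict values) are
    returned unchanged.
    """
    if isinstance(specs, dict) and isinstance(tensors, dict):
        return {
            key: prune_to_spec(specs[key], tensors[key])
            for key in specs
            if key in tensors
        }
    return tensors


# Shape tuples stand in for real tensors and specs here.
tensors = {"a": (3, 4), "b": (5, 6)}
specs = {"a": (4,)}
pruned = prune_to_spec(specs, tensors)  # {"a": (3, 4)}
```

After pruning, only "a" remains to be compared against its spec, so the unrelated shape of "b" can no longer cause a failure.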
Returns |
True if all Tensors are batched and False if all Tensors are unbatched.
|
Raises |
ValueError
|
If any of the following holds:
- Any of the tensors or specs have shapes with ndims == None, or
- The shapes of the tensors are not compatible with specs, or
- A mix of batched and unbatched tensors is provided, or
- The tensors are batched but have an incorrect number of outer dims.
|
TypeError
|
If the dtypes of tensors and specs are not compatible.
|
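The TypeError condition and the check_dtypes switch can be modelled with a few lines of plain Python. This is a sketch under stated assumptions rather than the library's logic: the helper name `check_dtypes_match` is hypothetical, dtypes are represented as strings, and "compatible" is simplified to exact equality.

```python
def check_dtypes_match(tensor_dtypes, spec_dtypes, check_dtypes=True):
    """Raise TypeError when any dtype differs from its spec.

    Hypothetical helper: both arguments are flat dicts with the same keys,
    mapping each name to a dtype string such as "float32".
    """
    if not check_dtypes:
        # Validation is switched off, so mismatches are not reported.
        return
    mismatched = {
        key: (tensor_dtypes[key], spec_dtype)
        for key, spec_dtype in spec_dtypes.items()
        if tensor_dtypes[key] != spec_dtype
    }
    if mismatched:
        raise TypeError(f"dtypes do not match specs: {mismatched}")


spec_dtypes = {"obs": "float32", "action": "int32"}
check_dtypes_match({"obs": "float32", "action": "int32"}, spec_dtypes)  # passes
```

Calling the sketch with `{"obs": "float64", "action": "int32"}` raises a TypeError, while the same call with `check_dtypes=False` returns without raising.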