Facilitates flattening and unflattening batch dims of a tensor.
tf_agents.networks.utils.BatchSquash(
batch_dims
)
Exposes a matched pair of flatten and unflatten methods. After flattening, only one batch dimension remains. This makes it possible to evaluate networks that expect inputs with exactly one batch dimension on inputs that have several.
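The round trip can be illustrated without TensorFlow. The sketch below is a hypothetical NumPy analogue (`NumpyBatchSquash` is not part of TF-Agents) that assumes static array shapes; it mirrors the documented contract rather than the library's actual implementation. The `dense` function is a stand-in for a network that accepts only one batch dimension.

```python
import numpy as np


class NumpyBatchSquash:
    """Hypothetical NumPy analogue of BatchSquash, for illustration only."""

    def __init__(self, batch_dims: int):
        # Mirrors the documented contract: negative batch_dims is rejected.
        if batch_dims < 0:
            raise ValueError("batch_dims must be non-negative.")
        self._batch_dims = batch_dims
        self._batch_shape = None  # filled in by flatten()

    def flatten(self, x: np.ndarray) -> np.ndarray:
        # Cache the leading (batch) part of the shape for unflatten().
        self._batch_shape = x.shape[: self._batch_dims]
        # Merge all leading batch dims into one; keep the remaining dims.
        return x.reshape((-1,) + x.shape[self._batch_dims:])

    def unflatten(self, x: np.ndarray) -> np.ndarray:
        # This sketch refuses to unflatten before anything was cached.
        if self._batch_shape is None:
            raise ValueError("flatten() must be called before unflatten().")
        # Put the cached batch dims back in front of x's non-batch dims,
        # which may differ from the input's (e.g. after a dense layer).
        return x.reshape(self._batch_shape + x.shape[1:])


# Stand-in "network" that only accepts one batch dimension: [N, 3] -> [N, 8].
weights = np.ones((3, 8))


def dense(x: np.ndarray) -> np.ndarray:
    assert x.ndim == 2, "expects exactly one batch dimension"
    return x @ weights


obs = np.zeros((4, 5, 3))  # [batch, time, features]
squash = NumpyBatchSquash(batch_dims=2)
out = squash.unflatten(dense(squash.flatten(obs)))
print(out.shape)  # (4, 5, 8)
```

Inside the network the input is `[20, 3]`; the caller only ever sees the original `[4, 5, ...]` batch layout.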
Args | |
---|---|
batch_dims | Number of batch dimensions the flatten/unflatten ops should handle. |
Raises | |
---|---|
ValueError | If batch_dims is negative. |
Methods
flatten
flatten(
tensor
)
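Flattens the leading batch_dims dimensions of tensor into a single batch dimension and caches the original batch shape so that unflatten can restore it.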
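As a worked example of the shape arithmetic, the hypothetical helper below (not a TF-Agents function) computes, for a fully known static shape, the shape flatten would produce and the batch shape it would need to remember. The merged dimension is the product of the leading batch_dims sizes; all other dimensions are kept unchanged.

```python
from math import prod
from typing import Tuple

Shape = Tuple[int, ...]


def flatten_shape(shape: Shape, batch_dims: int) -> Tuple[Shape, Shape]:
    """Hypothetical helper: returns (flattened shape, cached batch shape)."""
    if batch_dims < 0:
        raise ValueError("batch_dims must be non-negative.")
    batch_shape = tuple(shape[:batch_dims])  # what flatten() remembers
    rest = tuple(shape[batch_dims:])         # non-batch dims are kept as-is
    # prod(()) == 1, so batch_dims=0 yields a new leading dim of size 1.
    return (prod(batch_shape),) + rest, batch_shape


print(flatten_shape((4, 5, 3), batch_dims=2))  # ((20, 3), (4, 5))
```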
unflatten
unflatten(
tensor
)
Unflattens the tensor's leading dimension back into the original batch_dims dimensions, using the batch shape cached by flatten.
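The inverse step can be sketched the same way. The hypothetical helper below (not a TF-Agents function) replaces the merged leading dimension with the cached batch shape and keeps the tensor's own trailing dimensions, which is what lets a network change the feature size between flatten and unflatten. The explicit size check is an assumption added for clarity; a real reshape would fail on a mismatch anyway.

```python
from math import prod
from typing import Tuple

Shape = Tuple[int, ...]


def unflatten_shape(flat_shape: Shape, cached_batch_shape: Shape) -> Shape:
    """Hypothetical helper: shape that unflatten() would restore."""
    # Assumed sanity check: the merged dim must match the cached batch size.
    if prod(cached_batch_shape) != flat_shape[0]:
        raise ValueError("leading dim does not match the cached batch shape")
    # Cached batch dims first, then the tensor's own non-batch dims.
    return tuple(cached_batch_shape) + tuple(flat_shape[1:])


print(unflatten_shape((20, 8), (4, 5)))  # (4, 5, 8)
```

Because only one batch shape is cached per instance, each call to unflatten pairs with the most recent call to flatten on the same object.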