Look up spec's shape, check that its outer dim(s) are 1, and remove them.
tf_agents.utils.nest_utils.remove_singleton_batch_spec_dim(
    spec: tf.TypeSpec, outer_ndim: int
) -> tf.TypeSpec
If spec.shape[i] != 1 for any i in range(outer_ndim), we stop removing singleton batch dimensions at i and return what's left. This is necessary to handle the outputs of inconsistent layers such as tf.keras.layers.LSTM(), which may take an input of shape (batch, time, dim) = (1, 1, Nin) and emit only the batch entry when time == 1, so the output shape is (1, Nout). We log an error in these cases.
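To make the stopping rule concrete, here is a minimal pure-Python sketch that works on plain shape tuples instead of tf.TypeSpec objects. The helper name remove_singleton_dims is hypothetical and is not part of TF-Agents; it only mirrors the behavior described above, including the LSTM case where a (1, Nout) output arrives with outer_ndim=2.

```python
import logging
from typing import Optional, Sequence, Tuple


def remove_singleton_dims(
    shape: Sequence[Optional[int]], outer_ndim: int
) -> Tuple[Optional[int], ...]:
    """Strip up to `outer_ndim` leading size-1 dims from `shape`.

    Hypothetical helper (not the TF-Agents implementation): stops at the first
    position i < outer_ndim where shape[i] != 1 and returns what is left.
    Unknown dims (None) count as "not 1", so removal stops there too.
    """
    shape = tuple(shape)
    for i in range(outer_ndim):
        if i >= len(shape):
            # Ran out of dimensions before removing outer_ndim of them.
            logging.error("len(shape) < %d; stopping. Shape: %s", i + 1, shape)
            return shape[i:]
        if shape[i] != 1:
            # E.g. an LSTM output (1, Nout) when two singleton dims were expected.
            logging.error("shape[%d] != 1; stopping. Shape: %s", i, shape)
            return shape[i:]
    # All outer_ndim leading dims were 1, so drop them all.
    return shape[outer_ndim:]
```

With outer_ndim=2, an input of shape (1, 1, 3) loses both leading dims and becomes (3,), while the LSTM-style output (1, 4) loses only the first dim, becomes (4,), and an error is logged at i = 1.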
Args
spec | A tf.TypeSpec.
outer_ndim | The maximum number of outer singleton dims to remove.
Returns
A tf.TypeSpec, the spec without its outer batch dimension(s).
Raises
ValueError | If spec lacks a shape property.
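The Returns and Raises contract can be sketched the same way. This is an illustration only, under stated assumptions: ToySpec is a hypothetical frozen dataclass standing in for tf.TypeSpec, and remove_singleton_batch_dims_sketch is a hypothetical name, not a TF-Agents function. It shows that an object without a shape is rejected with ValueError and that the result is a new spec with the leading singleton dims dropped; the error logging from the real function is omitted for brevity.

```python
import dataclasses
from typing import Any, Optional, Tuple


@dataclasses.dataclass(frozen=True)
class ToySpec:
    """Hypothetical stand-in for tf.TypeSpec: just a shape and a dtype name."""
    shape: Optional[Tuple[Optional[int], ...]]
    dtype: str = "float32"


def remove_singleton_batch_dims_sketch(spec: Any, outer_ndim: int) -> Any:
    """Return a copy of `spec` with up to `outer_ndim` leading 1-dims removed."""
    shape = getattr(spec, "shape", None)
    if shape is None:
        # Mirrors the documented Raises clause: no shape property -> ValueError.
        raise ValueError(f"Spec lacks a shape property: {spec!r}")
    n = 0
    # Count leading size-1 dims, never more than outer_ndim or the rank.
    while n < min(outer_ndim, len(shape)) and shape[n] == 1:
        n += 1
    # Build a new immutable spec instead of mutating the input.
    return dataclasses.replace(spec, shape=tuple(shape[n:]))
```

Returning a fresh spec rather than editing the input keeps specs immutable, which matches how TensorFlow type specs are normally treated.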