In some environments, different sets of actions are available given different
observations. To represent this, env.observation contains both the
raw observation and an action mask for this particular observation. Our
network needs to know how to split env.observation into these two parts. The
raw observation will be fed into the wrapped network, and the action mask will
be optionally passed into the wrapped network to ensure that the network only
outputs possible actions.
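As a rough illustration of the idea, the following plain-Python sketch shows an observation that bundles a raw observation with an action mask, and how the mask can restrict a greedy choice to the allowed actions. The dictionary keys and the helpers split_observation and masked_argmax are hypothetical names chosen for this example; they are not part of the library.

```python
# Hypothetical environment observation: raw features plus a 0/1 action mask.
env_observation = {
    "observation": [0.2, -1.3, 0.7],
    "mask": [1, 0, 1, 0],  # only actions 0 and 2 are allowed
}


def split_observation(inputs):
    """Splits the bundled observation into (raw observation, action mask)."""
    return inputs["observation"], inputs["mask"]


def masked_argmax(scores, mask):
    """Returns the index of the best-scoring action among the allowed ones."""
    allowed = [i for i, m in enumerate(mask) if m]
    if not allowed:
        raise ValueError("The mask does not allow any action.")
    return max(allowed, key=lambda i: scores[i])


obs, mask = split_observation(env_observation)
# Stand-in scores that a network might produce for the four actions.
scores = [0.1, 0.9, 0.4, 0.8]
best_action = masked_argmax(scores, mask)
```

Here the highest raw score belongs to action 1, but the mask excludes it, so the sketch selects action 2 instead.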
The network uses the splitter_fn to separate the observation from the action
mask (i.e. observation, mask = splitter_fn(inputs)). Depending on the value
of passthrough_mask, the mask is either passed into the wrapped network or
dropped, i.e.
    obs, mask = splitter_fn(inputs)
    wrapped_network(obs, ...)  # If passthrough_mask is `False`.
    wrapped_network(obs, ..., mask=mask)  # Otherwise, i.e. if it is `True`.
In each case the observation part is fed into the wrapped_network. The input
spec of the wrapped network is expected to be compatible with the
observation part of the input of the MaskSplitterNetwork.
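The control flow described above can be sketched in plain Python as follows. This is a minimal illustration only, not the library's implementation: MaskSplitterSketch and toy_wrapped_network are hypothetical stand-ins, the input is assumed to be an (observation, mask) tuple, and the flag uses the passthrough_mask name from the Args list.

```python
class MaskSplitterSketch:
    """Illustrative stand-in that mimics the split-and-forward behaviour."""

    def __init__(self, splitter_fn, wrapped_network, passthrough_mask=False):
        self._splitter_fn = splitter_fn
        self._wrapped_network = wrapped_network
        self._passthrough_mask = passthrough_mask

    def __call__(self, inputs, **kwargs):
        obs, mask = self._splitter_fn(inputs)
        if self._passthrough_mask:
            # Forward the mask through the wrapped network's `mask` parameter.
            return self._wrapped_network(obs, mask=mask, **kwargs)
        # Drop the mask; only the observation part reaches the wrapped network.
        return self._wrapped_network(obs, **kwargs)


def toy_wrapped_network(obs, mask=None):
    """Hypothetical network: scores three actions, zeroing out masked ones."""
    scores = [sum(obs) + i for i in range(3)]
    if mask is not None:
        scores = [s if m else 0.0 for s, m in zip(scores, mask)]
    return scores


def splitter(inputs):
    """Assumes inputs is an (observation, mask) tuple."""
    return inputs[0], inputs[1]


inputs = ([1.0, 2.0], [1, 0, 1])

dropped = MaskSplitterSketch(splitter, toy_wrapped_network)(inputs)
passed = MaskSplitterSketch(
    splitter, toy_wrapped_network, passthrough_mask=True
)(inputs)
```

With the mask dropped, the toy network scores all three actions; with the mask passed through, the score of the disallowed action is zeroed.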
Args
splitter_fn
A function used to split observations that carry action
constraints (i.e. a mask) into the network-specific part and the mask.
Note: The input spec of the wrapped network must be compatible with the
network-specific half of the output of the splitter_fn on the input spec.
wrapped_network
A network.Network used to process the network-specific
part of the observation. When the mask is passed through, it is supplied
as the mask parameter of the wrapped network's call method.
passthrough_mask
If set to True, the mask is fed into the wrapped
network. If set to False, the mask portion of the input is
dropped and not fed into the wrapped network.
input_tensor_spec
A tensor_spec.TensorSpec or a tuple of specs
representing the input observations including the specs of the action
constraints.
name
A string representing the name of the network.
Raises
ValueError
If input_tensor_spec is not an instance of network.InputSpec.
Attributes
input_tensor_spec
Returns the spec of the input to the network of type InputSpec.
layers
Get the list of all (nested) sub-layers used in this Network.
Methods
create_variables
Args
input_tensor_spec
(Optional). Override or provide an input tensor spec
when creating variables.
**kwargs
Other arguments to network.call(), e.g. training=True.
Returns
Output specs - a nested spec calculated from the outputs (excluding any
batch dimensions). If any of the output elements is a tfp Distribution,
the associated spec entry returned is a DistributionSpec.
Raises
ValueError
If no input_tensor_spec is provided, and the network did
not provide one during construction.
summary
Args
line_length
Total length of printed lines (e.g. set this to adapt the
display to different terminal window sizes).
positions
Relative or absolute positions of log elements in each line.
If not provided, defaults to [.33, .55, .67, 1.].
print_fn
Print function to use. Defaults to print. It will be called
on each line of the summary. You can set it to a custom function in
order to capture the string summary.
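For instance, passing a list's append method as print_fn collects the summary lines instead of printing them. The summary function below is a hypothetical stand-in that only imitates calling print_fn once per line, so the snippet runs without constructing a network; the example lines it emits are made up for illustration.

```python
def summary(print_fn=print):
    """Hypothetical stand-in: calls print_fn once per summary line."""
    example_lines = (
        "Layer (type)    Output Shape",
        "dense (Dense)   (None, 4)",
    )
    for line in example_lines:
        print_fn(line)


lines = []
summary(print_fn=lines.append)  # capture each line instead of printing it
summary_text = "\n".join(lines)
```

Joining the captured lines yields the full summary as a single string, which can then be logged or written to a file.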