A TimeStep spec of the expected time_steps. Usually
provided by the user to the subclass.
action_spec
A nest of BoundedTensorSpec representing the actions. Usually
provided by the user to the subclass.
policy_state_spec
A nest of TensorSpec representing the policy_state.
Provided by the subclass, not directly by the user.
info_spec
A nest of TensorSpec representing the policy info. Provided by
the subclass, not directly by the user.
clip
Whether to clip actions to spec before returning them. Defaults to
True. Most policy-based algorithms (PCL, PPO, REINFORCE) use unclipped
continuous actions for training.
emit_log_probability
Emit log-probabilities of actions, if supported. If
True, policy_step.info will have CommonFields.LOG_PROBABILITY set.
Please consult utility methods provided in policy_step for setting and
retrieving these. When working with custom policies, either provide a
dictionary info_spec or a namedtuple with the field 'log_probability'.
automatic_state_reset
If True, then get_initial_policy_state is used
to clear state in action() and distribution() for time steps
where time_step.is_first().
observation_and_action_constraint_splitter
A function used to process
observations with action constraints. These constraints can indicate,
for example, a mask of valid/invalid actions for a given state of the
environment. The function takes in a full observation and returns a
tuple consisting of 1) the part of the observation intended as input to
the network and 2) the constraint. An example
observation_and_action_constraint_splitter could be as simple as:

    def observation_and_action_constraint_splitter(observation):
        return observation['network_input'], observation['constraint']

Note: when
using observation_and_action_constraint_splitter, make sure the
provided q_network is compatible with the network-specific half of the
output of the observation_and_action_constraint_splitter. In
particular, observation_and_action_constraint_splitter will be called
on the observation before passing to the network. If
observation_and_action_constraint_splitter is None, action constraints
are not applied.
validate_args
Python bool. Whether to verify inputs to, and outputs of,
functions like action and distribution against spec structures,
dtypes, and shapes. Research code may prefer to set this value to
False to allow iterating on input and output structures without being
hamstrung by overly rigid checking (at the cost of harder-to-debug
errors). See also TFAgent.validate_args.
name
A name for this module. Defaults to the class name.
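A minimal plain-Python sketch of the splitter contract described under observation_and_action_constraint_splitter, assuming a dict observation with the 'network_input' and 'constraint' keys from the example above. The fixed Q-values stand in for a q_network output, and masked_greedy_action is a hypothetical helper, not part of TF-Agents; it only illustrates how a validity mask can exclude actions before an argmax.

```python
import math
from typing import Any, Dict, Sequence, Tuple


def observation_and_action_constraint_splitter(
        observation: Dict[str, Any]) -> Tuple[Any, Any]:
    # Split a full observation into (network input, action constraint).
    return observation['network_input'], observation['constraint']


def masked_greedy_action(q_values: Sequence[float],
                         mask: Sequence[int]) -> int:
    # Hypothetical helper: pick the highest-value action among those whose
    # mask entry is truthy; invalid actions get -inf so they never win.
    if len(q_values) != len(mask):
        raise ValueError('q_values and mask must have the same length')
    masked = [q if m else -math.inf for q, m in zip(q_values, mask)]
    if all(v == -math.inf for v in masked):
        raise ValueError('mask allows no valid action')
    return max(range(len(masked)), key=lambda i: masked[i])


observation = {
    'network_input': [0.2, 0.7],
    'constraint': [1, 0, 1],  # action 1 is invalid in this state
}
net_input, mask = observation_and_action_constraint_splitter(observation)
# Stand-in for q_network(net_input): fixed Q-values for three actions.
q_values = [0.5, 2.0, 1.0]
action = masked_greedy_action(q_values, mask)
print(action)  # 2: action 1 has the highest Q-value but is masked out
```

Setting invalid entries to -inf rather than 0 keeps the mask correct even when every valid Q-value is negative.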
Attributes
action_spec
Describes the TensorSpecs of the Tensors expected by step(action).
action can be a single Tensor, or a nested dict, list or tuple of
Tensors.
collect_data_spec
Describes the Tensors written when using this policy with an environment.
emit_log_probability
Whether this policy instance emits log probabilities or not.
info_spec
Describes the Tensors emitted as info by action and distribution.
info can be an empty tuple, a single Tensor, or a nested dict,
list or tuple of Tensors.
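observation_and_action_constraint_splitter
The function that splits observations into network input and action
constraints, or None if action constraints are not applied.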
policy_state_spec
Describes the Tensors expected by step(_, policy_state).
policy_state can be an empty tuple, a single Tensor, or a nested dict,
list or tuple of Tensors.
policy_step_spec
Describes the output of action().
time_step_spec
Describes the TimeStep tensors returned by step().
trajectory_spec
Describes the Tensors written when using this policy with an environment.
validate_args
Whether action & distribution validate input and output args.
action(time_step, policy_state=(), seed=None)
Generates the next action given the time_step and policy_state.
Args
time_step
A TimeStep tuple corresponding to time_step_spec().
policy_state
A Tensor, or a nested dict, list or tuple of Tensors
representing the previous policy_state.
seed
Seed to use if action performs sampling (optional).
Returns
A PolicyStep named tuple containing:
action: An action Tensor matching the action_spec.
state: A policy state tensor to be fed into the next call to action.
info: Optional side information such as action log probabilities.
Raises
RuntimeError
If the subclass __init__ did not call super().__init__().
ValueError or TypeError
If validate_args is True and inputs or outputs do not match
time_step_spec, policy_state_spec, or policy_step_spec.
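The state-threading contract of action() can be illustrated without TensorFlow. In this sketch, CountingPolicy and its methods are hypothetical, a list of is_first flags replaces a full TimeStep, and PolicyStep is a local namedtuple that only mirrors the action, state, and info fields listed under Returns. The reset step imitates what automatic_state_reset describes: batch entries whose time step is the first of an episode get the initial state.

```python
import collections
from typing import List, Sequence

# Local stand-in mirroring the three fields listed under Returns.
PolicyStep = collections.namedtuple('PolicyStep', ['action', 'state', 'info'])


class CountingPolicy:
    """Hypothetical policy whose state counts steps since episode start."""

    def get_initial_state(self, batch_size: int) -> List[int]:
        return [0] * batch_size

    def action(self, is_first: Sequence[bool],
               policy_state: List[int]) -> PolicyStep:
        # Mimic automatic_state_reset: wherever a time step is the first
        # of an episode, replace that batch entry with the initial state.
        initial = self.get_initial_state(len(is_first))
        state = [init if first else s
                 for first, init, s in zip(is_first, initial, policy_state)]
        # Toy decision rule: alternate between actions 0 and 1.
        actions = [s % 2 for s in state]
        next_state = [s + 1 for s in state]
        return PolicyStep(action=actions, state=next_state, info=())


policy = CountingPolicy()
state = policy.get_initial_state(batch_size=2)
actions_seen = []
# On the third call, the second batch entry starts a new episode.
for is_first in ([True, True], [False, False], [False, True], [False, False]):
    step = policy.action(is_first, state)
    actions_seen.append(step.action)
    state = step.state  # feed the returned state into the next call
```

The caller owns the state: each returned state must be passed into the next call, and only the entry flagged as first is reset, so the two batch entries drift apart after the third call.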
distribution(time_step, policy_state=())
Generates the distribution over next actions given the time_step.
Args
time_step
A TimeStep tuple corresponding to time_step_spec().
policy_state
A Tensor, or a nested dict, list or tuple of Tensors
representing the previous policy_state.
Returns
A PolicyStep named tuple containing:
action: A tfp.distributions.Distribution capturing the distribution of next actions.
state: A policy state tensor for the next call to distribution.
info: Optional side information such as action log probabilities.
Raises
ValueError or TypeError
If validate_args is True and inputs or outputs do not match
time_step_spec, policy_state_spec, or policy_step_spec.
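Because distribution() returns a distribution rather than a concrete action, a caller can sample from it, score an action with its log-probability (the quantity emit_log_probability places in info), or take the most likely action. The Categorical class below is a hypothetical pure-Python stand-in for a TensorFlow Probability distribution, and the 'log_probability' dict key follows the dictionary info_spec convention described under emit_log_probability; this is a sketch under those assumptions, not TF-Agents code.

```python
import math
import random
from typing import Sequence


class Categorical:
    """Hypothetical stand-in for a discrete action distribution."""

    def __init__(self, probs: Sequence[float]):
        total = sum(probs)
        if total <= 0 or any(p < 0 for p in probs):
            raise ValueError('probs must be non-negative with positive sum')
        self.probs = [p / total for p in probs]

    def sample(self, rng: random.Random) -> int:
        # Inverse-CDF sampling over the action indices.
        u = rng.random()
        cumulative = 0.0
        for i, p in enumerate(self.probs):
            cumulative += p
            if u < cumulative:
                return i
        return len(self.probs) - 1  # guard against float round-off

    def log_prob(self, action: int) -> float:
        # Raises ValueError for an action with zero probability.
        return math.log(self.probs[action])

    def mode(self) -> int:
        return max(range(len(self.probs)), key=lambda i: self.probs[i])


dist = Categorical([0.1, 0.6, 0.3])
rng = random.Random(0)  # fixed seed for reproducible sampling
sampled = dist.sample(rng)
info = {'log_probability': dist.log_prob(sampled)}
greedy = dist.mode()
```

Keeping the distribution separate from the sampled action is what lets the same policy serve both stochastic collection (sample) and greedy evaluation (mode).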