Wraps a Keras RNN/LSTM/GRU layer to make network state more consistent.
tf_agents.keras_layers.RNNWrapper(
layer: tf.keras.layers.RNN, **kwargs
)
Args |
layer
|
An instance of tf.keras.layers.RNN or one of its subclasses (including
tf.keras.layers.{LSTM,GRU,...}).
|
**kwargs
|
Extra keyword arguments passed to the Layer parent class.
|
Raises |
TypeError
|
If layer is not an instance of tf.keras.layers.RNN or one of its subclasses.
|
NotImplementedError
|
If layer was created with return_state=False.
|
NotImplementedError
|
If layer was created with return_sequences=False.
|
Attributes |
cell
|
Return the cell underlying the RNN layer.
|
state_size
|
Return the state_size of the cell underlying the RNN layer.
|
wrapped_layer
|
Return the wrapped RNN layer.
|
Methods
get_initial_state
get_initial_state(
inputs=None
)