Meanwhile, when the decoder is used with AttentionWrapper, applying a
coverage penalty when computing scores is recommended
(https://arxiv.org/pdf/1609.08144.pdf). It encourages the decoder to
cover all inputs.
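As a rough illustration of the coverage penalty from the linked paper, the sketch below sums, over source positions, the log of the total attention each position received, capped at 1.0. The function name `coverage_penalty` and the list-of-lists layout are assumptions made for this example, not part of the decoder's API. The sketch also assumes every source position receives some nonzero attention, as a softmax would produce.

```python
import math


def coverage_penalty(attention_probs, weight):
    """Hypothetical helper: coverage penalty for one hypothesis.

    attention_probs[j][i] is the attention that decoding step j places
    on source position i. `weight` plays the role of
    coverage_penalty_weight.
    """
    num_source = len(attention_probs[0])
    total = 0.0
    for i in range(num_source):
        # Total attention this source position received over all steps.
        coverage = sum(step[i] for step in attention_probs)
        # Cap at 1.0 so a fully covered position contributes log(1) = 0.
        total += math.log(min(coverage, 1.0))
    return weight * total


# Two decoding steps over two source positions; the second position
# is under-attended, so the penalty is negative.
attention = [[0.9, 0.1], [0.8, 0.2]]
penalty = coverage_penalty(attention, weight=0.2)
```

A hypothesis whose attention covers every source position gets a penalty of zero, while neglected positions pull the score down in proportion to the weight.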
Args
cell
An RNNCell instance.
embedding
A callable that takes a vector tensor of ids (argmax ids),
or the params argument for embedding_lookup.
start_tokens
int32 vector shaped [batch_size], the start tokens.
end_token
int32 scalar, the token that marks end of decoding.
initial_state
A (possibly nested tuple of...) tensors and TensorArrays.
length_penalty_weight
Float weight to penalize length. Disabled with 0.0.
coverage_penalty_weight
Float weight to penalize the coverage of source
sentence. Disabled with 0.0.
reorder_tensor_arrays
If True, TensorArrays' elements within the cell
state will be reordered according to the beam search path. If the
TensorArray can be reordered, the stacked form will be returned.
Otherwise, the TensorArray will be returned as is. Set this flag to
False if the cell state contains TensorArrays that are not amenable
to reordering.
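To make the length penalty weight above more concrete, here is a minimal sketch of the length normalization described in the linked paper, where a hypothesis's log probability is divided by ((5 + length) / 6) raised to the weight. The helper names `length_penalty` and `beam_score` are hypothetical, and the sketch should be read as an approximation of the idea rather than the decoder's exact scoring code.

```python
def length_penalty(sequence_length, weight):
    """Hypothetical helper: GNMT-style length penalty.

    With weight 0.0 the result is 1.0, which disables the penalty.
    """
    return ((5.0 + sequence_length) / 6.0) ** weight


def beam_score(log_prob, sequence_length, length_penalty_weight):
    """Length-normalized score for one hypothesis."""
    return log_prob / length_penalty(sequence_length, length_penalty_weight)


# Same total log probability, different lengths: with a positive
# weight, the longer hypothesis ends up with the better score.
short_score = beam_score(-6.0, sequence_length=7, length_penalty_weight=1.0)
long_score = beam_score(-6.0, sequence_length=13, length_penalty_weight=1.0)
```

Because log probabilities are negative, dividing by a value greater than 1.0 raises the score, which offsets the natural bias of beam search toward short outputs.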
Raises
TypeError
If cell is not an instance of RNNCell,
or output_layer is not an instance of tf.keras.layers.Layer.
ValueError
If start_tokens is not a vector or
end_token is not a scalar.
Attributes
batch_size
output_dtype
A (possibly nested tuple of...) dtype[s].
output_size
tracks_own_finished
The BeamSearchDecoder shuffles its beams and their finished state.
For this reason, it conflicts with the dynamic_decode function's
tracking of finished states. Setting this property to true avoids
early stopping of decoding due to mismanagement of the finished state
in dynamic_decode.
final_state
An instance of BeamSearchDecoderState. Passed through to the
output.
sequence_lengths
An int64 tensor shaped [batch_size, beam_width].
The sequence lengths determined for each beam during decode.
NOTE: These are ignored; the updated sequence lengths are stored in
final_state.lengths.
Returns
outputs
An instance of FinalBeamSearchDecoderOutput where the
predicted_ids are the result of calling _gather_tree.
final_state
The same input instance of BeamSearchDecoderState.
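The predicted_ids mentioned above are reconstructed by walking backward through the beams' parent pointers. The pure-Python sketch below illustrates that backtracking for a single batch entry. It is a simplified stand-in, not the library's _gather_tree: the function name, the list-of-lists layout, and the omission of the batch dimension, end-token handling, and sequence-length limits are all assumptions made to keep the example short.

```python
def gather_tree(step_ids, parent_ids):
    """Simplified backtracking sketch for one batch entry.

    step_ids[t][k]:   token chosen at time t in beam slot k.
    parent_ids[t][k]: beam slot at time t - 1 that slot k extended.
    Returns beams[t][k], the token at time t of the k-th final beam.
    """
    max_time = len(step_ids)
    beam_width = len(step_ids[0])
    beams = [[None] * beam_width for _ in range(max_time)]
    for k in range(beam_width):
        slot = k
        # Walk from the last step to the first, following parent slots.
        for t in reversed(range(max_time)):
            beams[t][k] = step_ids[t][slot]
            slot = parent_ids[t][slot]
    return beams


step_ids = [[1, 2], [3, 4], [5, 6]]
parent_ids = [[0, 0], [1, 0], [1, 1]]
# Both final beams descend from slot 1 at time 1 and slot 0 at time 0.
full_beams = gather_tree(step_ids, parent_ids)
# full_beams == [[1, 1], [4, 4], [5, 6]]
```

The raw per-step ids are not full sequences on their own, because beams are reshuffled at every step; only after this backward pass does each column hold one coherent hypothesis.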