Calculates the full beams from the per-step IDs and parent beam IDs.
tfa.seq2seq.gather_tree(
    step_ids: tfa.types.TensorLike,
    parent_ids: tfa.types.TensorLike,
    max_sequence_lengths: tfa.types.TensorLike,
    end_token: tfa.types.Number
) -> tf.Tensor
For a given beam, past the time step containing the first decoded end_token, all values are filled in with end_token.
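To make the backtracking concrete, here is a pure-Python sketch of the semantics described above. It works on nested lists indexed as [time][batch][beam] instead of tensors. It is an illustration, not the TensorFlow Addons implementation: `gather_tree_reference` is a hypothetical name, and it raises `ValueError` where the real op raises `InvalidArgumentError`. For each batch entry and beam it starts at the last valid step, follows parent_ids back to step 0, then overwrites everything after the first end_token with end_token. In this sketch, steps at or beyond max_sequence_lengths are simply left as end_token.

```python
from typing import List

Tensor3D = List[List[List[int]]]  # indexed as [time][batch][beam]


def gather_tree_reference(
    step_ids: Tensor3D,
    parent_ids: Tensor3D,
    max_sequence_lengths: List[int],
    end_token: int,
) -> Tensor3D:
    """Hypothetical pure-Python sketch of gather_tree semantics (not the TFA code)."""
    max_time = len(step_ids)
    batch_size = len(step_ids[0]) if max_time else 0
    beam_width = len(step_ids[0][0]) if batch_size else 0

    # Start from an output filled with end_token; unreached steps keep that value.
    beams = [
        [[end_token] * beam_width for _ in range(batch_size)]
        for _ in range(max_time)
    ]

    for b in range(batch_size):
        seq_len = min(max_time, max_sequence_lengths[b])
        if seq_len <= 0:
            continue
        for k in range(beam_width):
            # The token at the last valid step of beam k is taken as-is.
            last = seq_len - 1
            beams[last][b][k] = step_ids[last][b][k]
            parent = parent_ids[last][b][k]
            # Walk backwards, reading each earlier token from the parent beam.
            for t in range(last - 1, -1, -1):
                if not 0 <= parent < beam_width:
                    raise ValueError(f"invalid parent id {parent} at step {t + 1}")
                beams[t][b][k] = step_ids[t][b][parent]
                parent = parent_ids[t][b][parent]
            # Everything after the first end_token becomes end_token.
            finished = False
            for t in range(seq_len):
                if finished:
                    beams[t][b][k] = end_token
                elif beams[t][b][k] == end_token:
                    finished = True
    return beams


# max_time=3, batch_size=1, beam_width=2
step_ids = [[[1, 2]], [[3, 4]], [[5, 6]]]
parent_ids = [[[0, 0]], [[1, 0]], [[1, 0]]]
print(gather_tree_reference(step_ids, parent_ids, [3], end_token=9))
# [[[1, 2]], [[4, 3]], [[5, 6]]]
```

Read down the time axis: beam 0 resolves to 1, 4, 5 and beam 1 to 2, 3, 6. With end_token=4, beam 0 would instead become 1, 4, 4, because every step after the first 4 is overwritten.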
| Args | |
|---|---|
| step_ids | The predicted token IDs. An int32 Tensor of shape [max_time, batch_size, beam_width]. |
| parent_ids | The parent beam indices. An int32 Tensor of shape [max_time, batch_size, beam_width]. |
| max_sequence_lengths | The maximum sequence length of each batch entry. An int32 Tensor of shape [batch_size]. |
| end_token | The end token ID. |
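The shape requirements listed under Args can be checked before calling the op. The helper below is a hypothetical sketch on nested Python lists (`check_gather_tree_shapes` is not part of TensorFlow Addons). It only confirms that step_ids and parent_ids agree on [max_time, batch_size, beam_width] and that max_sequence_lengths has one entry per batch entry; it does not check dtypes.

```python
from typing import List, Tuple


def check_gather_tree_shapes(
    step_ids: List[List[List[int]]],
    parent_ids: List[List[List[int]]],
    max_sequence_lengths: List[int],
) -> Tuple[int, int, int]:
    """Hypothetical helper: return (max_time, batch_size, beam_width) or raise ValueError."""

    def shape3(x: List[List[List[int]]]) -> Tuple[int, int, int]:
        # Require a rectangular [time][batch][beam] nesting.
        t = len(x)
        b = len(x[0]) if t else 0
        w = len(x[0][0]) if b else 0
        for row in x:
            if len(row) != b or any(len(beam) != w for beam in row):
                raise ValueError("ragged input; expected [max_time, batch_size, beam_width]")
        return t, b, w

    shape = shape3(step_ids)
    parent_shape = shape3(parent_ids)
    if parent_shape != shape:
        raise ValueError(f"parent_ids shape {parent_shape} != step_ids shape {shape}")
    if len(max_sequence_lengths) != shape[1]:
        raise ValueError(
            f"max_sequence_lengths has {len(max_sequence_lengths)} entries, "
            f"expected batch_size={shape[1]}"
        )
    return shape


print(check_gather_tree_shapes([[[1, 2]], [[3, 4]]], [[[0, 0]], [[1, 0]]], [2]))
# (2, 1, 2)
```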
| Returns |
|---|
| The reordered token IDs based on parent_ids, with the same shape as step_ids: [max_time, batch_size, beam_width]. |
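The returned tensor is time-major ([max_time, batch_size, beam_width]), so each decoded hypothesis runs down the first axis. In TensorFlow one would typically reorder it with tf.transpose(beams, perm=[1, 2, 0]) to get [batch_size, beam_width, max_time]. The pure-Python sketch below does the same on nested lists and additionally cuts each sequence at its first end_token. `beams_to_sequences` and its rule of dropping the end_token itself are assumptions for illustration, not part of the library.

```python
from typing import List


def beams_to_sequences(
    beams: List[List[List[int]]], end_token: int
) -> List[List[List[int]]]:
    """Hypothetical helper: [time][batch][beam] -> [batch][beam][tokens], cut at end_token."""
    max_time = len(beams)
    batch_size = len(beams[0]) if max_time else 0
    beam_width = len(beams[0][0]) if batch_size else 0
    result = []
    for b in range(batch_size):
        per_beam = []
        for k in range(beam_width):
            tokens = []
            for t in range(max_time):
                token = beams[t][b][k]
                if token == end_token:
                    break  # drop end_token and the padding that follows it
                tokens.append(token)
            per_beam.append(tokens)
        result.append(per_beam)
    return result


print(beams_to_sequences([[[1, 2]], [[4, 3]], [[4, 6]]], end_token=4))
# [[[1], [2, 3, 6]]]
```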
| Raises | |
|---|---|
| InvalidArgumentError | If parent_ids contains an invalid index. |
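Because the op rejects invalid parent indices, it can help to locate them before decoding. This is a minimal sketch assuming a valid parent index lies in [0, beam_width), the usual range of a beam index; `find_invalid_parent_ids` is a hypothetical helper. It is deliberately conservative: it reports every out-of-range entry, including ones the backtracking might never follow.

```python
from typing import List, Tuple


def find_invalid_parent_ids(
    parent_ids: List[List[List[int]]], beam_width: int
) -> List[Tuple[int, int, int, int]]:
    """Hypothetical helper: list (time, batch, beam, value) for out-of-range parents."""
    bad = []
    for t, per_batch in enumerate(parent_ids):
        for b, per_beam in enumerate(per_batch):
            for k, parent in enumerate(per_beam):
                # Assumed valid range for a beam index: 0 <= parent < beam_width.
                if not 0 <= parent < beam_width:
                    bad.append((t, b, k, parent))
    return bad


print(find_invalid_parent_ids([[[0, 0]], [[1, -1]], [[2, 0]]], beam_width=2))
# [(1, 0, 1, -1), (2, 0, 0, 2)]
```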