Encodes the `tf.data.Dataset` as stacked tensors.
@tf.function
tff.analytics.data_processing.to_stacked_tensor(
    ds: tf.data.Dataset
) -> tf.Tensor
This is effectively the inverse of `tf.data.Dataset.from_tensor_slices()`.
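The following sketch illustrates the inverse relationship. It assumes eager execution and uses only core TensorFlow, not `to_stacked_tensor` itself. Slicing a tensor with `from_tensor_slices` and then restacking the elements with `tf.stack` recovers the original tensor:

```python
import numpy as np
import tensorflow as tf

# A [5, 3, 2] tensor is sliced along its first dimension into 5 elements of shape [3, 2].
original = tf.reshape(tf.range(30, dtype=tf.float32), [5, 3, 2])
ds = tf.data.Dataset.from_tensor_slices(original)

# Restacking the elements along a new leading dimension recovers the tensor.
restacked = tf.stack(list(ds))
print(restacked.shape)  # (5, 3, 2)
np.testing.assert_array_equal(restacked.numpy(), original.numpy())
```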
All elements from the input dataset are combined into a structure of tensors that matches the input `ds.element_spec`. Each output tensor has the same shape as the corresponding element component, plus one additional leading dimension along which the elements are stacked. For example, if the dataset contains 5 elements with shape [3, 2], the returned tensor will have shape [5, 3, 2].
Note that each element in the dataset can be a single tensor or a structure of tensors.
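The stacking behavior described above can be approximated in eager mode with core TensorFlow. The helper `stack_dataset` below is hypothetical. It is a minimal sketch that assumes the whole dataset fits in memory, and it is not the TFF implementation. Unlike the `@tf.function`-decorated original, it iterates over the dataset in Python:

```python
import tensorflow as tf


def stack_dataset(ds: tf.data.Dataset):
    """Hypothetical eager-mode sketch of stacking a dataset's elements.

    Not the TFF implementation: it materializes every element in Python.
    """
    specs = ds.element_spec
    # Reject partially-defined element shapes, mirroring the documented ValueError.
    for spec in tf.nest.flatten(specs):
        if not spec.shape.is_fully_defined():
            raise ValueError(f"Element shape {spec.shape} is not fully defined.")
    elements = list(ds)
    if not elements:
        # An empty dataset stacks to tensors with a leading dimension of size 0.
        return tf.nest.map_structure(
            lambda s: tf.zeros([0] + s.shape.as_list(), s.dtype), specs)
    # Stack the matching leaves of every element along a new axis 0.
    return tf.nest.map_structure(lambda *xs: tf.stack(xs), *elements)


# Each element is a structure: {"x": float32 of shape [3, 2], "y": int32 scalar}.
ds = tf.data.Dataset.from_tensor_slices({
    "x": tf.ones([5, 3, 2]),
    "y": tf.range(5),
})
stacked = stack_dataset(ds)
print(stacked["x"].shape)  # (5, 3, 2)
print(stacked["y"].shape)  # (5,)
```

Mapping `tf.nest.map_structure` over all elements keeps the output structure identical to `ds.element_spec`. This lets the same code handle both single tensors and nested structures such as dicts.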
Dataset elements must have fully-defined shapes. Any partially-defined element shape raises a `ValueError`. If you pass in a batched dataset, use `drop_remainder=True` to ensure that the batched shape is fully defined.
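You can check the effect of `drop_remainder` on `element_spec` directly with `tf.data`. The 10-element dataset below is an arbitrary illustration:

```python
import tensorflow as tf

ds = tf.data.Dataset.range(10)  # 10 scalar int64 elements

# Without drop_remainder the final batch may be short, so the batch
# dimension is unknown (None) and the element shape is only partially defined.
print(ds.batch(4).element_spec.shape)  # (None,)

# With drop_remainder=True every batch has exactly 4 elements, so the shape is
# fully defined; the incomplete final batch (elements 8 and 9) is discarded.
batched = ds.batch(4, drop_remainder=True)
print(batched.element_spec.shape)  # (4,)
print(batched.element_spec.shape.is_fully_defined())  # True
```

This leaves two batches of shape [4], so stacking the second dataset would give a tensor of shape [2, 4].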
Returns
A structure of tensors encoding the input dataset.

Raises
`ValueError`: If any dataset element shape is not fully defined.