tf.data.experimental.load

Loads a previously saved dataset.

Example usage:

import os
import tempfile
import tensorflow as tf

path = os.path.join(tempfile.gettempdir(), "saved_data")
# Save a dataset
dataset = tf.data.Dataset.range(2)
tf.data.experimental.save(dataset, path)
# Load it back, passing the element spec of the saved dataset
new_dataset = tf.data.experimental.load(path,
    tf.TensorSpec(shape=(), dtype=tf.int64))
for elem in new_dataset:
  print(elem)
# Prints:
# tf.Tensor(0, shape=(), dtype=int64)
# tf.Tensor(1, shape=(), dtype=int64)

Note that to load a previously saved dataset, you must specify element_spec, a type signature of the elements of the saved dataset. It can be obtained from the original dataset via tf.data.Dataset.element_spec. This requirement exists so that shape inference of the loaded dataset does not need to perform I/O.
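As a minimal sketch of the point above, the element spec can be captured from the original dataset before saving and passed back when loading. This assumes a TensorFlow 2.x release in which tf.data.experimental.save and tf.data.experimental.load are available; the dict keys "id" and "score" and the directory name are illustrative only.

```python
import os
import tempfile

import tensorflow as tf

# Build a dataset whose elements are dicts (a nested structure).
dataset = tf.data.Dataset.from_tensor_slices(
    {"id": [1, 2, 3], "score": [0.5, 1.5, 2.5]})

# Capture the type signature of one element; load() takes it as element_spec.
spec = dataset.element_spec

# Hypothetical location for the saved dataset.
path = os.path.join(tempfile.mkdtemp(), "saved_data")
tf.data.experimental.save(dataset, path)

# Load with the same spec so shapes and dtypes are known without I/O.
restored = tf.data.experimental.load(path, element_spec=spec)
ids = [int(elem["id"]) for elem in restored]
```

If the original dataset object is no longer available at load time, the same spec can be written out by hand as a nested structure of tf.TensorSpec objects.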

If the dataset was saved with the default sharding, its element order is preserved when it is loaded.

The reader_func argument can be used to specify a custom order in which elements should be loaded from the individual shards. The reader_func is expected to take a single argument -- a dataset of datasets, each containing elements of one of the shards -- and return a dataset of elements. For example, the order of shards can be shuffled when loading them as follows:

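# NUM_SHARDS is a placeholder for the number of shards, used as the shuffle buffer size.
def custom_reader_func(datasets):
  # Shuffle the order in which the shards are read.
  datasets = datasets.shuffle(NUM_SHARDS)
  # Read the shards in parallel and merge their elements into one dataset.
  return datasets.interleave(
      lambda x: x, num_parallel_calls=tf.data.experimental.AUTOTUNE)

dataset = tf.data.experimental.load(
    path="/path/to/data", ..., reader_func=custom_reader_func)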
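The following is a self-contained sketch of the same idea, not a definitive recipe. To make the number of shards predictable, it assumes the dataset is saved with the shard_func argument of tf.data.experimental.save, which assigns each element an int64 shard ID. NUM_SHARDS, the helper names, and the directory name are illustrative choices.

```python
import os
import tempfile

import tensorflow as tf

NUM_SHARDS = 2  # illustrative value

def shard_by_parity(element):
  # Assign each element to a shard by its value; must return an int64 shard ID.
  return element % NUM_SHARDS

def shuffled_reader_func(datasets):
  # `datasets` is a dataset of per-shard datasets; shuffle the shard order.
  datasets = datasets.shuffle(NUM_SHARDS)
  # Interleave elements from the shards into a single dataset.
  return datasets.interleave(
      lambda x: x, num_parallel_calls=tf.data.experimental.AUTOTUNE)

dataset = tf.data.Dataset.range(10)
path = os.path.join(tempfile.mkdtemp(), "sharded_data")
tf.data.experimental.save(dataset, path, shard_func=shard_by_parity)

loaded = tf.data.experimental.load(
    path,
    element_spec=dataset.element_spec,
    reader_func=shuffled_reader_func)
values = [int(v) for v in loaded]
```

With a custom reader_func like this one, every element is still read exactly once, but the order is no longer guaranteed to match the order in which the dataset was saved.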

Args

path: Required. A path pointing to a previously saved dataset.
element_spec: Required. A nested structure of tf.TypeSpec objects matching the structure of an element of the saved dataset and specifying the type of individual element components.
compression: Optional. The algorithm to use to decompress the data when reading it. Supported options are GZIP and NONE. Defaults to NONE.
reader_func: Optional. A function to control how to read data from shards. If present, the function will be traced and executed as graph computation.

Returns

A tf.data.Dataset instance.
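As a short sketch of the compression argument, a dataset written with GZIP compression is read back by passing the same compression setting to load. This assumes that tf.data.experimental.save accepts a matching compression argument and that the option is given as the string "GZIP"; the directory name is illustrative.

```python
import os
import tempfile

import tensorflow as tf

dataset = tf.data.Dataset.range(5)
# Hypothetical location for the compressed dataset.
path = os.path.join(tempfile.mkdtemp(), "saved_gzip")

# Write the dataset with GZIP compression.
tf.data.experimental.save(dataset, path, compression="GZIP")

# Read it back; the compression setting must match the one used when saving.
loaded = tf.data.experimental.load(
    path, element_spec=dataset.element_spec, compression="GZIP")
values = [int(v) for v in loaded]
```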