tf.train.experimental.ShardingCallback

Checkpoint sharding callback function, along with a text description.

A callback function wrapper that will be executed to determine how tensors will be split into shards when the saver writes the checkpoint shards to disk.

The callback takes a list of tf.train.experimental.ShardableTensor objects as input (along with any kwargs defined by the tf.train.experimental.ShardingCallback subclass) and organizes the input tensors into shards. Tensors are first grouped by device task (see tf.DeviceSpec), and the callback is then called once for each group.
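The grouping step can be pictured with a small sketch that runs without TensorFlow. Everything in it is a stand-in for illustration: `FakeShardable`, its `task` field, and `run_callback_per_task` are hypothetical names, not part of the TensorFlow API, and the real saver derives the task from each tensor's device.

```python
from collections import defaultdict, namedtuple

# Hypothetical stand-in for tf.train.experimental.ShardableTensor.
FakeShardable = namedtuple(
    "FakeShardable", ["task", "checkpoint_key", "slice_spec", "tensor"])


def all_in_one(shardable_tensors):
  """Toy callback: put every tensor it receives into one shard."""
  shard = {}
  for st in shardable_tensors:
    shard.setdefault(st.checkpoint_key, {})[st.slice_spec] = st.tensor
  return [shard]


def run_callback_per_task(shardable_tensors, callback):
  """Groups tensors by task, then calls the callback once per group."""
  by_task = defaultdict(list)
  for st in shardable_tensors:
    by_task[st.task].append(st)
  shards = []
  for task in sorted(by_task):
    shards.extend(callback(by_task[task]))
  return shards


example_tensors = [
    FakeShardable("/job:worker/task:0", "a", "", [1.0]),
    FakeShardable("/job:worker/task:1", "b", "", [2.0]),
    FakeShardable("/job:worker/task:0", "c", "", [3.0]),
]
example_shards = run_callback_per_task(example_tensors, all_in_one)
```

In this sketch, even an all-in-one policy yields two shards, because the callback only ever sees the tensors of one task at a time.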

There are a few restrictions to keep in mind when creating a custom callback:

  • Tensors must not be removed from the checkpoint.
  • Tensors must not be reshaped.
  • Tensor dtypes must not change.
  • Tensors within a shard must belong to the same task.

Validation checks are performed after the callback function is executed to ensure these restrictions aren't violated.
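The restrictions above can be expressed as a small checker. This is only a sketch of the idea, not TensorFlow's actual validation code: `check_shards` and the `(task, shape, dtype)` tuple layout are hypothetical and chosen so the example runs without TensorFlow. The check for tensors the saver never provided is an extra sanity check in this sketch, not one of the listed restrictions.

```python
def check_shards(original, shards):
  """Raises ValueError if the shards break one of the restrictions.

  `original` maps (checkpoint_key, slice_spec) to a (task, shape, dtype) tuple.
  Each shard maps checkpoint_key -> {slice_spec: (task, shape, dtype)}.
  """
  seen = set()
  for shard in shards:
    tasks_in_shard = set()
    for key, slices in shard.items():
      for slice_spec, (task, shape, dtype) in slices.items():
        ident = (key, slice_spec)
        if ident not in original:
          # Extra sanity check for this sketch only.
          raise ValueError(f"unexpected tensor {ident}")
        _, want_shape, want_dtype = original[ident]
        if shape != want_shape:
          raise ValueError(f"{ident} was reshaped")
        if dtype != want_dtype:
          raise ValueError(f"{ident} changed dtype")
        tasks_in_shard.add(task)
        seen.add(ident)
    if len(tasks_in_shard) > 1:
      raise ValueError("a shard mixes tensors from different tasks")
  missing = set(original) - seen
  if missing:
    raise ValueError(f"tensors removed from checkpoint: {sorted(missing)}")


original = {
    ("a", ""): ("task:0", (2, 2), "float32"),
    ("b", ""): ("task:0", (3,), "int64"),
}
good = [{"a": {"": ("task:0", (2, 2), "float32")}},
        {"b": {"": ("task:0", (3,), "int64")}}]
check_shards(original, good)  # Passes silently.

bad = [{"a": {"": ("task:0", (4,), "float32")}}]  # "a" has a new shape.
try:
  check_shards(original, bad)
  error_message = None
except ValueError as e:
  error_message = str(e)
```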

Here's an example of a simple custom callback:

import tensorflow as tf

# Place all tensors in a single shard.
class AllInOnePolicy(tf.train.experimental.ShardingCallback):
  @property
  def description(self):
    return "Place all tensors in a single shard."

  def __call__(self, shardable_tensors):
    tensors = {}
    for shardable_tensor in shardable_tensors:
      tensor = shardable_tensor.tensor_save_spec.tensor
      checkpoint_key = shardable_tensor.checkpoint_key
      slice_spec = shardable_tensor.slice_spec

      tensors.setdefault(checkpoint_key, {})[slice_spec] = tensor
    return [tensors]

ckpt.save(
    "path",
    options=tf.train.CheckpointOptions(
        experimental_sharding_callback=AllInOnePolicy()))

The description attribute is used to identify the callback and to aid debugging during saving and restoration.

To take in kwargs, simply define the constructor and pass them in:

class ParameterPolicy(tf.train.experimental.ShardingCallback):
  def __init__(self, custom_param):
    self.custom_param = custom_param
  ...

ckpt.save(
    "path",
    options=tf.train.CheckpointOptions(
        experimental_sharding_callback=ParameterPolicy(custom_param=...)))
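As an illustration of a constructor parameter influencing the sharding, here is a minimal sketch of a policy that caps the number of tensors per shard. The class name, the `max_tensors` parameter, and the stand-in inputs are all hypothetical. In real use the class would subclass tf.train.experimental.ShardingCallback and receive ShardableTensor objects; it is a plain class here so the sketch runs without TensorFlow.

```python
from types import SimpleNamespace


class MaxTensorsPerShardPolicy:
  """Hypothetical policy: at most `max_tensors` tensors per shard."""

  def __init__(self, max_tensors):
    if max_tensors < 1:
      raise ValueError("max_tensors must be at least 1")
    self.max_tensors = max_tensors

  @property
  def description(self):
    return f"Place at most {self.max_tensors} tensors in each shard."

  def __call__(self, shardable_tensors):
    shards = []
    current, count = {}, 0
    for st in shardable_tensors:
      # Start a new shard once the current one is full.
      if count == self.max_tensors:
        shards.append(current)
        current, count = {}, 0
      current.setdefault(st.checkpoint_key, {})[st.slice_spec] = st.tensor
      count += 1
    if current:
      shards.append(current)
    return shards


# Stand-in inputs: simple objects with the three attributes used above.
fake_inputs = [
    SimpleNamespace(checkpoint_key=key, slice_spec="", tensor=index)
    for index, key in enumerate("abcde")
]
policy = MaxTensorsPerShardPolicy(max_tensors=2)
capped_shards = policy(fake_inputs)
```

With five inputs and a cap of two, this sketch produces three shards. No tensor is dropped or altered, which keeps it within the restrictions listed earlier.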

Attributes

description

Methods

__call__

Call self as a function.