Policy that splits tensors into shards according to the task in each tensor's device spec.
Inherits From: ShardingCallback
Attributes | |
---|---|
description | |
Methods
__call__
__call__(
shardable_tensors: Sequence[tf.train.experimental.ShardableTensor]
) -> Sequence[sharding_util.TensorSliceDict]
Callback that splits tensors into shards according to the task in each tensor's device spec.
Args | |
---|---|
shardable_tensors | A list of ShardableTensors. |
Returns | |
---|---|
A list of shard dicts, each mapping a checkpoint key to its tensor slices: [ {checkpoint key: {slice_spec: tensor} } ] |
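
The return shape described above can be illustrated with a minimal, pure-Python sketch. This is not TensorFlow's implementation: `FakeShardableTensor` and `shard_by_task` are hypothetical names introduced only for illustration, the task is modelled as a plain string rather than a device spec, and plain Python values stand in for tensors.

```python
from collections import namedtuple

# Hypothetical stand-in for tf.train.experimental.ShardableTensor.
# Only the fields this sketch needs are modelled; `task` is a plain string
# here, whereas the real policy reads the task from a device spec.
FakeShardableTensor = namedtuple(
    "FakeShardableTensor", ["checkpoint_key", "slice_spec", "tensor", "task"]
)


def shard_by_task(shardable_tensors):
    """Group tensors into one shard dict per task, in first-seen task order.

    Returns a list shaped like [{checkpoint_key: {slice_spec: tensor}}].
    """
    shards_by_task = {}
    for st in shardable_tensors:
        # One shard dict per distinct task.
        shard = shards_by_task.setdefault(st.task, {})
        # Within a shard, nest slices under their checkpoint key.
        shard.setdefault(st.checkpoint_key, {})[st.slice_spec] = st.tensor
    return list(shards_by_task.values())


tensors = [
    FakeShardableTensor("layer/kernel", "", [1.0, 2.0], "/job:worker/task:0"),
    FakeShardableTensor("layer/bias", "", [0.5], "/job:worker/task:1"),
    FakeShardableTensor("step", "", 7, "/job:worker/task:0"),
]
shards = shard_by_task(tensors)
```

In this sketch, tensors that share a task land in the same shard dict, so the number of shards equals the number of distinct tasks among the inputs.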