A class with a collection of APIs that can be called in a replica context.
tf.distribute.ReplicaContext(
strategy, replica_id_in_sync_group
)
You can use tf.distribute.get_replica_context to get an instance of ReplicaContext, which can only be called inside the function passed to tf.distribute.Strategy.run.
import tensorflow as tf
strategy = tf.distribute.MirroredStrategy(['GPU:0', 'GPU:1'])
def func():
  # Each replica reports its own id within the sync group.
  replica_context = tf.distribute.get_replica_context()
  return replica_context.replica_id_in_sync_group
strategy.run(func)
PerReplica:{
0: <tf.Tensor: shape=(), dtype=int32, numpy=0>,
1: <tf.Tensor: shape=(), dtype=int32, numpy=1>
}
Args | |
---|---|
strategy | A tf.distribute.Strategy. |
replica_id_in_sync_group | An integer, a Tensor or None. Prefer an integer whenever possible to avoid issues with nested tf.function. It accepts a Tensor only to be compatible with tpu.replicate. |
Attributes | |
---|---|
devices | Returns the devices this replica is to be executed on, as a tuple of strings. (deprecated) |
num_replicas_in_sync | Returns the number of replicas that are kept in sync. |
replica_id_in_sync_group | Returns the id of the replica. This identifies the replica among all replicas that are kept in sync. The value of the replica id can range from 0 to num_replicas_in_sync - 1. |
strategy | The current tf.distribute.Strategy object. |
Methods
all_gather
all_gather(
value, axis, options=None
)
All-gathers value across all replicas along axis.
For all strategies except tf.distribute.TPUStrategy, the input value on different replicas must have the same rank, and their shapes must be the same in all dimensions except the axis-th dimension. In other words, their shapes cannot differ in any dimension d where d is not equal to the axis argument. For example, given a tf.distribute.DistributedValues with component tensors of shape (1, 2, 3) and (1, 3, 3) on two replicas, you can call all_gather(..., axis=1, ...) on it, but not all_gather(..., axis=0, ...) or all_gather(..., axis=2, ...). However, with tf.distribute.TPUStrategy, all tensors must have exactly the same rank and the same shape.
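The shape rule above can be checked without any accelerators. The following is a minimal plain-Python sketch, not TensorFlow code: gathered_shape is a hypothetical helper that takes the per-replica shapes and an axis, and either returns the shape each replica would see after gathering or raises ValueError. It models only the non-TPU rule and, as a simplification, accepts only non-negative axis values.

```python
from typing import Sequence, Tuple


def gathered_shape(shapes: Sequence[Tuple[int, ...]], axis: int) -> Tuple[int, ...]:
    """Return the all-gathered shape, or raise ValueError if shapes are incompatible.

    Hypothetical helper: models the non-TPU rule only (same rank, and equal
    sizes in every dimension except `axis`).
    """
    if not shapes:
        raise ValueError("need at least one replica shape")
    rank = len(shapes[0])
    if rank == 0:
        raise ValueError("all_gather requires tensors with non-zero rank")
    if not 0 <= axis < rank:
        raise ValueError(f"axis {axis} is out of range for rank {rank}")
    for shape in shapes:
        if len(shape) != rank:
            raise ValueError("all replicas must hold tensors of the same rank")
        for d in range(rank):
            if d != axis and shape[d] != shapes[0][d]:
                raise ValueError(
                    f"shapes differ in dimension {d}, which is not the gather axis")
    # Sizes along the gather axis add up; every other dimension is unchanged.
    result = list(shapes[0])
    result[axis] = sum(shape[axis] for shape in shapes)
    return tuple(result)


shapes = [(1, 2, 3), (1, 3, 3)]
print(gathered_shape(shapes, axis=1))  # (1, 5, 3)
for bad_axis in (0, 2):
    try:
        gathered_shape(shapes, axis=bad_axis)
    except ValueError as err:
        print(f"axis={bad_axis}: {err}")
```

Both rejected axes fail for the same reason: the two shapes differ in dimension 1, so dimension 1 is the only axis along which they can be gathered.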
You can pass in a single tensor to all-gather:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
@tf.function
def gather_value():
  ctx = tf.distribute.get_replica_context()
  local_value = tf.constant([1, 2, 3])
  return ctx.all_gather(local_value, axis=0)
result = strategy.run(gather_value)
result
PerReplica:{
0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
}
strategy.experimental_local_results(result)
(<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
dtype=int32)>,
<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3],
dtype=int32)>)
You can also pass in a nested structure of tensors to all-gather, say, a list:
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
@tf.function
def gather_nest():
  ctx = tf.distribute.get_replica_context()
  value_1 = tf.constant([1, 2, 3])
  value_2 = tf.constant([[1, 2], [3, 4]])
  # all_gather a nest of `tf.distribute.DistributedValues`
  return ctx.all_gather([value_1, value_2], axis=0)
result = strategy.run(gather_nest)
result
[PerReplica:{
0: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
1: <tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>
}, PerReplica:{
0: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]], dtype=int32)>,
1: <tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]], dtype=int32)>
}]
strategy.experimental_local_results(result)
([<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
<tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]], dtype=int32)>],
[<tf.Tensor: shape=(6,), dtype=int32, numpy=array([1, 2, 3, 1, 2, 3], dtype=int32)>,
<tf.Tensor: shape=(4, 2), dtype=int32, numpy=
array([[1, 2],
[3, 4],
[1, 2],
[3, 4]], dtype=int32)>])
What if you are all-gathering tensors with different shapes on different replicas? Consider the following example with two replicas, where value is a nested structure consisting of two items to all-gather, a and b.
On Replica 0, value is {'a': [0], 'b': [[0, 1]]}.
On Replica 1, value is {'a': [1], 'b': [[2, 3], [4, 5]]}.
Result for all_gather with axis=0 (on each of the replicas) is:
```{'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]}```
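To make the two-replica example concrete, here is a small plain-Python model of gathering along axis 0. It is a sketch under stated assumptions rather than TensorFlow code: simulate_all_gather_axis0 is a hypothetical function, nested Python lists stand in for tensors, and concatenation happens in replica order.

```python
from typing import Dict, List


def simulate_all_gather_axis0(per_replica_values: List[Dict[str, list]]) -> Dict[str, list]:
    """Concatenate each entry along axis 0, in replica order.

    Hypothetical stand-in for all_gather(value, axis=0): nested Python lists
    play the role of tensors, so axis-0 concatenation is list concatenation.
    """
    keys = per_replica_values[0].keys()
    for value in per_replica_values:
        if value.keys() != keys:
            raise ValueError("the structure of value must be the same on all replicas")
    gathered = {}
    for key in keys:
        rows = []
        for value in per_replica_values:
            rows.extend(value[key])
        gathered[key] = rows
    return gathered


replica_0 = {'a': [0], 'b': [[0, 1]]}
replica_1 = {'a': [1], 'b': [[2, 3], [4, 5]]}
print(simulate_all_gather_axis0([replica_0, replica_1]))
# {'a': [0, 1], 'b': [[0, 1], [2, 3], [4, 5]]}
```

In this model, b has one row on Replica 0 and two rows on Replica 1, which is allowed because the shapes differ only along the gather axis.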
Args | |
---|---|
value | A nested structure of tf.Tensor which tf.nest.flatten accepts, or a tf.distribute.DistributedValues instance. The structure of the tf.Tensor needs to be the same on all replicas. The underlying tensor constructs can only be dense tensors with non-zero rank, NOT tf.IndexedSlices. |
axis | 0-D int32 Tensor. Dimension along which to gather. |
options | A tf.distribute.experimental.CommunicationOptions. Options to perform collective operations. This overrides the default options if the tf.distribute.Strategy takes one in the constructor. See tf.distribute.experimental.CommunicationOptions for details of the options. |
Returns | |
---|---|
A nested structure of tf.Tensor with the gathered values. The structure is the same as value. |
all_reduce
all_reduce(
reduce_op, value, options=None
)
All-reduces value across all replicas.
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
def step_fn():
  ctx = tf.distribute.get_replica_context()
  value = tf.identity(1.)
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, value)
strategy.experimental_local_results(strategy.run(step_fn))
(<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
<tf.Tensor: shape=(), dtype=float32, numpy=2.0>)
It supports batched operations. You can pass a list of values and it
attempts to batch them when possible. You can also specify options
to indicate the desired batching behavior, e.g. batch the values into
multiple packs so that they can better overlap with computations.
strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
def step_fn():
  ctx = tf.distribute.get_replica_context()
  value1 = tf.identity(1.)
  value2 = tf.identity(2.)
  return ctx.all_reduce(tf.distribute.ReduceOp.SUM, [value1, value2])
strategy.experimental_local_results(strategy.run(step_fn))
([<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>],
[<tf.Tensor: shape=(), dtype=float32, numpy=2.0>,
<tf.Tensor: shape=(), dtype=float32, numpy=4.0>])
Note that all replicas need to participate in the all-reduce; otherwise this operation hangs. If there are multiple all-reduces, they need to execute in the same order on all replicas. Dispatching all-reduce based on conditions is usually error-prone.
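As a rough model of what the batched example above computes, the sketch below reduces a list of values position by position across replicas and hands the same result back to every replica. It is plain Python, not TensorFlow: simulate_all_reduce and the REDUCERS table are hypothetical names, and only the "SUM" and "MEAN" string forms are modelled.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical reducers keyed by the string form of the reduce op.
REDUCERS: Dict[str, Callable[[Sequence[float]], float]] = {
    "SUM": lambda xs: sum(xs),
    "MEAN": lambda xs: sum(xs) / len(xs),
}


def simulate_all_reduce(reduce_op: str,
                        per_replica_values: List[List[float]]) -> List[List[float]]:
    """Reduce position-wise across replicas; every replica gets the same result."""
    if reduce_op not in REDUCERS:
        raise ValueError(f"unsupported reduce_op: {reduce_op}")
    length = len(per_replica_values[0])
    if any(len(values) != length for values in per_replica_values):
        raise ValueError("the structure of value must be the same on all replicas")
    reduce_fn = REDUCERS[reduce_op]
    # zip(*...) groups the i-th value of every replica together.
    reduced = [reduce_fn(column) for column in zip(*per_replica_values)]
    return [list(reduced) for _ in per_replica_values]


values = [[1.0, 2.0], [1.0, 2.0]]  # one inner list per replica
print(simulate_all_reduce("SUM", values))   # [[2.0, 4.0], [2.0, 4.0]]
print(simulate_all_reduce("MEAN", values))  # [[1.0, 2.0], [1.0, 2.0]]
```

The model takes every replica's values in a single call, which mirrors the requirement that all replicas participate; it does not attempt to reproduce the hang that occurs when one is missing.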
Known limitation: if value contains tf.IndexedSlices, attempting to compute the gradient w.r.t. value would result in an error.
This API currently can only be called in the replica context. Other variants to reduce values across replicas are:
* tf.distribute.StrategyExtended.reduce_to: the reduce and all-reduce API in the cross-replica context.
* tf.distribute.StrategyExtended.batch_reduce_to: the batched reduce and all-reduce API in the cross-replica context.
* tf.distribute.Strategy.reduce: a more convenient method to reduce to the host in cross-replica context.
Args | |
---|---|
reduce_op | A tf.distribute.ReduceOp value specifying how values should be combined. The string representation of the enum, such as "SUM" or "MEAN", is also accepted. |
value | A potentially nested structure of tf.Tensor or tf.IndexedSlices which tf.nest.flatten accepts. The structure and the shapes of value need to be the same on all replicas. |
options | A tf.distribute.experimental.CommunicationOptions. Options to perform collective operations. This overrides the default options if the tf.distribute.Strategy takes one in the constructor. See tf.distribute.experimental.CommunicationOptions for details of the options. |
Returns | |
---|---|
A nested structure of tf.Tensor with the reduced values. The structure is the same as value. |
merge_call
merge_call(
merge_fn, args=(), kwargs=None
)
Merge args across replicas and run merge_fn in a cross-replica context.
This allows communication and coordination when there are multiple calls to the step_fn triggered by a call to strategy.run(step_fn, ...).
See tf.distribute.Strategy.run for an explanation.
If not inside a distributed scope, this is equivalent to:
```
strategy = tf.distribute.get_strategy()
with cross-replica-context(strategy):
  return merge_fn(strategy, *args, **kwargs)
```
Args | |
---|---|
merge_fn | Function that joins arguments from threads that are given as PerReplica. It accepts a tf.distribute.Strategy object as the first argument. |
args | List or tuple with positional per-thread arguments for merge_fn. |
kwargs | Dict with keyword per-thread arguments for merge_fn. |
Returns | |
---|---|
The return value of merge_fn, except for PerReplica values which are unpacked. |
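The data flow of merge_call can be pictured with a small plain-Python model. This is a sketch under stated assumptions, not the TensorFlow implementation: the PerReplica class and simulate_merge_call function below are hypothetical stand-ins, positional arguments only are modelled, and the strategy argument is just passed through.

```python
from typing import Any, Callable, List, Tuple


class PerReplica:
    """Hypothetical container holding one value per replica."""

    def __init__(self, values):
        self.values = tuple(values)


def simulate_merge_call(merge_fn: Callable[..., Any],
                        per_replica_args: List[Tuple[Any, ...]],
                        strategy: Any = None) -> List[Any]:
    """Model merge_call: gather args, run merge_fn once, hand results back."""
    # Group the i-th positional argument of every replica into one PerReplica.
    merged_args = [PerReplica(column) for column in zip(*per_replica_args)]
    result = merge_fn(strategy, *merged_args)
    num_replicas = len(per_replica_args)
    if isinstance(result, PerReplica):
        # PerReplica results are unpacked: replica i receives component i.
        return [result.values[i] for i in range(num_replicas)]
    # Any other result is returned unchanged to every replica.
    return [result] * num_replicas


def total(strategy, x):
    # Combine the per-replica values into a single number.
    return sum(x.values)


def shifted(strategy, x):
    # Return a different value for each replica.
    return PerReplica(v + 10 for v in x.values)


print(simulate_merge_call(total, [(1,), (2,)]))    # [3, 3]
print(simulate_merge_call(shifted, [(1,), (2,)]))  # [11, 12]
```

The two calls show both return paths: a plain value is seen identically by every replica, while a PerReplica result is split so that each replica gets only its own component.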