Outputs the position of `index` in a permutation of `[0, ..., max_index]`.
```python
tf.random.experimental.index_shuffle(
    index, seed, max_index
)
```
For each possible `seed` and `max_index` there is one pseudorandom permutation of the sequence `S=[0, ..., max_index]`. Instead of materializing the full array, we can compute the new position of any single element in `S`. This can be useful for very large values of `max_index`.
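The idea of computing one element's new position without building the whole permutation can be sketched in plain Python. This sketch does not reproduce TensorFlow's kernel or its outputs. Instead, it swaps in a standard technique, a small keyed Feistel network over a power-of-two domain combined with cycle walking, to show that such a per-index bijection needs only constant memory. The helper names `feistel_index_shuffle` and `_round_fn`, the round count, and the SHA-256-based round function are all illustrative assumptions.

```python
import hashlib


def _round_fn(value, seed, round_idx, bits):
    """Keyed pseudorandom round function built from SHA-256 (an illustrative choice)."""
    data = f"{seed[0]}:{seed[1]}:{round_idx}:{value}".encode()
    digest = hashlib.sha256(data).digest()
    return int.from_bytes(digest[:8], "big") & ((1 << bits) - 1)


def feistel_index_shuffle(index, seed, max_index, rounds=4):
    """Return where `index` lands in a seeded permutation of range(max_index + 1).

    Hypothetical helper, not TensorFlow's algorithm. Uses O(1) memory:
    no list of size max_index + 1 is ever built.
    """
    if not 0 <= index <= max_index:
        raise ValueError("index must lie in [0, max_index]")
    n = max_index + 1
    # Smallest even bit width (at least 2) whose domain 2**width covers n,
    # so the domain splits into two equal halves for the Feistel rounds.
    width = max(2, (n - 1).bit_length())
    width += width % 2
    half = width // 2
    mask = (1 << half) - 1

    def permute(x):
        # A balanced Feistel network is a bijection on [0, 2**width)
        # no matter which round function is used.
        left, right = x >> half, x & mask
        for r in range(rounds):
            left, right = right, left ^ _round_fn(right, seed, r, half)
        return (left << half) | right

    # Cycle walking: keep applying the bijection until the value falls back
    # inside [0, n). This terminates because the cycle through `index`
    # must eventually return to `index` itself.
    y = permute(index)
    while y >= n:
        y = permute(y)
    return y
```

For example, `[feistel_index_shuffle(i, (5, 9), 9) for i in range(10)]` lists every value in `0..9` exactly once, but in an order unrelated to TensorFlow's output for the same seed.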
The input `index` and the output can be used as indices to shuffle a vector. For example:
```python
import tensorflow as tf

vector = tf.constant(['e0', 'e1', 'e2', 'e3'])
indices = tf.random.experimental.index_shuffle(tf.range(4), [5, 9], 3)
shuffled_vector = tf.gather(vector, indices)
print(shuffled_vector)
# tf.Tensor([b'e2' b'e0' b'e1' b'e3'], shape=(4,), dtype=string)
```
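To make the gather step concrete without TensorFlow, the sketch below replays it in plain Python. The index list `[2, 0, 1, 3]` is read off the printed output above rather than computed, and `gather` and `inverse_permutation` are hypothetical helpers written only for this illustration. Position `i` of the result takes the element at `indices[i]`, and gathering again with the inverse permutation undoes the shuffle.

```python
def gather(values, indices):
    """Plain-Python analogue of tf.gather along axis 0 for a flat list."""
    return [values[i] for i in indices]


def inverse_permutation(indices):
    """Return inv such that inv[indices[p]] == p for every position p."""
    inv = [0] * len(indices)
    for position, source in enumerate(indices):
        inv[source] = position
    return inv


vector = ["e0", "e1", "e2", "e3"]
# Index values implied by the printed output above (seed [5, 9], max_index 3).
indices = [2, 0, 1, 3]
shuffled = gather(vector, indices)                          # ['e2', 'e0', 'e1', 'e3']
restored = gather(shuffled, inverse_permutation(indices))  # ['e0', 'e1', 'e2', 'e3']
```

For tensors, TensorFlow provides `tf.math.invert_permutation` for the inverse step.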
More usefully, it can be used in a streaming (aka online) scenario such as `tf.data`, where each element of `vector` is processed individually and the whole `vector` is never materialized in memory.
```python
dataset = tf.data.Dataset.range(10)
dataset = dataset.map(
    lambda idx: tf.random.experimental.index_shuffle(idx, [5, 8], 9))
print(list(dataset.as_numpy_iterator()))
# [3, 8, 0, 1, 2, 7, 6, 9, 4, 5]
```
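The streaming pattern, one index in and one shuffled position out with nothing else kept in memory, can be sketched with a Python generator. The affine map `(a * index + b) mod n` used here is a deliberately simple stand-in for a seeded shuffle, assumed only for illustration. It is a bijection whenever `a` is coprime with `n`, but consecutive outputs differ by a fixed step, so it is not a good pseudorandom permutation. `affine_index_shuffle`, `shuffled_stream`, `a`, and `b` are hypothetical names.

```python
import math


def affine_index_shuffle(index, a, b, max_index):
    """Map index to (a * index + b) mod n, with n = max_index + 1.

    The map is a bijection on range(n) exactly when gcd(a, n) == 1.
    """
    n = max_index + 1
    if math.gcd(a, n) != 1:
        raise ValueError("a must be coprime with max_index + 1")
    return (a * index + b) % n


def shuffled_stream(max_index, a, b):
    """Yield shuffled positions one at a time; the full permutation is never stored."""
    for idx in range(max_index + 1):
        yield affine_index_shuffle(idx, a, b, max_index)


print(list(shuffled_stream(9, a=7, b=3)))  # [3, 0, 7, 4, 1, 8, 5, 2, 9, 6]
```

Because the generator is lazy, `next(shuffled_stream(10**12, a=2, b=5))` returns `5` immediately even though the domain has about a trillion elements.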
This operation is stateless (like the other `tf.random.stateless_*` functions), meaning the output is fully determined by `seed` (other inputs being equal). Each `seed` choice corresponds to one permutation, so when calling this function multiple times for the same shuffling, make sure to use the same `seed`. For example:
```python
# Reuses `vector` from the first example; every call passes the same seed.
seed = [5, 9]
idx0 = tf.random.experimental.index_shuffle(0, seed, 3)
idx1 = tf.random.experimental.index_shuffle(1, seed, 3)
idx2 = tf.random.experimental.index_shuffle(2, seed, 3)
idx3 = tf.random.experimental.index_shuffle(3, seed, 3)
shuffled_vector = tf.gather(vector, [idx0, idx1, idx2, idx3])
print(shuffled_vector)
# tf.Tensor([b'e2' b'e0' b'e1' b'e3'], shape=(4,), dtype=string)
```
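Statelessness is what makes separate per-element calls safe: the function keeps no internal state, so calling it once per index with the same `seed` yields exactly the positions of one permutation. If the calls used different seeds, they would draw from different permutations and the collected positions could contain duplicates. The sketch below illustrates the property in plain Python. `seeded_index_shuffle` is a hypothetical stand-in, not TensorFlow's algorithm: it hashes the seed into the parameters of an affine bijection, which is enough to show determinism but is a poor shuffle.

```python
import hashlib
import math


def seeded_index_shuffle(index, seed, max_index):
    """Stateless: the result depends only on (index, seed, max_index).

    Derives an affine bijection (a * index + b) mod n from a hash of the seed.
    """
    n = max_index + 1
    digest = hashlib.sha256(f"{seed[0]},{seed[1]},{n}".encode()).digest()
    a = int.from_bytes(digest[:8], "big") % n or 1
    # Bump `a` until it is coprime with n, which makes the map a bijection.
    while math.gcd(a, n) != 1:
        a += 1
    b = int.from_bytes(digest[8:16], "big") % n
    return (a * index + b) % n


seed = (5, 9)
# Four independent calls with one seed: together they form one permutation.
positions = [seeded_index_shuffle(i, seed, 3) for i in range(4)]
# Repeating the calls later reproduces exactly the same positions.
again = [seeded_index_shuffle(i, seed, 3) for i in range(4)]
```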
| Returns |
|---|
| If all inputs are scalars (shape `[2]` for `seed`), the output is a scalar with the same dtype as `index`. The output can be seen either as the new position of the value `index` in `S`, or as the index, in the vector before shuffling, of the element that ends up at position `index`. |
| If one or more inputs are vectors (shape `[n, 2]` for `seed`), the output is a vector of the same size, with each element shuffled independently. Scalar inputs are broadcast in this case. |
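The scalar-versus-vector rule above can be sketched as a small broadcasting wrapper in plain Python. Everything here is an assumption for illustration: `broadcast_apply` and `toy_shuffle` are hypothetical names, a scalar seed is modeled as a pair of ints (shape `[2]`) and a vector seed as a list of pairs (shape `[n, 2]`), and vector inputs are assumed to share one length. `toy_shuffle` is a seed-dependent rotation used only to exercise the broadcasting, not a real shuffle.

```python
def _is_scalar_seed(seed):
    """Shape [2]: a pair of ints. Shape [n, 2]: a sequence of such pairs."""
    return len(seed) == 2 and all(isinstance(s, int) for s in seed)


def broadcast_apply(fn, index, seed, max_index):
    """Apply fn(index, seed, max_index) element-wise, broadcasting scalar inputs."""
    index_scalar = isinstance(index, int)
    seed_scalar = _is_scalar_seed(seed)
    max_scalar = isinstance(max_index, int)
    if index_scalar and seed_scalar and max_scalar:
        # All scalar (seed of shape [2]): scalar in, scalar out.
        return fn(index, tuple(seed), max_index)
    # Assumption: every vector input must have the same length n.
    lengths = {
        len(value)
        for value, scalar in ((index, index_scalar), (seed, seed_scalar), (max_index, max_scalar))
        if not scalar
    }
    if len(lengths) != 1:
        raise ValueError(f"vector inputs must share one length, got {sorted(lengths)}")
    (n,) = lengths
    indices = [index] * n if index_scalar else list(index)
    seeds = [tuple(seed)] * n if seed_scalar else [tuple(s) for s in seed]
    maxes = [max_index] * n if max_scalar else list(max_index)
    # Each element is handled on its own, with its own seed and max_index.
    return [fn(i, s, m) for i, s, m in zip(indices, seeds, maxes)]


def toy_shuffle(index, seed, max_index):
    """Hypothetical stand-in for a real index shuffle: a seed-dependent rotation."""
    return (index + seed[0]) % (max_index + 1)


print(broadcast_apply(toy_shuffle, 1, (5, 9), 3))             # 2
print(broadcast_apply(toy_shuffle, [0, 1, 2, 3], (5, 9), 3))  # [1, 2, 3, 0]
```

Mismatched vector lengths raise `ValueError` in this sketch; how TensorFlow treats that case is not covered by the description above.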