Returns the indices of a tensor that give its sorted order along an axis.
tfp.experimental.distributions.marginal_fns.ps.argsort(
values, axis=-1, direction='ASCENDING', stable=False, name=None
)
For example:
>>> import tensorflow as tf
>>> values = [1, 10, 26.9, 2.8, 166.32, 62.3]
>>> sort_order = tf.argsort(values)
>>> sort_order.numpy()
array([0, 3, 1, 2, 5, 4], dtype=int32)
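To make the meaning of the returned indices concrete, here is a minimal plain-Python model of a 1-D argsort. It is a sketch, not TensorFlow's implementation. The function name argsort_1d is hypothetical. It mirrors the documented direction argument ('ASCENDING' or 'DESCENDING') and raises ValueError for any other direction. Python's sort is stable, so ties keep their original order in both directions.

```python
def argsort_1d(values, direction="ASCENDING"):
    """Hypothetical pure-Python model of a 1-D argsort (not TensorFlow's code).

    Returns the list of positions that orders `values` in the given direction.
    Python's sort is stable (also with reverse=True), so equal elements keep
    their original relative order.
    """
    if direction not in ("ASCENDING", "DESCENDING"):
        raise ValueError(f"Unknown direction: {direction!r}")
    # Sort the positions 0..n-1 by the value stored at each position.
    return sorted(range(len(values)),
                  key=lambda i: values[i],
                  reverse=(direction == "DESCENDING"))


values = [1, 10, 26.9, 2.8, 166.32, 62.3]
print(argsort_1d(values))                # [0, 3, 1, 2, 5, 4]
print(argsort_1d(values, "DESCENDING"))  # [4, 5, 2, 1, 3, 0]
```

The ascending result matches the tf.argsort output above. The descending call shows how direction reverses the order of positions rather than just flipping the list of values.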
For a 1D tensor, gathering the values by these indices reproduces tf.sort:
>>> sorted_values = tf.gather(values, sort_order)
>>> assert tf.reduce_all(sorted_values == tf.sort(values))
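For a 1-D input, gathering by an index list simply reads the values in that order. The sketch below models this with a hypothetical helper, gather_1d, standing in for tf.gather on a Python list. It assumes the indices are the argsort result shown above.

```python
def gather_1d(params, indices):
    """Hypothetical stand-in for tf.gather on a 1-D list: picks params[i] for each i."""
    return [params[i] for i in indices]


values = [1, 10, 26.9, 2.8, 166.32, 62.3]
sort_order = [0, 3, 1, 2, 5, 4]  # the argsort result shown above
sorted_values = gather_1d(values, sort_order)

# Reading the values in argsort order is the same as sorting them.
assert sorted_values == sorted(values)
print(sorted_values)  # [1, 2.8, 10, 26.9, 62.3, 166.32]
```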
For higher dimensions, the output has the same shape as values; along the given axis, each entry is the index (within that slice of the tensor) of the element that belongs at that position in sorted order.
>>> mat = [[30, 20, 10],
...        [20, 10, 30],
...        [10, 30, 20]]
>>> indices = tf.argsort(mat)
>>> indices.numpy()
array([[2, 1, 0],
       [1, 0, 2],
       [0, 2, 1]], dtype=int32)
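The role of axis can be sketched in plain Python for a 2-D list of lists. This is an illustrative model, not the library's code; argsort_2d is a hypothetical name, and only ascending order is modeled. With axis=-1 each row is argsorted independently. With axis=0 each column is argsorted, and the result is transposed back so it keeps the input's shape. The second example uses a non-symmetric matrix because mat above is symmetric, which happens to give the same indices for both axes.

```python
def argsort_2d(matrix, axis=-1):
    """Hypothetical pure-Python model of ascending argsort on a 2-D list of lists.

    The result has the same shape as `matrix`. With axis=-1 (or 1) each row is
    argsorted independently; with axis=0 (or -2) each column is.
    """
    def order(seq):
        # Stable ascending argsort of one 1-D slice.
        return sorted(range(len(seq)), key=lambda i: seq[i])

    if axis in (-1, 1):
        return [order(row) for row in matrix]
    if axis in (0, -2):
        columns = [list(col) for col in zip(*matrix)]
        col_orders = [order(col) for col in columns]
        # Transpose back so result[i][j] is the row index of the
        # i-th smallest element in column j.
        return [list(r) for r in zip(*col_orders)]
    raise ValueError(f"axis {axis} is out of range for a 2-D input")


mat = [[30, 20, 10],
       [20, 10, 30],
       [10, 30, 20]]
print(argsort_2d(mat))        # [[2, 1, 0], [1, 0, 2], [0, 2, 1]]

m = [[3, 1],
     [1, 2],
     [2, 0]]
print(argsort_2d(m, axis=0))  # [[1, 2], [2, 0], [0, 1]]
```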
If axis=-1, these indices can be passed to tf.gather (with batch_dims=-1) to sort each row:
>>> tf.gather(mat, indices, batch_dims=-1).numpy()
array([[10, 20, 30],
       [10, 20, 30],
       [10, 20, 30]], dtype=int32)
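For 2-D inputs, the batched gather above amounts to reordering each row of mat by the matching row of indices. The plain-Python model below is a sketch of that behavior. The helper gather_last_axis is hypothetical and covers only the 2-D, last-axis case.

```python
def gather_last_axis(params, indices):
    """Hypothetical model of tf.gather(params, indices, batch_dims=-1) for 2-D inputs.

    Row r of the result is params[r] reordered by indices[r].
    """
    return [[row[i] for i in idx] for row, idx in zip(params, indices)]


mat = [[30, 20, 10],
       [20, 10, 30],
       [10, 30, 20]]
indices = [[2, 1, 0],
           [1, 0, 2],
           [0, 2, 1]]  # argsort of each row, as in the example above
print(gather_last_axis(mat, indices))  # [[10, 20, 30], [10, 20, 30], [10, 20, 30]]
```

Because each row is paired with its own index row, this only sorts correctly when the indices came from an argsort along the last axis.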
Returns:
An int32 Tensor with the same shape as values, containing the indices that would sort each slice of values along the given axis.
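The return contract for a single 1-D slice can be stated as two checks: the indices form a permutation of 0..n-1, and reading values in that order gives a sorted sequence. The checker below is a hypothetical sketch of those two conditions, not part of any library. It deliberately ignores how ties are ordered, which is what the stable flag in the signature governs.

```python
def is_argsort(values, indices, direction="ASCENDING"):
    """Hypothetical checker for the argsort contract on a 1-D slice.

    `indices` must be a permutation of 0..n-1, and reading `values` in that
    order must be sorted in the requested direction. Tie order is not checked.
    """
    n = len(values)
    if sorted(indices) != list(range(n)):
        return False  # not a permutation of the positions
    picked = [values[i] for i in indices]
    pairs = list(zip(picked, picked[1:]))
    if direction == "ASCENDING":
        return all(a <= b for a, b in pairs)
    if direction == "DESCENDING":
        return all(a >= b for a, b in pairs)
    raise ValueError(f"Unknown direction: {direction!r}")


values = [1, 10, 26.9, 2.8, 166.32, 62.3]
print(is_argsort(values, [0, 3, 1, 2, 5, 4]))  # True
print(is_argsort(values, [0, 1, 2, 3, 4, 5]))  # False
```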
Raises:
ValueError: If axis is not a constant scalar, or the direction is invalid.
tf.errors.InvalidArgumentError: If values.dtype is not a float or int type.