tfp.substrates.jax.math.soft_sorting_matrix

Computes a matrix representing a continuous relaxation of sorting.

Given a vector x, there is a permutation matrix P_x that, when applied to x, gives x sorted in decreasing order. This function computes a continuous relaxation of P_x, parameterized by temperature. The relaxation is a unimodal row-stochastic matrix: all entries are non-negative, every row sums to 1, and each row has a unique maximum entry, with the positions of these row maxima forming a permutation. Each unique maximum entry corresponds to the location of a 1 in the exact sorting permutation.
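
To make these definitions concrete, the sketch below builds the exact matrix P_x with jnp.argsort and checks the properties of a unimodal row-stochastic matrix as defined in [1]: non-negative entries, rows summing to 1, and row-wise maxima that form a permutation. The helpers exact_sorting_matrix and is_unimodal_row_stochastic are hypothetical illustrations, not part of TFP.

```python
import jax.numpy as jnp

def exact_sorting_matrix(x):
    """Hypothetical helper: hard permutation matrix P_x with P_x @ x sorted in decreasing order."""
    n = x.shape[-1]
    order = jnp.argsort(-x)                   # indices of x from largest to smallest
    return jnp.eye(n, dtype=x.dtype)[order]   # row i is one-hot at position order[i]

def is_unimodal_row_stochastic(p, atol=1e-6):
    """Hypothetical helper: checks the three properties of a unimodal row-stochastic matrix."""
    n = p.shape[-1]
    nonnegative = bool(jnp.all(p >= 0))
    rows_sum_to_one = bool(jnp.allclose(p.sum(axis=-1), 1.0, atol=atol))
    row_argmax = jnp.argmax(p, axis=-1)       # position of the maximum in each row
    is_permutation = bool(jnp.array_equal(jnp.sort(row_argmax), jnp.arange(n)))
    return nonnegative and rows_sum_to_one and is_permutation

x = jnp.array([0.3, -1.2, 2.5, 0.9])
p = exact_sorting_matrix(x)
print(p @ x)                           # x sorted in decreasing order
print(is_unimodal_row_stochastic(p))   # True
```

Any permutation matrix passes this check; a uniform matrix such as jnp.ones((4, 4)) / 4 fails it, because every row has its maximum at the same (first) position.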

Complexity: Given a vector x of size N, this operation will take O(N**2) time.
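
As we read [1], the relaxation can be written as follows, with τ the temperature, i = 1, ..., N the row index, and 1 the all-ones vector (the notation here is adapted, not copied from TFP). The vector A_x 1 of pairwise-distance row sums needs all N^2 differences |x_j - x_k|, which is where the O(N**2) cost comes from:

$$
(A_x)_{jk} = |x_j - x_k|, \qquad
\widehat{P}_x[i,:] = \operatorname{softmax}\!\left(\frac{(N + 1 - 2i)\,x - A_x \mathbf{1}}{\tau}\right), \qquad i = 1, \dots, N.
$$

As τ approaches 0, each row's softmax tends to a one-hot vector at the argmax of its logits, and that argmax is the index of the i-th largest entry of x, which recovers P_x.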

This relaxation is called NeuralSort in [1].
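
A minimal pure-JAX sketch of the NeuralSort formula from [1] is shown below. The function name neural_sort is hypothetical and this is not the TFP implementation; it only illustrates the computation. It works along the last axis of x, broadcasts over any leading batch dimensions, and returns an array of shape x.shape + (N,).

```python
import jax
import jax.numpy as jnp

def neural_sort(x, temperature):
    """Hypothetical sketch of the NeuralSort relaxation of [1] along the last axis of x."""
    n = x.shape[-1]
    # Pairwise |x_j - x_k|: this N x N matrix is the source of the O(N**2) cost.
    pairwise = jnp.abs(x[..., :, None] - x[..., None, :])
    row_sums = pairwise.sum(axis=-1)  # (A_x 1)_j, shape [..., N]
    # Coefficients N + 1 - 2i for 1-based rows i = 1..N, i.e. N-1, N-3, ..., 1-N.
    scaling = n + 1 - 2 * jnp.arange(1, n + 1, dtype=x.dtype)
    # logits[..., i, j] = scaling[i] * x[j] - row_sums[j]  (i is 0-based here)
    logits = scaling[:, None] * x[..., None, :] - row_sums[..., None, :]
    # Softmax over positions j, so every row sums to 1.
    return jax.nn.softmax(logits / temperature, axis=-1)

x = jnp.array([0.3, -1.2, 2.5, 0.9])
p = neural_sort(x, temperature=0.1)
print(p.sum(axis=-1))          # each row sums to 1 (up to rounding)
print(jnp.argmax(p, axis=-1))  # [2 3 0 1]: indices of x from largest to smallest
```

Taking the softmax over the last axis is what makes the result row-stochastic; the row maxima land on the indices of x in decreasing order, matching the exact P_x.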

Args

x: float Tensor. The input for which the relaxed sorting matrix is computed. The relaxed permutation is computed with respect to the last axis.
temperature: Positive float Tensor. As temperature approaches zero, the result approaches the exact permutation matrix that sorts x from largest to smallest.
name: Python str name prefixed to Ops created by this function. Default value: None (i.e., 'soft_sorting_matrix').

Returns

soft_sort: A unimodal row-stochastic matrix. In the limit of low temperature, applying this matrix to x sorts x in decreasing order.

References

[1]: Aditya Grover, Eric Wang, Aaron Zweig, Stefano Ermon. Stochastic Optimization of Sorting Networks via Continuous Relaxations. https://arxiv.org/abs/1903.08850