Tensorflow -使用自定义比较器排序Tensor

zkure5ic  于 2023-03-09  发布在  其他
关注(0)|答案(3)|浏览(148)

如何根据仅使用Tensorflow运算的自定义比较函数对形状为[n,2]的整数TensorflowTensor进行排序?
假设Tensor中有两个元素[x1,y1]和[x2,y2],我想对Tensor进行排序,使元素按条件x1* y2〉x2 * y1重新排序。

axr492tv

axr492tv1#

假设您可以为元素创建一个指标(如果不能,请参见下面的一般情况)(这里,将不等式重新排列为***x1/y1〉x2/y2***,因此指标将为***x/y***,并依赖TensorFlow生成***inf***(无穷大)以除以零),请使用tf.nn.top_k(),如下所示代码(已测试):

import tensorflow as tf

x = tf.constant( [ [1,2], [3,4], [1,3], [2,5] ] ) # some example numbers

s = tf.truediv( x[ ..., 0 ], x[ ..., 1 ] ) # your sort condition
val, idx = tf.nn.top_k( s, x.get_shape()[ 0 ].value )
x_sorted = tf.gather( x, idx )

with tf.Session() as sess:
    print( sess.run( x_sorted ) )

输出:
[[3 4]
[1和2]
[二、五]
[13]]
如果你不能或者不容易创建一个度量,那么仍然假设这个关系给你一个well-ordering(否则结果是不确定的)。在这种情况下,你为整个集合构建比较矩阵,并按行和(即有多少其他元素更大)对元素排序;这当然是要排序的元素数量的二次方。2这段代码(经过测试):

import tensorflow as tf

x = tf.constant( [ [1,2], [3,4], [1,3], [2,5] ] ) # some example numbers

x1, y1 = x[ ..., 0 ][ None, ... ], x[ ..., 1 ][ None, ... ] # expanding dims into cols
x2, y2 = x[ ..., 0, None ],        x[ ..., 1, None ] # expanding into rows
r = tf.cast( tf.less( x1 * y2, x2 * y1 ), tf.int32 ) # your sort condition, with implicit broadcasting
s = tf.reduce_sum( r, axis = 1 ) # how many other elements are greater

val, idx = tf.nn.top_k( s, s.get_shape()[ 0 ].value )
x_sorted = tf.gather( x, idx )

with tf.Session() as sess:
    print( sess.run( x_sorted ) )

输出:
[[3 4]
[1和2]
[二、五]
[13]]

olhwl3o2

olhwl3o22#

作为top_k彼得· solr 丹答案的替代,在1.13之后有一个tf.argsort
对于〈1.13,请使用

tf.nn.top_k( s, tf.shape(x)[0] )

如果无法在静态图形中获取形状。

xmq68pz9

xmq68pz93#

如果你想根据一个掩码来排列一个输入Tensor,在某种程度上,掩码包含了你感兴趣的输入Tensor的哪些索引的信息,那么这可能对你有用-

import tensorflow as tf
inp = tf.constant([
    [7, 0, 3, 0, 5, 0],
    [0, 7, 0, 9, 0, 2]])

mask = tf.constant([
    [1, 0, 1, 0, 1, 0],
    [0, 1, 0, 1, 0, 1]])

order = tf.argsort(mask, axis=-1, direction='DESCENDING', stable=True)
out = tf.gather(inp, order, batch_dims=-1)
print(out)

# tf.Tensor(
# [[7 3 5 0 0 0]
#  [7 9 2 0 0 0]], shape=(2, 6), dtype=int32)

在掩码中,我们已经编码了关于我们感兴趣排序的索引的信息。

相关问题