我希望得到tf.math.bincount中的最大/最小值,而不是权重和。基本上,目前它的工作方式如下:
tf.math.bincount
values = tf.constant([1,1,2,3,2,4,4,5]) weights = tf.constant([1,5,0,1,0,5,4,5]) tf.math.bincount(values, weights=weights) #[0 6 0 1 9 5]
但是,我想获取冲突权重的最大值/最小值,例如,对于max,应返回:[0 5 0 1 5 5]个
[0 5 0 1 5 5]
8e2ybdfx1#
这需要一些技巧,但您可以按如下方式完成:
def bincount_with_max_weight(values: tf.Tensor, weights: tf.Tensor) -> tf.Tensor: _range = tf.range(tf.reduce_max(values) + 1) return tf.map_fn(lambda x: tf.maximum( tf.reduce_max(tf.gather(weights, tf.where(tf.equal(values, x)))), 0), _range)
示例案例的输出为:
第一行计算values中的值的范围:
values
_range = tf.range(tf.reduce_max(values) + 1)
并且在第二行中,使用tf.map_fn和tf.where来计算weight的每个_range中的元素的最大值,其中tf.where检索子句为真的索引,tf.gather检索与所提供的索引相对应的值。tf.maximum Package 输出以处理values中不存在该元素的情况,即:在示例情况下,0不存在于values中,因此没有tf.maximum的输出对于0:
tf.map_fn
tf.where
weight
_range
tf.gather
tf.maximum
0
[-2147483648 5 0 1 5 5]
这也可以应用于最终结果Tensor而不是每个元素:
def bincount_with_max_weight(values: tf.Tensor, weights: tf.Tensor) -> tf.Tensor: _range = tf.range(tf.reduce_max(values) + 1) result = tf.map_fn(lambda x: tf.reduce_max(tf.gather(weights, tf.where(tf.equal(values, x)))), _range) return tf.maximum(result, 0)
请注意,如果使用负权重,则这将不起作用-在这种情况下,可以使用tf.where来与最小整数值进行比较(在示例中为tf.int32.min,但这可以应用于任何数字dtype),而不是应用tf.maximum:
tf.int32.min
def bincount_with_max_weight(values: tf.Tensor, weights: tf.Tensor) -> tf.Tensor: _range = tf.range(tf.reduce_max(values) + 1) result = tf.map_fn(lambda x: tf.reduce_max(tf.gather(weights, tf.where(tf.equal(values, x)))), _range) return tf.where(tf.equal(result, tf.int32.min), 0, result)
对于处理2DTensor的情况,我们可以使用tf.map_fn来将最大权重函数应用于批处理中的每对值和权重:
def bincount_with_max_weight(values: tf.Tensor, weights: tf.Tensor, axis: Optional[int] = None) -> tf.Tensor: _range = tf.range(tf.reduce_max(values) + 1) def mapping_function(x: int, _values: tf.Tensor, _weights: tf.Tensor) -> tf.Tensor: return tf.reduce_max(tf.gather(_weights, tf.where(tf.equal(_values, x)))) if axis == -1: result = tf.map_fn(lambda pair: tf.map_fn(lambda x: mapping_function(x, *pair), _range), (values, weights), dtype=tf.int32) else: result = tf.map_fn(lambda x: mapping_function(x, values, weights), _range) return tf.where(tf.equal(result, tf.int32.min), 0, result)
对于提供的2D示例:
values = tf.constant([[1, 1, 2, 3], [2, 1, 4, 5]]) weights = tf.constant([[1, 5, 0, 1], [0, 5, 4, 5]]) print(bincount_with_max_weight(values, weights, axis=-1))
输出为:
tf.Tensor( [[0 5 0 1 0 0] [0 5 0 0 4 5]], shape=(2, 6), dtype=int32)
此实现是最初描述的方法的一般化-如果省略axis,它将计算1D情况的结果。
axis
33qvvth12#
要获得更快的执行速度,
values = tf.constant([[1,1,2,3], [2,1,4,5]]) weights = tf.constant([[1,5,0,1], [0,5,4,5]]) def find_max_bins(output , values , weights): np.maximum.at(output , values , weights) return output @tf.function(input_signature=[tf.TensorSpec(shape=[None], dtype = tf.float32), tf.TensorSpec(shape=[None], dtype = tf.int32), tf.TensorSpec(shape=[None], dtype = tf.int32) ]) def tf_function(output , values , weights): print(values) y = tf.numpy_function(find_max_bins, [output , values , weights], tf.float32) return y length = np.max(values)+1 initial_value = [0 for x in range(length)] variable = tf.Variable(initial_value = initial_value, shape=(length) , dtype=tf.float32) for i , (value , weight) in enumerate(zip(values , weights)): if(i > 0): output = tf.stack([output , tf_function(variable , value , weight)] , 0) else: output = tf_function(variable , value , weight) variable.assign_sub(initial_value)
输出量:
<tf.Tensor: shape=(2, 6), dtype=float32, numpy= array([[0., 5., 0., 1., 0., 0.], [0., 5., 0., 0., 4., 5.]], dtype=float32)>
2条答案
按热度按时间8e2ybdfx1#
这需要一些技巧,但您可以按如下方式完成:
示例案例的输出为:
第一行计算
values
中的值的范围:并且在第二行中,使用
tf.map_fn
和tf.where
来计算weight
的每个_range
中的元素的最大值,其中tf.where
检索子句为真的索引,tf.gather
检索与所提供的索引相对应的值。tf.maximum
Package 输出以处理values
中不存在该元素的情况,即:在示例情况下,0
不存在于values
中,因此没有tf.maximum
的输出对于0:这也可以应用于最终结果Tensor而不是每个元素:
请注意,如果使用负权重,则这将不起作用-在这种情况下,可以使用
tf.where
来与最小整数值进行比较(在示例中为tf.int32.min
,但这可以应用于任何数字dtype),而不是应用tf.maximum
:更新
对于处理2DTensor的情况,我们可以使用
tf.map_fn
来将最大权重函数应用于批处理中的每对值和权重:对于提供的2D示例:
输出为:
此实现是最初描述的方法的一般化-如果省略
axis
,它将计算1D情况的结果。33qvvth12#
要获得更快的执行速度,
输出量: