Keras后端开关与tf.where结合使用,无法按预期工作

mftmpeh8  于 2022-11-24  发布在  其他
关注(0)|答案(1)|浏览(123)

我有一个自定义的损失函数,我想把值从基于一位热码的编码改为特定范围内的值,以计算IOU。
这段代码的一部分是看在Tensor中哪里有1,否则就有0。为此,我使用tf.where返回位置。我有一个形状为[batch_size,S1,S2,12]的向量,我只关心最后一个维度,这就是为什么我使用tf. where的[...,2]。
我的预测经常是全零,因为我的背景事件中没有任何值,而且我的网络会时不时地预测一个全零向量,这意味着tf.where会返回一个空Tensor,这就是为什么我想用K.switch来检查Tensor是否是空的,因为如果是的话,我想返回零。
现在的问题是,K.switch希望then else选项的shape具有相同的shape,但我需要输出具有shape [batch_size,S1,S2,1]。我尝试了不同的方法,但我无法使其工作。我需要得到shape [batch_size,S1,S2,1]的零值,或者我需要where_box1来使[batch_size,S1,S2,1]具有浮点数。
当where_box1_temp为空时,K.switch返回一个空的零向量,这不是我想要的。当我使用tf.zeros([batch_size,S1,S2,1])时,它会抱怨当where_box1_temp为空时,条件的形状不同....

where_box1_temp = tf.where(y_pred[...,C+1:C+13])[...,2]

where_box1 = K.switch(tf.equal(tf.size(where_box1_temp),0) , 
                          tf.zeros_like(where_box1_temp) , where_box1_temp)
mqxuamgl

mqxuamgl1#

于是我找到了一个变通办法,也许这对别人是有帮助的:

where_box1_temp = tf.where(y_pred[...,C+1:C+13],[1,2,3,4,5,6,7,8,9,10,11,12],0)

where_box1 = tf.reshape(K.sum(where_box1_temp,axis=3),[batch_size,5,5])

这允许我有一个我想要的形状的Tensor,其中所有的背景/零预测值都是0,而不必使用k。开关,也不必有任何空维度或类似的麻烦。

相关问题