在tensorflow中使用掩码执行矩阵乘法时出错

ia2d9nvy  于 2023-05-01  发布在  其他
关注(0)|答案(1)|浏览(138)

假设val是一个大小为(2,N)的矩阵。我需要将它与大小为mask的掩码矩阵相乘
(K,K)包含不同索引处的值0和1。这应该输出一个大小为(N, K, K)的矩阵result,其中result的每个子矩阵沿着维度0为(K,K)矩阵,其中零被val(i,1)替换,一被val(i,2)替换。
比如说

mask = tf.constant([[0, 1, 0, 1],
                   [1, 0, 0, 1],
                   [1, 1, 1, 0],
                   [0, 1, 0, 0]], dtype=tf.int32)

val = tf.constant([[3, 2, 8, 1, 9, 5, 6], [7, 4, 9, 8, 3, 1, 9]])

那么输出应该是像这样的7 x 4 x 4矩阵,

result  = 

tf.Tensor(
[[[ 3.  7.  3.  7.]
  [ 7.  3.  3.  7.]
  [ 7.  7.  7.  3.]
  [ 3.  7.  3.  3.]]

 [[ 2.  4.  2.  4.]
  [ 4.  2.  2.  4.]
  [ 4.  4.  4.  2.]
  [ 3.  4.  3.  2.]]
          :
          :
          :
 [[ 6.  9.  6.  9.]
  [ 9.  6.  6.  9.]
  [ 9.  9.  9.  6.]
  [ 6.  9.  6.  6.]]]

目前,我使用for循环迭代瓦尔的每一列,以执行以下操作
val[0,i]*mask + val[1,i]*(1-mask)
我希望将其矢量化,以结合tensorflow的矩阵乘法功能。

e0bqpujr

e0bqpujr1#

对于tensorflow,不需要显式的for循环。
编辑:我的代码中有一个错误

import tensorflow as tf

mask = tf.constant([[0, 1, 0, 1],
                    [1, 0, 0, 1],
                    [1, 1, 1, 0],
                    [0, 1, 0, 0]], dtype=tf.int32)

val = tf.constant([[3, 2, 8, 1, 9, 5, 6], [7, 4, 9, 8, 3, 1, 9]])

mask_broadcasted = tf.broadcast_to(tf.expand_dims(mask, axis=0), (val.shape[1], *mask.shape))
mask_inv_broadcasted = 1 - mask_broadcasted

val_expanded = tf.expand_dims(val, axis=-1)  # Add an extra dimension to val (expand the last dimension)
val_expanded_reshaped = tf.reshape(val_expanded, (2, -1, 1, 1))

result = val_expanded_reshaped[0] * mask_inv_broadcasted + val_expanded_reshaped[1] * mask_broadcasted

print(result)

相关问题