是否可以在Tensorflow中对秩为3的矩阵执行稀疏-密集矩阵乘法?

3pvhb19x  于 2023-01-05  发布在  其他
关注(0)|答案(1)|浏览(152)

我尝试在TensorFlow中执行稀疏矩阵-密集矩阵乘法,其中两个矩阵都有一个前导批处理维度我知道TensorFlow为秩为2的矩阵提供了tf.sparse.sparse_dense_matmul函数,但我正在寻找一种方法来处理秩为3的矩阵。TensorFlow中是否有内置函数或方法可以有效地处理这种情况,而不需要昂贵的整形或切片操作?性能在我的应用程序中至关重要。
为了说明我的问题,考虑下面的示例代码:

import tensorflow as tf

# Define sparse and dense matrices with leading batch dimension
sparse_tensor = tf.SparseTensor(indices=[[0, 1, 1], [0, 0, 1], [1, 1, 1], [1, 2, 1], [2, 1, 1]],
                                values=[1, 1, 1, 1, 1],
                                dense_shape=[3, 3, 2])
dense_matrix = tf.constant([[[0.1, 0.2, 0.3, 0.4], [0.5, 0.6, 0.7, 0.8]],
                        [[0.9, 0.10, 0.11, 0.12], [0.13, 0.14, 0.15, 0.16]],
                        [[0.17, 0.18, 0.19, 0.20], [0.21, 0.22, 0.23, 0.24]]], dtype=tf.float32)
        
# Perform sparse x dense matrix multiplication
result = tf.???(sparse_tensor, dense_matrix)  # Result should have shape [3, 3, 4]
bxgwgixi

bxgwgixi1#

在TF中,稀疏和密集乘法只将密集广播到稀疏。否则,batch_sparse_dense_matmul可以简单地通过,

tf.sparse.reduce_sum(tf.sparse.expand_dims(sparse_tensor,-1)*tf.expand_dims(dense_matrix,1), 2)
#[3,3,2,1] * [3,1,2,4] and reduce sum along dim=2
# the above throws error 
# because the last dim of sparse tensor [1] cannot be broadcasted to [4]

为了解决上述问题,我们需要tile稀疏Tensor的最后一个维度,使其为4。

k = 4
tf.sparse.concat(-1,[tf.sparse.expand_dims(sparse_tensor, -1)]*k)
##[3,3,2,4]

总的来说,

tf.sparse.reduce_sum(tf.sparse.concat(-1,[tf.sparse.expand_dims(sparse_tensor, -1)]*k)*tf.expand_dims(dense_matrix,1), 2)
#timeit:1.99ms

另一种方法是使用tf.map_fn

tf.map_fn(
        lambda x: tf.sparse.sparse_dense_matmul(x[0], x[1]),
        elems=(tf.sparse.reorder(tf.cast(sparse_tensor, tf.float32)),dense_matrix ), fn_output_signature=tf.float32
    )
#timeit:4.42ms

相关问题