找不到有效的PyTorch或NumPy广播来避免瓶颈操作

yrdbyhpb  于 2023-03-02  发布在  其他
关注(0)|答案(1)|浏览(149)

我在我的基于PyTorch的代码中有以下实现,它涉及到一个嵌套的for循环。嵌套的for循环沿着if条件使代码执行起来非常慢。我试图避免嵌套循环,以涉及NumPy和PyTorch中的广播概念,但没有产生任何结果。任何关于避免for循环的帮助将不胜感激。
下面是我读过的链接PyTorchNumPy

#!/usr/bin/env python
# coding: utf-8

import torch

batch_size=32
mask=torch.FloatTensor(batch_size).uniform_() > 0.8

teacher_count=510
student_count=420
feature_dim=750
student_output=torch.zeros([batch_size,student_count])
teacher_output=torch.zeros([batch_size,teacher_count])

student_adjacency_mat=torch.randint(0,1,(student_count,student_count))
teacher_adjacency_mat=torch.randint(0,1,(teacher_count,teacher_count))

student_feat=torch.rand([batch_size,feature_dim])
student_graph=torch.rand([student_count,feature_dim])
teacher_feat=torch.rand([batch_size,feature_dim])
teacher_graph=torch.rand([teacher_count,feature_dim])

for m in range(batch_size):
    if mask[m]==1:
        for i in range(student_count):
            for j in range(student_count):
                student_output[m][i]=student_output[m][i]+student_adjacency_mat[i][j]*torch.dot(student_feat[m],student_graph[j])
    if mask[m]==0:
        for i in range(teacher_count):
            for j in range(teacher_count):
                teacher_output[m][i]=teacher_output[m][i]+teacher_adjacency_mat[i][j]*torch.dot(teacher_feat[m],teacher_graph[j])
oxcyiej7

oxcyiej71#

问题陈述

您要执行的操作非常简单。如果您仔细查看循环:

for m in range(batch_size):
  if mask[m]==1:
    for i in range(student_count):
      for j in range(student_count):
         student_output[m][i] += student_adjacency_mat[i][j]*torch.dot(student_feat[m],student_graph[j])

   if mask[m]==0:
     for i in range(teacher_count):
       for j in range(teacher_count):
         teacher_output[m][i] += teacher_adjacency_mat[i][j]*torch.dot(teacher_feat[m],teacher_graph[j])

与我们相关的要素:

  • 您有两个基于掩码的独立操作,掩码最终可以单独计算。
  • 每个操作都在相邻矩阵之间循环,* 即 * student_count²
  • 赋值操作归结为
output[m,i] += adj_matrix[i,j] * <feats[m] / graph[j]>

其中adj_matrix[i,j]是标量。

使用torch.einsum

这是torch.einsum的一个典型用例,您可以阅读更多关于this thread的内容,我碰巧也写过an answer
如果我们不考虑所有的实现细节,那么torch.einsum的公式是相当不言自明的:

o = torch.einsum('ij,mf,jf->mi', adj_matrix, feats, graph)

在伪代码中,这可以归结为:

o[m,i] += adj_matrix[i,j]*feats[m,f]*graph[j,f]

对于您感兴趣域中的所有ijmf
结合使用M = mask[:,None]扩展到适当形式的掩码,这为学生Tensor提供:

>>> student = M*torch.einsum('ij,mf,jf->mi', student_adjacency_mat, student_feat, student_graph)

对于教师结果,您可以使用~M反转掩码:

>>> teacher = ~M*torch.einsum('ij,mf,jf->mi', teacher_adjacency_mat, teacher_feat, teacher_graph)
使用torch.matmul

或者,由于这是torch.einsum的一个相当简单的应用程序,因此您也可以通过两次调用torch.matmul来避免。给定AB,这两个矩阵分别由ikkj索引,您将得到A@B,它对应于ik@kj -> ij。因此,您可以通过下式获得所需的结果:

>>> g = student_feat@student_graph.T # mf@jf.T -> mf@fj -> mj
>>> g@student_adjacency_mat.T        # mj@ij.T -> mj@ji -> mi

看看这两个步骤与torch的关系。einsum调用'ij,mf,jf-〉mi'。首先是mf,jf->mj,然后是mj,ij->mi
边注您当前的虚拟学生和教师邻接矩阵初始化为零。

student_adjacency_mat=torch.randint(0,2,(student_count,student_count)).float()
teacher_adjacency_mat=torch.randint(0,2,(teacher_count,teacher_count)).float()

相关问题