在PyTorch中加速SVD

7xzttuei  于 2022-11-23  发布在  其他
关注(0)|答案(3)|浏览(313)

我正在用Pytorch为CIFAR10做一些分类任务,对于每次迭代,我必须在能够前馈到模型之前对每个批次做一些预处理。下面是每个批次的预处理部分的代码:

S = torch.zeros((batch_size, C, H, W))
for i in range(batch_size):
    img = batch[i, :, :, :]
    for c in range(C):                
        U, _, V = torch.svd(img[c])
        S[i, c] = U[:, 0].view(-1, 1).matmul(V[:, 0].view(1, -1))

然而,这个计算非常慢。有没有什么方法可以加快这个代码?

5cnsuln7

5cnsuln71#

批量计算

假设您有PyTorch〉= 1.2.0,那么支持批处理SVD,因此您可以使用

U, _, V = torch.svd(batch)
S = U[:, :, :, 0].unsqueeze(3) @ V[:, :, :, 0].unsqueeze(2)

我发现它平均比迭代版本快一点。

截断SVD(仅CPU)

如果你没有cuda加速,你可以使用截断SVD来避免计算不必要的奇异值/向量。不幸的是,PyTorch不支持截断SVD和AFAIK,没有批处理或GPU版本可用。我知道有两个选择

这两个选项都允许您选择要返回的组件数量。在OP的原始问题中,我们只需要第一个组件。
虽然我没有在稀疏矩阵上使用它,但我发现svdsk=1在CPUTensor上比torch.svd快10倍。我发现randomized_svd只快2倍。您的结果将取决于实际数据。此外,svds应该比randomized_svd更精确一些。请记住,这些结果与torch.svd结果之间会有微小的差异,但这些差异应该可以忽略不计。

import scipy.sparse.linalg as sp
import numpy as np

S = torch.zeros((batch_size, C, H, W))
for i in range(batch_size):
    img = batch[i, :, :, :]
    for c in range(C):
        u, _, v = sp.svds(img[c], k=1)
        S[i, c] = torch.from_numpy(np.outer(u, v))
mo49yndu

mo49yndu2#

PyTorch现在有了speed optimised Linear Algebra operations,类似于numpy的linalg模块,包括torch.linalg.svd
在CPU上实现SVD时,为了提高速度,使用LAPACK例程 ?gesdd(一种分治算法)而不是 ?gesvd。类似地,在GPU上实现SVD时,在CUDA 10.1.243及更高版本上使用cuSOLVER例程 gesvdjgesvdjBatched,在CUDA早期版本上使用MAGMA例程gesdd。

yvgpqqbh

yvgpqqbh3#

有一个非常有趣的工具可以批量操作SVD:批处理svd

相关问题