在pytorch中有bitcount操作吗

yruzcnhs  于 2023-10-20  发布在  其他
关注(0)|答案(3)|浏览(165)

我尝试用pytorch写一个简单的汉明距离函数:

xorimg = torch.bitwise_xor(img1,img2)
for i in range(bitlen) : 
        hdist = hdist + (xorimg & 1)
        xorimg = xorimg >> 1

我想知道是否有一个简单的bitcount操作来计算1s位以摆脱for循环
例如:

xorimg = torch.bitwise_xor(img1,img2)
hdist = torch.bitcount(xorimg)

或者任何其他等效的方法来摆脱耗时的for循环?
或pytorch直接支持汉明距离,如:

hdist = torch.hamming(img1,img2)

那就更好了
谢谢你的帮忙。

31moq8wy

31moq8wy1#

我发布了第二个答案,因为我有一些时间来思考这个问题,我意识到你可能可以很容易地自己提供一个本地实现,即使PyTorch没有。
这是C语言的源代码:

#include <stdlib.h>

void bitcount(unsigned long long *buf, size_t len) {
  for (size_t i = 0; i < len; ++i) {
    buf[i] = __builtin_popcountll(buf[i]);
  }
}

编译工具:

cc -O3 -march=nehalem bitcount.c -fPIC -shared -o bitcount.so

(Note 2008年发布的Intel Nehalem CPU架构是第一个支持POPCNT指令的架构。
我在一个包含1000万个随机64位整数的数组上测试了性能:

import numpy as np

def Bitcount1(a):
  counts = a & 1
  for i in range(1, 64):
    counts += (a >> i) & 1
  return counts

# More efficient versions from Wikipedia
# https://en.wikipedia.org/wiki/Hamming_weight

m1  = 0x5555555555555555
m2  = 0x3333333333333333
m4  = 0x0f0f0f0f0f0f0f0f
m8  = 0x00ff00ff00ff00ff
m16 = 0x0000ffff0000ffff
m32 = 0x00000000ffffffff
h01 = 0x0101010101010101

def Bitcount2(a):
  a = (a & m1 ) + ((a >>  1) & m1 )
  a = (a & m2 ) + ((a >>  2) & m2 )
  a = (a & m4 ) + ((a >>  4) & m4 )
  a = (a & m8 ) + ((a >>  8) & m8 )
  a = (a & m16) + ((a >> 16) & m16)
  a = (a & m32) + ((a >> 32) & m32)
  return a

def Bitcount3(a):
  a = a - ((a >> 1) & m1)
  a = (a & m2) + ((a >> 2) & m2)
  a = (a + (a >> 4)) & m4
  return (a * h01) >> 56

# Native solution via ctypes
import ctypes
lib = ctypes.cdll.LoadLibrary('./bitcount.so')
lib.bitcount.restype = None
lib.bitcount.argtypes = (np.ctypeslib.ndpointer(dtype=np.ulonglong, flags=('ALIGNED', 'CONTIGUOUS', 'WRITEABLE')), ctypes.c_size_t)

def Bitcount4(a):
  a = np.copy(a, order='C')
  lib.bitcount(a, a.size)
  return a

#
# Test code follows
#

def GenTestArray(n):
  return np.random.default_rng().integers(0, 2**64, n, dtype=np.uint64)

a=GenTestArray(10_000_000)
r1 = Bitcount1(a)
r2 = Bitcount2(a)
r3 = Bitcount3(a)
r4 = Bitcount4(a)

assert (r1 == r2).all()
assert (r1 == r3).all()
assert (r1 == r4).all()

from timeit import timeit

print(timeit('Bitcount1(a)', globals=globals(), number=1))
print(timeit('Bitcount2(a)', globals=globals(), number=1))
print(timeit('Bitcount3(a)', globals=globals(), number=1))
print(timeit('Bitcount4(a)', globals=globals(), number=1))

结果如下(执行时间以秒为单位):

2.3275568269891664    # Bitcount1()
0.31765274197096005   # Bitcount2()
0.16485577001003549   # Bitcount3()
0.023819021007511765  # Bitcount4()

当然,这些时间在运行之间会有所不同,这取决于机器,但我认为整体情况非常清楚:虽然Bitcount2()比Bitcount1()快8倍(接近我的预测),Bitcount3()又快了近两倍,但本机解决方案将所有其他解决方案都打得落花流水:它比原始方法快了近100倍,比次优解决方案快了7倍。
因此,如果你想真正快速地执行位计数,你必须加载一个本地库来执行。
注意:上面的解决方案使用了numpy数组,而你问的是PyTorch。但是你可能可以使用相同的方法,因为Tensors和numpy数组非常相似,你通常可以很便宜地在它们之间转换。参见:https://www.tensorflow.org/tutorials/customization/basics#numpy_compatibility

ztmd8pv5

ztmd8pv52#

参见答案here
如果要在cupy部分使用cuda内部函数__popc(),则替换以下while循环(bitcount):

while(x != 0){
  x = x & (x - 1);
  dist[elem_idx]++;
}

使用cuda的int32函数:

dist[elem_idx] = __popc(x);

或者在torch.py和cupy.py中稍微修改一下int64:

dist[elem_idx] = __popcll(x);
fykwrbwg

fykwrbwg3#

我不能给予一个明确的答案,关于这是否可以用PyTorch原生地完成(尽管一些快速的谷歌搜索建议不可以),但是如果你想加速你的循环,计数位是一个众所周知的问题,有合理的标准解决方案。这里列出了一些算法:https://graphics.stanford.edu/~seander/bithacks.html#CountBitsSetParallel和这里:https://en.wikipedia.org/wiki/Hamming_weight
维基百科上标记为popcount64a()的函数应该很容易移植到Python,并在24次操作中计算64位整数的位数,这将比当前版本快8倍,如果bitlen=64,则执行192次操作。

相关问题