pytorch 如何以最快的方式操作具有多个区域的掩码中每个区域的值?

trnvg8h3  于 2023-10-20  发布在  其他
关注(0)|答案(2)|浏览(117)

所以,我正在研究文本本地化的分割模型的一些处理步骤。
简单的问题。我在下面做了几张图片:
1.掩码(默认每个区域的值标记为“1”,其他灰色单元格的值标记为“0”)(1)

所以,我想做一些事情来得到这个Tensor,其中每个区域现在都有它所有值的总和。目标Tensor是下图(2)
1.目标

问题:有没有一种方法可以在矩阵风格中完成我需要的任务(为了更快的计算)或者一些相关的关键字?
我尝试了:我尝试过使用标签坐标来循环OpenCV的findContours(),但我不知道这是这种操作任务的最佳实践。
数组复制

  1. import numpy as np
  2. mask = np.array([[1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
  3. [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0],
  4. [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0],
  5. [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
  6. [0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
  7. [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype=np.float32)

(Edit:add numpy array for reproducible)

rdrgkggo

rdrgkggo1#

您可以使用scipy.ndimage.label来识别聚类,然后计算每组的总和:

  1. from scipy import ndimage
  2. labels, n = ndimage.label(mask)
  3. out = np.zeros_like(mask, dtype=int)
  4. for i in range(1, n+1):
  5. m = labels == i
  6. out[m] = mask[m].sum()

输出量:

  1. array([[12, 12, 12, 12, 0, 0, 0, 0],
  2. [12, 12, 12, 12, 0, 7, 7, 0],
  3. [12, 12, 12, 12, 0, 7, 7, 7],
  4. [ 0, 0, 0, 0, 0, 0, 7, 7],
  5. [ 0, 5, 5, 5, 0, 0, 0, 0],
  6. [ 0, 5, 5, 0, 0, 0, 0, 0]])
展开查看全部
jgzswidk

jgzswidk2#

这就像一个数据结构问题,你可以使用deep first recursion算法来解决。

  1. def recursion(mask, h, w, i, j):
  2. if status[i][j]==True or mask[i][j]==0.0:
  3. return
  4. status[i][j]=True
  5. if i-1>=0:
  6. recursion(mask, h, w, i-1, j)
  7. if j-1>=0:
  8. recursion(mask, h, w, i, j-1)
  9. if i+1<h:
  10. recursion(mask, h, w, i+1, j)
  11. if j+1<w:
  12. recursion(mask, h, w, i, j+1)
  13. import numpy as np
  14. mask = np.array([[1.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
  15. [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 0.0],
  16. [1.0, 1.0, 1.0, 1.0, 0.0, 1.0, 1.0, 1.0],
  17. [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 1.0],
  18. [0.0, 1.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0],
  19. [0.0, 1.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0]], dtype=np.float32)
  20. status = np.zeros_like(mask, np.bool_)
  21. ret = np.copy(mask)
  22. h, w = mask.shape
  23. for i in range(h):
  24. for j in range(w):
  25. if status[i][j]==False and mask[i][j]==1.0:
  26. recursion(mask, h, w, i, j)
  27. kk = np.sum(status)
  28. for ii in range(h):
  29. for jj in range(w):
  30. if status[ii][jj]==True:
  31. ret[ii][jj] = kk
  32. status[ii][jj] = False
  33. mask[ii][jj] = -1
  34. ret
展开查看全部

相关问题