pytorch scatter_函数学习笔记

x33g5p2x  于2021-11-10 转载在 其他  
字(2.0k)|赞(0)|评价(0)|浏览(358)

在pytorch中,scatter是一个非常实用的映射函数,其将一个源张量(src)中的值按照指定的轴方向(dim)和对应的位置关系(index)逐个填充到目标张量(target)中,其函数写法为:

  1. target.scatter(dim, index, src)

其中各变量及参数的说明如下:

  • target:即目标张量,将在该张量上进行映射
  • src:即源张量,将把该张量上的元素逐个映射到目标张量上
  • dim:指定轴方向,定义了填充方式。对于二维张量,dim=0表示逐列进行行填充,而dim=1表示逐行进行列填充
  • index: 按照轴方向,在target张量中需要填充的位置

dim 0:

把a按顺序(一行一行遍历)给b的索引(index)赋值,index是行编号

实际就是把a的行按照新的顺序赋值给b,顺序就是index行编号。

列子1:

  1. import torch
  2. a = (torch.arange(10) + 1).reshape(5, 2).float()
  3. print(a)
  4. print("-------------------------------------")
  5. b = torch.zeros(5, 3)
  6. b_ = b.scatter(dim=0, index=torch.LongTensor([[4, 2], [3, 0], [2, 0], [1, 0], [0, 0]]), src=a)
  7. print(b_)

结果:

tensor([[ 1.,  2.],
        [ 3.,  4.],
        [ 5.,  6.],
        [ 7.,  8.],
        [ 9., 10.]])

tensor([[ 9., 10.,  0.],
        [ 7.,  0.,  0.],
        [ 5.,  2.,  0.],
        [ 3.,  0.,  0.],
        [ 1.,  0.,  0.]])

例子2:

  1. import torch
  2. a = (torch.arange(10)+1).reshape(2,5).float()
  3. print(a)
  4. print("-------------------------------------")
  5. b = torch.zeros(3, 5)
  6. b_= b.scatter(dim=0, index=torch.LongTensor([[0, 2]]),src=a)
  7. # b 0行 a 1列 1行
  8. # b 2行 a 2列 1行
  9. print(b_)
  10. print("-------------------------------------")
  11. b_= b.scatter(dim=0, index=torch.LongTensor([[0, 2, 1, 1, 2], [2, 0, 2, 1, 0]]),src=a)
  12. # 7是因为两个 第2行 第4列 值发生覆盖了。
  13. # 第1行 第1列
  14. # 第3行 第2列
  15. # 第2行 第3列
  16. # 第2行 第4列
  17. # 第3行 第5列
  18. # 第3行 第1列
  19. # 第1行 第2列
  20. # 第3行 第3列
  21. # 第2行 第4列
  22. # 第1行 第5列
  23. print(b_)

结果:

tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10.]])

tensor([[1., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 2., 0., 0., 0.]])

tensor([[ 1.,  7.,  0.,  0., 10.],
        [ 0.,  0.,  3.,  9.,  0.],
        [ 6.,  2.,  8.,  0.,  5.]])

dim1:

把a按顺序(一行一行遍历)给b的索引(index)赋值,index是列编号

实际就是把a的行重新设置,赋值到b的行上,新的位置,就是index索引(列编号位置)。

  1. import torch
  2. a = (torch.arange(10)+1).reshape(2,5).float()
  3. print(a)
  4. print("-------------------------------------")
  5. b = torch.zeros(3, 5)
  6. b_= b.scatter(dim=1, index=torch.LongTensor([[0, 2]]),src=a)
  7. #b 0列 第1行, a 0列 第1行
  8. #b 2列 第1行 ,a 1列 第1行
  9. print(b_)
  10. print("-------------------------------------")
  11. b_= b.scatter(dim=1, index=torch.LongTensor([[0, 2, 1, 1, 2], [2, 0, 2, 1, 0]]),src=a)
  12. #把a的第1行按顺序 放在b的第1行上,顺序是index
  13. #4的来源:
  14. 0, 2, 1, [1], 2
  15. #把a的第2行按顺序 放在b的第2行上,顺序是index
  16. print(b_)

结果:

tensor([[ 1.,  2.,  3.,  4.,  5.],
        [ 6.,  7.,  8.,  9., 10.]])

tensor([[1., 0., 2., 0., 0.],
        [0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0.]])

tensor([[ 1.,  5.,  2.,  0.,  0.],
        [10.,  9.,  8.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.]])

相关文章