PyTorch:将类别标签转换为独热(one-hot)向量表示

g6ll5ycj  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(120)

我正在CIFAR-10上训练一个AC-GAN(acGAN)。到目前为止,对应的标签用其数值[1,...,10]表示。现在我想把这些值改成独热(one-hot)向量表示,例如 class[1] = [1,0,0,0,0,0,0,0,0,0](共 10 维,对应 10 个类别)。我不知道该如何修改生成器和判别器类来实现这一点。感谢任何帮助!
以下是目前为止的两个类(Generator 和 Discriminator):

  1. class Generator(nn.Module):
  2. def __init__(self, latent_size , nb_filter, n_classes):
  3. super(Generator, self).__init__()
  4. self.label_embedding = nn.Embedding(n_classes, latent_size)
  5. self.conv1 = nn.ConvTranspose2d(latent_size, nb_filter * 8, 4, 1, 0)
  6. self.bn1 = nn.BatchNorm2d(nb_filter * 8)
  7. self.conv2 = nn.ConvTranspose2d(nb_filter * 8, nb_filter * 4, 4, 2, 1)
  8. self.bn2 = nn.BatchNorm2d(nb_filter * 4)
  9. self.conv3 = nn.ConvTranspose2d(nb_filter * 4, nb_filter * 2, 4, 2, 1)
  10. self.bn3 = nn.BatchNorm2d(nb_filter * 2)
  11. self.conv4 = nn.ConvTranspose2d(nb_filter * 2, nb_filter * 1, 4, 2, 1)
  12. self.bn4 = nn.BatchNorm2d(nb_filter * 1)
  13. self.conv5 = nn.ConvTranspose2d(nb_filter * 1, 3, 4, 2, 1)
  14. self.__initialize_weights()
  15. def forward(self, input, cl):
  16. x = torch.mul(self.label_embedding(cl), input)
  17. x = x.view(x.size(0), -1, 1, 1)
  18. x = self.conv1(x)
  19. x = self.bn1(x)
  20. x = F.relu(x)
  21. x = self.conv2(x)
  22. x = self.bn2(x)
  23. x = F.relu(x)
  24. x = self.conv3(x)
  25. x = self.bn3(x)
  26. x = F.relu(x)
  27. x = self.conv4(x)
  28. x = self.bn4(x)
  29. x = F.relu(x)
  30. x = self.conv5(x)
  31. return torch.tanh(x)
  32. def __initialize_weights(self):
  33. for m in self.modules():
  34. if isinstance(m, nn.Conv2d):
  35. m.weight.data.normal_(0.0, 0.02)
  36. elif isinstance(m, nn.BatchNorm2d):
  37. m.weight.data.normal_(1.0, 0.02)
  38. m.bias.data.fill_(0)
  39. class Discriminator(nn.Module):
  40. def __init__(self, nb_filter, n_classes):
  41. super(Discriminator, self).__init__()
  42. self.nb_filter = nb_filter
  43. self.conv1 = nn.Conv2d(3, nb_filter, 4, 2, 1)
  44. self.conv2 = nn.Conv2d(nb_filter, nb_filter * 2, 4, 2, 1)
  45. self.bn2 = nn.BatchNorm2d(nb_filter * 2)
  46. self.conv3 = nn.Conv2d(nb_filter * 2, nb_filter * 4, 4, 2, 1)
  47. self.bn3 = nn.BatchNorm2d(nb_filter * 4)
  48. self.conv4 = nn.Conv2d(nb_filter * 4, nb_filter * 8, 4, 2, 1)
  49. self.bn4 = nn.BatchNorm2d(nb_filter * 8)
  50. self.conv5 = nn.Conv2d(nb_filter * 8, nb_filter * 1, 4, 1, 0)
  51. self.gan_linear = nn.Linear(nb_filter * 1, 1)
  52. self.aux_linear = nn.Linear(nb_filter * 1, n_classes)
  53. self.__initialize_weights()
  54. def forward(self, input):
  55. x = self.conv1(input)
  56. x = F.leaky_relu(x, 0.2)
  57. x = self.conv2(x)
  58. x = self.bn2(x)
  59. x = F.leaky_relu(x, 0.2)
  60. x = self.conv3(x)
  61. x = self.bn3(x)
  62. x = F.leaky_relu(x, 0.2)
  63. x = self.conv4(x)
  64. x = self.bn4(x)
  65. x = F.leaky_relu(x, 0.2)
  66. x = self.conv5(x)
  67. x = x.view(-1, self.nb_filter * 1)
  68. c = self.aux_linear(x)
  69. s = self.gan_linear(x)
  70. s = torch.sigmoid(s)
  71. return s.squeeze(1), c.squeeze(1)
  72. def __initialize_weights(self):
  73. for m in self.modules():
  74. if isinstance(m, nn.Conv2d):
  75. m.weight.data.normal_(0.0, 0.02)
  76. elif isinstance(m, nn.BatchNorm2d):
  77. m.weight.data.normal_(1.0, 0.02)
  78. m.bias.data.fill_(0)
c7rzv4ha

c7rzv4ha1#

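要转换为独热(one-hot)编码表示,可以使用内置的 `nn.functional.one_hot`。例如,`F.one_hot(labels, num_classes=10)` 会把形状为 (batch,) 的 int64 标签张量转换为形状为 (batch, 10) 的独热张量。注意标签必须是从 0 开始的索引(0–9);如果标签是 1–10,需要先减 1。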

相关问题