PyTorch:转换为独热(one-hot)向量表示

g6ll5ycj  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(116)

我正在 CIFAR-10 上训练一个 ACGAN。到目前为止,对应的标签由它们的数值 [1,...,10] 表示。现在,我想把这些值改为独热(one-hot)向量表示,例如 class[1] = [1,0,0,0,0,0,0,0,0,0](共 10 维,对应 10 个类别)。我不知道应该如何修改生成器和判别器类来实现这一点。感谢任何帮助!
以下是目前为止的生成器类和判别器类:

class Generator(nn.Module):
    """Conditional DCGAN-style generator for an ACGAN.

    The noise vector is multiplied element-wise by a learned per-class
    embedding. The product is reshaped to a 1x1 feature map and decoded by
    five transposed convolutions (1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 ->
    64x64) into a 3-channel image squashed to [-1, 1] by ``tanh``.

    Args:
        latent_size: Length of the noise vector and of each class embedding.
        nb_filter: Base channel width; the layers use multiples of it.
        n_classes: Number of conditioning classes (rows of the embedding).
    """

    def __init__(self, latent_size , nb_filter, n_classes):
        super(Generator, self).__init__()
        self.label_embedding = nn.Embedding(n_classes, latent_size)
        self.conv1 = nn.ConvTranspose2d(latent_size, nb_filter * 8, 4, 1, 0)
        self.bn1 = nn.BatchNorm2d(nb_filter * 8)
        self.conv2 = nn.ConvTranspose2d(nb_filter * 8, nb_filter * 4, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(nb_filter * 4)
        self.conv3 = nn.ConvTranspose2d(nb_filter * 4, nb_filter * 2, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(nb_filter * 2)
        self.conv4 = nn.ConvTranspose2d(nb_filter * 2, nb_filter * 1, 4, 2, 1)
        self.bn4 = nn.BatchNorm2d(nb_filter * 1)
        self.conv5 = nn.ConvTranspose2d(nb_filter * 1, 3, 4, 2, 1)
        self.__initialize_weights()

    def _embed_labels(self, cl):
        """Map class labels to ``latent_size``-dimensional embedding vectors.

        ``cl`` may be either of these:

        * integer class indices, looked up with ``nn.Embedding`` exactly as
          before; or
        * a one-hot / soft label tensor whose last dimension equals
          ``n_classes``, e.g. the output of ``F.one_hot(labels, n_classes)``.
          It is multiplied with the embedding table. For a true one-hot
          vector this selects exactly the same row as the index lookup.
        """
        table = self.label_embedding.weight
        # nn.Embedding rejects floating-point input, so any float tensor, or
        # any tensor whose trailing dim spans all classes, is treated as a
        # one-hot / soft label vector.
        is_vector = cl.is_floating_point() or (
            cl.dim() > 1 and cl.size(-1) == self.label_embedding.num_embeddings
        )
        if is_vector:
            return cl.to(table.dtype) @ table
        return self.label_embedding(cl)

    def forward(self, input, cl):
        """Generate images conditioned on class labels.

        Args:
            input: Noise tensor. It is presumably shaped (batch, latent_size)
                so that it matches the embedding output — TODO confirm
                against the caller.
            cl: Integer class indices, or one-hot / soft labels whose last
                dimension is ``n_classes`` (see ``_embed_labels``).

        Returns:
            Image tensor of shape (batch, 3, 64, 64) with values in [-1, 1].
        """
        x = torch.mul(self._embed_labels(cl), input)
        # Treat the conditioned latent vector as a 1x1 feature map.
        x = x.view(x.size(0), -1, 1, 1)
        x = self.conv1(x)
        x = self.bn1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.relu(x)
        x = self.conv4(x)
        x = self.bn4(x)
        x = F.relu(x)
        x = self.conv5(x)
        return torch.tanh(x)

    def __initialize_weights(self):
        """Apply DCGAN-style init: N(0, 0.02) conv weights, N(1, 0.02) BN scales."""
        for m in self.modules():
            # ConvTranspose2d does NOT subclass Conv2d, so it must be listed
            # explicitly; otherwise none of this generator's convs are initialized.
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.zeros_(m.bias)

class Discriminator(nn.Module):
    """ACGAN discriminator with a real/fake head and an auxiliary classifier.

    Four stride-2 convolutions reduce a 64x64 image to 4x4. A final 4x4
    convolution collapses that to a 1x1 map with ``nb_filter`` channels,
    which feeds two linear heads: one real/fake score and one set of class
    logits.

    Args:
        nb_filter: Base channel width; the layers use multiples of it.
        n_classes: Number of classes predicted by the auxiliary head.
    """

    def __init__(self, nb_filter, n_classes):
        super(Discriminator, self).__init__()
        self.nb_filter = nb_filter
        self.conv1 = nn.Conv2d(3, nb_filter, 4, 2, 1)
        self.conv2 = nn.Conv2d(nb_filter, nb_filter * 2, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(nb_filter * 2)
        self.conv3 = nn.Conv2d(nb_filter * 2, nb_filter * 4, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(nb_filter * 4)
        self.conv4 = nn.Conv2d(nb_filter * 4, nb_filter * 8, 4, 2, 1)
        self.bn4 = nn.BatchNorm2d(nb_filter * 8)
        self.conv5 = nn.Conv2d(nb_filter * 8, nb_filter * 1, 4, 1, 0)
        self.gan_linear = nn.Linear(nb_filter * 1, 1)
        self.aux_linear = nn.Linear(nb_filter * 1, n_classes)
        self.__initialize_weights()

    def forward(self, input):
        """Score a batch of images.

        Args:
            input: Image batch, expected shape (batch, 3, 64, 64). Inputs
                whose final feature map is not 1x1 raise an error in the
                linear heads rather than being folded into the batch dim.

        Returns:
            Tuple ``(s, c)``:

            * ``s`` is the sigmoid real/fake probability, shape (batch,).
            * ``c`` holds the raw class logits, shape (batch, n_classes).
              They can be passed to ``F.cross_entropy`` with integer targets
              or, in PyTorch >= 1.10, with one-hot / probability targets.
        """
        x = self.conv1(input)
        x = F.leaky_relu(x, 0.2)
        x = self.conv2(x)
        x = self.bn2(x)
        x = F.leaky_relu(x, 0.2)
        x = self.conv3(x)
        x = self.bn3(x)
        x = F.leaky_relu(x, 0.2)
        x = self.conv4(x)
        x = self.bn4(x)
        x = F.leaky_relu(x, 0.2)
        x = self.conv5(x)
        # Keep the batch dimension intact. The former view(-1, nb_filter)
        # silently turned leftover spatial positions into extra "samples".
        x = x.flatten(1)
        c = self.aux_linear(x)
        s = self.gan_linear(x)
        s = torch.sigmoid(s)
        # c.squeeze(1) is a no-op unless n_classes == 1.
        return s.squeeze(1), c.squeeze(1)

    def __initialize_weights(self):
        """Apply DCGAN-style init: N(0, 0.02) conv weights, N(1, 0.02) BN scales."""
        for m in self.modules():
            # Listing ConvTranspose2d too keeps this in line with the generator;
            # it is not a subclass of Conv2d.
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.normal_(m.weight, 1.0, 0.02)
                nn.init.zeros_(m.bias)
c7rzv4ha

c7rzv4ha1#

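要转换为独热(one-hot)编码表示,可以使用内置的 `torch.nn.functional.one_hot`,例如 `F.one_hot(labels, num_classes=10)`。注意:
- `one_hot` 要求标签取值为 0 到 num_classes-1;若标签是 1 到 10,需要先减 1。
- `nn.Embedding` 只接受整数类别索引,因此独热向量有两种用法:要么先用 `argmax(dim=-1)` 还原为索引,要么直接与嵌入矩阵相乘(`one_hot.float() @ self.label_embedding.weight`)。对严格的独热向量,两者结果相同。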

相关问题