我正在 CIFAR-10 上训练一个 AC-GAN。到目前为止,类别标签用整数值 [1,...,10] 表示。现在我想把这些标签改成独热(one-hot)向量表示,例如 class[1] = [1,0,0,0,0,0,0,0,0,0]。我不知道应该如何修改生成器类和判别器类来实现这一点。感谢任何帮助!
以下是目前的类定义:
class Generator(nn.Module):
    """Conditional (AC-GAN style) generator built from transposed convolutions.

    A latent vector is modulated element-wise by a learned per-class vector
    and then upsampled by five transposed convolutions
    (1x1 -> 4x4 -> 8x8 -> 16x16 -> 32x32 -> 64x64) into a 3-channel image
    squashed to [-1, 1] by ``tanh``.

    Args:
        latent_size: Length of the latent vector; also the embedding width.
        nb_filter: Base channel count; hidden layers use multiples of it.
        n_classes: Number of distinct class labels.
    """

    def __init__(self, latent_size, nb_filter, n_classes):
        super(Generator, self).__init__()
        self.label_embedding = nn.Embedding(n_classes, latent_size)
        self.conv1 = nn.ConvTranspose2d(latent_size, nb_filter * 8, 4, 1, 0)
        self.bn1 = nn.BatchNorm2d(nb_filter * 8)
        self.conv2 = nn.ConvTranspose2d(nb_filter * 8, nb_filter * 4, 4, 2, 1)
        self.bn2 = nn.BatchNorm2d(nb_filter * 4)
        self.conv3 = nn.ConvTranspose2d(nb_filter * 4, nb_filter * 2, 4, 2, 1)
        self.bn3 = nn.BatchNorm2d(nb_filter * 2)
        self.conv4 = nn.ConvTranspose2d(nb_filter * 2, nb_filter * 1, 4, 2, 1)
        self.bn4 = nn.BatchNorm2d(nb_filter * 1)
        self.conv5 = nn.ConvTranspose2d(nb_filter * 1, 3, 4, 2, 1)
        self.__initialize_weights()

    def forward(self, input, cl):
        """Generate a batch of images.

        Args:
            input: Latent tensor that is multiplied element-wise with the
                per-class vector and reshaped to ``(batch, -1, 1, 1)``.
            cl: Class labels. Either an integer tensor of class indices
                (looked up in ``label_embedding``), or a floating-point
                tensor of one-hot rows of width ``n_classes``, e.g.
                ``torch.nn.functional.one_hot(idx, n_classes).float()``.

        Returns:
            Tensor of generated images with values in ``[-1, 1]``.
        """
        if cl.is_floating_point():
            # A one-hot row times the embedding table selects the same row an
            # index lookup would, so both label formats share the weights.
            table = self.label_embedding.weight
            label_vec = cl.to(table.dtype) @ table
        else:
            label_vec = self.label_embedding(cl)
        x = torch.mul(label_vec, input)
        x = x.view(x.size(0), -1, 1, 1)
        x = F.relu(self.bn1(self.conv1(x)))
        x = F.relu(self.bn2(self.conv2(x)))
        x = F.relu(self.bn3(self.conv3(x)))
        x = F.relu(self.bn4(self.conv4(x)))
        x = self.conv5(x)
        return torch.tanh(x)

    def __initialize_weights(self):
        """Apply N(0, 0.02) to conv weights and N(1, 0.02) / 0 to batch norms."""
        for m in self.modules():
            # ConvTranspose2d is not a subclass of Conv2d, so it has to be
            # listed explicitly; otherwise no conv layer here is ever matched.
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                m.weight.data.normal_(0.0, 0.02)
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.normal_(1.0, 0.02)
                m.bias.data.fill_(0)
class Discriminator(nn.Module):
    """AC-GAN discriminator with a real/fake head and a class-logit head.

    Four stride-2 convolutions followed by a 4x4 valid convolution reduce the
    image to ``nb_filter`` features, which feed two linear heads.

    Args:
        nb_filter: Base channel count; hidden layers use multiples of it.
        n_classes: Number of outputs of the auxiliary classification head.
    """

    def __init__(self, nb_filter, n_classes):
        super().__init__()
        self.nb_filter = nb_filter
        width = nb_filter
        self.conv1 = nn.Conv2d(3, width, kernel_size=4, stride=2, padding=1)
        self.conv2 = nn.Conv2d(width, width * 2, kernel_size=4, stride=2, padding=1)
        self.bn2 = nn.BatchNorm2d(width * 2)
        self.conv3 = nn.Conv2d(width * 2, width * 4, kernel_size=4, stride=2, padding=1)
        self.bn3 = nn.BatchNorm2d(width * 4)
        self.conv4 = nn.Conv2d(width * 4, width * 8, kernel_size=4, stride=2, padding=1)
        self.bn4 = nn.BatchNorm2d(width * 8)
        self.conv5 = nn.Conv2d(width * 8, width * 1, kernel_size=4, stride=1, padding=0)
        self.gan_linear = nn.Linear(width * 1, 1)
        self.aux_linear = nn.Linear(width * 1, n_classes)
        self.__initialize_weights()

    def forward(self, input):
        """Score a batch of images.

        Args:
            input: Image tensor with 3 channels. The flatten to
                ``(-1, nb_filter)`` keeps one row per image only when the
                last convolution yields a 1x1 map (64x64 inputs with these
                layer settings).

        Returns:
            A pair ``(s, c)``: ``s`` is the sigmoid real/fake score with
            dim 1 squeezed, ``c`` is the auxiliary head output (also passed
            through ``squeeze(1)``).
        """
        features = F.leaky_relu(self.conv1(input), 0.2)
        stages = (
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, norm in stages:
            features = F.leaky_relu(norm(conv(features)), 0.2)
        flat = self.conv5(features).view(-1, self.nb_filter * 1)
        class_out = self.aux_linear(flat)
        realness = torch.sigmoid(self.gan_linear(flat))
        return realness.squeeze(1), class_out.squeeze(1)

    def __initialize_weights(self):
        """Apply N(0, 0.02) to conv weights and N(1, 0.02) / 0 to batch norms."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.weight.data.normal_(1.0, 0.02)
                layer.bias.data.fill_(0)
            elif isinstance(layer, nn.Conv2d):
                layer.weight.data.normal_(0.0, 0.02)
1条答案
按热度按时间c7rzv4ha1#
要转换为独热(one-hot)编码表示,可以使用内置的
torch.nn.functional.one_hot
。注意它返回的是整型(LongTensor)张量,如需参与浮点运算请先调用 .float()。