PyTorch 未计算损失函数梯度 - FastAI 卷积 VAE

zvokhttg  于 2022-11-09  发布在  其他
关注(0)|答案(1)|浏览(186)

我将 PyTorch 的 VAE 示例修改为卷积网络，然后想在 FastAI 中实现它。

class convVAE(nn.Module):
    """Convolutional variational autoencoder for single-channel 28x28 images.

    Adapted from the PyTorch VAE example: two stride-2 convolutions encode an
    image into the mean and log-variance of a ``dim_z``-dimensional Gaussian;
    a linear layer followed by two transposed convolutions decodes a latent
    sample back into per-pixel probabilities.

    Args:
        dim_z: Size of the latent space.
    """

    def __init__(self, dim_z=20):
        super().__init__()

        # Encoder: 1x28x28 -> 32x13x13 -> 64x6x6, i.e. 2304 flattened features.
        self.cv1 = nn.Conv2d(1, 32, 3, stride=2)
        self.cv2 = nn.Conv2d(32, 64, 3, stride=2)
        self.fc31 = nn.Linear(2304, dim_z)  # mean head
        self.fc32 = nn.Linear(2304, dim_z)  # log-variance head

        # Decoder mirrors the encoder: dim_z -> 64x6x6 -> 32x13x13 -> 1x28x28.
        self.fc4 = nn.Linear(dim_z, 2304)
        self.cv5 = nn.ConvTranspose2d(64, 32, 3, stride=2)
        # The transposed conv alone yields 27x27; output_padding=1 restores 28x28.
        self.cv6 = nn.ConvTranspose2d(32, 1, 3, stride=2, output_padding=1)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the approximate posterior for images ``x``."""
        h1 = F.leaky_relu(self.cv1(x))
        h2 = F.leaky_relu(self.cv2(h1)).view(-1, 2304)
        return self.fc31(h2), self.fc32(h2)

    def reparameterize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        Sampling through ``eps`` keeps ``z`` differentiable w.r.t. ``mu`` and
        ``logvar`` (the reparameterization trick).
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map latent codes ``z`` to per-pixel probabilities, shape (N, 1, 28, 28)."""
        h5 = F.leaky_relu(self.fc4(z)).view(-1, 64, 6, 6)
        h6 = F.leaky_relu(self.cv5(h5))
        return torch.sigmoid(self.cv6(h6))

    def forward(self, x):
        """Encode, sample and decode ``x``.

        Returns:
            ``(y_hat, mu, logvar)`` where ``y_hat`` is the reconstruction
            flattened to ``(N, 784)``.
        """
        mu, logvar = self.encode(x)
        z = self.reparameterize(mu, logvar)
        return self.decode(z).view(-1, 784), mu, logvar

    @staticmethod
    def get_loss(res, y):
        """VAE loss: summed reconstruction BCE plus the KL divergence.

        A static method, so it can be passed as ``convVAE.get_loss`` (e.g. as a
        fastai ``loss_func``) or called on an instance.

        Args:
            res: ``(y_hat, mu, logvar)`` as returned by ``forward``.
            y: Target images; reshaped to ``(-1, 784)`` to match ``y_hat``.

        Returns:
            Scalar tensor ``BCE + KLD``, both summed over the batch.
        """
        y_hat, mu, logvar = res

        # F.binary_cross_entropy(input, target): the prediction must come first
        # and the ground truth second. With the order swapped, the target pixels
        # are fed through log() (clamped at -100 for exact 0/1 values), which
        # inflates the loss enormously and puts the prediction in the target slot.
        # reshape (not view) also accepts non-contiguous target tensors.
        BCE = F.binary_cross_entropy(
            y_hat,
            y.reshape(-1, 784),
            reduction='sum')

        # Closed-form KL(N(mu, sigma^2) || N(0, I)), summed over batch and latents.
        KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

        return BCE + KLD
block = DataBlock(
    blocks=(ImageBlock(cls=PILImageBW), ImageBlock(cls=PILImageBW)),
    get_items=get_image_files,
    splitter=RandomSplitter(valid_pct=0.2, seed=42),
    # Autoencoder: the target is the input image itself.
    get_y=(lambda x: x),
    batch_tfms=aug_transforms(mult=2., do_flip=False))

path = untar_data(URLs.MNIST)
# num_workers=0 — presumably because the lambda in get_y cannot be pickled
# for worker processes on some platforms; TODO confirm.
loaders = block.dataloaders(path/"training", num_workers=0, bs=32)
loaders.train.show_batch(max_n=4, nrows=1)

mdl = convVAE(5)
learn = Learner(loaders, mdl, loss_func=convVAE.get_loss)
# ShortEpochCallback stops after a fraction of the epoch (quick smoke test).
learn.fit(1, cbs=ShortEpochCallback())

梯度似乎没有根据损失正确计算：仅训练一步后，所有参数都变成了 NaN。损失函数本身能够算出结果，但数值相当大，量级约为 O(1e6)。同样的模型和损失函数在原生 PyTorch 实现中可以正常工作。
编辑：问题似乎出在写成了 `def init(.)` 而不是 `def __init__(.)`（捂脸）。

kmbjn2e3

kmbjn2e31#

您的 BCE 计算中存在错误：`F.binary_cross_entropy(input, target)` 的第一个参数应是模型预测，第二个参数应是真实标签，而您的代码把两者传反了：

BCE = F.binary_cross_entropy(
    y.view(-1, 784),  # BUG: the first argument (input) should be the model prediction y_hat
    y_hat,  # BUG: the second argument (target) should be the ground truth y
    reduction='sum')

解决方法很简单：交换这两个参数，即改为 `F.binary_cross_entropy(y_hat, y.view(-1, 784), reduction='sum')`。

相关问题