基于Pytorch的自动编码器隐藏特征提取

ubbxdtey  于 2022-11-09  发布在  其他
关注(0)|答案(2)|浏览(274)

按照本post中的教程,我正在尝试训练自动编码器并从其隐藏层中提取特征。
所以我的问题是:
1.在autocoder类中,有一个“forward”函数。但是,我在代码中找不到任何调用该函数的地方。那么,它是如何训练的呢?
1.我的问题是因为我觉得如果我想提取特征,我应该在自动编码器类中添加另一个函数(f“forward_hidden”):

def forward(self, features):
     #print("in forward")
     #print(type(features))
     activation = self.encoder_hidden_layer(features)
     activation = torch.relu(activation)
     code = self.encoder_output_layer(activation)
     code = torch.relu(code)
     activation = self.decoder_hidden_layer(code)
     activation = torch.relu(activation)
     activation = self.decoder_output_layer(activation)
     reconstructed = torch.relu(activation)
     return reconstructed

 def forward_hidden(self, features):
     activation = self.encoder_hidden_layer(features)
     activation = torch.relu(activation)
     code = self.encoder_output_layer(activation)
     code = torch.relu(code)
     return code

然后,经过训练,也就是在主代码中的这一行之后:

print("AE, epoch : {}/{}, loss = {:.6f}".format(epoch + 1, epochs_AE, loss))

我可以把下面的代码从隐藏层中检索出特征:

hidden_features = model_AE.forward_hidden(my_input)

这种方法正确吗?不过,我仍然想知道“forward”函数是如何用于训练的。因为我在被调用的代码中看不到它。

7hiiyaii

7hiiyaii1#

forward是模型的本质,实际上定义了模型的功能。
在培训期间,它被隐含地调用为model(input)
如果您想知道如何在运行模型后提取中间特征,您可以注册一个forward-hook,如here所述,它将为您“捕获”这些值。

doinxwow

doinxwow2#

当使用PyTorch创建nn.Module类时,forward函数被隐式调用,您不需要单独调用它。

相关问题