Pytorch中MLP结构到RNN结构的转换

vwoqyblh  于 2023-01-13  发布在  其他
关注(0)|答案(1)|浏览(126)

我想设计一个RNN类结构,它产生与MLP类结构相同的输入和输出。
实际上我的MLP代码是

class sample(nn.Module):
   def__init__(self):
      super(sample, self).init()
      self.linear = nn.Linear(1, 20)
   def forward(self, t, is_train = False, y = None):
       a = self.linear(t)
       return a

我试过

class sample(nn.Module):
    def __init__(self, input_size, hidden_dim, num_layer):
        super(sample, self).__init__()
        self.input_size = input_size
        self.hidden_dim = hidden_dim
        self.num_layer = num_layer
        self.linear = nn.rnn(1, 20,1, batch_size = false)
    def forward(self, t, is_train = False, y = None):
        a = self.rnn(t) 
        return a

但我认为它是错误的,我应该如何修改代码?

6rqinv9w

6rqinv9w1#

你是说像这样的东西吗?

class sample(nn.Module):
    def __init__(self, input_size, hidden_dim, num_layer):
        super().__init__()
        self.rnn = nn.RNN(input_size, hidden_dim, num_layer)
    def forward(self, x):
        return self.rnn(x)

但如果实在无法更改代码,您可以简单地将sample替换为nn.RNNsample = nn.RNN
torch.nn.RNN
顺便说一句,https://discuss.pytorch.org/可能是问这类问题的更好的地方...

相关问题