我在fastai书中从头开始学习第12章关于RNN/LSTM的内容,但是在尝试从头开始训练一个定制的LSTM时遇到了困难。
这是样板位(遵循书中的示例)
from fastai.text.all import *
path = untar_data(URLs.HUMAN_NUMBERS)
lines = L()
with open(path/'train.txt') as f: lines += L(*f.readlines())
with open(path/'valid.txt') as f: lines += L(*f.readlines())
text = ' . '.join([l.strip() for l in lines])
tokens = text.split(' ')
vocab = L(*tokens).unique()
word2idx = {w:i for i,w in enumerate(vocab)}
nums = L(word2idx[i] for i in tokens)
def group_chunks(ds, bs):
m = len(ds) // bs
new_ds = L()
for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))
return new_ds
sl = 3
bs = 64
seqs = L((tensor(nums[i:i+sl]), nums[i+sl])
for i in range(0,len(nums)-sl-1,sl))
cut = int(len(seqs) * 0.8)
dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs),
group_chunks(seqs[cut:], bs),
bs=bs, drop_last=True, shuffle=False)
这是最重要的
class LSTMCell(Module):
def __init__(self, ni, nh):
self.forget_gate = nn.Linear(ni + nh, nh)
self.input_gate = nn.Linear(ni + nh, nh)
self.cell_gate = nn.Linear(ni + nh, nh)
self.output_gate = nn.Linear(ni + nh, nh)
def forward(self, input, state):
h, c = state
h = torch.cat([h, input], dim=1)
c = c * torch.sigmoid(self.forget_gate(h))
c = c + torch.sigmoid(self.input_gate(h)) * torch.tanh(self.cell_gate(h))
h = torch.sigmoid(self.output_gate(h)) * torch.tanh(c)
return h, (h, c)
class MyModel(Module):
def __init__(self, vocab_sz, n_hidden):
self.i_h = nn.Embedding(vocab_sz, n_hidden)
self.cells = [LSTMCell(bs, n_hidden) for _ in range(sl)]
self.h_o = nn.Linear(n_hidden, vocab_sz)
self.h = torch.zeros(bs, n_hidden)
self.c = torch.zeros(bs, n_hidden)
def forward(self, x):
x = self.i_h(x)
h, c = self.h, self.c
for i, cell in enumerate(self.cells):
res, (h, c) = cell(x[:, i, :], (h, c))
self.h = h.detach()
self.c = c.detach()
return self.h_o(res)
def reset(self):
self.h.zero_()
self.c.zero_()
learn = Learner(dls, MyModel(len(vocab), 64), loss_func=CrossEntropyLossFlat(), metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(5, 1e-2)
培训输出如下所示
任何帮助都很感激
1条答案
按热度按时间6tqwzwtp1#
经过一番尝试,我终于弄明白了。问题是我初始化单元格列表的方式。在
MyModule.__init__
中,我只需要将行改为它被破坏的原因是通过在常规列表中初始化模块,参数对pytorch/fastai隐藏。通过使用
nn.ModuleList
,参数被注册并可以被训练