Issue
I'm following Chapter 12 of the fastai book (building RNNs/LSTMs from scratch), but I'm getting stuck trying to train a custom-built LSTM. Here is my code.
This is the boilerplate bit, following the examples in the book:
from fastai.text.all import *

path = untar_data(URLs.HUMAN_NUMBERS)
lines = L()
with open(path/'train.txt') as f: lines += L(*f.readlines())
with open(path/'valid.txt') as f: lines += L(*f.readlines())

# Tokenise and numericalise the whole corpus
text = ' . '.join([l.strip() for l in lines])
tokens = text.split(' ')
vocab = L(*tokens).unique()
word2idx = {w:i for i,w in enumerate(vocab)}
nums = L(word2idx[i] for i in tokens)

# Reorder the dataset so that item j of batch t is continued by item j of batch t+1
def group_chunks(ds, bs):
    m = len(ds) // bs
    new_ds = L()
    for i in range(m): new_ds += L(ds[i + m*j] for j in range(bs))
    return new_ds

sl = 3   # sequence length
bs = 64  # batch size
seqs = L((tensor(nums[i:i+sl]), nums[i+sl])
         for i in range(0, len(nums)-sl-1, sl))
cut = int(len(seqs) * 0.8)
dls = DataLoaders.from_dsets(group_chunks(seqs[:cut], bs),
                             group_chunks(seqs[cut:], bs),
                             bs=bs, drop_last=True, shuffle=False)
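To see what group_chunks is doing, here is a tiny sketch of my own (not from the book), run on a 12-element dataset with a batch size of 3:

# Hypothetical demo, assuming group_chunks as defined above
demo = group_chunks(L(range(12)), 3)
print(list(demo))  # [0, 4, 8, 1, 5, 9, 2, 6, 10, 3, 7, 11]
# Batches of 3 come out as [0,4,8], [1,5,9], [2,6,10], [3,7,11]:
# position j of successive batches walks the original sequence in
# order (0 -> 1 -> 2 -> 3), so hidden state can carry across batches.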
And this is the meat of the thing:

class LSTMCell(Module):
    def __init__(self, ni, nh):
        # Each gate sees the concatenated [hidden, input] vector
        self.forget_gate = nn.Linear(ni + nh, nh)
        self.input_gate  = nn.Linear(ni + nh, nh)
        self.cell_gate   = nn.Linear(ni + nh, nh)
        self.output_gate = nn.Linear(ni + nh, nh)

    def forward(self, input, state):
        h, c = state
        h = torch.cat([h, input], dim=1)
        # Standard LSTM update: forget part of the cell state,
        # add the gated new candidate, then emit the gated output
        c = c * torch.sigmoid(self.forget_gate(h))
        c = c + torch.sigmoid(self.input_gate(h)) * torch.tanh(self.cell_gate(h))
        h = torch.sigmoid(self.output_gate(h)) * torch.tanh(c)
        return h, (h, c)
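As a quick sanity check (a minimal sketch of mine, with made-up sizes), one step through the cell with random data confirms the shapes line up:

# Hypothetical smoke test for the cell above; sizes are arbitrary
ni, nh, n = 64, 64, 8
cell = LSTMCell(ni, nh)
x  = torch.randn(n, ni)
h0 = torch.zeros(n, nh)
c0 = torch.zeros(n, nh)
out, (h1, c1) = cell(x, (h0, c0))
print(out.shape, h1.shape, c1.shape)  # torch.Size([8, 64]) for all three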
class MyModel(Module):
    def __init__(self, vocab_sz, n_hidden):
        self.i_h = nn.Embedding(vocab_sz, n_hidden)
        self.cells = [LSTMCell(bs, n_hidden) for _ in range(sl)]
        self.h_o = nn.Linear(n_hidden, vocab_sz)
        self.h = torch.zeros(bs, n_hidden)
        self.c = torch.zeros(bs, n_hidden)

    def forward(self, x):
        x = self.i_h(x)
        h, c = self.h, self.c
        # Unroll the sequence, one cell per time step
        for i, cell in enumerate(self.cells):
            res, (h, c) = cell(x[:, i, :], (h, c))
        # Detach so gradients don't flow back across batches (truncated BPTT)
        self.h = h.detach()
        self.c = c.detach()
        return self.h_o(res)

    def reset(self):
        self.h.zero_()
        self.c.zero_()
learn = Learner(dls, MyModel(len(vocab), 64), loss_func=CrossEntropyLossFlat(),
                metrics=accuracy, cbs=ModelResetter)
learn.fit_one_cycle(5, 1e-2)
The training output shows that the model isn't really learning. Any help appreciated.
Solution
After some playing around I was able to figure it out. The issue was the way I was initialising the list of cells in MyModel.__init__. I only needed to change that line to:
self.cells = nn.ModuleList([LSTMCell(bs, n_hidden) for _ in range(sl)])
The reason it was broken: initialising the submodules in a regular Python list hides their parameters from PyTorch/fastai, so the optimiser never sees or updates them (the embedding and output layers, being direct attributes, still train). Wrapping the cells in nn.ModuleList registers them as submodules, so their parameters show up in model.parameters() and can be trained.
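You can check this directly. Here is a minimal sketch of mine (the names Broken and Fixed are made up, and it assumes the LSTMCell class above):

def n_params(m): return sum(p.numel() for p in m.parameters())

class Broken(Module):
    def __init__(self): self.cells = [LSTMCell(64, 64) for _ in range(3)]

class Fixed(Module):
    def __init__(self): self.cells = nn.ModuleList([LSTMCell(64, 64) for _ in range(3)])

print(n_params(Broken()))  # 0 -- the plain list hides every gate from PyTorch
print(n_params(Fixed()))   # 99072 -- 3 cells x 4 gates x (128*64 + 64) weights

One unrelated wart worth noting: LSTMCell(bs, n_hidden) passes the batch size as the cell's input size ni. It only works here because bs and n_hidden are both 64; since each cell consumes the embedding output, LSTMCell(n_hidden, n_hidden) is what's actually meant.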
Answered By - Mark Dunne