Issue
The code below is from https://pytorch.org/tutorials/beginner/nlp/sequence_models_tutorial.html
for i in inputs:
    # Step through the sequence one element at a time.
    # After each step, hidden contains the hidden state.
    out, hidden = lstm(i.view(1, 1, -1), hidden)
For me it seems like handling the LSTM in this way breaks the computational graph, as hidden keeps getting overridden. Should all the hidden states not be stored in an array, so that the computational graph is maintained and backprop can flow through the hidden states?
Solution
In this case, no: you are feeding the hidden state from one time step to the next at every loop iteration. This means the gradient flow is maintained, and backpropagation will occur through the hidden states as well.
To give a clear answer to your question: yes, the hidden variable is being overwritten. However, the activations corresponding to those hidden states are still cached in memory for backpropagation (a quick check of this is shown further down).
If you take the example from the tutorial page, they are looping through the sequence of elements one by one:
torch.manual_seed(1)
lstm = nn.LSTM(3, 3) # H_in = 3, H_hidden = H_out = 3
inputs = torch.randn(5, 1, 3) # sequence length = 5
Our data sample is shaped as (sequence_length=5, batch_size=1, feature_size=3).
The hidden states h_0 and c_0 are initialized once:
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))
Then we loop over the sequence, running the LSTM block on one element of the sequence at a time:
for element in inputs:
    out, hidden = lstm(element[None], hidden)
The reshape with view in the tutorial page is superfluous and won't work in general cases where batch_size is not equal to 1. Doing element[None] simply adds a leading dimension to the tensor, which is what we want.
So the input passed at each step is essentially a single-element sequence shaped (1, 1, 3). Do note this is a stateful call, since we are providing the hidden states from the previous time step.
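As a quick illustration (my own sketch, reusing the inputs tensor defined above), here is how the two reshapes compare:

x = inputs[0]                   # shape (1, 3): (batch_size, feature_size)
print(x[None].shape)            # torch.Size([1, 1, 3]): adds the sequence dimension
print(x.view(1, 1, -1).shape)   # torch.Size([1, 1, 3]) here, but only because batch_size == 1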
Here the final output and hidden states are:
>>> out
tensor([[[-0.3600, 0.0893, 0.0215]]], grad_fn=<StackBackward>) # <- h_5
>>> hidden
(tensor([[[-0.3600, 0.0893, 0.0215]]], grad_fn=<StackBackward>), # <- h_5
tensor([[[-1.1298, 0.4467, 0.0254]]], grad_fn=<StackBackward>)) # <- c_5
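To verify that backpropagation really does flow through every overwritten hidden, here is a small check of my own (not in the tutorial), reusing the lstm and inputs defined above. If the initial states require gradients, backpropagating from the last step's output alone still populates their .grad, which is only possible if the graph runs through all intermediate hidden states:

h0 = torch.randn(1, 1, 3, requires_grad=True)
c0 = torch.randn(1, 1, 3, requires_grad=True)

hidden = (h0, c0)
for element in inputs:
    out, hidden = lstm(element[None], hidden)   # hidden is rebound, but the graph is kept

out.sum().backward()          # backprop from the last output only
print(h0.grad is not None)    # True: the gradient reached h_0 through all 5 steps
print(c0.grad is not None)    # True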
This is actually the default behavior of nn.LSTM when calling it with the whole sequence: the hidden states are passed from one sequence element to the next.
torch.manual_seed(1)
hidden = (torch.randn(1, 1, 3),
          torch.randn(1, 1, 3))
out, hidden = lstm(inputs, hidden)
Then:
>>> out
tensor([[[-0.2682, 0.0304, -0.1526]], # <- h_1
[[-0.5370, 0.0346, -0.1958]], # <- h_2
[[-0.3947, 0.0391, -0.1217]], # <- h_3
[[-0.1854, 0.0740, -0.0979]], # <- h_4
[[-0.3600, 0.0893, 0.0215]]], grad_fn=<StackBackward>) # <- h_5
>>> hidden
(tensor([[[-0.3600, 0.0893, 0.0215]]], grad_fn=<StackBackward>), # <- h_5
tensor([[[-1.1298, 0.4467, 0.0254]]], grad_fn=<StackBackward>)) # <- c_5
You can see that here out contains all the consecutive hidden states, while in the previous example it only contained the last one.
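If you want to convince yourself the two approaches are equivalent, here is a small sanity check (my own sketch, reusing the same lstm and inputs): feed the same initial states to the element-by-element loop and to a single whole-sequence call, then compare the stacked per-step outputs.

h0, c0 = torch.randn(1, 1, 3), torch.randn(1, 1, 3)

hidden = (h0, c0)
outs = []
for element in inputs:
    out, hidden = lstm(element[None], hidden)
    outs.append(out)
looped = torch.cat(outs)                        # shape (5, 1, 3)

full, _ = lstm(inputs, (h0, c0))                # one call over the whole sequence
print(torch.allclose(looped, full, atol=1e-6))  # True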
Answered By - Ivan