Issue
I'm writing a German->English translator using an encoder/decoder pattern, where the encoder connects to the decoder by passing the state output of its last LSTM layer as the input state of the decoder's LSTM.
I'm stuck, though, because I don't know how to interpret the output of the encoder's LSTM. A small example:
tensor = tf.random.normal(shape=[2, 2, 2])
lstm = tf.keras.layers.LSTM(units=3, return_sequences=True, return_state=True)
result = lstm(tensor)
print("result:\n", result)
Executing this in Tensorflow 2.0.0 produces:
result:
[
<tf.Tensor: id=6423, shape=(2, 2, 3), dtype=float32, numpy=
array([[[ 0.05060377, -0.00500009, -0.10052835],
[ 0.01804499, 0.0022153 , 0.01820258]],
[[ 0.00813384, -0.08705016, 0.06510869],
[-0.00241707, -0.05084776, 0.08321179]]], dtype=float32)>,
<tf.Tensor: id=6410, shape=(2, 3), dtype=float32, numpy=
array([[ 0.01804499, 0.0022153 , 0.01820258],
[-0.00241707, -0.05084776, 0.08321179]], dtype=float32)>,
<tf.Tensor: id=6407, shape=(2, 3), dtype=float32, numpy=
array([[ 0.04316794, 0.00382055, 0.04829971],
[-0.00499733, -0.10105743, 0.1755833 ]], dtype=float32)>
]
The result is a list of three tensors. The first appears to be the output of all hidden states, as selected by return_sequences=True. My question is: what is the interpretation of the second and third tensors in result?
Solution
An LSTM cell in Keras gives you three outputs:
- an output state $o_t$ (1st output)
- a hidden state $h_t$ (2nd output)
- a cell state $c_t$ (3rd output)
An LSTM cell is illustrated here:
[diagram of an LSTM cell]
The output state is generally passed to any upper layers, but not to any layers to the right. You would use this state when predicting your final output.
The cell state is information that is transported from previous LSTM cells to the current LSTM cell. When it arrives in the LSTM cell, the cell decides whether information from the cell state should be deleted, i.e. whether we "forget" some states. This is done by a forget gate: this gate takes the current features $x_t$ and the hidden state from the previous cell $h_{t-1}$ as inputs, and outputs a vector of values between 0 and 1 that we multiply element-wise with the last cell state $c_{t-1}$. After determining what information we want to forget, we update the cell state with the input gate. This gate also takes $x_t$ and $h_{t-1}$ as inputs and produces an input which is added to the last cell state (from which we have already forgotten information). This sum is the new cell state $c_t$.
To get the new hidden state, we combine the cell state with the output gate, which again produces a vector of values between 0 and 1 that determines which information from the cell state should be kept in the hidden state and which should be discarded.
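The gate updates described above can be sketched in a few lines of NumPy. This is a minimal illustration with made-up random weights and toy dimensions, not the actual Keras implementation:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def lstm_step(x_t, h_prev, c_prev, W, U, b):
    """One LSTM step following the forget/input/output gate description.
    W, U, b hold the stacked weights for the i, f, g, o gates."""
    z = x_t @ W + h_prev @ U + b
    i, f, g, o = np.split(z, 4, axis=-1)
    i, f, o = sigmoid(i), sigmoid(f), sigmoid(o)
    c_t = f * c_prev + i * np.tanh(g)   # forget old info, then add new input
    h_t = o * np.tanh(c_t)              # new hidden state from the cell state
    return h_t, c_t

# toy dimensions: 2 input features, 3 units
rng = np.random.default_rng(0)
x = rng.normal(size=(1, 2))
h0 = np.zeros((1, 3))
c0 = np.zeros((1, 3))
W = rng.normal(size=(2, 12))  # 4 gates * 3 units
U = rng.normal(size=(3, 12))
b = np.zeros(12)
h1, c1 = lstm_step(x, h0, c0, W, U, b)
print(h1.shape, c1.shape)  # (1, 3) (1, 3)
```

Note that $h_t$ is always bounded in (-1, 1) because it is a sigmoid-gated tanh of the cell state, while $c_t$ can grow beyond that range over many timesteps.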
As you have correctly interpreted, the first tensor is the output of all hidden states.
The second tensor is the hidden output, i.e. $h_t$, which acts as the short-term memory of the neural network. The third tensor is the cell output, i.e. $c_t$, which acts as the long-term memory of the neural network.
In the Keras documentation it is written that
whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)
Unfortunately they do not use the terms hidden state and cell state. In their terminology, the memory state is the short-term memory, i.e. the hidden state. The carry state is carried through all LSTM cells, i.e. it is the cell state.
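You can verify this naming yourself (assuming TensorFlow 2.x is installed): because the first output is the full sequence of hidden states, its last timestep must equal the final memory state:

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
inputs = tf.random.normal(shape=(2, 2, 2))
lstm = tf.keras.layers.LSTM(units=3, return_sequences=True, return_state=True)
whole_seq_output, final_memory_state, final_carry_state = lstm(inputs)

# the "memory state" is the hidden state h_t of the last timestep
print(np.allclose(whole_seq_output[:, -1, :], final_memory_state))  # True
```

There is no such relationship for the carry state, which is why you need return_state=True to get at the cell state at all.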
We can also verify this using the source code of the LSTM cell, where a forward step is given by
def step(cell_inputs, cell_states):
    """Step function that will be used by Keras RNN backend."""
    h_tm1 = cell_states[0]  # previous memory state (hidden state)
    c_tm1 = cell_states[1]  # previous carry state (cell state)
    z = backend.dot(cell_inputs, kernel)
    z += backend.dot(h_tm1, recurrent_kernel)
    z = backend.bias_add(z, bias)
    z0, z1, z2, z3 = array_ops.split(z, 4, axis=1)
    i = nn.sigmoid(z0)  # input gate
    f = nn.sigmoid(z1)  # forget gate
    c = f * c_tm1 + i * nn.tanh(z2)  # new cell state
    o = nn.sigmoid(z3)  # output gate
    h = o * nn.tanh(c)  # new hidden state
    return h, [h, c]
From these formulas we can easily see that the first and second outputs are the output/hidden state and the third output is the cell state. The comments in the source also confirm that the hidden state is called the "memory state" and the cell state the "carry state".
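Tying this back to the original encoder/decoder question: you would pass both states to the decoder's LSTM via initial_state. A minimal sketch with made-up dimensions, assuming TensorFlow 2.x:

```python
import tensorflow as tf

units = 3

# encoder: keep only the final hidden and cell state
encoder_inputs = tf.keras.Input(shape=(None, 2))
_, enc_h, enc_c = tf.keras.layers.LSTM(units, return_state=True)(encoder_inputs)

# decoder: initialize its LSTM with the encoder's states
decoder_inputs = tf.keras.Input(shape=(None, 2))
decoder_outputs = tf.keras.layers.LSTM(units, return_sequences=True)(
    decoder_inputs, initial_state=[enc_h, enc_c])

model = tf.keras.Model([encoder_inputs, decoder_inputs], decoder_outputs)
print(model.output_shape)  # (None, None, 3)
```

The key point is that initial_state expects the list [h, c], i.e. the 2nd and 3rd outputs discussed above, in that order.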
Answered By - yuki