This equation updates the memory cell inside the LSTM block. It takes into account the cell's previous value as well as the outputs of the input and forget gates. The memory cell is responsible for "remembering" a state over a time series.
| Symbol | Meaning |
| --- | --- |
| \(u\) | The state of the input neuron to the LSTM. |
| \(g^\text{input}\) | The state of the input gate of the LSTM. |
| \(g^\text{forget}\) | The state of the forget gate of the LSTM. |
| \(c\) | The memory cell of an LSTM. |
| \(n\) | Any given whole number, \( n \in \htmlId{tooltip-setOfWholeNumbers}{\mathbb{W}}\). |
The purpose of the memory cell is to retain some, but not all, information. This means that some of the information stored in the memory cell needs to be removed and replaced with new observations.
The forget gate
\[\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}(\htmlId{tooltip-wholeNumber}{n}+1) = \htmlId{tooltip-sigmoid}{\sigma}(\htmlId{tooltip-weightMatrix}{\mathbf{W}}^{\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}}[1;x^{\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}}])\]
generates a vector of "weights". If a weight is close to 0, the corresponding value will be removed from the memory cell. On the other hand, if a weight is close to 1, the associated value in the memory cell will remain mostly unchanged. In order to forget some of the information, the weights \(\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}(\htmlId{tooltip-wholeNumber}{n}+1)\) are multiplied element-wise with the past state of the memory cell:
\[\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}(\htmlId{tooltip-wholeNumber}{n}+1)\cdot \htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n})\]
The result of this multiplication is the retained part of the previous memory.
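As a rough NumPy sketch of this step (the weight matrix `W_f`, the input `x`, and the cell values are made-up placeholders, not values from this page; the same pattern applies to the input gate below):

```python
import numpy as np

def sigmoid(z):
    # Logistic sigmoid, applied element-wise.
    return 1.0 / (1.0 + np.exp(-z))

# Hypothetical 2-entry memory cell driven by a 2-dimensional input.
W_f    = np.array([[ 0.5, -1.0, 0.3],
                   [ 0.1,  0.4, 2.0]])   # placeholder forget-gate weights
x      = np.array([0.2, -0.5])           # placeholder input at step n+1
c_prev = np.array([0.7,  0.3])           # memory cell state c(n)

# g_forget(n+1) = sigma(W^{g_forget} [1; x]); the leading 1 is the bias term.
g_forget = sigmoid(W_f @ np.concatenate(([1.0], x)))

# Element-wise product: the part of the old memory that is retained.
retained = g_forget * c_prev
```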
Similarly, the input gate
\[\htmlId{tooltip-inputGateLSTM}{g^\text{input}}(\htmlId{tooltip-wholeNumber}{n}+1) = \htmlId{tooltip-sigmoid}{\sigma}(\htmlId{tooltip-weightMatrix}{\mathbf{W}}^{\htmlId{tooltip-inputGateLSTM}{g^\text{input}}}[1;x^{\htmlId{tooltip-inputGateLSTM}{g^\text{input}}}])\]
generates a vector of weights that judges how relevant the new input, \(\htmlId{tooltip-inputNeuronLSTM}{u}(\htmlId{tooltip-wholeNumber}{n}+1)\), is. Again, the weights are multiplied element-wise with the new input:
\[\htmlId{tooltip-inputGateLSTM}{g^\text{input}} (\htmlId{tooltip-wholeNumber}{n}+1) \cdot \htmlId{tooltip-inputNeuronLSTM}{u}(\htmlId{tooltip-wholeNumber}{n}+1)\]
The result of this multiplication is the "newly remembered" information that will now be stored in the memory cell.
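The input-gate step mirrors the forget gate. A minimal sketch, with placeholder gate values that anticipate the worked example below:

```python
import numpy as np

g_input = np.array([0.8, 0.2])   # placeholder input-gate weights g_input(n+1)
u_next  = np.array([0.4, 0.6])   # new input u(n+1)

# Element-wise product: the "newly remembered" part that enters the cell.
new_part = g_input * u_next      # array([0.32, 0.12])
```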
Now, the new information can replace the parts that have been forgotten to obtain a new state of the memory cell. We do this by simply summing the old memories and new observations:
\[\htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}+1)=\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}(\htmlId{tooltip-wholeNumber}{n}+1)\cdot \htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}) + \htmlId{tooltip-inputGateLSTM}{g^\text{input}} (\htmlId{tooltip-wholeNumber}{n}+1) \cdot \htmlId{tooltip-inputNeuronLSTM}{u}(\htmlId{tooltip-wholeNumber}{n}+1)\]
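A minimal sketch of this update rule, assuming the gate activations have already been computed (the function name `update_cell` is my own choice, not from this page):

```python
def update_cell(c_prev, u_next, g_forget, g_input):
    """c(n+1) = g_forget(n+1) * c(n) + g_input(n+1) * u(n+1),
    with both products taken element-wise."""
    return g_forget * c_prev + g_input * u_next
```

Because both products are element-wise, each entry of the memory cell is forgotten and refreshed independently of the others.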
Let the current memory cell be
\[\htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}) = \begin{bmatrix}0.7 \\0.3\end{bmatrix}\]
and the state of the input neuron
\[\htmlId{tooltip-inputNeuronLSTM}{u}(\htmlId{tooltip-wholeNumber}{n}+1) = \begin{bmatrix}0.4 \\0.6\end{bmatrix}\]
Now, the forget gate decided that the first value in the memory cell is completely irrelevant, making its weight 0. The second value is still important, so its weight stays large:
\[\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}(\htmlId{tooltip-wholeNumber}{n}+1) = \begin{bmatrix}0.0 \\0.9\end{bmatrix}\]
The input gate decided that the first part of the input is more important to remember than the second:
\[\htmlId{tooltip-inputGateLSTM}{g^\text{input}}(\htmlId{tooltip-wholeNumber}{n}+1) = \begin{bmatrix}0.8 \\0.2\end{bmatrix}\]
Using these values, we can calculate the new state of the memory cell:
\[\htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}+1)=\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}(\htmlId{tooltip-wholeNumber}{n}+1)\cdot \htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}) + \htmlId{tooltip-inputGateLSTM}{g^\text{input}} (\htmlId{tooltip-wholeNumber}{n}+1) \cdot \htmlId{tooltip-inputNeuronLSTM}{u}(\htmlId{tooltip-wholeNumber}{n}+1)\]
\[\htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}+1)=\begin{bmatrix}0.0 \\0.9\end{bmatrix}\cdot\begin{bmatrix}0.7 \\0.3\end{bmatrix}+\begin{bmatrix}0.8 \\0.2\end{bmatrix}\cdot\begin{bmatrix}0.4 \\0.6\end{bmatrix}\]
\[\htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}+1)=\begin{bmatrix}0.0 \\0.27\end{bmatrix}+\begin{bmatrix}0.32 \\0.12\end{bmatrix}\]
\[\htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}+1)=\begin{bmatrix}0.32 \\0.39\end{bmatrix}\]
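The arithmetic can be double-checked numerically; this snippet only reproduces the element-wise computation above with the same vectors:

```python
import numpy as np

c_n      = np.array([0.7, 0.3])   # c(n)
u_next   = np.array([0.4, 0.6])   # u(n+1)
g_forget = np.array([0.0, 0.9])   # g_forget(n+1)
g_input  = np.array([0.8, 0.2])   # g_input(n+1)

c_next = g_forget * c_n + g_input * u_next
print(c_next)  # [0.32 0.39] (up to floating-point rounding)
```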