Update of a memory cell in an LSTM

Description

This equation updates the memory cell inside the LSTM block. It combines the cell's previous value, scaled by the forget gate, with the state of the input neuron, scaled by the input gate. The memory cell is responsible for "remembering" a state over a time series.

Equation

\[\htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}+1)=\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}(\htmlId{tooltip-wholeNumber}{n}+1)\cdot \htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}) + \htmlId{tooltip-inputGateLSTM}{g^\text{input}} (\htmlId{tooltip-wholeNumber}{n}+1) \cdot \htmlId{tooltip-inputNeuronLSTM}{u}(\htmlId{tooltip-wholeNumber}{n}+1)\]
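
Read element-wise, this update is a one-line computation. Below is a minimal NumPy sketch of just this equation, assuming the gate states and the input neuron state (defined below) are already given as vectors; all names are hypothetical:

```python
import numpy as np

def update_memory_cell(c_prev, g_forget, g_input, u_new):
    """c(n+1) = g_forget(n+1) * c(n) + g_input(n+1) * u(n+1),
    where * is the element-wise (Hadamard) product."""
    return g_forget * c_prev + g_input * u_new
```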

Symbols Used

\(u\)

This symbol represents the state of the input neuron to the LSTM.

\(g^\text{input}\)

This symbol represents the state of the input gate of the LSTM.

\(g^\text{forget}\)

This symbol represents the state of the forget gate of the LSTM.

\(c\)

This symbol represents the memory cell of an LSTM.

\(n\)

This symbol represents any given whole number, \( n \in \htmlId{tooltip-setOfWholeNumbers}{\mathbb{W}}\).

Derivation

The purpose of the memory cell is to retain some, but not all, information. This means that some of the information stored in the memory cell needs to be removed and replaced with new observations.

The forget gate

\[\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}(\htmlId{tooltip-wholeNumber}{n}+1) = \htmlId{tooltip-sigmoid}{\sigma}(\htmlId{tooltip-weightMatrix}{\mathbf{W}}^{\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}}[1;x^{\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}}])\]

generates a vector of "weights". If a weight is closer to 0, this value will be removed from the memory cell. On the other hand, if a weight is close to 1, the associated value in the memory cell will remain mostly unchanged. In order to forget some of the information, the weights, \(\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}(\htmlId{tooltip-wholeNumber}{n}+1)\) are multiplied element-wise with the past state of the memory cell:

\[\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}(\htmlId{tooltip-wholeNumber}{n}+1)\cdot \htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n})\]

The result of this multiplication is the retained part of the previous memory.
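
As a sketch, the gate computation can be written as follows; the contents of the input vector \(x^{g^\text{forget}}\) depend on the architecture, so it is treated here as a given vector, and all names are hypothetical:

```python
import numpy as np

def sigmoid(z):
    # squashes every entry into (0, 1), so the output can act as a weight
    return 1.0 / (1.0 + np.exp(-z))

def gate(W, x):
    """sigma(W [1; x]): prepend a constant 1 as the bias entry,
    apply the gate's weight matrix, then squash element-wise."""
    x_biased = np.concatenate(([1.0], x))
    return sigmoid(W @ x_biased)
```

Prepending the constant 1 folds the bias term into the weight matrix, matching the \([1;x]\) notation above.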

Similarly, the input gate

\[\htmlId{tooltip-inputGateLSTM}{g^\text{input}}(\htmlId{tooltip-wholeNumber}{n}+1) = \htmlId{tooltip-sigmoid}{\sigma}(\htmlId{tooltip-weightMatrix}{\mathbf{W}}^{\htmlId{tooltip-inputGateLSTM}{g^\text{input}}}[1;x^{\htmlId{tooltip-inputGateLSTM}{g^\text{input}}}])\]

generates a vector of weights that judge how relevant the new input, \(\htmlId{tooltip-inputNeuronLSTM}{u}(\htmlId{tooltip-wholeNumber}{n}+1)\), is. Again, the weights are multiplied element-wise with the new input:

\[\htmlId{tooltip-inputGateLSTM}{g^\text{input}} (\htmlId{tooltip-wholeNumber}{n}+1) \cdot \htmlId{tooltip-inputNeuronLSTM}{u}(\htmlId{tooltip-wholeNumber}{n}+1)\]

The result of this multiplication is the "newly remembered" information that will not be stored in the memory cell.

Now, the new information can replace the parts that have been forgotten to obtain a new state of the memory cell. We do this by simply summing the old memories and new observations:

\[\htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}+1)=\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}(\htmlId{tooltip-wholeNumber}{n}+1)\cdot \htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}) + \htmlId{tooltip-inputGateLSTM}{g^\text{input}} (\htmlId{tooltip-wholeNumber}{n}+1) \cdot \htmlId{tooltip-inputNeuronLSTM}{u}(\htmlId{tooltip-wholeNumber}{n}+1)\]
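
Putting the pieces together, one full update step might look like the following sketch; the weight matrices, their shapes, and the gate input vector are hypothetical placeholders:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

rng = np.random.default_rng(0)
W_forget = rng.normal(size=(2, 4))  # maps [1; x] (4 entries) to 2 gate values
W_input = rng.normal(size=(2, 4))

x = rng.normal(size=3)         # gate input vector
c_prev = np.array([0.7, 0.3])  # current memory cell c(n)
u_new = np.array([0.4, 0.6])   # input neuron state u(n+1)

x_biased = np.concatenate(([1.0], x))    # [1; x]
g_forget = sigmoid(W_forget @ x_biased)  # forget weights in (0, 1)
g_input = sigmoid(W_input @ x_biased)    # input weights in (0, 1)

# keep part of the old state, add part of the new input
c_next = g_forget * c_prev + g_input * u_new
print(c_next)
```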

Example

Let the current memory cell be

\[\htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}) = \begin{bmatrix}0.7 \\0.3\end{bmatrix}\]

and the state of the input neuron

\[\htmlId{tooltip-inputNeuronLSTM}{u}(\htmlId{tooltip-wholeNumber}{n}+1) = \begin{bmatrix}0.4 \\0.6\end{bmatrix}\]

Now, suppose the forget gate decides that the first value in the memory cell is completely irrelevant, making its weight 0, while the second value is still important, so its weight stays large:

\[\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}(\htmlId{tooltip-wholeNumber}{n}+1) = \begin{bmatrix}0.0 \\0.9\end{bmatrix}\]

The input gate decides that the first part of the input is more important to remember than the second:

\[\htmlId{tooltip-inputGateLSTM}{g^\text{input}}(\htmlId{tooltip-wholeNumber}{n}+1) = \begin{bmatrix}0.8 \\0.2\end{bmatrix}\]

Using these values, we can calculate the new state of the memory cell:

\[\htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}+1)=\htmlId{tooltip-forgetGateLSTM}{g^\text{forget}}(\htmlId{tooltip-wholeNumber}{n}+1)\cdot \htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}) + \htmlId{tooltip-inputGateLSTM}{g^\text{input}} (\htmlId{tooltip-wholeNumber}{n}+1) \cdot \htmlId{tooltip-inputNeuronLSTM}{u}(\htmlId{tooltip-wholeNumber}{n}+1)\]

\[\htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}+1)=\begin{bmatrix}0.0 \\0.9\end{bmatrix}\cdot\begin{bmatrix}0.7 \\0.3\end{bmatrix}+\begin{bmatrix}0.8 \\0.2\end{bmatrix}\cdot\begin{bmatrix}0.4 \\0.6\end{bmatrix}\]

\[\htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}+1)=\begin{bmatrix}0.0 \\0.27\end{bmatrix}+\begin{bmatrix}0.32 \\0.12\end{bmatrix}\]

\[\htmlId{tooltip-memoryCellLSTM}{c}(\htmlId{tooltip-wholeNumber}{n}+1)=\begin{bmatrix}0.32 \\0.39\end{bmatrix}\]
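
As a quick sanity check, the same arithmetic in NumPy:

```python
import numpy as np

c_n = np.array([0.7, 0.3])       # c(n)
u = np.array([0.4, 0.6])         # u(n+1)
g_forget = np.array([0.0, 0.9])  # forget gate state
g_input = np.array([0.8, 0.2])   # input gate state

print(g_forget * c_n + g_input * u)  # [0.32 0.39]
```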

