PyTorch LSTM / GRU
Recurrent Neural Networks (RNNs) suffer from the vanishing gradient problem when processing sequential data, which makes it difficult for them to learn long-range dependencies.

Long Short-Term Memory (LSTM) networks and Gated Recurrent Units (GRUs) address this issue by introducing gating mechanisms. They are core models for sequence tasks such as time series analysis, natural language processing, and speech recognition.
1. Limitations of RNNs and Gating Mechanisms
The standard RNN combines the current input with the previous hidden state at each time step to compute a new hidden state:

$$
h_t = \tanh\left( W_h \cdot h_{t-1} + W_x \cdot x_t + b \right)
$$
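The recurrence can be unrolled by hand and checked against `nn.RNN`. The sketch below uses arbitrary sizes. PyTorch stores $W_x$ as `weight_ih_l0` and $W_h$ as `weight_hh_l0`, and it splits the bias $b$ into two vectors (`bias_ih_l0` and `bias_hh_l0`) whose sum plays the role of $b$.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=3, hidden_size=4, batch_first=True)  # tanh nonlinearity by default
x = torch.randn(2, 5, 3)                                       # (batch_size, seq_len, input_size)

h = torch.zeros(2, 4)                                          # h_0 = 0, as nn.RNN does by default
outs = []
for t in range(x.size(1)):
    # h_t = tanh(W_x x_t + W_h h_{t-1} + b), with b = bias_ih + bias_hh
    h = torch.tanh(x[:, t] @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
                   + h @ rnn.weight_hh_l0.T + rnn.bias_hh_l0)
    outs.append(h)
manual = torch.stack(outs, dim=1)                              # (batch_size, seq_len, hidden_size)

ref, _ = rnn(x)
print(torch.allclose(manual, ref, atol=1e-6))  # True
```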
This structure has two core issues:

- Vanishing gradients: during backpropagation through time, the gradient is multiplied at every step by the recurrent weight matrix $W_h$ and by the derivative of tanh, which is at most 1. Over long sequences the gradient decays exponentially, so the error signal from early time steps barely contributes to the weight updates. As a result, the model cannot learn long-range dependencies.
- Exploding gradients: when the largest eigenvalue (in magnitude) of the recurrent weight matrix exceeds 1, gradients can instead grow exponentially during backpropagation. This makes training unstable and is typically mitigated by gradient clipping.
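Both effects can be observed with a small experiment. The sketch below uses arbitrary sequence length and sizes with `nn.RNN`'s default initialization. It measures how strongly the final output depends on the input at each time step, then applies gradient clipping with `torch.nn.utils.clip_grad_norm_`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len = 50
rnn = nn.RNN(input_size=8, hidden_size=32, batch_first=True)  # vanilla tanh RNN
x = torch.randn(1, seq_len, 8, requires_grad=True)

output, h_n = rnn(x)
loss = output[:, -1].sum()          # loss depends only on the last time step
loss.backward()

# Norm of dLoss/dx_t for every time step t
grad_norms = x.grad[0].norm(dim=1)
print(f"t = 0:  {grad_norms[0].item():.2e}")
print(f"t = {seq_len - 1}: {grad_norms[-1].item():.2e}")  # typically many orders of magnitude larger

# Gradient clipping: rescale all parameter gradients so their global L2 norm is at most max_norm
norm_before = torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=1.0)
norm_after = torch.norm(torch.stack([p.grad.norm() for p in rnn.parameters()]))
print(f"grad norm before clipping: {norm_before.item():.3f}, after: {norm_after.item():.3f}")
```

In a training loop, call `clip_grad_norm_` after `loss.backward()` and before `optimizer.step()`. The value `max_norm=1.0` here is an arbitrary choice.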
The core idea of gating mechanisms is to introduce learnable "switches" that let the network decide, at each time step, which information to keep, which to discard, and which new information to write into memory.

LSTM uses three gates (forget gate, input gate, output gate) along with a separate cell state. GRU simplifies the structure to two gates (reset gate, update gate), which gives it fewer parameters and usually faster training.
2. LSTM Principles
2.1 Core Structure and Three Gates

LSTM maintains two state vectors across time steps:

- Cell state ($c_t$): carries long-term memory, allowing information to flow almost unchanged.
- Hidden state ($h_t$): short-term memory, which also serves as the output at the current time step.
Each gate is a linear transformation followed by a sigmoid activation. This produces values between 0 and 1 that act as "valves":

- Forget gate: decides which information to discard from the cell state.
- Input gate: decides which new information to store in the cell state.
- Output gate: decides which parts of the cell state are exposed as the hidden state (the output).
2.2 Forward Computation Formulas
$$
\begin{aligned}
\text{Input: } & x_t \text{ (current input)},\ h_{t-1} \text{ (previous hidden state)},\ c_{t-1} \text{ (previous cell state)} \\
\text{Forget gate: } & f_t = \sigma\left( W_f \cdot [h_{t-1}, x_t] + b_f \right) \\
\text{Input gate: } & i_t = \sigma\left( W_i \cdot [h_{t-1}, x_t] + b_i \right) \\
\text{Candidate value: } & \tilde{g}_t = \tanh\left( W_g \cdot [h_{t-1}, x_t] + b_g \right) \\
\text{Output gate: } & o_t = \sigma\left( W_o \cdot [h_{t-1}, x_t] + b_o \right) \\
\text{Update cell state: } & c_t = f_t \odot c_{t-1} + i_t \odot \tilde{g}_t \\
\text{Update hidden state: } & h_t = o_t \odot \tanh(c_t)
\end{aligned}
$$
Here, $\odot$ denotes element-wise multiplication (the Hadamard product), $\sigma$ denotes the sigmoid function, and $[h_{t-1}, x_t]$ denotes the concatenation of the two vectors.
Computation logic interpretation:

- $f_t \odot c_{t-1}$: the forget gate determines how much of the old memory to retain (near 0 = forget, near 1 = keep).
- $i_t \odot \tilde{g}_t$: the input gate determines how much new information to write; $\tilde{g}_t$ is the candidate new content.
- $o_t \odot \tanh(c_t)$: the output gate determines what to extract from the cell state as the hidden state output.
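The sketch below implements one LSTM step by hand and checks it against `nn.LSTMCell`, which connects the formulas to PyTorch. The sizes are arbitrary. PyTorch differs from the notation above only in bookkeeping:

- It splits $W \cdot [h_{t-1}, x_t]$ into an input-to-hidden matrix (`weight_ih`) and a hidden-to-hidden matrix (`weight_hh`), which is equivalent to multiplying by the concatenation.
- It keeps two bias vectors (`bias_ih`, `bias_hh`) whose sum plays the role of $b$.
- It stacks the four gate blocks in the order input, forget, candidate, output.

```python
import torch
import torch.nn as nn

def lstm_step(x_t, h_prev, c_prev, cell):
    """One LSTM time step written out from the equations, reusing an nn.LSTMCell's weights."""
    # W . [h_{t-1}, x_t] + b, computed as two matmuls plus two biases
    gates = (x_t @ cell.weight_ih.T + cell.bias_ih
             + h_prev @ cell.weight_hh.T + cell.bias_hh)   # (batch, 4 * hidden)
    i, f, g, o = gates.chunk(4, dim=1)                     # PyTorch order: i, f, g, o
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)                                      # candidate value g~_t
    c_t = f * c_prev + i * g                               # update cell state
    h_t = o * torch.tanh(c_t)                              # update hidden state
    return h_t, c_t

torch.manual_seed(0)
cell = nn.LSTMCell(input_size=3, hidden_size=5)
x_t = torch.randn(2, 3)                                    # (batch, input_size)
h_prev, c_prev = torch.randn(2, 5), torch.randn(2, 5)      # (batch, hidden_size)

h_manual, c_manual = lstm_step(x_t, h_prev, c_prev, cell)
h_ref, c_ref = cell(x_t, (h_prev, c_prev))
print(torch.allclose(h_manual, h_ref, atol=1e-6))  # True
print(torch.allclose(c_manual, c_ref, atol=1e-6))  # True
```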
3. GRU Principles
3.1 Core Structure and Two Gates

GRU merges LSTM's forget gate and input gate into a single update gate. It also eliminates the separate cell state and keeps only the hidden state, which results in a simpler structure.

- Reset gate: decides how much of the historical state to ignore when computing the candidate hidden state.
- Update gate: decides how much of the historical state to retain and how much of the new candidate state to write in.
3.2 Forward Computation Formulas
$$
\begin{aligned}
\text{Input: } & x_t \text{ (current input)},\ h_{t-1} \text{ (previous hidden state)} \\
\text{Reset gate: } & r_t = \sigma\left( W_r \cdot [h_{t-1}, x_t] + b_r \right) \\
\text{Update gate: } & z_t = \sigma\left( W_z \cdot [h_{t-1}, x_t] + b_z \right) \\
\text{Candidate value: } & \tilde{h}_t = \tanh\left( W_h \cdot [r_t \odot h_{t-1}, x_t] + b_h \right) \\
\text{Update hidden state: } & h_t = (1 - z_t) \odot h_{t-1} + z_t \odot \tilde{h}_t
\end{aligned}
$$
Computation logic interpretation:

- When the reset gate $r_t$ is near 0, the candidate state $\tilde{h}_t$ barely depends on the historical state, effectively starting fresh.
- When the update gate $z_t$ is near 1, the new state favors the candidate value; when it is near 0, more of the historical state is retained.
- GRU has no separate cell state and needs three weight blocks instead of LSTM's four. With the same input_size and hidden_size, it therefore has 75% of the LSTM's parameters.
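PyTorch's `nn.GRU`/`nn.GRUCell` implement a slightly different variant of the formulas above, per the PyTorch documentation:

- The reset gate multiplies the result of the hidden-to-hidden matmul, $r_t \odot (W_{hn} h_{t-1} + b_{hn})$, rather than $h_{t-1}$ before it. This is a genuine, if small, variant.
- The final update is $h_t = (1 - z_t) \odot n_t + z_t \odot h_{t-1}$, so in PyTorch a $z_t$ near 1 keeps the old state. This is the opposite of the convention above, but it amounts only to a relabeling, since $\sigma(-a) = 1 - \sigma(a)$.

The sketch below, with arbitrary sizes, writes out PyTorch's version and checks it against `nn.GRUCell`.

```python
import torch
import torch.nn as nn

def gru_step(x_t, h_prev, cell):
    """One GRU step following PyTorch's formulation, reusing an nn.GRUCell's weights."""
    gi = x_t @ cell.weight_ih.T + cell.bias_ih      # input contributions, (batch, 3 * hidden)
    gh = h_prev @ cell.weight_hh.T + cell.bias_hh   # hidden contributions, (batch, 3 * hidden)
    i_r, i_z, i_n = gi.chunk(3, dim=1)              # PyTorch order: reset, update, candidate
    hr, hz, hn = gh.chunk(3, dim=1)
    r = torch.sigmoid(i_r + hr)                     # reset gate
    z = torch.sigmoid(i_z + hz)                     # update gate
    n = torch.tanh(i_n + r * hn)                    # reset applied after the hidden matmul
    return (1 - z) * n + z * h_prev                 # here z near 1 keeps the old state

torch.manual_seed(0)
cell = nn.GRUCell(input_size=3, hidden_size=5)
x_t, h_prev = torch.randn(2, 3), torch.randn(2, 5)

h_manual = gru_step(x_t, h_prev, cell)
h_ref = cell(x_t, h_prev)
print(torch.allclose(h_manual, h_ref, atol=1e-6))  # True
```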
4. LSTM in PyTorch
This section details the parameters, input/output shapes, and hidden state initialization methods of nn.LSTM.
4.1 nn.LSTM Parameters Explained
```python
import torch
import torch.nn as nn

lstm = nn.LSTM(
    input_size=64,        # dimension of the input vector at each time step
    hidden_size=128,      # dimension of the hidden state (and cell state)
    num_layers=2,         # number of stacked layers (default: 1)
    bias=True,            # whether to use bias terms (default: True)
    batch_first=False,    # whether the batch dimension comes first in input/output (default: False)
    dropout=0.0,          # dropout probability between layers (only effective if num_layers > 1)
    bidirectional=False,  # whether to use a bidirectional LSTM (default: False)
    proj_size=0,          # projection size (LSTM with projections); default 0 means no projection
)

# Count total parameters
total_params = sum(p.numel() for p in lstm.parameters())
print(f"LSTM number of parameters: {total_params:,}")
# 231,424 parameters for input_size=64, hidden_size=128, num_layers=2
```

Parameter count formula (one layer, one direction), where $I$ = input_size and $H$ = hidden_size:

$$
\text{Parameters per layer} = 4 \times \left( H \times I + H \times H + 2H \right) = 4H \left( I + H + 2 \right)
$$

Here, the factor 4 corresponds to the four sets of weights: input gate, forget gate, candidate value, and output gate. PyTorch stacks them in that order inside `weight_ih_l{k}` and `weight_hh_l{k}`. The 2 comes from PyTorch keeping two bias vectors per layer (`bias_ih_l{k}` and `bias_hh_l{k}`); a textbook formulation with a single bias would have +1 instead.

For layers after the first, $I$ is replaced by $H$ (or $2H$ if bidirectional). For the example above, this gives $4 \times 128 \times (64 + 128 + 2) + 4 \times 128 \times (128 + 128 + 2) = 99{,}328 + 132{,}096 = 231{,}424$.
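The formula can be checked directly against PyTorch. The helper `expected_params` below is a hypothetical name, not a PyTorch function. It applies the formula layer by layer, using 4 weight blocks for `nn.LSTM` and 3 for `nn.GRU`, which also confirms the 75% ratio mentioned in the GRU section.

```python
import torch.nn as nn

def expected_params(input_size, hidden_size, num_layers, gate_blocks, bidirectional=False):
    """Weights + both biases of a stacked nn.LSTM (gate_blocks=4) or nn.GRU (gate_blocks=3)."""
    directions = 2 if bidirectional else 1
    total = 0
    for layer in range(num_layers):
        # Layers after the first receive the (possibly concatenated) hidden state as input
        layer_input = input_size if layer == 0 else hidden_size * directions
        per_direction = gate_blocks * hidden_size * (layer_input + hidden_size + 2)
        total += per_direction * directions
    return total

def count(module):
    return sum(p.numel() for p in module.parameters())

lstm = nn.LSTM(input_size=64, hidden_size=128, num_layers=2)
gru = nn.GRU(input_size=64, hidden_size=128, num_layers=2)

print(count(lstm), expected_params(64, 128, 2, gate_blocks=4))  # 231424 231424
print(count(gru), expected_params(64, 128, 2, gate_blocks=3))   # 173568 173568
print(count(gru) / count(lstm))                                 # 0.75
```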
4.2 Input and Output Shapes

Shapes are a common source of errors when using LSTM, so pay special attention to the effect of the batch_first parameter.
```python
import torch
import torch.nn as nn

# --- batch_first=False (default) ---
lstm = nn.LSTM(input_size=32, hidden_size=64, batch_first=False)

# Input shape: (seq_len, batch_size, input_size)
seq_len, batch_size, input_size = 10, 4, 32
x = torch.randn(seq_len, batch_size, input_size)
output, (h_n, c_n) = lstm(x)

print(f"output shape: {output.shape}")
# torch.Size([10, 4, 64]) -> (seq_len, batch_size, hidden_size)
# Hidden state output at each time step

print(f"h_n shape: {h_n.shape}")
# torch.Size([1, 4, 64]) -> (num_layers * num_directions, batch_size, hidden_size)
# Hidden state at the final time step

print(f"c_n shape: {c_n.shape}")
# torch.Size([1, 4, 64]) -> same as h_n; cell state at the final time step

# --- batch_first=True (recommended, more intuitive) ---
lstm_bf = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)

# Input shape: (batch_size, seq_len, input_size)
x = torch.randn(batch_size, seq_len, input_size)
output, (h_n, c_n) = lstm_bf(x)

print(f"output shape: {output.shape}")
# torch.Size([4, 10, 64]) -> (batch_size, seq_len, hidden_size)

print(f"h_n shape: {h_n.shape}")
# torch.Size([1, 4, 64]) -> (num_layers, batch_size, hidden_size)
# Note: the shape of h_n is unaffected by batch_first

# --- Multi-layer bidirectional LSTM output shapes ---
lstm_bd = nn.LSTM(
    input_size=32, hidden_size=64,
    num_layers=3, bidirectional=True, batch_first=True,
)
x = torch.randn(batch_size, seq_len, input_size)
output, (h_n, c_n) = lstm_bd(x)

print(f"output shape: {output.shape}")
# torch.Size([4, 10, 128])
# hidden_size x 2 = 128, because both directions are concatenated

print(f"h_n shape: {h_n.shape}")
# torch.Size([6, 4, 64])
# num_layers x num_directions = 3 x 2 = 6
```

Output shape summary:
| Variable | batch_first=False | batch_first=True |
|---|---|---|
| output | (seq_len, N, H * D) | (N, seq_len, H * D) |
| h_n | (L * D, N, H) | (L * D, N, H) |
| c_n | (L * D, N, H) | (L * D, N, H) |

N = batch_size, H = hidden_size, L = num_layers, D = 2 (bidirectional) or 1 (unidirectional).
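For multi-layer or bidirectional models, it often helps to split the first axis of h_n. Following the layout described in the PyTorch documentation, `h_n.view(num_layers, num_directions, batch, hidden_size)` separates layers from directions. The sketch below uses arbitrary sizes and checks that the top layer's forward state matches the last time step of output, while its backward state matches the first time step.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, T, I, H, L = 4, 10, 32, 64, 3
lstm = nn.LSTM(input_size=I, hidden_size=H, num_layers=L, bidirectional=True, batch_first=True)
x = torch.randn(N, T, I)
output, (h_n, c_n) = lstm(x)            # output: (N, T, 2H), h_n: (2L, N, H)

# Separate the (L * D) axis into layers and directions (direction 0 = forward, 1 = backward)
h_layers = h_n.view(L, 2, N, H)
h_fwd_last = h_layers[-1, 0]            # top layer, forward direction: (N, H)
h_bwd_last = h_layers[-1, 1]            # top layer, backward direction: (N, H)

# The forward pass ends at the last time step; the backward pass ends at t = 0
print(torch.allclose(output[:, -1, :H], h_fwd_last))  # True
print(torch.allclose(output[:, 0, H:], h_bwd_last))   # True

# A common sequence representation: both final states concatenated
sequence_repr = torch.cat([h_fwd_last, h_bwd_last], dim=1)
print(sequence_repr.shape)  # torch.Size([4, 128])
```

Note that `output[:, -1, :]` alone would pair the forward direction's final state with the backward direction's *first* state.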
4.3 Hidden State Initialization
```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True)
batch_size = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lstm = lstm.to(device)

# Method 1: No initial state passed; PyTorch automatically initializes h_0 and c_0 to zeros
x = torch.randn(batch_size, 10, 32, device=device)
output, (h_n, c_n) = lstm(x)

# Method 2: Manually initialize to zeros (equivalent to Method 1, but explicit)
# Shape is (num_layers * num_directions, batch_size, hidden_size), even with batch_first=True
num_layers, num_directions = 2, 1
h_0 = torch.zeros(num_layers * num_directions, batch_size, 64, device=device)
c_0 = torch.zeros(num_layers * num_directions, batch_size, 64, device=device)
output, (h_n, c_n) = lstm(x, (h_0, c_0))

# Method 3: Stateful mode, passing the state across consecutive batches
# Useful for slicing long sequences (e.g., language modeling, long text generation)
# detach() breaks the computational graph so backprop stops at the chunk boundary
# and memory does not grow with every batch
data_loader = [torch.randn(batch_size, 10, 32) for _ in range(3)]  # stand-in for a real DataLoader
h, c = h_0, c_0
for batch_x in data_loader:
    batch_x = batch_x.to(device)
    output, (h, c) = lstm(batch_x, (h, c))
    h = h.detach()  # break gradient flow, keep only the values
    c = c.detach()

# Method 4: Helper that sizes the zero state from the module's own attributes
# (adapts automatically to num_layers and bidirectional)
def init_hidden(lstm_module, batch_size, device):
    num_layers = lstm_module.num_layers
    hidden_size = lstm_module.hidden_size
    directions = 2 if lstm_module.bidirectional else 1
    h = torch.zeros(num_layers * directions, batch_size, hidden_size, device=device)
    c = torch.zeros(num_layers * directions, batch_size, hidden_size, device=device)
    return h, c

h_0, c_0 = init_hidden(lstm, batch_size, device)

# Orthogonal initialization, which can help stabilize training, is applied to the
# recurrent weights rather than to the hidden state
for name, param in lstm.named_parameters():
    if "weight_hh" in name:
        nn.init.orthogonal_(param)
```
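Method 3 works because the recurrence only needs the previous state. Processing a sequence in chunks while carrying (h, c) forward gives the same forward result as processing it in one call. The sketch below checks this numerically with arbitrary sizes; `detach()` only limits how far gradients flow back and does not change the values.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True)
x = torch.randn(8, 20, 32)                        # 8 sequences, 20 time steps each

full_out, (h_full, c_full) = lstm(x)              # all 20 steps in one call

state = None                                      # None -> zero initial state
chunk_outs = []
for chunk in x.split(5, dim=1):                   # four chunks of 5 time steps
    out, state = lstm(chunk, state)
    state = tuple(s.detach() for s in state)      # truncate backprop at the chunk boundary
    chunk_outs.append(out)
chunked_out = torch.cat(chunk_outs, dim=1)

print(torch.allclose(full_out, chunked_out, atol=1e-6))  # True
print(torch.allclose(h_full, state[0], atol=1e-6))       # True
```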
5. GRU in PyTorch
The GRU interface is nearly identical to LSTM's, with the main difference being the absence of the cell state c.
5.1 nn.GRU Parameters Explained
```python
import torch.nn as nn

gru = nn.GRU(
    input_size=64,
    hidden_size=128,
    num_layers=2,
    bias=True,
    batch_first=True,     # recommended to set to True
    dropout=0.3,          # dropout between stacked layers (only effective if num_layers > 1)
    bidirectional=False,
)
```

5.2 Basic Usage Example
```python
import torch
import torch.nn as nn

gru = nn.GRU(input_size=32, hidden_size=64, batch_first=True)
batch_size, seq_len = 8, 10
x = torch.randn(batch_size, seq_len, 32)

# GRU returns only output and h_n (no c_n)
output, h_n = gru(x)

print(f"output shape: {output.shape}")
# torch.Size([8, 10, 64]) -> (batch_size, seq_len, hidden_size)

print(f"h_n shape: {h_n.shape}")
# torch.Size([1, 8, 64]) -> (num_layers, batch_size, hidden_size)

# Extract the output of the last time step (e.g., for classification tasks)
last_hidden = output[:, -1, :]  # (batch_size, hidden_size)
# Or equivalently, for a unidirectional GRU, take the top layer's final hidden state.
# h_n[-1] is safer than h_n.squeeze(), which would also drop the batch dimension when batch_size == 1
last_hidden = h_n[-1]           # (batch_size, hidden_size)
```
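In a classifier, the last-time-step representation is typically fed into a linear layer. A minimal sketch, assuming a hypothetical `GRUClassifier` module, 5 arbitrary classes, and random stand-in data:

```python
import torch
import torch.nn as nn

class GRUClassifier(nn.Module):
    """Hypothetical sequence classifier built on the top layer's final hidden state."""

    def __init__(self, input_size, hidden_size, num_classes, num_layers=1):
        super().__init__()
        self.gru = nn.GRU(input_size, hidden_size, num_layers=num_layers, batch_first=True)
        self.head = nn.Linear(hidden_size, num_classes)

    def forward(self, x):             # x: (batch_size, seq_len, input_size)
        _, h_n = self.gru(x)          # h_n: (num_layers, batch_size, hidden_size)
        return self.head(h_n[-1])     # logits: (batch_size, num_classes)

torch.manual_seed(0)
model = GRUClassifier(input_size=32, hidden_size=64, num_classes=5, num_layers=2)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

x = torch.randn(8, 10, 32)            # random stand-in inputs
labels = torch.randint(0, 5, (8,))    # random stand-in labels

# One training step
logits = model(x)
loss = nn.functional.cross_entropy(logits, labels)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(logits.shape)  # torch.Size([8, 5])
```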