
PyTorch LSTM / GRU


Recurrent Neural Networks (RNNs) face the vanishing gradient problem when processing sequential data, making it difficult to learn long-range dependencies.


Long Short-Term Memory (LSTM) networks and Gated Recurrent Units (GRUs) address this issue by introducing gating mechanisms, forming the core models for sequence tasks such as time series analysis, natural language processing, and speech recognition.


1. Limitations of RNNs and Gating Mechanisms


The standard RNN combines the current input with the previous hidden state at each time step to compute a new hidden state:

$$
h_{t} = \tanh\left( W_{h} \cdot h_{t-1} + W_{x} \cdot x_{t} + b \right)
$$
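
To connect this formula to PyTorch, the sketch below computes one step by hand and compares it with `nn.RNNCell`. PyTorch keeps two bias vectors (`bias_ih` and `bias_hh`), and their sum plays the role of the single $b$ above. The sizes are arbitrary choices for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, batch_size = 3, 5, 2
cell = nn.RNNCell(input_size, hidden_size, nonlinearity="tanh")

x_t = torch.randn(batch_size, input_size)      # current input
h_prev = torch.randn(batch_size, hidden_size)  # previous hidden state

with torch.no_grad():
    # h_t = tanh(W_h h_{t-1} + W_x x_t + b), with b = bias_ih + bias_hh
    h_manual = torch.tanh(
        h_prev @ cell.weight_hh.T + x_t @ cell.weight_ih.T + cell.bias_ih + cell.bias_hh
    )
    h_builtin = cell(x_t, h_prev)

print(torch.allclose(h_manual, h_builtin, atol=1e-5))  # True
```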

This structure has two core issues:

  • Vanishing gradients: During backpropagation through time, the gradient is multiplied at every step by the recurrent weight matrix and by the derivative of tanh, which is at most 1. Over long sequences it can decay exponentially, so early time steps contribute almost nothing to the weight update and the model cannot learn long-range dependencies.
  • Exploding gradients: When the largest eigenvalue (in magnitude) of the recurrent weight matrix exceeds 1, gradients can grow exponentially during backpropagation and make training unstable. This is typically mitigated by gradient clipping.
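
Gradient clipping is applied after `loss.backward()` and before the optimizer step. Below is a minimal sketch using `torch.nn.utils.clip_grad_norm_`. The model, the random data, the toy loss, and the threshold `max_norm=1.0` are all illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
rnn = nn.RNN(input_size=8, hidden_size=16, batch_first=True)
x = torch.randn(4, 50, 8)            # (batch, seq_len, input_size)
output, h_n = rnn(x)
loss = output.pow(2).mean()          # toy loss, for illustration only
loss.backward()

# Rescale all gradients in place so that their combined L2 norm is at most max_norm.
# The return value is the total norm measured before clipping.
max_norm = 1.0
total_norm = torch.nn.utils.clip_grad_norm_(rnn.parameters(), max_norm=max_norm)

# Recompute the combined norm after clipping to confirm the effect
clipped_norm = torch.norm(torch.stack([p.grad.norm() for p in rnn.parameters()]))
print(f"before: {total_norm.item():.4f}, after: {clipped_norm.item():.4f}")
```

In a real training loop, `optimizer.step()` would follow the clipping call.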

The core idea of gating mechanisms is to introduce learnable "switches" that allow the network to decide autonomously, at each time step, which information to retain, which to discard, and which new information to write into memory.

LSTM uses three gates (forget gate, input gate, output gate) along with a separate cell state; GRU simplifies the structure into two gates (reset gate, update gate), resulting in fewer parameters and faster training.


2. LSTM Principles


2.1 Core Structure and Three Gates


LSTM maintains two state vectors across time steps:

  • Cell State (c_t): Carrier of long-term memory, allowing information to flow almost unchanged.
  • Hidden State (h_t): Short-term memory, also serving as the output at the current time step.

Each of the three gates is a linear transformation followed by a sigmoid activation. Each produces values between 0 and 1 that act as "valves":

  • Forget Gate: Decides which information to discard from the cell state.
  • Input Gate: Decides which new information to store in the cell state.
  • Output Gate: Decides which parts of the cell state are exposed as the hidden state (the output).

2.2 Forward Computation Formulas

$$
\begin{aligned}
\text{Input: } & x_{t} \text{ (current input)},\ h_{t-1} \text{ (previous hidden state)},\ c_{t-1} \text{ (previous cell state)} \\
\text{Forget gate: } & f_{t} = \sigma\left( W_{f} \cdot \left[ h_{t-1}, x_{t} \right] + b_{f} \right) \\
\text{Input gate: } & i_{t} = \sigma\left( W_{i} \cdot \left[ h_{t-1}, x_{t} \right] + b_{i} \right) \\
\text{Candidate value: } & g_{t} = \tanh\left( W_{g} \cdot \left[ h_{t-1}, x_{t} \right] + b_{g} \right) \\
\text{Output gate: } & o_{t} = \sigma\left( W_{o} \cdot \left[ h_{t-1}, x_{t} \right] + b_{o} \right) \\
\text{Update cell state: } & c_{t} = f_{t} \odot c_{t-1} + i_{t} \odot g_{t} \\
\text{Update hidden state: } & h_{t} = o_{t} \odot \tanh\left( c_{t} \right)
\end{aligned}
$$

Here, $\odot$ denotes element-wise multiplication (Hadamard product), $\sigma$ denotes the sigmoid function, and $[h_{t-1}, x_t]$ denotes the concatenation of the two vectors.

Computation logic interpretation:

  • f_t ⊙ c_{t-1}: The forget gate determines how much historical memory to retain (near 0 = forget, near 1 = keep).
  • i_t ⊙ g_t: The input gate determines how much new information to write; g_t is the candidate new content.
  • o_t ⊙ tanh(c_t): The output gate determines what to extract from the cell state as the hidden state output.
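
The formulas can be checked numerically against `nn.LSTMCell`. PyTorch stacks the four blocks inside `weight_ih` and `weight_hh` in the order input gate, forget gate, candidate g, output gate. Each cell also has two bias vectors (`bias_ih`, `bias_hh`), and their sum corresponds to the single bias in each formula. `lstm_step` is a hypothetical helper written for this illustration, and the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, batch_size = 3, 4, 2
cell = nn.LSTMCell(input_size, hidden_size)

x_t = torch.randn(batch_size, input_size)
h_prev = torch.randn(batch_size, hidden_size)
c_prev = torch.randn(batch_size, hidden_size)

def lstm_step(x_t, h_prev, c_prev, cell):
    # All four pre-activations at once; blocks are ordered (i, f, g, o) in PyTorch
    gates = (x_t @ cell.weight_ih.T + cell.bias_ih
             + h_prev @ cell.weight_hh.T + cell.bias_hh)
    i, f, g, o = gates.chunk(4, dim=1)
    i, f, o = torch.sigmoid(i), torch.sigmoid(f), torch.sigmoid(o)
    g = torch.tanh(g)                # candidate value
    c_t = f * c_prev + i * g         # update cell state
    h_t = o * torch.tanh(c_t)        # update hidden state
    return h_t, c_t

with torch.no_grad():
    h_manual, c_manual = lstm_step(x_t, h_prev, c_prev, cell)
    h_ref, c_ref = cell(x_t, (h_prev, c_prev))

print(torch.allclose(h_manual, h_ref, atol=1e-5),
      torch.allclose(c_manual, c_ref, atol=1e-5))  # True True
```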

3. GRU Principles


3.1 Core Structure and Two Gates


GRU merges LSTM's forget gate and input gate into a single update gate and eliminates the separate cell state. It keeps only the hidden state, which results in a simpler structure.

  • Reset Gate: Decides how much of the historical state to ignore when computing the candidate hidden state.
  • Update Gate: Decides how much of the historical state to retain and how much of the new candidate state to write in.

3.2 Forward Computation Formulas

$$
\begin{aligned}
\text{Input: } & x_{t} \text{ (current input)},\ h_{t-1} \text{ (previous hidden state)} \\
\text{Reset gate: } & r_{t} = \sigma\left( W_{r} \cdot \left[ h_{t-1}, x_{t} \right] + b_{r} \right) \\
\text{Update gate: } & z_{t} = \sigma\left( W_{z} \cdot \left[ h_{t-1}, x_{t} \right] + b_{z} \right) \\
\text{Candidate value: } & \tilde{h}_{t} = \tanh\left( W_{h} \cdot \left[ r_{t} \odot h_{t-1}, x_{t} \right] + b_{h} \right) \\
\text{Update hidden state: } & h_{t} = \left( 1 - z_{t} \right) \odot h_{t-1} + z_{t} \odot \tilde{h}_{t}
\end{aligned}
$$

Computation logic interpretation:

  • When the reset gate r_t is near 0, the candidate state h̃_t barely depends on the historical state, effectively starting fresh.
  • When the update gate z_t is near 1, the new state favors the candidate value; when it is near 0, it retains more of the historical state.
  • GRU has no separate cell state. With three weight blocks instead of four, it has 3/4 as many parameters as an LSTM with the same input and hidden sizes.
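
PyTorch's `nn.GRU` and `nn.GRUCell` differ from the formulas above in two ways.

First, PyTorch writes the update as $h_t = (1 - z_t) \odot n_t + z_t \odot h_{t-1}$, where $n_t$ is the candidate. In PyTorch, therefore, $z_t$ near 1 keeps the old state, which is the reverse of the convention above. The two conventions are equivalent up to replacing $z_t$ with $1 - z_t$.

Second, PyTorch applies the reset gate after the recurrent matrix multiplication: $n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{t-1} + b_{hn}))$. This is a slightly different variant from the one shown above.

The sketch below reproduces one `nn.GRUCell` step under PyTorch's convention. `gru_step` is a hypothetical helper, and the sizes are arbitrary.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
input_size, hidden_size, batch_size = 3, 4, 2
cell = nn.GRUCell(input_size, hidden_size)
x_t = torch.randn(batch_size, input_size)
h_prev = torch.randn(batch_size, hidden_size)

def gru_step(x_t, h_prev, cell):
    # PyTorch stacks the three blocks in the order (reset r, update z, candidate n)
    gi = x_t @ cell.weight_ih.T + cell.bias_ih      # input contributions
    gh = h_prev @ cell.weight_hh.T + cell.bias_hh   # recurrent contributions
    r_i, z_i, n_i = gi.chunk(3, dim=1)
    r_h, z_h, n_h = gh.chunk(3, dim=1)
    r = torch.sigmoid(r_i + r_h)
    z = torch.sigmoid(z_i + z_h)
    n = torch.tanh(n_i + r * n_h)     # reset gate applied after the recurrent matmul
    return (1 - z) * n + z * h_prev   # PyTorch convention: z near 1 keeps h_prev

with torch.no_grad():
    h_manual = gru_step(x_t, h_prev, cell)
    h_builtin = cell(x_t, h_prev)

print(torch.allclose(h_manual, h_builtin, atol=1e-5))  # True
```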

4. LSTM in PyTorch


This section details the parameters, input/output shapes, and hidden state initialization methods of nn.LSTM.


4.1 nn.LSTM Parameters Explained

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(
    input_size=64,       # dimension of the input vector at each time step
    hidden_size=128,     # dimension of the hidden state (and cell state)
    num_layers=2,        # number of stacked layers (default: 1)
    bias=True,           # whether to use bias terms (default: True)
    batch_first=False,   # whether the batch dimension comes first in input/output (default: False)
    dropout=0.0,         # dropout probability between layers (only effective if num_layers > 1)
    bidirectional=False, # whether to use a bidirectional LSTM (default: False)
    proj_size=0          # projection size (LSTM with projections); default 0 means no projection
)

# Count total parameters
total_params = sum(p.numel() for p in lstm.parameters())
print(f"LSTM Number of Parameters: {total_params:,}")
# 231,424 parameters for input_size=64, hidden_size=128, num_layers=2:
# layer 1 = 99,328 and layer 2 = 132,096 (layer 2's input size is hidden_size=128)
```

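Parameter count for one unidirectional layer of nn.LSTM, with H = hidden_size and I = the input size of that layer:

$$
\text{Parameters per layer} = 4 \times \left( H \times I + H \times H + 2H \right) = 4H \left( I + H + 2 \right)
$$

The factor 4 corresponds to the four weight blocks: input gate, forget gate, candidate value, and output gate. The 2H term appears because PyTorch keeps two bias vectors per layer (`bias_ih` and `bias_hh`). With a single bias per gate, as in the formulas of Section 2.2, this term would be H.

For stacked layers, every layer after the first has I = hidden_size, or 2 × hidden_size when bidirectional. The two-layer example above therefore has 4 × 128 × (64 + 128 + 2) + 4 × 128 × (128 + 128 + 2) = 99,328 + 132,096 = 231,424 parameters.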
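
The counting rule can be wrapped in a small function and checked against what PyTorch actually allocates. `recurrent_param_count` is a hypothetical helper; it assumes `bias=True` and, for LSTM, `proj_size=0`. Passing `gate_blocks=3` applies the same rule to `nn.GRU`.

```python
import torch.nn as nn

def recurrent_param_count(input_size, hidden_size, num_layers=1,
                          bidirectional=False, gate_blocks=4):
    """Count parameters of nn.LSTM (gate_blocks=4) or nn.GRU (gate_blocks=3).

    Assumes bias=True and, for LSTM, proj_size=0.
    """
    directions = 2 if bidirectional else 1
    total, layer_input = 0, input_size
    for _ in range(num_layers):
        # weight_ih + weight_hh + bias_ih + bias_hh for one direction of one layer
        per_direction = gate_blocks * hidden_size * (layer_input + hidden_size + 2)
        total += directions * per_direction
        # layers after the first consume the (possibly concatenated) hidden output
        layer_input = hidden_size * directions
    return total

def actual_param_count(module):
    return sum(p.numel() for p in module.parameters())

lstm = nn.LSTM(input_size=64, hidden_size=128, num_layers=2)
gru = nn.GRU(input_size=64, hidden_size=128, num_layers=2)
print(recurrent_param_count(64, 128, 2), actual_param_count(lstm))                # 231424 231424
print(recurrent_param_count(64, 128, 2, gate_blocks=3), actual_param_count(gru))  # 173568 173568
```

The GRU total is exactly 3/4 of the LSTM total, matching the three-versus-four weight blocks.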

4.2 Input and Output Shapes


This is a common source of errors when using LSTM; special attention must be paid to the effect of the batch_first parameter.

```python
import torch
import torch.nn as nn

# ── batch_first=False (default) ────────────────────────
lstm = nn.LSTM(input_size=32, hidden_size=64, batch_first=False)

# Input shape: (seq_len, batch_size, input_size)
seq_len, batch_size, input_size = 10, 4, 32
x = torch.randn(seq_len, batch_size, input_size)
output, (h_n, c_n) = lstm(x)

print(f"output Shape: {output.shape}")
# torch.Size([10, 4, 64]) → (seq_len, batch_size, hidden_size)
# Hidden state output at each time step

print(f"h_n Shape: {h_n.shape}")
# torch.Size([1, 4, 64]) → (num_layers * num_directions, batch_size, hidden_size)
# Hidden state at the final time step

print(f"c_n Shape: {c_n.shape}")
# torch.Size([1, 4, 64]) → same as h_n, cell state at the final time step

# ── batch_first=True (recommended, more intuitive) ─────────────
lstm_bf = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)

# Input shape: (batch_size, seq_len, input_size)
x = torch.randn(batch_size, seq_len, input_size)
output, (h_n, c_n) = lstm_bf(x)

print(f"output Shape: {output.shape}")
# torch.Size([4, 10, 64]) → (batch_size, seq_len, hidden_size)

print(f"h_n Shape: {h_n.shape}")
# torch.Size([1, 4, 64]) → (num_layers, batch_size, hidden_size)
# Note: h_n shape is unaffected by batch_first
```
```python
import torch
import torch.nn as nn

batch_size, seq_len, input_size = 4, 10, 32

# ── Multi-layer bidirectional LSTM output shapes ─────────────────
lstm_bd = nn.LSTM(
    input_size=input_size, hidden_size=64,
    num_layers=3, bidirectional=True, batch_first=True
)
x = torch.randn(batch_size, seq_len, input_size)
output, (h_n, c_n) = lstm_bd(x)

print(f"output Shape: {output.shape}")
# torch.Size([4, 10, 128])
# hidden_size × 2 = 128, because both directions are concatenated

print(f"h_n Shape: {h_n.shape}")
# torch.Size([6, 4, 64])
# num_layers × num_directions = 3 × 2 = 6
```

Output shape summary:

| Variable | batch_first=False | batch_first=True |
|---|---|---|
| output | (seq_len, N, H * D) | (N, seq_len, H * D) |
| h_n | (L * D, N, H) | (L * D, N, H) |
| c_n | (L * D, N, H) | (L * D, N, H) |

N = batch_size, H = hidden_size, L = num_layers, D = 2 (bidirectional) or 1 (unidirectional)
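
Because `h_n` stacks layers and directions in a single leading dimension, separating them takes a `view` to `(num_layers, num_directions, batch, hidden_size)`. The sketch below reuses the 3-layer bidirectional configuration with arbitrary random inputs. It extracts the top layer's final forward and backward states and checks where they appear in `output`: the forward direction finishes at the last time step, while the backward direction finishes at time step 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
batch_size, seq_len, input_size, hidden_size, num_layers = 4, 10, 32, 64, 3
lstm = nn.LSTM(input_size, hidden_size, num_layers=num_layers,
               bidirectional=True, batch_first=True)
x = torch.randn(batch_size, seq_len, input_size)
output, (h_n, c_n) = lstm(x)

# Split the leading dimension into (layer, direction)
h_view = h_n.view(num_layers, 2, batch_size, hidden_size)
last_fwd = h_view[-1, 0]   # top layer, forward direction
last_bwd = h_view[-1, 1]   # top layer, backward direction

# Forward states occupy output[..., :H]; backward states occupy output[..., H:]
print(torch.allclose(last_fwd, output[:, -1, :hidden_size]),
      torch.allclose(last_bwd, output[:, 0, hidden_size:]))  # True True

# A common sequence summary for classification: concatenate both directions
summary = torch.cat([last_fwd, last_bwd], dim=1)  # (batch_size, 2 * hidden_size)
```

A frequent mistake is to take `output[:, -1, :]` as the bidirectional summary. That slice pairs the forward state at the end of the sequence with the backward state after it has seen only one token.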

4.3 Hidden State Initialization

```python
import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=32, hidden_size=64, num_layers=2, batch_first=True)
batch_size = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
lstm = lstm.to(device)

# Method 1: No initial state passed; PyTorch automatically initializes it to zeros
x = torch.randn(batch_size, 10, 32, device=device)
output, (h_n, c_n) = lstm(x)

# Method 2: Manually initialize to zeros (equivalent to Method 1, but explicit)
num_layers, num_directions = 2, 1
h_0 = torch.zeros(num_layers * num_directions, batch_size, 64, device=device)
c_0 = torch.zeros(num_layers * num_directions, batch_size, 64, device=device)
output, (h_n, c_n) = lstm(x, (h_0, c_0))

# Method 3: Stateful mode, carrying the state across consecutive batches
# Useful when a long sequence is sliced into chunks (e.g., language modeling, long text generation).
# detach() keeps the values but cuts the computational graph, so backpropagation stops at the
# chunk boundary instead of reaching back through every earlier batch (which keeps growing memory).
# `data_loader` is a hypothetical stand-in: consecutive chunks with a fixed batch size.
data_loader = [torch.randn(batch_size, 10, 32) for _ in range(3)]
h, c = h_0, c_0
for batch_x in data_loader:
    batch_x = batch_x.to(device)
    output, (h, c) = lstm(batch_x, (h, c))
    h = h.detach()  # Break gradient flow, keep only the values
    c = c.detach()

# Method 4: Helper that builds zero states from the module's own configuration
def init_hidden(lstm_module, batch_size, device):
    num_layers = lstm_module.num_layers
    hidden_size = lstm_module.hidden_size
    directions = 2 if lstm_module.bidirectional else 1
    h = torch.zeros(num_layers * directions, batch_size, hidden_size, device=device)
    c = torch.zeros(num_layers * directions, batch_size, hidden_size, device=device)
    return h, c

h_0, c_0 = init_hidden(lstm, batch_size, device)

# Orthogonal initialization helps stabilize training, but it is applied to the
# recurrent weights (weight_hh_l0, weight_hh_l1, ...), not to the hidden state
for name, param in lstm.named_parameters():
    if name.startswith("weight_hh"):
        nn.init.orthogonal_(param)
```
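
Method 3 is usually paired with a backward pass per chunk, a scheme known as truncated backpropagation through time. Below is a minimal training-loop sketch. The random "long sequence", the regression targets, the `nn.Linear` head, and the SGD settings are all hypothetical choices for illustration. Without the `detach()` step, the second `backward()` would try to traverse the previous chunk's graph, which has already been freed, and would typically raise a `RuntimeError`.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Hypothetical data: one long sequence per batch row, split into fixed-length chunks
batch_size, total_len, chunk_len, input_size, hidden_size = 4, 60, 20, 8, 16
long_seq = torch.randn(batch_size, total_len, input_size)
targets = torch.randn(batch_size, total_len, 1)

lstm = nn.LSTM(input_size, hidden_size, batch_first=True)
head = nn.Linear(hidden_size, 1)
optimizer = torch.optim.SGD(list(lstm.parameters()) + list(head.parameters()), lr=0.01)
loss_fn = nn.MSELoss()

state = None  # None -> zero initial state for the first chunk
losses = []
for start in range(0, total_len, chunk_len):
    x = long_seq[:, start:start + chunk_len]
    y = targets[:, start:start + chunk_len]
    output, state = lstm(x, state)           # continue from the previous chunk's (h, c)
    loss = loss_fn(head(output), y)
    optimizer.zero_grad()
    loss.backward()                          # gradients flow only within this chunk
    optimizer.step()
    state = tuple(s.detach() for s in state) # keep values, cut the graph
    losses.append(loss.item())

print(len(losses))  # 3 chunks of length 20
```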

5. GRU in PyTorch


The GRU interface is nearly identical to LSTM's, with the main difference being the absence of the cell state c.

5.1 nn.GRU Parameters Explained

```python
import torch.nn as nn

gru = nn.GRU(
    input_size=64,
    hidden_size=128,
    num_layers=2,
    bias=True,
    batch_first=True,  # Recommended to set to True
    dropout=0.3,       # Dropout between layers
    bidirectional=False,
)
```

5.2 Basic Usage Example

```python
import torch
import torch.nn as nn

gru = nn.GRU(input_size=32, hidden_size=64, batch_first=True)
batch_size, seq_len = 8, 10
x = torch.randn(batch_size, seq_len, 32)

# GRU returns only output and h_n (no c_n)
output, h_n = gru(x)

print(f"output Shape: {output.shape}")
# torch.Size([8, 10, 64]) → (batch_size, seq_len, hidden_size)

print(f"h_n Shape: {h_n.shape}")
# torch.Size([1, 8, 64]) → (num_layers, batch_size, hidden_size)

# Extract the output of the last time step (e.g., for classification tasks)
last_hidden = output[:, -1, :]  # (batch_size, hidden_size)
# Equivalent for a unidirectional GRU: the top layer's final hidden state.
# h_n[-1] is safer than h_n.squeeze(), which would also drop the batch dim when batch_size == 1
last_hidden = h_n[-1]           # (batch_size, hidden_size)
```
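
With zero-padded, variable-length batches, `output[:, -1, :]` reads the padding positions for every sequence shorter than the maximum length. One standard remedy is to pack the batch with `pack_padded_sequence` before the GRU. `h_n` then holds each sequence's state at its own last valid step, returned in the original batch order when `enforce_sorted=False`. The lengths and sizes below are arbitrary examples.

```python
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

torch.manual_seed(0)
gru = nn.GRU(input_size=5, hidden_size=6, batch_first=True)

# Three sequences with true lengths 4, 2, 3, zero-padded to length 4
lengths = torch.tensor([4, 2, 3])
x = torch.randn(3, 4, 5)
for i, n in enumerate(lengths.tolist()):
    x[i, n:] = 0.0  # padding positions

packed = pack_padded_sequence(x, lengths, batch_first=True, enforce_sorted=False)
packed_out, h_n = gru(packed)
output, out_lengths = pad_packed_sequence(packed_out, batch_first=True)

# Pick each sequence's output at its own last valid time step
last_valid = output[torch.arange(3), lengths - 1]   # (batch_size, hidden_size)
print(torch.allclose(h_n[-1], last_valid))          # True

# Padded positions in the unpacked output are filled with zeros
print(output[1, 2:].abs().sum().item())             # 0.0
```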