PyTorch torch.nn.LSTM function | Rookie Tutorial


torch.nn.LSTM is the PyTorch module that implements a (optionally multi-layer) Long Short-Term Memory network.

LSTM is a special type of recurrent neural network capable of learning long-term dependencies, widely used in sequence modeling tasks.

Function Definition

torch.nn.LSTM(input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False)

Parameter Explanation:

  • input_size (int): Dimension of input features.
  • hidden_size (int): Dimension of hidden state.
  • num_layers (int): Number of LSTM layers. Default is 1.
  • bias (bool): Whether to use bias. Default is True.
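  • batch_first (bool): If True, input and output shapes are (batch, seq, feature) instead of (seq, batch, feature). Default is False; the examples on this page set it to True explicitly.
  • dropout (float): Dropout probability applied to the output of every layer except the last. It only has an effect when num_layers > 1. Default is 0.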
  • bidirectional (bool): Whether to use bidirectional LSTM. Default is False.

Input and Output

Input:

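  • input: Tensor with shape (batch, seq_len, input_size) when batch_first=True, or (seq_len, batch, input_size) when batch_first=False
  • h_0: Initial hidden state, shape (num_layers * num_directions, batch, hidden_size). Optional; defaults to zeros.
  • c_0: Initial cell state, shape (num_layers * num_directions, batch, hidden_size). Optional; defaults to zeros.

Here num_directions is 2 if bidirectional=True, otherwise 1.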

Output:

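  • output: Hidden state of the last layer at every time step, shape (batch, seq_len, num_directions * hidden_size) when batch_first=True
  • h_n: Final hidden state of every layer, shape (num_layers * num_directions, batch, hidden_size)
  • c_n: Final cell state of every layer, shape (num_layers * num_directions, batch, hidden_size)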
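The relationship between output and h_n can be checked directly. The following is a minimal sketch, assuming a unidirectional LSTM with batch_first=True; the sizes are arbitrary values chosen for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes only (not tied to any particular task)
lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
x = torch.randn(3, 5, 8)  # (batch, seq_len, input_size)

with torch.no_grad():
    output, (h_n, c_n) = lstm(x)

# output holds the top layer's hidden state at every time step
print(output.shape)  # torch.Size([3, 5, 16])
# h_n and c_n hold the final state of every layer
print(h_n.shape)     # torch.Size([2, 3, 16])

# For a unidirectional LSTM, the last time step of output
# is the final hidden state of the top layer
last_step_matches = torch.allclose(output[:, -1, :], h_n[-1])
print(last_step_matches)
```

This equality only holds in this simple form for the unidirectional case; with bidirectional=True the backward direction finishes at the first time step instead.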

Usage Examples

Example 1: Basic Usage

Create and use an LSTM:

import torch
import torch.nn as nn

# Create LSTM: input dimension 256, hidden dimension 512, 2 layers
lstm = nn.LSTM(input_size=256, hidden_size=512, num_layers=2, batch_first=True)

# Create input: batch=4, sequence length=10, input dimension=256
input_tensor = torch.randn(4, 10, 256)

# Forward pass
output, (h_n, c_n) = lstm(input_tensor)

print("Input shape:", input_tensor.shape)
print("Output shape:", output.shape)  # (4, 10, 512)
print("Hidden state shape:", h_n.shape)  # (2, 4, 512)
print("Cell state shape:", c_n.shape)  # (2, 4, 512)

Example 2: Bidirectional LSTM

Use a bidirectional LSTM to capture context from both directions of the sequence:

import torch
import torch.nn as nn

# Bidirectional LSTM
bilstm = nn.LSTM(input_size=256, hidden_size=256, num_layers=2, batch_first=True, bidirectional=True)

input_tensor = torch.randn(4, 10, 256)
output, (h_n, c_n) = bilstm(input_tensor)

print("Bidirectional LSTM output shape:", output.shape)  # (4, 10, 512) = 256*2
print("Hidden state shape:", h_n.shape)  # (4, 4, 256) = 2 layers * 2 directions
print("Last layer hidden state:", h_n[-2:, :, :].shape)  # (2, 4, 256): forward and backward states of the last layer

Example 3: Initialize Hidden State Manually

Manually initialize hidden states:

import torch
import torch.nn as nn

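# Sizes shared by the LSTM and its initial states
batch_size = 4
num_layers = 2
hidden_size = 512

# num_layers must match the first dimension of h_0 and c_0
lstm = nn.LSTM(input_size=256, hidden_size=hidden_size, num_layers=num_layers, batch_first=True)

# Manually create initial hidden and cell states: (num_layers, batch, hidden_size)
h_0 = torch.zeros(num_layers, batch_size, hidden_size)
c_0 = torch.zeros(num_layers, batch_size, hidden_size)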

# Pass in initial states
input_tensor = torch.randn(4, 10, 256)
output, (h_n, c_n) = lstm(input_tensor, (h_0, c_0))

print("Custom initial state used successfully")
print("Output shape:", output.shape)

Example 4: Complete Sentiment Classification Model

Text classification based on LSTM:

import torch
import torch.nn as nn

class LSTMClassifier(nn.Module):
    def __init__(self, vocab_size, embed_dim, hidden_dim, num_classes=2):
        super(LSTMClassifier, self).__init__()
        # Embedding layer
        self.embedding = nn.Embedding(vocab_size, embed_dim, padding_idx=0)
        # LSTM layer
        self.lstm = nn.LSTM(
            input_size=embed_dim,
            hidden_size=hidden_dim,
            num_layers=2,
            batch_first=True,
            bidirectional=True,
            dropout=0.3
        )
        # Fully connected classification layer
        self.fc = nn.Linear(hidden_dim * 2, num_classes)

    def forward(self, x):
        # x: (batch, seq_len)
        embedded = self.embedding(x)  # (batch, seq_len, embed_dim)
        # LSTM output
        output, (hidden, cell) = self.lstm(embedded)
        # hidden: (4, batch, hidden_dim) = 2 layers * 2 directions
        # Concatenate the last layer's forward (hidden[-2]) and backward (hidden[-1]) final states
        hidden = torch.cat([hidden[-2], hidden[-1]], dim=1)  # (batch, hidden_dim*2)
        # Classification
        logits = self.fc(hidden)
        return logits

# Instantiate model
vocab_size = 10000
model = LSTMClassifier(vocab_size=vocab_size, embed_dim=128, hidden_dim=128, num_classes=2)

# Test input: batch=8, sequence length=50
input_ids = torch.randint(1, vocab_size, (8, 50))
output = model(input_ids)

print("Model structure:")
print(model)
print("Input shape:", input_ids.shape)
print("Output shape:", output.shape)  # (8, 2)

Example 5: Stacked Multi-layer LSTM

Deep LSTM network:

import torch
import torch.nn as nn

# 4-layer stacked LSTM with dropout
deep_lstm = nn.LSTM(
    input_size=256,
    hidden_size=512,
    num_layers=4,
    batch_first=True,
    dropout=0.4  # Dropout between layers
)

input_tensor = torch.randn(2, 20, 256)
output, (h_n, c_n) = deep_lstm(input_tensor)

print("4-layer LSTM output shape:", output.shape)
print("Hidden state shape (4 layers):", h_n.shape)
print("Cell state shape (4 layers):", c_n.shape)

Concept of LSTM Gates

LSTM controls information flow through three gates:

  • Forget Gate: Determines how much of the previous cell state to retain
  • Input Gate: Determines how much new candidate information to write into the cell state
  • Output Gate: Determines how much of the cell state to expose as the hidden state
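To make the gates concrete, the sketch below recomputes a single LSTM time step by hand and compares it with nn.LSTMCell. It relies on PyTorch stacking the gate weights in the order input, forget, cell candidate, output; the sizes and variable names are illustrative assumptions.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

input_size, hidden_size = 4, 3  # illustrative sizes
cell = nn.LSTMCell(input_size, hidden_size)

x = torch.randn(2, input_size)        # one time step for a batch of 2
h_prev = torch.randn(2, hidden_size)  # previous hidden state
c_prev = torch.randn(2, hidden_size)  # previous cell state

with torch.no_grad():
    # All four gate pre-activations are computed in one matrix product
    gates = x @ cell.weight_ih.T + cell.bias_ih + h_prev @ cell.weight_hh.T + cell.bias_hh
    i, f, g, o = gates.chunk(4, dim=1)

    i = torch.sigmoid(i)  # input gate: how much new information to write
    f = torch.sigmoid(f)  # forget gate: how much of c_prev to keep
    g = torch.tanh(g)     # candidate values for the cell state
    o = torch.sigmoid(o)  # output gate: how much of the cell state to expose

    c_manual = f * c_prev + i * g
    h_manual = o * torch.tanh(c_manual)

    # Reference result from the built-in cell
    h_ref, c_ref = cell(x, (h_prev, c_prev))

gates_match = (
    torch.allclose(h_manual, h_ref, atol=1e-6)
    and torch.allclose(c_manual, c_ref, atol=1e-6)
)
print(gates_match)
```

nn.LSTM applies the same update at every time step and in every layer; nn.LSTMCell is used here only because it exposes a single step.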

Frequently Asked Questions

Q1: What does batch_first=True mean?

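With batch_first=True, the first dimension of the input and output tensors is the batch size. With batch_first=False (the default), the first dimension is the sequence length. The setting does not change the shapes of h_n and c_n.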
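A short sketch of both layouts, assuming arbitrary illustrative sizes, shows which shapes change with the setting and which do not:

```python
import torch
import torch.nn as nn

x_batch_first = torch.randn(4, 10, 32)  # (batch, seq_len, feature)

# batch_first=True: batch is the first dimension of input and output
lstm_bf = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
out_bf, (h_bf, _) = lstm_bf(x_batch_first)
print(out_bf.shape)  # torch.Size([4, 10, 64])

# batch_first=False: the same data must be laid out as (seq_len, batch, feature)
lstm_sf = nn.LSTM(input_size=32, hidden_size=64, batch_first=False)
x_seq_first = x_batch_first.transpose(0, 1).contiguous()  # (10, 4, 32)
out_sf, (h_sf, _) = lstm_sf(x_seq_first)
print(out_sf.shape)  # torch.Size([10, 4, 64])

# h_n is (num_layers * num_directions, batch, hidden_size) in both cases
print(h_bf.shape, h_sf.shape)  # torch.Size([1, 4, 64]) torch.Size([1, 4, 64])
```

Passing a batch-first tensor to a module created with batch_first=False does not raise an error as long as the feature dimension matches; the module silently treats the batch dimension as time, so it is worth checking shapes explicitly.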

Q2: When to use bidirectional LSTM?

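Use it when the whole sequence is available at once and context from both directions helps, as in sequence labeling or sentiment analysis. It does not fit step-by-step generation, where future tokens are not yet known. Machine translation commonly uses an encoder-decoder architecture instead.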
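When a bidirectional LSTM is used, the two directions can be separated again from output. The sketch below assumes a single-layer bidirectional LSTM with illustrative sizes, and checks where each direction's final state sits.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

hidden_size = 16  # illustrative size
bilstm = nn.LSTM(input_size=8, hidden_size=hidden_size,
                 batch_first=True, bidirectional=True)
x = torch.randn(3, 5, 8)  # (batch, seq_len, input_size)

with torch.no_grad():
    output, (h_n, _) = bilstm(x)

# The last dimension of output is laid out as [forward | backward]
forward_out = output[:, :, :hidden_size]
backward_out = output[:, :, hidden_size:]

# The forward direction finishes at the last time step,
# the backward direction finishes at the first time step
fwd_matches = torch.allclose(forward_out[:, -1, :], h_n[0])
bwd_matches = torch.allclose(backward_out[:, 0, :], h_n[1])
print(fwd_matches, bwd_matches)

# A common sequence-level feature: concatenate both final states
sequence_vec = torch.cat([h_n[0], h_n[1]], dim=1)
print(sequence_vec.shape)  # torch.Size([3, 32])
```

Taking output[:, -1, :] as the sequence representation would pair the forward direction's final state with the backward direction's first step, which is why the final states are usually read from h_n.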

Q3: How to choose hidden layer size?

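Values between 128 and 512 are typical starting points; adjust based on task complexity and data volume. A hidden size that is too small can lead to underfitting, while one that is too large may cause overfitting and makes the parameter count grow roughly quadratically.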
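One way to see the cost of a larger hidden size is to count parameters. The sketch below assumes a single-layer, unidirectional LSTM with bias; lstm_param_count is a hypothetical helper written for this illustration, not part of PyTorch.

```python
import torch.nn as nn

def lstm_param_count(input_size, hidden_size):
    # Hypothetical helper: each of the 4 gate blocks has input weights,
    # recurrent weights, and two bias vectors (bias_ih and bias_hh)
    return 4 * hidden_size * (input_size + hidden_size + 2)

counts = {}
for hidden_size in (128, 256, 512):
    lstm = nn.LSTM(input_size=256, hidden_size=hidden_size)
    actual = sum(p.numel() for p in lstm.parameters())
    counts[hidden_size] = actual
    print(hidden_size, actual, lstm_param_count(256, hidden_size))
```

Doubling hidden_size more than doubles the parameter count because the recurrent weights scale with hidden_size squared.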


Application Scenarios

nn.LSTM is mainly used in the following scenarios:

  • Natural Language Processing: Text classification, named entity recognition
  • Time Series Prediction: Stock prediction, speech recognition
  • Sequence-to-Sequence Tasks: Machine translation, text generation

Tip: When using bidirectional=True, the last dimension of output becomes hidden_size * 2, so any layer that consumes it must expect that size.

