# PyTorch `torch.nn.RNN`
The PyTorch `torch.nn.RNN` class implements a multi-layer Elman Recurrent Neural Network (RNN) with either $\tanh$ or ReLU non-linearities.
Unlike feedforward neural networks, RNNs maintain an internal hidden state that acts as a memory, allowing them to process sequences of inputs of arbitrary length. This makes them highly effective for sequential data processing tasks such as time-series forecasting, natural language processing (NLP), and speech recognition.
---
## Introduction
An Elman RNN applies a recurrence relation to a sequence of input vectors. For each element in an input sequence, the RNN layer computes the next hidden state $h_t$ using the current input $x_t$ and the previous hidden state $h_{t-1}$:
$$h_t = \tanh(W_{ih} x_t + b_{ih} + W_{hh} h_{t-1} + b_{hh})$$
Where:
* $h_t$ is the hidden state at time $t$.
* $x_t$ is the input at time $t$.
* $h_{t-1}$ is the hidden state from the previous time step; for the first element of the sequence it is the initial hidden state $h_0$.
* $W_{ih}$ and $W_{hh}$ are the learnable input-to-hidden and hidden-to-hidden weights.
* $b_{ih}$ and $b_{hh}$ are the corresponding bias vectors.
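To make the recurrence concrete, the following minimal sketch unrolls it by hand for a single-layer, unidirectional `nn.RNN` and compares the result with the module's own output. The sizes are arbitrary illustrative values; the parameter names `weight_ih_l0`, `weight_hh_l0`, `bias_ih_l0`, and `bias_hh_l0` are the attributes PyTorch exposes for layer 0.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary illustrative sizes (assumptions, not taken from the text above)
seq_len, batch_size, input_size, hidden_size = 4, 2, 3, 5

# Single layer, tanh non-linearity, default layout (seq, batch, feature)
rnn = nn.RNN(input_size, hidden_size)

x = torch.randn(seq_len, batch_size, input_size)
h = torch.zeros(batch_size, hidden_size)  # h_0

# Learnable parameters of layer 0
W_ih, W_hh = rnn.weight_ih_l0, rnn.weight_hh_l0  # (hidden, input), (hidden, hidden)
b_ih, b_hh = rnn.bias_ih_l0, rnn.bias_hh_l0      # (hidden,), (hidden,)

with torch.no_grad():
    manual_steps = []
    for t in range(seq_len):
        # h_t = tanh(W_ih x_t + b_ih + W_hh h_{t-1} + b_hh), applied per batch row
        h = torch.tanh(x[t] @ W_ih.T + b_ih + h @ W_hh.T + b_hh)
        manual_steps.append(h)
    manual_output = torch.stack(manual_steps)  # (seq_len, batch, hidden)

    output, h_n = rnn(x)  # h_0 defaults to zeros

matches = torch.allclose(manual_output, output, atol=1e-5)
print("Manual recurrence matches nn.RNN:", matches)
```

With `nonlinearity='relu'`, the same loop would apply `torch.relu` in place of `torch.tanh`.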
While more advanced architectures like LSTMs (`torch.nn.LSTM`) and GRUs (`torch.nn.GRU`) are often preferred in practice to mitigate the vanishing gradient problem, understanding `torch.nn.RNN` is fundamental to mastering sequence modeling in PyTorch.
---
## Syntax and Parameters
### Initialization Signature
```python
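class torch.nn.RNN(input_size, hidden_size, num_layers=1, nonlinearity='tanh', bias=True, batch_first=False, dropout=0.0, bidirectional=False, device=None, dtype=None)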
```
### Key Constructor Parameters
| Parameter | Type | Default | Description |
| :--- | :--- | :--- | :--- |
| `input_size` | `int` | *Required* | The number of expected features in the input $x$. |
| `hidden_size` | `int` | *Required* | The number of features in the hidden state $h$. |
| `num_layers` | `int` | `1` | Number of recurrent layers stacked on top of each other. |
| `nonlinearity` | `str` | `'tanh'` | The non-linear activation function to use. Can be either `'tanh'` or `'relu'`. |
| `bias` | `bool` | `True` | If `False`, the layer does not use bias weights $b_{ih}$ and $b_{hh}$. |
| `batch_first` | `bool` | `False` | If `True`, the input and output tensors are provided as `(batch, seq, feature)` instead of `(seq, batch, feature)`. |
| `dropout` | `float` | `0.0` | If non-zero, applies dropout with probability `dropout` to the outputs of each RNN layer except the last. Has an effect only when `num_layers > 1`. |
| `bidirectional` | `bool` | `False` | If `True`, becomes a bidirectional RNN. |
### Input and Output Shapes
Assuming `batch_first=False` (the default), and with `num_directions` equal to 2 if `bidirectional=True` and 1 otherwise:
* **Inputs**:
    * `input` of shape `(seq_len, batch_size, input_size)`: Tensor containing the features of the input sequence.
    * `h_0` of shape `(num_layers * num_directions, batch_size, hidden_size)`: Tensor containing the initial hidden state for each element in the batch. Defaults to zeros if not provided.
* **Outputs**:
    * `output` of shape `(seq_len, batch_size, num_directions * hidden_size)`: Tensor containing the output features ($h_t$) from the last layer of the RNN, for each time step $t$.
    * `h_n` of shape `(num_layers * num_directions, batch_size, hidden_size)`: Tensor containing the final hidden state of every layer (and direction), that is, the state after the last element of the sequence has been processed.
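The shapes above can be checked with a short sketch. It assumes a two-layer bidirectional RNN with arbitrary sizes, and uses `view` to separate the two directions, which relies on the forward and backward features being concatenated along the last dimension of `output`.

```python
import torch
import torch.nn as nn

# Arbitrary illustrative sizes (assumptions)
seq_len, batch_size, input_size, hidden_size, num_layers = 7, 4, 8, 16, 2
num_directions = 2  # because bidirectional=True

rnn = nn.RNN(input_size, hidden_size, num_layers=num_layers, bidirectional=True)

x = torch.randn(seq_len, batch_size, input_size)
output, h_n = rnn(x)

print(tuple(output.shape))  # (seq_len, batch_size, num_directions * hidden_size)
print(tuple(h_n.shape))     # (num_layers * num_directions, batch_size, hidden_size)

# Split the direction axis out of both tensors
output_by_dir = output.view(seq_len, batch_size, num_directions, hidden_size)
h_n_by_layer = h_n.view(num_layers, num_directions, batch_size, hidden_size)

# Last layer, forward direction: final state is the output at the last time step
fwd_ok = torch.allclose(output_by_dir[-1, :, 0], h_n_by_layer[-1, 0])
# Last layer, backward direction: final state is the output at the first time step
bwd_ok = torch.allclose(output_by_dir[0, :, 1], h_n_by_layer[-1, 1])
print(fwd_ok, bwd_ok)
```

Note that for a bidirectional RNN, `output[-1]` alone is therefore not the same as the final hidden state: the backward half of `output[-1]` has only seen one input element.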
---
## Code Example
Below is a complete, self-contained code example demonstrating how to initialize a `torch.nn.RNN` layer, pass data through it, and handle the output shapes.
```python
import torch
import torch.nn as nn
# 1. Define Hyperparameters
batch_size = 3
seq_len = 5
input_size = 10 # e.g., word embedding dimension
hidden_size = 20 # Dimension of the hidden state
num_layers = 2 # Stacked RNN layers
# 2. Instantiate the RNN Layer
# We set batch_first=True to align with standard data pipelines (Batch, Sequence, Feature)
rnn = nn.RNN(
    input_size=input_size,
    hidden_size=hidden_size,
    num_layers=num_layers,
    batch_first=True,
    nonlinearity='tanh',
)
# 3. Create Dummy Input Data
# Shape: (batch_size, seq_len, input_size)
dummy_input = torch.randn(batch_size, seq_len, input_size)
# 4. Initialize the Hidden State (Optional)
# Shape: (num_layers * num_directions, batch_size, hidden_size)
# Since bidirectional=False, num_directions is 1.
h0 = torch.zeros(num_layers, batch_size, hidden_size)
# 5. Forward Pass
# If h0 is not explicitly passed, PyTorch automatically initializes it to zeros.
output, hn = rnn(dummy_input, h0)
# 6. Inspect Output Shapes
print("--- Shape Verification ---")
print(f"Input Shape: {dummy_input.shape}")
print(f"Output Shape: {output.shape}") # Expected: (batch_size, seq_len, hidden_size)
print(f"h_n Shape: {hn.shape}") # Expected: (num_layers, batch_size, hidden_size)
# Verify that the last step of the output matches the last layer's hidden state
# output[:, -1, :] should equal hn[-1, :, :]
last_step_output = output[:, -1, :]
last_layer_hidden = hn[-1, :, :]
assert torch.allclose(last_step_output, last_layer_hidden, atol=1e-6)
print("\nVerification successful: Last output step matches final hidden state.")
```
---
## Best Practices and Common Pitfalls
### 1. Watch Out for the `batch_first` Flag
By default, PyTorch recurrent layers expect inputs of shape `(seq_len, batch_size, input_size)`, a time-major layout that was historically chosen for performance reasons. Many data pipelines, however, produce batch-first tensors of shape `(batch_size, seq_len, input_size)`.
* **Tip**: Always set `batch_first=True` explicitly if your data pipeline produces batch-first tensors. A mismatched layout often raises no error: the layer silently treats the batch dimension as time and vice versa. Note that `batch_first` **only** affects the input and output tensors; the hidden state tensors $h_0$ and $h_n$ always have the shape `(num_layers * num_directions, batch_size, hidden_size)`.
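The following sketch illustrates both points with arbitrary sizes: a batch-first tensor fed to a default (time-major) layer runs without error but is interpreted with the wrong sequence length and batch size, and the hidden state keeps its `(num_layers * num_directions, batch_size, hidden_size)` layout even when `batch_first=True`.

```python
import torch
import torch.nn as nn

# Arbitrary illustrative sizes (assumptions)
batch_size, seq_len, input_size, hidden_size, num_layers = 3, 5, 10, 20, 2

# Batch-first data, as many data pipelines produce it
x = torch.randn(batch_size, seq_len, input_size)

# Correct: the layer is told that the data is batch-first
rnn_bf = nn.RNN(input_size, hidden_size, num_layers=num_layers, batch_first=True)
out_bf, h_bf = rnn_bf(x)
print(tuple(out_bf.shape))  # batch-first output
print(tuple(h_bf.shape))    # hidden state is NOT batch-first

# Mistake: default layer receives batch-first data. No error is raised;
# the layer simply reads it as seq_len=3, batch_size=5.
rnn_default = nn.RNN(input_size, hidden_size, num_layers=num_layers)
out_wrong, h_wrong = rnn_default(x)
print(tuple(h_wrong.shape))  # batch dimension is 5, i.e. the real seq_len

# Alternative fix: transpose to time-major before calling the default layer
out_ok, h_ok = rnn_default(x.transpose(0, 1).contiguous())
print(tuple(out_ok.shape), tuple(h_ok.shape))
```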
### 2. The Vanishing and Exploding Gradient Problem
Standard RNNs struggle to learn long-term dependencies (sequences longer than roughly 10 to 20 steps) due to vanishing or exploding gradients during backpropagation through time (BPTT).
* **Tip**: If you are training on long sequences, use `torch.nn.LSTM` or `torch.nn.GRU` instead of `torch.nn.RNN`. If you must use standard RNNs, apply gradient clipping using `torch.nn.utils.clip_grad_norm_` to prevent exploding gradients.
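A minimal sketch of gradient clipping in a single training step is shown below. The linear `head`, the random data, the learning rate, and `max_norm = 1.0` are illustrative assumptions rather than recommended settings. `clip_grad_norm_` is called after `backward()` and before `optimizer.step()`, and returns the total gradient norm measured before clipping.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model: an RNN followed by a linear regression head
rnn = nn.RNN(input_size=4, hidden_size=8, batch_first=True)
head = nn.Linear(8, 1)
params = list(rnn.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)

# Random stand-in data: 16 sequences of 50 steps with 4 features each
x = torch.randn(16, 50, 4)
y = torch.randn(16, 1)

output, h_n = rnn(x)
prediction = head(output[:, -1, :])  # predict from the last time step
loss = nn.functional.mse_loss(prediction, y)

optimizer.zero_grad()
loss.backward()

# Rescale gradients in place so their combined L2 norm is at most max_norm
max_norm = 1.0  # assumed value; tune for your task
norm_before = torch.nn.utils.clip_grad_norm_(params, max_norm=max_norm)
norm_after = torch.sqrt(sum(p.grad.pow(2).sum() for p in params))

optimizer.step()
print(f"grad norm before: {norm_before.item():.4f}, after: {norm_after.item():.4f}")
```

Clipping addresses exploding gradients only; it does not help with vanishing gradients, which is why gated architectures remain the usual choice for long sequences.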
### 3. Reusing Hidden States Across Batches
When processing continuous sequences (like long text documents divided into chunks), you may want to pass the final hidden state `h_n` of the current batch as the initial hidden state `h_0` of the next batch.
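* **Pitfall**: If you pass `h_n` directly to the next iteration without detaching it, PyTorch will attempt to backpropagate through the history of all previous batches. Depending on the training loop, this either keeps every earlier computation graph alive, so memory grows until an out-of-memory error occurs, or fails with `RuntimeError: Trying to backward through the graph a second time` because the earlier graph's buffers have already been freed.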
* **Solution**: Always call `.detach()` on the hidden state if you are carrying it over to a new batch:
```python
# Carry over state without tracking history
h0 = hn.detach()
```
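Put into a full loop, the pattern looks roughly like the sketch below, which trains on one long sequence split into fixed-length chunks (truncated BPTT). The linear `head`, the random data, and the chunk length are hypothetical choices made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Arbitrary illustrative sizes (assumptions)
batch_size, chunk_len, input_size, hidden_size = 2, 10, 4, 8

rnn = nn.RNN(input_size, hidden_size, batch_first=True)
head = nn.Linear(hidden_size, 1)  # hypothetical per-step regression head
params = list(rnn.parameters()) + list(head.parameters())
optimizer = torch.optim.SGD(params, lr=0.01)

# One long sequence per batch row, processed in 4 consecutive chunks
long_x = torch.randn(batch_size, 4 * chunk_len, input_size)
long_y = torch.randn(batch_size, 4 * chunk_len, 1)

h = None  # None lets nn.RNN start from a zero hidden state
losses = []
for start in range(0, long_x.size(1), chunk_len):
    x = long_x[:, start:start + chunk_len]
    y = long_y[:, start:start + chunk_len]

    output, h = rnn(x, h)
    loss = nn.functional.mse_loss(head(output), y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    # Keep the values but cut the graph, so the next chunk's backward pass
    # stops at the chunk boundary
    h = h.detach()
    losses.append(loss.item())

print(losses)
```

Removing the `h = h.detach()` line in this loop makes the second call to `loss.backward()` reach back into the first chunk's graph, which has already been freed.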