# PyTorch torch.nn.Flatten Module
* * *
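`torch.nn.Flatten` is a PyTorch module (a subclass of `nn.Module`) that flattens a contiguous range of tensor dimensions into a single dimension.
By default it keeps dimension 0 (the batch dimension) and merges all remaining dimensions, which makes it the usual bridge between convolutional layers and fully connected (`nn.Linear`) layers.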
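With the default arguments, the module gives the same result as reshaping each sample into a single vector. A minimal sketch comparing `nn.Flatten()` with `Tensor.reshape`, assuming a batch of 4 RGB images of size 32x32:

```python
import torch
import torch.nn as nn

x = torch.randn(4, 3, 32, 32)  # (batch, channels, height, width)

flat_module = nn.Flatten()(x)            # module form, usable inside nn.Sequential
flat_reshape = x.reshape(x.size(0), -1)  # keep dim 0, merge everything else

print(flat_module.shape)                       # torch.Size([4, 3072])
print(torch.equal(flat_module, flat_reshape))  # True
```

The module form is mainly a convenience: it lets the reshape step live inside an `nn.Sequential` container instead of in a custom `forward` method.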
### Class Definition
```python
torch.nn.Flatten(start_dim=1, end_dim=-1)
```
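**Parameters:**
* `start_dim` (int): The first dimension to flatten (inclusive). Default is 1, which leaves dimension 0 (usually the batch dimension) unchanged.
* `end_dim` (int): The last dimension to flatten (inclusive). Default is -1, i.e. the last dimension. Negative values count from the end.

The sizes of dimensions `start_dim` through `end_dim` are multiplied into one dimension; all other dimensions keep their sizes.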
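Because both bounds are inclusive and negative values count from the end, it helps to see which dimensions actually get merged. The sketch below uses an arbitrary `(2, 3, 4, 5)` tensor chosen for illustration and also checks that the module agrees with the functional `torch.flatten` called with the same arguments:

```python
import torch
import torch.nn as nn

x = torch.randn(2, 3, 4, 5)

# (start_dim, end_dim) pairs; both ends are included in the merged range
for start, end in [(1, -1), (1, 2), (0, 1), (-2, -1)]:
    out = nn.Flatten(start_dim=start, end_dim=end)(x)
    # the functional form takes the same arguments and gives the same result
    assert torch.equal(out, torch.flatten(x, start_dim=start, end_dim=end))
    print((start, end), "->", tuple(out.shape))
# (1, -1) -> (2, 60)
# (1, 2) -> (2, 12, 5)
# (0, 1) -> (6, 4, 5)
# (-2, -1) -> (2, 3, 20)
```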
* * *
## Usage Examples
### Example 1: Basic Usage
```python
import torch
import torch.nn as nn

flatten = nn.Flatten()  # defaults: start_dim=1, end_dim=-1

# 4D input: (batch, channels, height, width)
x = torch.randn(4, 3, 32, 32)
output = flatten(x)

print("Input shape:", x.shape)        # torch.Size([4, 3, 32, 32])
print("Output shape:", output.shape)  # torch.Size([4, 3072])
print("After flattening: 3*32*32 = 3072 dimensions")
```
### Example 2: Preserve Batch Dimension
```python
import torch
import torch.nn as nn

x = torch.randn(8, 64, 8, 8)
print("Input:", x.shape)  # torch.Size([8, 64, 8, 8])

# start_dim=1 preserves the batch dimension: (8, 64*8*8) = (8, 4096)
out1 = nn.Flatten(start_dim=1)(x)
print("Start from dimension 1:", out1.shape)  # torch.Size([8, 4096])

# start_dim=0 also merges the batch dimension: (8*64*8*8,) = (32768,)
out2 = nn.Flatten(start_dim=0)(x)
print("Start from dimension 0:", out2.shape)  # torch.Size([32768])
```
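Because the default `start_dim=1` treats dimension 0 as the batch, passing a single unbatched image produces a shape that may be unexpected. A minimal sketch, assuming one `(3, 32, 32)` image; adding a batch dimension with `unsqueeze(0)` restores the intended behavior:

```python
import torch
import torch.nn as nn

flatten = nn.Flatten()
img = torch.randn(3, 32, 32)  # a single image without a batch dimension

# Dimension 0 (channels) is kept as if it were the batch: (3, 32*32)
print(flatten(img).shape)               # torch.Size([3, 1024])

# Add a batch dimension of size 1 first: (1, 3*32*32)
print(flatten(img.unsqueeze(0)).shape)  # torch.Size([1, 3072])
```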
### Example 3: 3D Input
```python
import torch
import torch.nn as nn

# 3D input: (batch, seq_len, features)
x = torch.randn(4, 100, 512)

# Flatten sequence and features together: (4, 100*512) = (4, 51200)
flatten = nn.Flatten(start_dim=1)
output = flatten(x)

print("Input shape:", x.shape)        # torch.Size([4, 100, 512])
print("Output shape:", output.shape)  # torch.Size([4, 51200])
```
### Example 4: Complete CNN Example
```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),   # (4, 3, 32, 32) -> (4, 32, 32, 32)
    nn.ReLU(),
    nn.MaxPool2d(2),                  # -> (4, 32, 16, 16)
    nn.Conv2d(32, 64, 3, padding=1),  # -> (4, 64, 16, 16)
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),          # -> (4, 64, 1, 1)
    nn.Flatten(),                     # -> (4, 64)
    nn.Linear(64, 10),                # -> (4, 10)
)

x = torch.randn(4, 3, 32, 32)
output = model(x)
print("Input:", x.shape, "-> Output:", output.shape)
```
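In Example 4, `nn.AdaptiveAvgPool2d(1)` fixes the flattened size at 64 regardless of the input resolution. Without such a pooling layer, `in_features` of the following `nn.Linear` must equal channels × height × width of the last feature map. One common approach, sketched here under the assumption of 3x32x32 inputs, is to run a dummy tensor through the convolutional part once and read the size off the flattened output:

```python
import torch
import torch.nn as nn

# Convolutional part only (no adaptive pooling), assuming 3x32x32 inputs
features = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),  # -> (N, 32, 32, 32)
    nn.ReLU(),
    nn.MaxPool2d(2),                 # -> (N, 32, 16, 16)
    nn.Flatten(),                    # -> (N, 32*16*16) = (N, 8192)
)

# Infer the flattened size with a dummy batch of one image
with torch.no_grad():
    n_features = features(torch.zeros(1, 3, 32, 32)).shape[1]
print(n_features)  # 8192

model = nn.Sequential(features, nn.Linear(n_features, 10))
print(model(torch.randn(4, 3, 32, 32)).shape)  # torch.Size([4, 10])
```

The dummy input must match the real data's channel count and spatial size; if the input resolution changes, the flattened size changes with it. PyTorch's `nn.LazyLinear` is an alternative that infers `in_features` on the first forward pass.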
* * *
## Usage Scenarios
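* **CNN to FC**: Flatten the `(batch, channels, height, width)` output of convolutional or pooling layers into `(batch, features)` so it can be fed to an `nn.Linear` layer.
* **Dimension Transformation**: Merge any contiguous range of dimensions into one by choosing `start_dim` and `end_dim`, for example collapsing `(batch, seq_len, features)` into `(batch, seq_len * features)`.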
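As an illustration of the second scenario, `start_dim` can be set beyond the batch dimension so that only the trailing dimensions are merged. The shapes below are hypothetical: a batch of sequences whose per-step features are stored as an 8 x 64 grid.

```python
import torch
import torch.nn as nn

x = torch.randn(2, 10, 8, 64)  # hypothetical (batch, seq_len, 8, 64) tensor

# Keep batch and seq_len, merge only the last two dimensions: (2, 10, 8*64)
merge_last_two = nn.Flatten(start_dim=2)
out = merge_last_two(x)
print(out.shape)  # torch.Size([2, 10, 512])
```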
* * *