The `torch.split` function is a fundamental tensor manipulation utility in PyTorch. It allows developers to partition a single tensor into multiple smaller tensors (sub-tensors) along a specified dimension.
Unlike slicing, which extracts a single sub-tensor, `torch.split` is designed to segment an entire tensor in a single operation. This is highly useful in deep learning workflows, such as splitting a batch of data across multiple GPUs, separating multi-head attention projections in Transformers, or dividing feature maps along the channel dimension.
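As a minimal sketch of the Transformer use case, assume a single fused linear layer that produces the query, key, and value projections concatenated along the last dimension. The dimensions and the `qkv_proj` layer are hypothetical and chosen only for illustration; real attention implementations may lay out their projections differently.

```python
import torch

# Hypothetical dimensions, chosen only for illustration
batch, seq_len, d_model = 2, 5, 8

x = torch.randn(batch, seq_len, d_model)

# Assumption: one fused layer outputs Q, K and V concatenated on the last dim
qkv_proj = torch.nn.Linear(d_model, 3 * d_model)
qkv = qkv_proj(x)  # shape [2, 5, 24]

# Split the last dimension into three equal chunks of size d_model
q, k, v = torch.split(qkv, d_model, dim=-1)

print(q.shape, k.shape, v.shape)
```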
---
## Syntax and Parameters
The `torch.split` function takes an input tensor and splits it along a given dimension based on either a single chunk size or a list of specific split sizes.
### Function Signature
```python
torch.split(tensor, split_size_or_sections, dim=0) -> Tuple[Tensor, ...]
```
### Parameters
| Parameter | Type | Description |
| :--- | :--- | :--- |
| `tensor` | `torch.Tensor` | The source tensor to be split. |
| `split_size_or_sections` | `int` or `list[int]` | **If `int`**: The size of each chunk along `dim`. If the size of `tensor` along `dim` is not divisible by this value, the last chunk is smaller.<br>**If `list[int]`**: The exact size of each chunk along `dim`, in order. The elements must sum to the size of `tensor` along `dim`. |
| `dim` | `int` | The dimension along which to split the tensor. Defaults to `0` (the first dimension). |
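The two forms of `split_size_or_sections` are easiest to compare on a small 1-D tensor, where the actual values are visible rather than just the shapes. This is a short illustrative sketch; the tensor `t` and the section sizes are arbitrary.

```python
import torch

t = torch.arange(7)  # tensor([0, 1, 2, 3, 4, 5, 6])

# Integer: chunks of size 3; the last chunk holds the single leftover element
by_size = torch.split(t, 3)
print(by_size)
# (tensor([0, 1, 2]), tensor([3, 4, 5]), tensor([6]))

# List: exact sizes, which must sum to 7
by_sections = torch.split(t, [2, 4, 1])
print(by_sections)
# (tensor([0, 1]), tensor([2, 3, 4, 5]), tensor([6]))
```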
### Return Value
* **`Tuple[Tensor, ...]`**: A tuple containing the resulting sub-tensors. These sub-tensors are **views** of the original tensor (they share its underlying data storage).
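Both properties can be checked directly, as a sketch: the type of the result, and whether a chunk's first element lies inside the original tensor's memory. The `data_ptr()` arithmetic below assumes a contiguous 1-D input, where the second chunk of size 2 starts exactly two elements after the start of `x`.

```python
import torch

x = torch.arange(6)
parts = torch.split(x, 2)

# torch.split returns a tuple, which can be indexed or unpacked directly
print(type(parts))  # <class 'tuple'>
a, b, c = parts

# b covers x[2:4]; as a view, its first element sits inside x's memory,
# 2 elements past the start of x (valid for a contiguous 1-D x)
print(b.data_ptr() == x.data_ptr() + 2 * x.element_size())  # True
```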
---
## Code Examples
Below is a complete, runnable script demonstrating the two primary ways to use `torch.split` (splitting by a fixed chunk size and splitting by custom, variable sizes), plus how a chunk size that does not evenly divide the dimension is handled.
```python
import torch

# Create a dummy tensor representing a batch of feature maps.
# torch.arange is deterministic, so no random seed is needed.
# Shape: [Batch Size: 4, Channels: 6, Height: 2, Width: 2]
x = torch.arange(48).view(4, 6, 2, 2)
print("Original Tensor Shape:", x.shape)

# =====================================================================
# Example 1: Splitting into equal-sized chunks (using an integer)
# =====================================================================
print("\n--- Example 1: Split along Dimension 1 (Channels) into chunks of size 2 ---")
# Since dim 1 has size 6, splitting with size 2 yields 3 tensors of shape [4, 2, 2, 2]
splits_equal = torch.split(x, split_size_or_sections=2, dim=1)
print(f"Number of splits: {len(splits_equal)}")
for i, chunk in enumerate(splits_equal):
    print(f" Chunk {i} shape: {chunk.shape}")

# =====================================================================
# Example 2: Splitting into variable-sized chunks (using a list)
# =====================================================================
print("\n--- Example 2: Split along Dimension 1 with custom sizes [1, 3, 2] ---")
# The sum of the list [1, 3, 2] must equal the size of dim 1 (6)
splits_custom = torch.split(x, split_size_or_sections=[1, 3, 2], dim=1)
print(f"Number of splits: {len(splits_custom)}")
for i, chunk in enumerate(splits_custom):
    print(f" Chunk {i} shape: {chunk.shape}")

# =====================================================================
# Example 3: Handling non-divisible sizes
# =====================================================================
print("\n--- Example 3: Splitting with non-divisible chunk size ---")
# Splitting dim 1 (size 6) into chunks of size 4
# yields one chunk of size 4 and a remaining chunk of size 2.
splits_remainder = torch.split(x, split_size_or_sections=4, dim=1)
print(f"Number of splits: {len(splits_remainder)}")
for i, chunk in enumerate(splits_remainder):
    print(f" Chunk {i} shape: {chunk.shape}")
```
### Output
```text
Original Tensor Shape: torch.Size([4, 6, 2, 2])

--- Example 1: Split along Dimension 1 (Channels) into chunks of size 2 ---
Number of splits: 3
 Chunk 0 shape: torch.Size([4, 2, 2, 2])
 Chunk 1 shape: torch.Size([4, 2, 2, 2])
 Chunk 2 shape: torch.Size([4, 2, 2, 2])

--- Example 2: Split along Dimension 1 with custom sizes [1, 3, 2] ---
Number of splits: 3
 Chunk 0 shape: torch.Size([4, 1, 2, 2])
 Chunk 1 shape: torch.Size([4, 3, 2, 2])
 Chunk 2 shape: torch.Size([4, 2, 2, 2])

--- Example 3: Splitting with non-divisible chunk size ---
Number of splits: 2
 Chunk 0 shape: torch.Size([4, 4, 2, 2])
 Chunk 1 shape: torch.Size([4, 2, 2, 2])
```
---
## Best Practices and Common Pitfalls
### 1. Memory Management: Views vs. Copies
`torch.split` returns **views** of the original tensor, not copies. The returned tensors share the same underlying memory buffer as the input tensor.
* **Pitfall**: Modifying a split chunk in place (for example, `chunk[0] = 0` or `chunk.add_(1)`) also modifies the original tensor, and vice versa.
* **Best Practice**: If you need to modify the split tensors independently of the source tensor, call `.clone()` on each chunk you plan to modify:
```python
# splits_equal is a tuple, so clone each element rather than the tuple itself
independent_chunks = [chunk.clone() for chunk in splits_equal]
```
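The pitfall and the fix can be seen side by side in a small sketch on a 1-D tensor (an illustrative example, separate from the script above): an in-place write through a view changes the source, while a write to a clone does not.

```python
import torch

x = torch.zeros(6)
first, second = torch.split(x, 3)

# In-place write through a view: the original tensor changes too
first.fill_(1.0)
print(x)  # tensor([1., 1., 1., 0., 0., 0.])

# Clone first to get an independent copy, then modify it freely
independent = second.clone()
independent.fill_(2.0)
print(x)  # tensor([1., 1., 1., 0., 0., 0.])  (unchanged)
```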
### 2. `torch.split` vs. `torch.chunk`
PyTorch offers another splitting function called `torch.chunk`. It is important to know when to use which:
* Use **`torch.split`** when you want to define the **size of each chunk** (e.g., "give me chunks of size 2").
* Use **`torch.chunk`** when you want to define the **number of chunks** (e.g., "split this tensor into 3 parts").
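The difference is easiest to see on a small tensor. In this sketch, both calls agree when the size divides evenly; the last line illustrates that `torch.chunk` can return fewer chunks than requested, because it uses a chunk size of ceil(6 / 4) = 2 here.

```python
import torch

t = torch.arange(6)

# Same result when the size divides evenly: 3 chunks of size 2
print(torch.split(t, 2))  # (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))
print(torch.chunk(t, 3))  # (tensor([0, 1]), tensor([2, 3]), tensor([4, 5]))

# Asking torch.chunk for 4 chunks of a size-6 tensor yields only 3
print(len(torch.chunk(t, 4)))  # 3
```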
### 3. Sum of Sections Must Match Dimension Size
When passing a list of integers to `split_size_or_sections`, the sum of the integers in the list must exactly equal the size of the dimension you are splitting.
* **Pitfall**: If your tensor has size `6` along `dim=1`, calling `torch.split(tensor, [2, 3], dim=1)` will raise a `RuntimeError` because $2 + 3 = 5 \neq 6$.
* **Best Practice**: Always verify that your custom split list accounts for all elements along the target dimension.