
# PyTorch `torch.split`

The `torch.split` function is a fundamental tensor manipulation utility in PyTorch. It partitions a single tensor into multiple smaller tensors (sub-tensors) along a specified dimension. Unlike slicing, which extracts one sub-tensor at a time, `torch.split` segments the entire tensor in a single call. This is useful in many deep learning workflows, such as splitting a batch of data across multiple GPUs, separating multi-head attention projections in Transformers, or dividing feature maps along the channel dimension.

---

## Syntax and Parameters

`torch.split` takes an input tensor and splits it along a given dimension, using either a single chunk size or a list of explicit chunk sizes.

### Function Signature

```python
torch.split(tensor, split_size_or_sections, dim=0) -> Tuple[Tensor, ...]
```

### Parameters

| Parameter | Type | Description |
| :--- | :--- | :--- |
| `tensor` | `torch.Tensor` | The source tensor to be split. |
| `split_size_or_sections` | `int` or `list(int)` | **If `int`**: the size of each chunk along `dim`. The last chunk is smaller if the size of the tensor along `dim` is not divisible by this value. **If `list(int)`**: the exact size of each individual chunk along `dim`. The list elements must sum to the total size of the tensor along `dim`. |
| `dim` | `int` | The dimension along which to split the tensor. Defaults to `0` (the first dimension). |

### Return Value

* **Tuple of tensors**: the resulting sub-tensors, in order along `dim`. These sub-tensors are **views** of the original tensor (they share the same underlying data storage).

---

## Code Examples

Below is a complete, runnable script demonstrating the two primary ways to use `torch.split` (splitting by an equal chunk size and splitting by custom, variable sizes), plus the case where the chunk size does not evenly divide the dimension.

```python
import torch

# Create a dummy tensor representing a batch of feature maps
# Shape: [Batch Size: 4, Channels: 6, Height: 2, Width: 2] -> 4 * 6 * 2 * 2 = 96 elements
x = torch.arange(96).view(4, 6, 2, 2)
print("Original Tensor Shape:", x.shape)

# =====================================================================
# Example 1: Splitting into equal-sized chunks (using an integer)
# =====================================================================
print("\n--- Example 1: Split along Dimension 1 (Channels) into chunks of size 2 ---")
# Since dim 1 has size 6, splitting with size 2 yields 3 tensors of shape [4, 2, 2, 2]
splits_equal = torch.split(x, split_size_or_sections=2, dim=1)
print(f"Number of splits: {len(splits_equal)}")
for i, chunk in enumerate(splits_equal):
    print(f"  Chunk {i} shape: {chunk.shape}")

# =====================================================================
# Example 2: Splitting into variable-sized chunks (using a list)
# =====================================================================
print("\n--- Example 2: Split along Dimension 1 with custom sizes [1, 3, 2] ---")
# The sum of the list [1, 3, 2] must equal the size of dim 1 (6)
splits_custom = torch.split(x, split_size_or_sections=[1, 3, 2], dim=1)
print(f"Number of splits: {len(splits_custom)}")
for i, chunk in enumerate(splits_custom):
    print(f"  Chunk {i} shape: {chunk.shape}")

# =====================================================================
# Example 3: Handling non-divisible sizes
# =====================================================================
print("\n--- Example 3: Splitting with non-divisible chunk size ---")
# Splitting dim 1 (size 6) into chunks of size 4
# yields one chunk of size 4 and a remaining chunk of size 2.
splits_remainder = torch.split(x, split_size_or_sections=4, dim=1)
print(f"Number of splits: {len(splits_remainder)}")
for i, chunk in enumerate(splits_remainder):
    print(f"  Chunk {i} shape: {chunk.shape}")
```

### Output

```text
Original Tensor Shape: torch.Size([4, 6, 2, 2])

--- Example 1: Split along Dimension 1 (Channels) into chunks of size 2 ---
Number of splits: 3
  Chunk 0 shape: torch.Size([4, 2, 2, 2])
  Chunk 1 shape: torch.Size([4, 2, 2, 2])
  Chunk 2 shape: torch.Size([4, 2, 2, 2])

--- Example 2: Split along Dimension 1 with custom sizes [1, 3, 2] ---
Number of splits: 3
  Chunk 0 shape: torch.Size([4, 1, 2, 2])
  Chunk 1 shape: torch.Size([4, 3, 2, 2])
  Chunk 2 shape: torch.Size([4, 2, 2, 2])

--- Example 3: Splitting with non-divisible chunk size ---
Number of splits: 2
  Chunk 0 shape: torch.Size([4, 4, 2, 2])
  Chunk 1 shape: torch.Size([4, 2, 2, 2])
```

---

## Best Practices and Common Pitfalls

### 1. Memory Management: Views vs. Copies

`torch.split` returns **views** of the original tensor, not copies. The returned tensors share the same underlying memory buffer as the input tensor.

* **Pitfall**: Modifying an element in one of the split chunks modifies the original tensor (and vice versa).
* **Best Practice**: If you need to modify the split tensors independently of the source tensor, call `.clone()` on the chunks. Because `torch.split` returns a tuple, call `.clone()` on the individual tensors, not on the tuple itself:

```python
independent_chunk = splits_equal[0].clone()  # clone a single chunk
independent_chunks = [c.clone() for c in splits_equal]  # clone every chunk
```

### 2. `torch.split` vs. `torch.chunk`

PyTorch offers another splitting function called `torch.chunk`.
It is important to know when to use which:

* Use **`torch.split`** when you want to define the **size of each chunk** (e.g., "give me chunks of size 2").
* Use **`torch.chunk`** when you want to define the **number of chunks** (e.g., "split this tensor into 3 parts"). Note that `torch.chunk` may return fewer chunks than requested when the dimension size is not divisible by the number of chunks.

### 3. Sum of Sections Must Match Dimension Size

When passing a list of integers to `split_size_or_sections`, the integers must sum exactly to the size of the dimension being split.

* **Pitfall**: If your tensor has size `6` along `dim=1`, calling `torch.split(tensor, [2, 3], dim=1)` raises a `RuntimeError` because $2 + 3 = 5 \neq 6$.
* **Best Practice**: Always verify that your custom split list accounts for every element along the target dimension.
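The three best practices above can be checked directly. The script below is a minimal sketch that reuses a `[4, 6, 2, 2]` tensor like the one in the Code Examples section and shows (1) a write through a split view reaching the source tensor while a clone stays independent, (2) `torch.split` and `torch.chunk` producing different results for the same argument `4`, and (3) a mismatched section list raising an error. The helper `split_checked` is a hypothetical name introduced here for illustration; it is not part of PyTorch.

```python
import torch

# Same shape as the Code Examples tensor: [4, 6, 2, 2] (96 elements)
x = torch.arange(96).view(4, 6, 2, 2)

# --- 1. Views vs. copies ---
first, second, third = torch.split(x, 2, dim=1)
first[0, 0, 0, 0] = -1   # writes through the view into x
safe = second.clone()    # independent copy of channels 2-3
safe[0, 0, 0, 0] = -2    # does not touch x
print("x[0, 0, 0, 0] =", x[0, 0, 0, 0].item())  # -1 (changed via the view)
print("x[0, 2, 0, 0] =", x[0, 2, 0, 0].item())  # 8 (unchanged by the clone)

# --- 2. torch.split (chunk size) vs. torch.chunk (number of chunks) ---
y = torch.arange(6)
by_size = [t.tolist() for t in torch.split(y, 4)]   # chunks of size 4
by_count = [t.tolist() for t in torch.chunk(y, 4)]  # asks for 4 chunks
print("split(y, 4):", by_size)   # [[0, 1, 2, 3], [4, 5]]
print("chunk(y, 4):", by_count)  # [[0, 1], [2, 3], [4, 5]] -> only 3 chunks

# --- 3. Sections must sum to the dimension size ---
def split_checked(t, sections, dim):
    """Hypothetical helper (not a PyTorch API): validate before splitting."""
    total = sum(sections)
    if total != t.size(dim):
        raise ValueError(
            f"sections {sections} sum to {total}, but dim {dim} has size {t.size(dim)}"
        )
    return torch.split(t, sections, dim=dim)

parts = split_checked(x, [1, 3, 2], dim=1)
print("custom split shapes:", [tuple(p.shape) for p in parts])

try:
    torch.split(x, [2, 3], dim=1)  # 2 + 3 = 5, but dim 1 has size 6
    mismatch_error = None
except RuntimeError as err:
    mismatch_error = err
print("mismatch raised RuntimeError:", mismatch_error is not None)
```

Validating the list up front, as `split_checked` does, is only one option; it lets you report the offending section list in your own error message instead of relying on the `RuntimeError` raised by `torch.split`.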