PyTorch torch.narrow

Example

import torch

# Create a tensor
x = torch.tensor([[1,2,3,4],
                  [5,6,7,8],
                  [9,10,11,12]])

# On the first dimension (rows), take 2 rows starting from index 0
y = torch.narrow(x, dim=0, start=0, length=2)

print("Original tensor:")
print(x)

print("\nSlice result:")
print(y)

Output:

Original tensor:
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])

Slice result:
tensor([[1, 2, 3, 4],
        [5, 6, 7, 8]])

Example

import torch

# Create a tensor
x = torch.tensor([[1,2,3,4],
                  [5,6,7,8],
                  [9,10,11,12]])

# On the second dimension (columns), take 2 columns starting from index 1
y = torch.narrow(x, dim=1, start=1, length=2)

print("Original tensor:")
print(x)

print("\nSlice result:")
print(y)

Output:

Original tensor:
tensor([[ 1,  2,  3,  4],
        [ 5,  6,  7,  8],
        [ 9, 10, 11, 12]])

Slice result:
tensor([[ 2,  3],
        [ 6,  7],
        [10, 11]])
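
The column selection above should match basic slicing over the range start to start + length on the same dimension. The following is a minimal sketch checking that equivalence; the variable names `start`, `length`, `narrowed`, and `sliced` are chosen here for illustration only.

```python
import torch

x = torch.tensor([[1, 2, 3, 4],
                  [5, 6, 7, 8],
                  [9, 10, 11, 12]])

start, length = 1, 2

# Select `length` columns beginning at `start` with torch.narrow
narrowed = torch.narrow(x, dim=1, start=start, length=length)

# The same selection written with ordinary slice syntax
sliced = x[:, start:start + length]

# Compare shapes and values of the two results
print(torch.equal(narrowed, sliced))
```

Slice syntax is usually shorter when the dimension is fixed in the code; torch.narrow can be more convenient when the dimension is only known at runtime, since it is passed as an ordinary argument.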

Example

import torch

# Create a 3D tensor
x = torch.randn(5,6,7)

# On the first dimension, take 3 elements starting from index 2
y = torch.narrow(x, dim=0, start=2, length=3)

print("Original shape:", x.shape)
print("Shape after slicing:", y.shape)

Output:

Original shape: torch.Size([5, 6, 7])
Shape after slicing: torch.Size([3, 6, 7])
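
As the shapes above suggest, only the narrowed dimension changes size while the others are preserved. The sketch below illustrates this by narrowing each dimension of a 3D tensor in turn, using the tensor method form `x.narrow(dim, start, length)`; the `shapes` list is a helper introduced here only for the illustration.

```python
import torch

x = torch.randn(5, 6, 7)

shapes = []
for dim in range(x.dim()):
    # Take 3 elements starting from index 1 along the current dimension
    y = x.narrow(dim, 1, 3)
    shapes.append(tuple(y.shape))
    print(f"dim={dim}: {tuple(y.shape)}")
```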

Note: torch.narrow returns a view that shares storage with the original tensor, not a copy, so in-place changes to the result also change the original. If an independent copy is needed, use torch.narrow_copy.
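
The difference between the view and the copy can be observed by writing into each result in place. This is a small sketch, with the variable names `view` and `copy` chosen here for illustration: an in-place write to the narrowed view should show up in the original tensor, while a write to the result of torch.narrow_copy should not.

```python
import torch

x = torch.zeros(3, 4)

# A view of the first two rows: shares memory with x
view = torch.narrow(x, 0, 0, 2)

# An independent copy of the same two rows
copy = torch.narrow_copy(x, 0, 0, 2)

view.fill_(1.0)  # in-place write through the view, also changes x
copy.fill_(5.0)  # in-place write to the copy, leaves x untouched

print(x)
print(copy)
```

After running this, the first two rows of `x` hold ones and the last row is still zeros, while `copy` holds fives.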

