
PyTorch torch.take_along_dim

`torch.take_along_dim` is a PyTorch function that selects elements at given index positions along a specified dimension. It retrieves values from `input` along the `dim` dimension according to the positions given in `indices`.

### Function Definition

    torch.take_along_dim(input, indices, dim=None)

**Parameters**:

* `input` (Tensor): The input tensor.
* `indices` (Tensor): The index tensor specifying the positions of the elements to retrieve. It must have dtype `torch.long`. When `dim` is given, it must have the same number of dimensions as `input` and be broadcastable against `input` in every dimension other than `dim`.
* `dim` (int, optional): The dimension along which to index. If it is omitted (`None`), `input` is treated as if it were flattened to 1D.

**Returns**:

* `torch.Tensor`: A new tensor composed of the elements retrieved according to the indices.

* * *

## Usage Examples

## Example

    import torch

    # Create a 2D tensor
    x = torch.tensor([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])
    print("Original tensor:")
    print(x)

    # Fetch elements along dim=1 (pick a column for each position in a row)
    indices = torch.tensor([[0, 1, 2],
                            [2, 1, 0],
                            [0, 0, 0]])
    y = torch.take_along_dim(x, indices, dim=1)

    print("\nIndices:")
    print(indices)
    print("\nElements fetched along dim=1:")
    print(y)

Output:

    Original tensor:
    tensor([[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]])

    Indices:
    tensor([[0, 1, 2],
            [2, 1, 0],
            [0, 0, 0]])

    Elements fetched along dim=1:
    tensor([[1, 2, 3],
            [6, 5, 4],
            [7, 7, 7]])

## Example

    import torch

    # Fetch elements along dim=0
    x = torch.tensor([[1, 2, 3],
                      [4, 5, 6],
                      [7, 8, 9]])

    # Fetch a different row for each column
    indices = torch.tensor([[0, 1, 2],
                            [2, 0, 1],
                            [1, 2, 0]])
    y = torch.take_along_dim(x, indices, dim=0)

    print("Original tensor:")
    print(x)
    print("\nIndices:")
    print(indices)
    print("\nElements fetched along dim=0:")
    print(y)

Output:

    Original tensor:
    tensor([[1, 2, 3],
            [4, 5, 6],
            [7, 8, 9]])

    Indices:
    tensor([[0, 1, 2],
            [2, 0, 1],
            [1, 2, 0]])

    Elements fetched along dim=0:
    tensor([[1, 5, 9],
            [7, 2, 6],
            [4, 8, 3]])
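Because the index tensor has one entry per output element, the function pairs naturally with operations that return indices along a dimension. The following is a minimal sketch, assuming a small made-up `scores` tensor (the names `scores`, `order`, and `best` are illustrative only), that feeds the results of `torch.argsort` and `torch.argmax` back into `torch.take_along_dim`.

```python
import torch

# Hypothetical data: 2 rows with 4 scores each (illustrative values only).
scores = torch.tensor([[0.3, 0.9, 0.1, 0.5],
                       [0.7, 0.2, 0.8, 0.4]])

# argsort returns, for each row, the positions that sort that row.
order = torch.argsort(scores, dim=1, descending=True)

# Taking along the same dim with those positions yields the sorted values.
sorted_scores = torch.take_along_dim(scores, order, dim=1)

# keepdim=True keeps dim 1, so the index has the same number of dims as scores.
best = torch.argmax(scores, dim=1, keepdim=True)
best_scores = torch.take_along_dim(scores, best, dim=1)

print(order)          # tensor([[1, 3, 0, 2], [2, 0, 3, 1]])
print(sorted_scores)  # rows sorted in descending order
print(best_scores)    # shape (2, 1): the maximum of each row
```

Note the `keepdim=True` in the `argmax` call: without it the index tensor would lose a dimension and no longer line up with `scores`.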

## Example

    import torch

    # Using on a 3D tensor
    x = torch.arange(24).reshape(2, 3, 4)
    print("Original shape:", x.shape)

    # Fetch elements along dim=1.
    # indices must have the same number of dimensions as x, so a trailing
    # dimension of size 1 is added; it broadcasts across the last dimension.
    indices = torch.tensor([[0, 1, 2],
                            [2, 0, 1]]).unsqueeze(-1)
    y = torch.take_along_dim(x, indices, dim=1)

    print("Indices shape:", indices.shape)
    print("Result shape:", y.shape)
    print("\nResult:")
    print(y)

Output:

    Original shape: torch.Size([2, 3, 4])
    Indices shape: torch.Size([2, 3, 1])
    Result shape: torch.Size([2, 3, 4])

    Result:
    tensor([[[ 0,  1,  2,  3],
             [ 4,  5,  6,  7],
             [ 8,  9, 10, 11]],

            [[20, 21, 22, 23],
             [12, 13, 14, 15],
             [16, 17, 18, 19]]])
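One way to read the 3D case is as a `torch.gather` call in which the index tensor has already been expanded to the full output shape. The sketch below, which assumes an index tensor with the same number of dimensions as the input (the names `idx`, `via_take`, and `via_gather` are illustrative), compares the two calls.

```python
import torch

x = torch.arange(24).reshape(2, 3, 4)

# Row order to pick per batch entry, with a trailing size-1 dim: shape (2, 3, 1).
idx = torch.tensor([[0, 1, 2],
                    [2, 0, 1]]).unsqueeze(-1)

# take_along_dim broadcasts idx across the last dimension on its own.
via_take = torch.take_along_dim(x, idx, dim=1)

# gather does not broadcast, so the index is expanded by hand to (2, 3, 4).
via_gather = torch.gather(x, 1, idx.expand(-1, -1, x.size(2)))

print(via_take.shape)                      # torch.Size([2, 3, 4])
print(torch.equal(via_take, via_gather))   # True
print(via_take[1, 0])                      # tensor([20, 21, 22, 23])
```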

## Example

    import torch

    # Application: select specific elements along the sequence dimension,
    # for example keeping the top-k values per head in an attention mechanism
    batch_size = 2
    num_heads = 3
    seq_len = 4
    head_dim = 5

    # Simulate attention weights for a query
    attn_weights = torch.randn(batch_size, num_heads, seq_len)

    # Get top-k indices for each head
    k = 2
    indices = torch.argsort(attn_weights, dim=-1, descending=True)[..., :k]
    print("Indices shape:", indices.shape)

    # Simulate value tensor
    value = torch.randn(batch_size, num_heads, seq_len, head_dim)

    # Fetch corresponding values along the seq_len dimension
    selected_value = torch.take_along_dim(value, indices.unsqueeze(-1).expand(-1, -1, -1, head_dim), dim=2)
    print("Value shape:", value.shape)
    print("Selected Value shape:", selected_value.shape)

Output:

    Indices shape: torch.Size([2, 3, 2])
    Value shape: torch.Size([2, 3, 4, 5])
    Selected Value shape: torch.Size([2, 3, 2, 5])
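Since `torch.take_along_dim` broadcasts `indices` against `input` in the dimensions other than `dim`, the explicit `expand` over `head_dim` should not be strictly necessary. The following sketch assumes the same shapes as the attention example and a fixed seed for reproducibility (the names `expanded` and `broadcast` are illustrative); it checks that both forms select the same values.

```python
import torch

torch.manual_seed(0)  # fixed seed so the comparison is reproducible

batch_size, num_heads, seq_len, head_dim, k = 2, 3, 4, 5, 2
attn_weights = torch.randn(batch_size, num_heads, seq_len)
value = torch.randn(batch_size, num_heads, seq_len, head_dim)

# Top-k positions along the sequence dimension: shape (2, 3, 2).
indices = torch.argsort(attn_weights, dim=-1, descending=True)[..., :k]

# Explicitly expanded index: shape (2, 3, 2, 5).
expanded = torch.take_along_dim(
    value, indices.unsqueeze(-1).expand(-1, -1, -1, head_dim), dim=2
)

# Index with a trailing size-1 dim: shape (2, 3, 2, 1), broadcast internally.
broadcast = torch.take_along_dim(value, indices.unsqueeze(-1), dim=2)

print(broadcast.shape)                   # torch.Size([2, 3, 2, 5])
print(torch.equal(expanded, broadcast))  # True
```

The `unsqueeze(-1)` is still needed in both forms, because `indices` must have the same number of dimensions as `value`.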

* * *

Note: `torch.take_along_dim` allows indexing by dimension, which is more flexible than `torch.take`, since the latter always treats the tensor as one-dimensional.
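To make the contrast concrete, here is a small sketch on a 3x3 tensor. The names `flat_idx` and `col_idx` are illustrative; the example selects the diagonal once through flat positions and once through a per-row column index.

```python
import torch

x = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# torch.take indexes into the flattened tensor: positions 0, 4, 8.
flat_idx = torch.tensor([0, 4, 8])
taken = torch.take(x, flat_idx)

# Without dim, take_along_dim also treats the input as flattened.
taken_flat = torch.take_along_dim(x, flat_idx)

# With dim=1, each row gets its own column index: shape (3, 1).
col_idx = torch.tensor([[0], [1], [2]])
taken_dim = torch.take_along_dim(x, col_idx, dim=1)

print(taken)       # tensor([1, 5, 9])
print(taken_flat)  # tensor([1, 5, 9])
print(taken_dim)   # tensor([[1], [5], [9]])
```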
