## Example
import torch

# Gather along dim=1 of a 3D tensor with torch.take_along_dim.
# take_along_dim requires the index tensor to have the same number of
# dimensions as the input; its non-`dim` dimensions are broadcast against
# the input.
x = torch.arange(24).reshape(2, 3, 4)
print("Original shape:", x.shape)

# For each batch element, the order in which rows of x are read along dim=1.
indices = torch.tensor([[0, 1, 2],
                        [2, 0, 1]])

# Passing the 2D `indices` directly with the 3D `x` raises a RuntimeError
# (dimension-count mismatch). Adding a trailing size-1 axis,
# (2, 3) -> (2, 3, 1), lets it broadcast to (2, 3, 4), so every column of
# each selected row is taken.
y = torch.take_along_dim(x, indices.unsqueeze(-1), dim=1)
print("Indices shape:", indices.shape)
print("Result shape:", y.shape)
print("\nResult:")
print(y)
Output:
Original shape: torch.Size([2, 3, 4])
Indices shape: torch.Size([2, 3])
Result shape: torch.Size([2, 3, 4])

Result:
tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[20, 21, 22, 23],
         [12, 13, 14, 15],
         [16, 17, 18, 19]]])
## Example
import torch

# Application: keep only the top-k positions along the sequence (seq_len)
# dimension for every (batch, head) pair, e.g. selecting the values that
# receive the highest attention weights.
batch_size = 2
num_heads = 3
seq_len = 4
head_dim = 5

# Simulated attention weights: one score per sequence position,
# shape (batch_size, num_heads, seq_len).
attn_weights = torch.randn(batch_size, num_heads, seq_len)

# Indices of the k largest weights per head, ordered by descending weight.
# topk avoids sorting the whole last dimension just to keep k entries.
k = 2
indices = torch.topk(attn_weights, k, dim=-1).indices
print("Indices shape:", indices.shape)

# Simulated value tensor, shape (batch_size, num_heads, seq_len, head_dim).
value = torch.randn(batch_size, num_heads, seq_len, head_dim)

# Gather the selected positions along seq_len (dim=2). The trailing size-1
# axis, (B, H, k) -> (B, H, k, 1), is broadcast across head_dim by
# take_along_dim, so no explicit expand is needed.
selected_value = torch.take_along_dim(value, indices.unsqueeze(-1), dim=2)
print("Value shape:", value.shape)
print("Selected Value shape:", selected_value.shape)
Output:
Indices shape: torch.Size([2, 3, 2])
Value shape: torch.Size([2, 3, 4, 5])
Selected Value shape: torch.Size([2, 3, 2, 5])
* * *
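Note: `torch.take_along_dim` selects values along a chosen dimension using an index tensor. This is more flexible than `torch.take`, which always treats the input as if it were flattened to one dimension. The index tensor must have the same number of dimensions as the input; its other dimensions are broadcast against the input. For example, use `indices.unsqueeze(-1)` to select whole rows.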
[PyTorch torch Reference Manual](#)
YouTip