PyTorch torch.index_add Function
torch.index_add is a PyTorch function that adds the values of a source tensor into an input tensor at the positions given by index along the dimension dim, and returns the result as a new tensor. For dim=0, for example, the i-th row of source (scaled by alpha) is added to row index[i] of input.
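The behaviour described above can be spelled out as an explicit loop. The following is a minimal sketch for the dim=0 case only; `index_add_reference` is a hypothetical helper written for illustration and is not part of PyTorch.

```python
import torch

def index_add_reference(inp, index, source, alpha=1):
    """Illustrative loop-based equivalent of torch.index_add along dim=0."""
    out = inp.clone()  # work on a copy so the caller's tensor is untouched
    for i, target in enumerate(index.tolist()):
        # Row i of source is scaled by alpha and added to row index[i]
        out[target] += alpha * source[i]
    return out

inp = torch.zeros(4, 2)
index = torch.tensor([0, 2, 2])  # row 2 is targeted twice
source = torch.tensor([[1., 1.], [2., 2.], [3., 3.]])

expected = index_add_reference(inp, index, source)
actual = torch.index_add(inp, 0, index, source)

print(torch.equal(expected, actual))  # True
print(actual[2].tolist())             # [5.0, 5.0]
```

The loop is only meant to make the semantics readable; the built-in function should be preferred in real code.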
Function Definition

    torch.index_add(input, dim, index, source, *, alpha=1)

Parameters:

- input (Tensor): Input tensor.
- dim (int): Dimension along which to index.
- index (Tensor): One-dimensional integer tensor specifying the positions to add to.
- source (Tensor): Source tensor holding the values to add. It must have the same dtype as input and the same shape as input, except along dim, where its size must equal the length of index.
- alpha (Number, optional): Scaling factor applied to source before it is added. Defaults to 1.

Return Value:

- torch.Tensor: A new tensor containing the result. The input tensor itself is not modified.

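The relationship between input, dim, index and source can be checked with a few lines of code. This sketch assumes the documented shape rule: source matches input in every dimension except dim, where its size equals the number of indices. `expected_source_shape` is a hypothetical helper introduced only for this illustration.

```python
import torch

def expected_source_shape(input_shape, dim, index):
    """Hypothetical helper: the shape source needs for a given input, dim and index."""
    shape = list(input_shape)
    shape[dim] = index.numel()  # one slice of source per entry in index
    return torch.Size(shape)

inp = torch.zeros(3, 4, 5)
index = torch.tensor([0, 2])  # valid positions along every dimension of inp

for dim in range(inp.dim()):
    shape = expected_source_shape(inp.shape, dim, index)
    source = torch.ones(shape, dtype=inp.dtype)  # same dtype as the input
    out = torch.index_add(inp, dim, index, source)
    # The result always has the shape of the input, whatever dim is used
    print(dim, tuple(shape), tuple(out.shape))
```

Building source from the input's shape and dtype in this way avoids the shape and dtype mismatches that otherwise surface as runtime errors.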
Usage Examples

Example

    import torch

    # Create the input tensor
    input = torch.randn(4, 5)

    # Create the index and source tensors
    index = torch.tensor([0, 2, 3])
    source = torch.randn(3, 5)

    # Add along dim=0
    output = torch.index_add(input, dim=0, index=index, source=source)

    print("Input shape:", input.shape)
    print("Index:", index)
    print("Source shape:", source.shape)
    print("Result shape:", output.shape)
    print("\nResult:")
    print(output)

The output result is (the values differ on every run because the tensors are random):

    Input shape: torch.Size([4, 5])
    Index: tensor([0, 2, 3])
    Source shape: torch.Size([3, 5])
    Result shape: torch.Size([4, 5])

    Result:
    tensor([[ 1.8435,  0.3463, -0.1024,  0.5678,  0.1234],
            [-0.5678,  0.8901, -0.2345,  0.6789, -0.1234],
            [ 2.3456,  0.4567,  0.7890, -0.3456,  0.5678],
            [-0.7890,  1.2345,  0.3456, -0.8901,  0.2345]])

Example

    import torch

    # Scale source using the alpha parameter
    input = torch.zeros(5)
    index = torch.tensor([0, 2, 4])
    # source must have the same dtype as input, so use floating-point values
    source = torch.tensor([10., 20., 30.])

    # alpha=2 means that source * 2 is added at the indexed positions
    output = torch.index_add(input, dim=0, index=index, source=source, alpha=2)

    print("Input:", input)
    print("Source:", source)
    print("Result after alpha=2:", output)

The output result is:

    Input: tensor([0., 0., 0., 0., 0.])
    Source: tensor([10., 20., 30.])
    Result after alpha=2: tensor([20.,  0., 40.,  0., 60.])

Example

    import torch

    # Addition along another dimension
    input = torch.zeros(3, 4, 5)
    index = torch.tensor([1, 3])
    # source matches input except along dim=1, where its size is len(index) = 2
    source = torch.randn(3, 2, 5)

    # Add along dim=1
    output = torch.index_add(input, dim=1, index=index, source=source)

    print("Input shape:", input.shape)
    print("Index shape:", index.shape)
    print("Source shape:", source.shape)
    print("Result shape:", output.shape)

The output result is:

    Input shape: torch.Size([3, 4, 5])
    Index shape: torch.Size([2])
    Source shape: torch.Size([3, 2, 5])
    Result shape: torch.Size([3, 4, 5])

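To see which slices an addition along dim=1 actually touches, the call can be compared with explicit slicing. This is a small sketch with deterministic values (chosen here for illustration) rather than the random tensors used in the example.

```python
import torch

inp = torch.zeros(3, 4, 5)
index = torch.tensor([1, 3])
# Deterministic source of shape (3, 2, 5): size 2 along dim=1, one slice per index
source = torch.arange(30.).reshape(3, 2, 5)

# Manual version: slice i of source (along dim=1) is added to slice index[i] of the input
manual = inp.clone()
for i, target in enumerate(index.tolist()):
    manual[:, target, :] += source[:, i, :]

result = torch.index_add(inp, dim=1, index=index, source=source)

print(torch.equal(manual, result))                # True
print(result[:, 0, :].abs().sum().item() == 0.0)  # True: position 0 along dim=1 is untouched
```

Positions along dim=1 that do not appear in index (0 and 2 here) keep their original values.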
Example

    import torch

    # Application in neural networks: attention-style aggregation
    # Suppose several key-value pairs need to be aggregated into each query

    # Simulate queries and key-value pairs
    num_queries = 2
    num_kv = 4
    dim = 3

    # Index of the query that each key-value pair belongs to
    query_idx = torch.tensor([0, 0, 1, 1])

    # One value vector per key-value pair
    values = torch.randn(num_kv, dim) * 10

    # Output buffer with one row per query
    output = torch.zeros(num_queries, dim)

    # Add each value to the row of its query; duplicate indices accumulate
    output = torch.index_add(output, dim=0, index=query_idx, source=values)

    print("Query index:", query_idx)
    print("Values:", values)
    print("Aggregated result:", output)

Note: torch.index_add does not modify the original input tensor, but returns a new tensor. If there are duplicate indices in index, the values will be accumulated. The alpha parameter can be used to scale the source values.
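The points in the note can be verified with a short script. The sketch below also uses the in-place method `Tensor.index_add_` and contrasts the result with plain advanced-indexing assignment, which does not accumulate duplicate indices; the tensors are arbitrary illustrative values.

```python
import torch

inp = torch.zeros(3)
index = torch.tensor([0, 0, 2])  # position 0 appears twice
source = torch.tensor([1., 2., 3.])

# Out-of-place: returns a new tensor and leaves inp as it was
out = torch.index_add(inp, 0, index, source, alpha=10)
print(out.tolist())  # [30.0, 0.0, 30.0] -> position 0 receives 10*1 + 10*2
print(inp.tolist())  # [0.0, 0.0, 0.0]

# In-place variant: the trailing underscore marks methods that modify the tensor
buf = torch.zeros(3)
buf.index_add_(0, index, source, alpha=10)
print(buf.tolist())  # [30.0, 0.0, 30.0]

# Plain indexed assignment does NOT accumulate duplicates:
# only one of the two writes to position 0 survives
naive = torch.zeros(3)
naive[index] += source
print(naive[0].item() != 3.0)  # True
```

When the original tensor is no longer needed, the in-place form avoids allocating a second tensor; otherwise the out-of-place function is the safer default.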