PyTorch torch.scatter_reduce Function

torch.scatter_reduce is a PyTorch function that reduces the values of a source tensor into an input tensor at the positions given by an index tensor along a chosen dimension. It supports several reductions: sum, product, mean, maximum, and minimum.

Function Definition

torch.scatter_reduce(input, dim, index, src, reduce, *, include_self=True)
Parameters:

  • input (Tensor): The tensor to reduce into. It is not modified; the result is returned as a new tensor.
  • dim (int): The dimension along which to index.
  • index (Tensor): An int64 tensor with the same number of dimensions as src, specifying where each value of src is scattered into input along dim.
  • src (Tensor): The source tensor containing the values to be aggregated. It must have the same dtype as input.
  • reduce (str): The reduction to apply. This argument is required and has no default; valid values are 'sum', 'prod', 'mean', 'amax', and 'amin'.
  • include_self (bool, optional): Whether the original values of input at the indexed positions take part in the reduction. Default is True.

Return Value:

  • torch.Tensor: Returns the tensor after aggregation.
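
To make the indexing rule concrete, here is a minimal sketch of what reduce='sum' does along dim=0 on 2-D tensors, written as plain loops and compared against torch.scatter_reduce. The helper name scatter_sum_dim0 is hypothetical and only meant for illustration; it assumes include_self=True and is not how PyTorch implements the operation.

```python
import torch

def scatter_sum_dim0(input, index, src):
    """Hypothetical loop-based reference for reduce='sum', dim=0, include_self=True."""
    out = input.clone()  # work on a copy so the caller's tensor is untouched
    rows, cols = index.shape
    for i in range(rows):
        for j in range(cols):
            # Along dim=0 the index value replaces the row coordinate;
            # the column coordinate j is kept as is.
            target_row = index[i, j].item()
            out[target_row, j] += src[i, j]
    return out

input = torch.ones(2, 3)
index = torch.tensor([[0, 1, 0],
                      [0, 1, 1]])
src = torch.tensor([[1., 2., 3.],
                    [4., 5., 6.]])

expected = scatter_sum_dim0(input, index, src)
actual = torch.scatter_reduce(input, 0, index, src, reduce='sum')

print(expected)                       # rows: [6., 1., 4.] and [1., 8., 7.]
print(torch.equal(expected, actual))  # True
```

Positions that no index refers to (such as row 1, column 0 here) simply keep their original value.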

Usage Examples

Example

import torch

# Create the input tensor (float32)
input = torch.ones(3, 5)

# Create the index (int64) and the source; src must have the same dtype as input
index = torch.tensor([[0, 1, 2, 0, 0],
                      [1, 2, 0, 1, 2],
                      [2, 0, 1, 2, 0]])
src = torch.tensor([[1., 1., 1., 1., 1.],
                    [2., 2., 2., 2., 2.],
                    [3., 3., 3., 3., 3.]])

# Use sum aggregation: output[index[i][j]][j] += src[i][j]
output = torch.scatter_reduce(input, dim=0, index=index, src=src, reduce='sum')

print("Input:")
print(input)
print("\nAggregation method: sum")
print("Result:")
print(output)

The output result is:

Input:
tensor([[1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1.]])

Aggregation method: sum
Result:
tensor([[2., 4., 3., 2., 5.],
        [3., 2., 4., 3., 1.],
        [4., 3., 2., 4., 3.]])

Example

import torch

# Test different aggregation methods.
# input is all zeros and include_self defaults to True,
# so the zeros take part in every reduction.
input = torch.zeros(3, 3)
index = torch.tensor([[0, 0, 0],
                      [1, 1, 1],
                      [2, 2, 2]])
src = torch.tensor([[2., 3., 4.],
                    [5., 6., 7.],
                    [8., 9., 10.]])

# Use prod (product): 0 * src is 0 everywhere
output_prod = torch.scatter_reduce(input, 0, index, src, reduce='prod')
print("prod Aggregation:")
print(output_prod)

# Use amax (maximum): max(0, src) is src, since src is positive
output_max = torch.scatter_reduce(input, 0, index, src, reduce='amax')
print("\namax Aggregation:")
print(output_max)

# Use amin (minimum): min(0, src) is 0, since src is positive
output_min = torch.scatter_reduce(input, 0, index, src, reduce='amin')
print("\namin Aggregation:")
print(output_min)

The output result is:

prod Aggregation:
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])

amax Aggregation:
tensor([[ 2.,  3.,  4.],
        [ 5.,  6.,  7.],
        [ 8.,  9., 10.]])

amin Aggregation:
tensor([[0., 0., 0.],
        [0., 0., 0.],
        [0., 0., 0.]])
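
The 'mean' reduction appears in the parameter list but not in the examples. The following sketch, using made-up values, shows how it behaves with and without include_self. With include_self=True the original value at each indexed position counts as one more element of the average.

```python
import torch

input = torch.tensor([1.0, 2.0, 3.0, 4.0])
index = torch.tensor([0, 1, 0, 1, 2, 1])
src = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0])

# The original values of input are averaged together with the scattered values:
# position 0 -> (1 + 1 + 3) / 3, position 1 -> (2 + 2 + 4 + 6) / 4,
# position 2 -> (3 + 5) / 2, position 3 is never indexed and stays 4.
mean_with_self = torch.scatter_reduce(input, 0, index, src, reduce='mean')
print(mean_with_self)     # approximately [1.6667, 3.5, 4.0, 4.0]

# Only the scattered values are averaged:
# position 0 -> (1 + 3) / 2, position 1 -> (2 + 4 + 6) / 3,
# position 2 -> 5, position 3 is never indexed and stays 4.
mean_without_self = torch.scatter_reduce(
    input, 0, index, src, reduce='mean', include_self=False
)
print(mean_without_self)  # [2., 4., 5., 4.]
```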

Example

import torch

# Use the include_self parameter
input = torch.tensor([1.0, 2.0, 3.0])
index = torch.tensor([0, 0, 0])
src = torch.tensor([10.0, 20.0, 30.0])

# Include self-values (default): 1 + 10 + 20 + 30 = 61
output1 = torch.scatter_reduce(input, 0, index, src, reduce='sum', include_self=True)
print("include_self=True:", output1)

# Do not include self-values: 10 + 20 + 30 = 60
output2 = torch.scatter_reduce(input, 0, index, src, reduce='sum', include_self=False)
print("include_self=False:", output2)

The output result is:

include_self=True: tensor([61.,  2.,  3.])
include_self=False: tensor([60.,  2.,  3.])


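Note: torch.scatter_reduce does not modify the original input tensor; it returns a new tensor (the in-place variant is Tensor.scatter_reduce_). Several indices may point to the same location, in which case their values are combined using the chosen reduction. With include_self=False, the original values of input at the indexed positions are excluded from the reduction, while positions that no index refers to keep their original values. This is useful, for example, when aggregating node features in graph neural networks, where the initial contents of input should not be counted in the result.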
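
As an illustration of the graph use case, the sketch below aggregates neighbor features per destination node on a tiny made-up graph. The names x, edge_src, edge_dst, and base are hypothetical, and the layout is only one possible way to set this up. It also checks that the input tensor is left unchanged.

```python
import torch

# Hypothetical node features for 3 nodes, 2 features each
x = torch.tensor([[-1., 2.],
                  [-3., 4.],
                  [-5., 6.]])

# Hypothetical directed edges: edge_src[k] -> edge_dst[k]
edge_src = torch.tensor([1, 2, 0, 0])
edge_dst = torch.tensor([0, 0, 1, 2])

# One message per edge: the feature vector of the sending node
messages = x[edge_src]                              # shape (4, 2)

# index must have the same number of dimensions as src,
# so repeat each destination id across the feature dimension
index = edge_dst.unsqueeze(1).expand_as(messages)   # shape (4, 2)

base = torch.zeros_like(x)

# include_self=False: the zeros in base are placeholders and are ignored
agg_max = torch.scatter_reduce(
    base, 0, index, messages, reduce='amax', include_self=False
)

# include_self=True: the zeros compete with the (negative) messages
agg_max_with_self = torch.scatter_reduce(base, 0, index, messages, reduce='amax')

print(agg_max)            # rows: [-3., 6.], [-1., 2.], [-1., 2.]
print(agg_max_with_self)  # rows: [0., 6.], [0., 2.], [0., 2.]
print(base)               # still all zeros: the input was not modified
```

Node 0 receives messages from nodes 1 and 2, so its first feature is max(-3, -5) = -3 when the placeholder zero is excluded, and 0 when it is included.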

