

# PyTorch torch.argwhere Function

* * *

`torch.argwhere` returns the indices of all non-zero (or `True`) elements of the input tensor. To find the elements that satisfy a condition, pass a boolean mask such as `x > 0`.

### Function Definition

```python
torch.argwhere(input)
```

**Parameters**:

* `input` (Tensor): Input tensor.

**Return Value**:

* `torch.Tensor`: A 2D tensor in which each row is the index of one non-zero element. Its shape is (number of non-zero elements, `input.dim()`).

* * *

## Usage Examples

## Example

```python
import torch

# Create a tensor
x = torch.tensor([[1, 0, 2], [0, 3, 0], [4, 5, 0]])
print("Original tensor:")
print(x)

# Return indices of non-zero elements
indices = torch.argwhere(x)
print("\nIndices of non-zero elements:")
print(indices)
```

Output:

```
Original tensor:
tensor([[1, 0, 2],
        [0, 3, 0],
        [4, 5, 0]])

Indices of non-zero elements:
tensor([[0, 0],
        [0, 2],
        [1, 1],
        [2, 0],
        [2, 1]])
```

## Example

```python
import torch

# Boolean condition
x = torch.tensor([[True, False, True], [False, True, False], [True, True, False]])
print("Boolean tensor:")
print(x)

indices = torch.argwhere(x)
print("\nIndices of True values:")
print(indices)
```

Output:

```
Boolean tensor:
tensor([[ True, False,  True],
        [False,  True, False],
        [ True,  True, False]])

Indices of True values:
tensor([[0, 0],
        [0, 2],
        [1, 1],
        [2, 0],
        [2, 1]])
```

## Example

```python
import torch

# Find elements greater than a certain value
x = torch.randn(3, 4)
threshold = 0
print("Original tensor:")
print(x)

# Find indices of elements greater than threshold
indices = torch.argwhere(x > threshold)
print(f"\nIndices of elements greater than {threshold}:")
print(indices)

# torch.nonzero (with its default arguments) gives the same result
indices2 = torch.nonzero(x > threshold)
print("\nResult using nonzero:")
print(indices2)
```

Output (the values differ between runs because the input is random):

```
Original tensor:
tensor([[-1.2345,  0.5678, -0.8901,  1.2345],
        [ 0.3456, -0.6789,  0.9012, -0.1234],
        [-0.5678,  1.2345, -0.3456,  0.7890]])

Indices of elements greater than 0:
tensor([[0, 1],
        [0, 3],
        [1, 0],
        [1, 2],
        [2, 1],
        [2, 3]])

Result using nonzero:
tensor([[0, 1],
        [0, 3],
        [1, 0],
        [1, 2],
        [2, 1],
        [2, 3]])
```

## Example

```python
import torch

# 1D tensor
x = torch.tensor([1, 0, 0, 4, 0, 5, 0])
indices = torch.
```
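The 1D example above breaks off in the source before the call is complete. As a minimal sketch, assuming it was meant to pass the 1D tensor to `torch.argwhere`, the following shows what the result looks like in the 1D case. The names `flat` and `values` are illustrative and do not come from the original page.

```python
import torch

# 1D tensor (same values as the truncated example)
x = torch.tensor([1, 0, 0, 4, 0, 5, 0])

# Assumption: the original example called torch.argwhere on x
indices = torch.argwhere(x)
print(indices)        # one row per non-zero element, one column per dimension
print(indices.shape)  # torch.Size([3, 1])

# The result is still 2D for a 1D input; drop the extra dimension
# to get plain positions that can be used for indexing
flat = indices.squeeze(1)
values = x[flat]
print(flat)    # tensor([0, 3, 5])
print(values)  # tensor([1, 4, 5])
```

Even for a 1D input the result keeps two dimensions, so `squeeze(1)` is one way to obtain a flat list of positions before indexing back into the tensor.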