## PyTorch `torch.trace` Tutorial

In linear algebra, the **trace** of a matrix is the sum of the elements on its main diagonal. PyTorch provides a built-in function, `torch.trace`, to compute this value. This tutorial covers the syntax, parameters, behavior, and practical examples of `torch.trace`.

---

## Function Definition

`torch.trace` takes a 2D tensor (matrix) and returns the sum of the elements along its main diagonal.

### Syntax

```python
torch.trace(input) -> Tensor
```

### Parameters

* **`input`**: A 2D input tensor (matrix).

### Return Value

* A 0-dimensional (scalar) tensor containing the sum of the main diagonal elements. It prints as, for example, `tensor(15)`; call `.item()` on it to get a plain Python number.

---

## Code Examples

### Example 1: Calculating the Trace of a Square Matrix

A square matrix has the same number of rows and columns. Its trace is the sum of the elements whose row index equals their column index: $A_{0,0} + A_{1,1} + A_{2,2} + \dots$

```python
import torch

# Create a 3x3 square matrix
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])

# Calculate the trace (1 + 5 + 9)
y = torch.trace(a)

print("Input Matrix:")
print(a)
print("\nTrace:")
print(y)
```

**Output:**

```text
Input Matrix:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

Trace:
tensor(15)
```

---

### Example 2: Calculating the Trace of a Non-Square Matrix

`torch.trace` also accepts non-square matrices. For an $m \times n$ matrix it sums $A_{i,i}$ for $i = 0, \dots, \min(m, n) - 1$, that is, along the diagonal from the top-left corner $(0, 0)$ until it reaches the edge of the smaller dimension.

```python
import torch

# Create a 2x3 non-square matrix
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])

# Calculate the trace (1 + 5)
y = torch.trace(a)

print("Input Matrix:")
print(a)
print("\nTrace:")
print(y)
```

**Output:**

```text
Input Matrix:
tensor([[1, 2, 3],
        [4, 5, 6]])

Trace:
tensor(6)
```

---

## Important Considerations & Limitations

### 1. Dimensionality Constraint

`torch.trace` is strictly designed for **2D tensors (matrices)**.
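Because of this constraint, it can be useful to validate the shape yourself before calling `torch.trace`. The following is a minimal sketch that uses a hypothetical helper, `checked_trace` (not part of PyTorch), to reject non-2D inputs with a `ValueError`. It also cross-checks both tutorial matrices against `torch.diagonal(...).sum()`, which covers the same $\min(m, n)$ diagonal elements.

```python
import torch


def checked_trace(x: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper (not a PyTorch API): trace of a 2D tensor only."""
    if x.dim() != 2:
        # Fail early with a message that names the offending shape
        raise ValueError(f"expected a 2D tensor, got shape {tuple(x.shape)}")
    return torch.trace(x)


square = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
wide = torch.tensor([[1, 2, 3], [4, 5, 6]])

# For 2D inputs, the trace equals the sum of torch.diagonal(),
# which holds min(rows, cols) elements: 3 for square, 2 for wide
print(checked_trace(square), torch.diagonal(square).sum())  # tensor(15) tensor(15)
print(checked_trace(wide), torch.diagonal(wide).sum())      # tensor(6) tensor(6)

try:
    checked_trace(torch.arange(3))  # 1D input is rejected before torch.trace runs
except ValueError as e:
    print(e)  # expected a 2D tensor, got shape (3,)
```

Raising `ValueError` with the offending shape is a choice of this sketch; adapt the exception type and message to your codebase's conventions.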
If you pass a 1D tensor or a tensor with more than two dimensions, PyTorch raises a `RuntimeError`.

#### Error Example (Higher-Dimensional Tensor):

```python
import torch

# Attempting to use torch.trace on a 4D tensor raises an error
try:
    a = torch.randn(3, 4, 4, 5)
    y = torch.trace(a)
except RuntimeError as e:
    print(f"RuntimeError: {e}")
```

**Output** (recent CPU builds; the exact wording varies across PyTorch versions and devices):

```text
RuntimeError: trace: expected a matrix, but got tensor with dim 4
```

### 2. Alternative for Multi-Dimensional Tensors

To compute a trace over two specific dimensions of a higher-dimensional tensor, combine `torch.diagonal()` with `.sum()`. `torch.diagonal(a, dim1=..., dim2=...)` removes the two chosen dimensions and appends their diagonal as a new last dimension, so summing over that last dimension gives one trace per index of the remaining dimensions.

```python
import torch

# Create a 4D tensor of shape (3, 4, 4, 5)
a = torch.randn(3, 4, 4, 5)

# Take the diagonal over dimensions 1 and 2 (result shape: (3, 5, 4)),
# then sum over that last dimension to get one trace per (i, j) pair
y = torch.diagonal(a, dim1=1, dim2=2).sum(-1)

print("Output shape:", y.shape)
```

**Output:**

```text
Output shape: torch.Size([3, 5])
```
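A common case of this pattern is a batch of square matrices with shape `(batch, n, n)`. The sketch below computes one trace per matrix with `torch.diagonal(..., dim1=-2, dim2=-1).sum(-1)` and cross-checks it against `torch.einsum("bii->b", ...)`, where the repeated index selects diagonal entries, and against a per-matrix loop over `torch.trace`. The batch size, matrix size, and use of `float64` are arbitrary example choices.

```python
import torch

torch.manual_seed(0)
batch = torch.randn(8, 3, 3, dtype=torch.float64)  # 8 independent 3x3 matrices

# Diagonal of the last two dims is appended as a new last dim of size 3;
# summing it away leaves one trace per matrix, shape (8,)
traces = torch.diagonal(batch, dim1=-2, dim2=-1).sum(-1)

# Same quantity via einsum: "ii" picks the diagonal, omitting i from the output sums it
traces_einsum = torch.einsum("bii->b", batch)

# Reference: apply the 2D-only torch.trace to each matrix separately
traces_loop = torch.stack([torch.trace(m) for m in batch])

print(traces.shape)  # torch.Size([8])
print(torch.allclose(traces, traces_einsum), torch.allclose(traces, traces_loop))  # True True
```

The `torch.diagonal` version avoids a Python loop and works for any number of leading batch dimensions, which is why it is the usual replacement when `torch.trace` rejects the input.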