## PyTorch `torch.trace` Tutorial
In linear algebra, the **trace** of a matrix is the sum of its main diagonal elements. PyTorch provides a built-in function, `torch.trace`, to compute this value efficiently.
This tutorial covers the syntax, parameters, behavior, and practical examples of using `torch.trace` in PyTorch.
---
## Function Definition
The `torch.trace` function takes a 2D tensor (matrix) and returns the sum of the elements along its main diagonal.
### Syntax
```python
torch.trace(input) -> Tensor
```
### Parameters
* **`input`**: A 2D input tensor (matrix).
### Return Value
* A 0-dimensional (scalar) tensor containing the sum of the main diagonal elements. Call `.item()` on it to get a plain Python number.
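As a quick check of the return value, the sketch below inspects the dimensionality of the result and converts it to a Python number. The matrix and variable names are illustrative only.

```python
import torch

# Illustrative 2x2 float matrix
m = torch.tensor([[1.0, 2.0],
                  [3.0, 4.0]])

t = torch.trace(m)   # 1.0 + 4.0

print(t.dim())       # number of dimensions of the result
print(t.shape)       # shape of the result
print(t.item())      # the value as a plain Python float
```

The result has no dimensions (`t.dim()` is `0` and `t.shape` is `torch.Size([])`), so `.item()` is the usual way to extract the value.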
---
## Code Examples
### Example 1: Calculating the Trace of a Square Matrix
A square matrix has an equal number of rows and columns. The trace is calculated as the sum of elements where the row index equals the column index ($A_{0,0} + A_{1,1} + A_{2,2} + \dots$).
```python
import torch
# Create a 3x3 square matrix
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6],
                  [7, 8, 9]])
# Calculate the trace (1 + 5 + 9)
y = torch.trace(a)
print("Input Matrix:")
print(a)
print("\nTrace:")
print(y)
```
**Output:**
```text
Input Matrix:
tensor([[1, 2, 3],
        [4, 5, 6],
        [7, 8, 9]])

Trace:
tensor(15)
```
---
### Example 2: Calculating the Trace of a Non-Square Matrix
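Although the mathematical trace is defined for square matrices, `torch.trace` also accepts non-square 2D tensors. It sums the elements along the main diagonal starting from the top-left corner $(0,0)$ and stopping when the smaller dimension runs out, so a matrix with $r$ rows and $c$ columns contributes $\min(r, c)$ elements.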
```python
import torch
# Create a 2x3 non-square matrix
a = torch.tensor([[1, 2, 3],
                  [4, 5, 6]])
# Calculate the trace (1 + 5)
y = torch.trace(a)
print("Input Matrix:")
print(a)
print("\nTrace:")
print(y)
```
**Output:**
```text
Input Matrix:
tensor([[1, 2, 3],
        [4, 5, 6]])

Trace:
tensor(6)
```
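To make the non-square behavior concrete, the following sketch compares `torch.trace` with `torch.diagonal(...).sum()` on a wide matrix and on its transpose. The names `wide` and `tall` are illustrative only.

```python
import torch

wide = torch.tensor([[1, 2, 3],
                     [4, 5, 6]])   # shape (2, 3)
tall = wide.T                      # shape (3, 2)

for name, m in [("wide", wide), ("tall", tall)]:
    diag = torch.diagonal(m)       # elements where row index == column index
    print(name, diag.tolist(), torch.trace(m).item(), diag.sum().item())
```

In both cases the diagonal holds two elements (`1` and `5`), so the trace is `6` regardless of which dimension is the smaller one.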
---
## Important Considerations & Limitations
### 1. Dimensionality Constraint
`torch.trace` accepts only **2D tensors (matrices)**. If you pass a 1D tensor or a tensor with more than two dimensions, PyTorch raises a `RuntimeError`.
#### Error Example (Higher-Dimensional Tensor):
```python
import torch
# Attempting to use torch.trace on a 4D tensor will raise an error
try:
    a = torch.randn(3, 4, 4, 5)
    y = torch.trace(a)
except RuntimeError as e:
    print(f"RuntimeError: {e}")
```
**Output** (on CPU; the exact wording may vary between PyTorch versions and devices):
```text
RuntimeError: trace: expected a matrix, but got tensor with dim 4
```
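The same constraint applies to 1D tensors. The sketch below only checks that a `RuntimeError` is raised and does not rely on the wording of the message; the names `v` and `raised` are illustrative.

```python
import torch

v = torch.tensor([1.0, 2.0, 3.0])   # 1D tensor, not a matrix

try:
    torch.trace(v)
    raised = False
except RuntimeError:
    raised = True

print("RuntimeError raised:", raised)
```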
### 2. Alternative for Multi-Dimensional Tensors
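If you need the trace over two specific dimensions of a multi-dimensional tensor, use `torch.diagonal()` to extract the diagonal along those dimensions and then `.sum()` over the last dimension of the result, which is where `torch.diagonal()` places the diagonal elements.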
```python
import torch
# Create a 4D tensor of shape (3, 4, 4, 5)
a = torch.randn(3, 4, 4, 5)
# Compute the sum of the diagonal elements over the 2nd and 3rd dimensions (indices 1 and 2)
# This mimics the trace operation for higher-dimensional tensors
y = torch.diagonal(a, dim1=1, dim2=2).sum(-1)
print("Output shape:", y.shape)
```
**Output:**
```text
Output shape: torch.Size([3, 5])
```
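For the common case of a batch of square matrices, a minimal sketch of the same idea is shown below. It compares the `torch.diagonal(...).sum(-1)` approach with calling `torch.trace` on each matrix in a loop, and with `torch.einsum` as a further alternative. The tensor shape and variable names are assumptions chosen for illustration.

```python
import torch

torch.manual_seed(0)
batch = torch.randn(3, 4, 4)   # 3 matrices, each of shape (4, 4)

# Batched trace: take the diagonal over the last two dims, then sum it
batched = torch.diagonal(batch, dim1=-2, dim2=-1).sum(-1)

# Reference: apply torch.trace to each 2D matrix separately
looped = torch.stack([torch.trace(m) for m in batch])

# Alternative: the repeated index "ii" selects the diagonal and sums it
via_einsum = torch.einsum("bii->b", batch)

print(batched.shape)
print(torch.allclose(batched, looped))
print(torch.allclose(batched, via_einsum))
```

All three approaches produce one trace value per matrix in the batch; the vectorized versions avoid the Python-level loop.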