## PyTorch `torch.equal` Function
The `torch.equal` function in PyTorch checks whether two tensors have the same shape and the same elements. It returns a single Python boolean: `True` if both the shapes and all elements match, and `False` otherwise.
---
### Function Definition
```python
torch.equal(input, other) -> bool
```
#### Parameters:
* **`input`** *(Tensor)*: The first tensor to compare.
* **`other`** *(Tensor)*: The second tensor to compare.
#### Returns:
* **`bool`**: `True` if the tensors are equal in both shape and content, `False` otherwise.
---
## Code Examples
The following example uses `torch.equal` to compare identical tensors, tensors with different elements, and tensors with different shapes.
```python
import torch
# Create two identical tensors
x = torch.tensor([1, 2, 3])
y = torch.tensor([1, 2, 3])
# Check if they are equal
result = torch.equal(x, y)
print(f"Are x and y equal? {result}")
# Create two different tensors (different elements)
a = torch.tensor([1, 2, 4])
b = torch.tensor([1, 2, 3])
result2 = torch.equal(a, b)
print(f"Are a and b equal? {result2}")
# Create two tensors with different shapes
c = torch.tensor([[1, 2, 3]])
result3 = torch.equal(x, c)
print(f"Are x and c equal? {result3}")
```
### Output
```text
Are x and y equal? True
Are a and b equal? False
Are x and c equal? False
```
---
## Key Considerations
When using `torch.equal`, keep the following behaviors in mind:
### 1. No Broadcasting
Unlike element-wise comparison operators (such as `==` or `torch.eq`), `torch.equal` **does not** perform broadcasting. If the shapes of the two tensors do not match exactly, the function immediately returns `False` without raising an error.
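To make the contrast concrete, here is a minimal sketch of the same pair of tensors compared both ways. The tensor names `row` and `stacked` are illustrative only and do not come from the PyTorch documentation.

```python
import torch

row = torch.tensor([1, 2, 3])          # shape (3,)
stacked = torch.tensor([[1, 2, 3],
                        [1, 2, 3]])    # shape (2, 3)

# `==` broadcasts `row` across both rows of `stacked`
elementwise = row == stacked
print(elementwise.shape)               # torch.Size([2, 3])
print(bool(elementwise.all()))         # True

# torch.equal requires matching shapes, so no broadcasting takes place
print(torch.equal(row, stacked))       # False
```

Because of this, `(a == b).all()` and `torch.equal(a, b)` can disagree whenever the two shapes are merely broadcast-compatible rather than identical.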
### 2. `torch.equal` vs. `torch.eq`
It is important to distinguish between `torch.equal` and `torch.eq` (or the `==` operator):
* **`torch.equal(tensor1, tensor2)`**: Returns a **single boolean** (`True` or `False`). It checks if the entire contents and shapes are identical.
* **`torch.eq(tensor1, tensor2)`**: Returns a **boolean tensor** of the same shape (after broadcasting), representing the element-wise comparison results.
#### Comparison Example:
```python
import torch
tensor1 = torch.tensor([1, 2, 3])
tensor2 = torch.tensor([1, 5, 3])
# Element-wise comparison
print(torch.eq(tensor1, tensor2))
# Output: tensor([ True, False,  True])
# Global comparison
print(torch.equal(tensor1, tensor2))
# Output: False
```
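One practical consequence of the difference in return types is how each result behaves in an `if` statement. The sketch below is illustrative (the names `expected` and `actual` are made up for this example): the boolean from `torch.equal` can be used as a condition directly, while a multi-element boolean tensor cannot.

```python
import torch

expected = torch.tensor([1, 2, 3])
actual = torch.tensor([1, 5, 3])

# torch.equal returns a plain Python bool, so it can drive an `if` directly
is_same = torch.equal(expected, actual)
print(type(is_same), is_same)

# A boolean tensor with several elements is ambiguous as a condition
raised = False
try:
    if expected == actual:
        pass
except RuntimeError as err:
    raised = True
    print(f"RuntimeError: {err}")

# Reducing with .all() yields a single value that converts to bool
all_match = bool((expected == actual).all())
print(all_match)
```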
### 3. Floating-Point Precision
When comparing floating-point tensors (such as `torch.float32` or `torch.float64`), small rounding errors can cause `torch.equal` to return `False` even when the values are mathematically expected to be equal.
To compare floating-point tensors within a tolerance, use **`torch.allclose`** instead:
```python
import torch

# Two float32 tensors that differ only by a tiny rounding-sized amount
float_tensor1 = torch.tensor([1.0, 2.0])
float_tensor2 = torch.tensor([1.0000001, 2.0])

# Exact comparison fails; tolerance-based comparison succeeds
print(torch.equal(float_tensor1, float_tensor2))     # Output: False
print(torch.allclose(float_tensor1, float_tensor2))  # Output: True
```
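The same effect appears with computed values, not only with hand-written literals. The following sketch assumes `torch.float64` tensors and uses the `rtol` and `atol` keyword arguments of `torch.allclose` to show how the tolerance changes the result; the specific tolerance values are chosen purely for illustration.

```python
import torch

# In float64, 0.1 + 0.2 is not exactly 0.3
total = torch.tensor([0.1], dtype=torch.float64) + torch.tensor([0.2], dtype=torch.float64)
target = torch.tensor([0.3], dtype=torch.float64)

exact = torch.equal(total, target)
print(exact)    # False

# Default tolerances (rtol=1e-05, atol=1e-08) absorb the rounding error
close = torch.allclose(total, target)
print(close)    # True

# With a near-zero tolerance, allclose also reports a mismatch
strict = torch.allclose(total, target, rtol=0.0, atol=1e-20)
print(strict)   # False
```

`torch.allclose` treats two values as close when `|input - other| <= atol + rtol * |other|`, so the appropriate tolerances depend on the magnitude of the data being compared.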