PyTorch Softmax
The Softmax function is one of the most fundamental activation functions in deep learning, particularly for multi-class classification tasks. In PyTorch, `torch.nn.functional.softmax` and its object-oriented counterpart `torch.nn.Softmax` are used to convert a vector of raw, unnormalized scores (logits) into a probability distribution over predicted output classes.
---
## Introduction
The Softmax function takes a $K$-dimensional vector of real numbers and rescales it so that every element lies between $0$ and $1$ and the elements sum to $1$.
Mathematically, for an input vector $\mathbf{z} = [z_1, z_2, \dots, z_K]$, the Softmax value of the $i$-th element is defined as:
$$\text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}$$
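To make the formula concrete, the following minimal pure-Python sketch evaluates it term by term. The helper name `softmax_reference` is hypothetical and exists only for illustration; it is not how PyTorch implements the operation internally.

```python
import math

def softmax_reference(z):
    """Evaluate the Softmax definition directly: exp(z_i) / sum_j exp(z_j)."""
    exps = [math.exp(z_i) for z_i in z]
    total = sum(exps)
    return [e / total for e in exps]

probs = softmax_reference([1.0, 2.0, 5.0])
print(probs)       # three probabilities, largest for the largest logit
print(sum(probs))  # approximately 1.0
```

Because the exponential is monotonic, the ordering of the logits is preserved in the resulting probabilities.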
### Why is it used?
* **Probability Distribution:** It maps raw network outputs (logits) to a valid probability distribution, making it easy to interpret model predictions.
* **Differentiability:** Unlike the `argmax` function, Softmax is fully differentiable, allowing gradients to flow backward during backpropagation.
* **Multi-Class Classification:** It serves as the standard final-layer activation function for single-label, multi-class classification problems.
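
The differentiability point can be checked with a small sketch. The logit values are arbitrary, and differentiating the probability of one class is just one convenient way to trigger backpropagation.

```python
import torch
import torch.nn.functional as F

logits = torch.tensor([1.0, 2.0, 5.0], requires_grad=True)

# The Softmax output stays attached to the autograd graph
probs = F.softmax(logits, dim=0)
probs[2].backward()          # differentiate the probability of class 2
print(logits.grad)           # one gradient entry per logit

# argmax returns an integer index with no gradient history
index = torch.argmax(logits)
print(index.requires_grad)   # False
```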
---
## Syntax and Parameters
PyTorch provides two primary ways to apply Softmax:
1. **Functional API:** `torch.nn.functional.softmax(input, dim=None, dtype=None)`
2. **Module API:** `torch.nn.Softmax(dim=None)`
### Parameters
| Parameter | Type | Description | Required / Optional |
| :--- | :--- | :--- | :--- |
| `input` | `torch.Tensor` | The input tensor containing logits. (Functional API only) | Required |
| `dim` | `int` | The dimension along which Softmax will be computed. | Optional in the signature (defaults to `None`), but should always be passed; omitting it triggers a deprecation warning in recent versions |
| `dtype` | `torch.dtype` | If specified, the input tensor is cast to this type before the operation. | Optional |
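
As a rough illustration of the `dtype` parameter, the sketch below passes half-precision logits and asks for a `float32` result. The tensor values are arbitrary.

```python
import torch
import torch.nn.functional as F

# Half-precision logits (arbitrary illustrative values)
logits_fp16 = torch.tensor([[2.0, 1.0, 0.1]], dtype=torch.float16)

# The input is cast to float32 before Softmax is computed
probs_fp32 = F.softmax(logits_fp16, dim=-1, dtype=torch.float32)

print(probs_fp32.dtype)   # torch.float32
print(probs_fp32.shape)   # same shape as the input
```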
### Input and Output Shapes
* **Input:** A tensor of any shape $(*)$, where $*$ stands for any number of dimensions.
* **Output:** A tensor of the **exact same shape** as the input, with values normalized along the specified dimension `dim`.
### Understanding the `dim` Parameter
The `dim` parameter determines the axis along which the exponentials are summed and normalized:
* `dim=-1` (or `dim=1` for a 2D tensor of shape `[batch_size, classes]`) applies Softmax across the columns (classes), which is the most common use case.
* `dim=0` applies Softmax down the rows (across different samples in a batch), which is rarely what you want in classification.
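
The difference between the two choices is easiest to see by checking which sums equal $1$. This is a minimal sketch with arbitrary values for a batch of 2 samples and 3 classes.

```python
import torch
import torch.nn.functional as F

# 2 samples, 3 classes (arbitrary illustrative values)
logits = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0]])

per_sample = F.softmax(logits, dim=-1)  # normalize across classes
per_class = F.softmax(logits, dim=0)    # normalize across the batch

print(per_sample.sum(dim=-1))  # each row sums to 1
print(per_class.sum(dim=0))    # each column sums to 1
print(per_class.sum(dim=-1))   # rows generally do not sum to 1
```

With `dim=0`, each sample's scores no longer form a probability distribution over classes, which is why that choice is rarely appropriate for classification.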
---
## Code Example
Below is a complete, runnable script demonstrating how to use Softmax on 1D and 2D tensors, how the `dim` parameter changes the output, and how to verify that the outputs sum to $1$.
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
# Set random seed for reproducibility
torch.manual_seed(42)
# ----------------------------------------------------------------
# 1. 1D Tensor Example (Single Sample)
# ----------------------------------------------------------------
logits_1d = torch.tensor([1.0, 2.0, 5.0])
print("--- 1D Tensor Example ---")
print(f"Raw Logits: {logits_1d}")
# Apply functional softmax along the only dimension (dim=0)
probabilities_1d = F.softmax(logits_1d, dim=0)
print(f"Softmax Probabilities: {probabilities_1d}")
print(f"Sum of Probabilities: {probabilities_1d.sum().item():.4f}\n")
# ----------------------------------------------------------------
# 2. 2D Tensor Example (Batch of Samples)
# ----------------------------------------------------------------
# Simulating a batch of 2 samples with 4 class logits each
logits_2d = torch.tensor([
    [2.0, 1.0, 0.1, 3.0],  # Sample 1
    [0.5, 4.5, 1.2, 0.2],  # Sample 2
])
print("--- 2D Tensor Example ---")
print(f"Raw Logits (Shape: {logits_2d.shape}):\n{logits_2d}\n")
# Apply Softmax across classes (dim=-1 or dim=1)
probabilities_2d = F.softmax(logits_2d, dim=-1)
print("Softmax Probabilities (dim=-1):")
print(probabilities_2d)
# Verify that each row sums to 1.0
row_sums = probabilities_2d.sum(dim=-1)
print(f"Sum along dim=-1 for each sample: {row_sums}\n")
# ----------------------------------------------------------------
# 3. Object-Oriented Module API (nn.Softmax)
# ----------------------------------------------------------------
print("--- nn.Softmax Module Example ---")
# Instantiate the Softmax module specifying the dimension
softmax_layer = nn.Softmax(dim=-1)
# Pass the tensor through the module
output = softmax_layer(logits_2d)
print(f"Module Output:\n{output}")
```
---
## Best Practices and Common Pitfalls
### 1. Avoid `nn.Softmax` with `nn.CrossEntropyLoss`
A very common beginner mistake is manually applying Softmax to the model's final layer when using PyTorch's `nn.CrossEntropyLoss`.
* **Why?** PyTorch's `nn.CrossEntropyLoss` internally combines `nn.LogSoftmax` and `nn.NLLLoss` (Negative Log-Likelihood Loss) into a single class.
* **The Pitfall:** If you apply Softmax in your model's forward pass and then pass those probabilities to `nn.CrossEntropyLoss`, Softmax is effectively applied twice. The loss treats the probabilities as if they were logits, which compresses the differences between classes, shrinks the gradients, and slows or stalls training.
* **The Solution:** Your model should output raw, unnormalized **logits**. Only apply Softmax explicitly when you are running inference to get human-readable probabilities.
```python
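import torch
import torch.nn as nn
import torch.nn.functional as F
logits = torch.randn(4, 3)            # Example raw outputs: 4 samples, 3 classes
targets = torch.tensor([0, 2, 1, 2])  # Example class indices
# INCORRECT for training
model_output = F.softmax(logits, dim=-1)
loss = nn.CrossEntropyLoss()(model_output, targets)  # Double softmax!
# CORRECT for training
model_output = logits  # Raw outputs
loss = nn.CrossEntropyLoss()(model_output, targets)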
```
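The `LogSoftmax` plus `NLLLoss` decomposition described above can be checked numerically. The sketch below uses a hypothetical random batch of 4 samples and 3 classes, and also shows Softmax being applied only at inference time.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 3)             # hypothetical batch: 4 samples, 3 classes
targets = torch.tensor([0, 2, 1, 2])   # class indices

# CrossEntropyLoss on raw logits ...
ce = nn.CrossEntropyLoss()(logits, targets)

# ... matches NLLLoss applied to log-softmax outputs
nll = nn.NLLLoss()(F.log_softmax(logits, dim=-1), targets)
print(torch.allclose(ce, nll))

# At inference time, Softmax turns logits into readable probabilities
with torch.no_grad():
    probs = F.softmax(logits, dim=-1)
    predicted = probs.argmax(dim=-1)
print(predicted)
```

Since Softmax preserves the ordering of the logits, the predicted class is the same whether `argmax` is taken over the logits or over the probabilities.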
### 2. Always Explicitly Define `dim`
In older versions of PyTorch, omitting the `dim` parameter was allowed, and the dimension was chosen implicitly from the tensor's number of dimensions. In recent versions, omitting `dim` emits a warning that this implicit choice is deprecated, and the dimension that gets picked may not be the one you intended. Always pass `dim` explicitly: `dim=-1` (the last dimension) or whichever dimension you intend to normalize.
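
The sketch below records whatever warning is issued when `dim` is omitted, then repeats the call with an explicit `dim`. It assumes a PyTorch release in which the implicit choice is still accepted with a warning; the exact warning text may differ between versions.

```python
import warnings
import torch
import torch.nn.functional as F

x = torch.randn(2, 3, 4)  # arbitrary 3D tensor

# Omitting dim: record any warning PyTorch issues
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    implicit = F.softmax(x)

for w in caught:
    print(w.category.__name__, w.message)

# Passing dim explicitly leaves no ambiguity
explicit = F.softmax(x, dim=-1)
print(explicit.sum(dim=-1))  # every entry is approximately 1
```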
### 3. Numerical Stability: Softmax vs. LogSoftmax
When calculating loss functions manually or working with logits of very large magnitude, Softmax-based computations can become numerically unstable: a naively computed exponential can overflow, and very small probabilities can underflow to $0$, which becomes $-\infty$ once a logarithm is taken.
* If you need to compute log-probabilities, use `torch.nn.functional.log_softmax` instead of `torch.log(torch.softmax(...))`. `log_softmax` uses a numerically stable implementation that avoids exponentiating large numbers directly.
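
The underflow problem can be reproduced with two logits that are far apart. The values below are arbitrary and chosen only so that the smaller probability is not representable in `float32`.

```python
import torch
import torch.nn.functional as F

# Logits with a large gap (arbitrary illustrative values)
logits = torch.tensor([0.0, 200.0])

# Naive: the small probability underflows to 0 in float32, so log gives -inf
naive = torch.log(F.softmax(logits, dim=0))

# Stable: log-probabilities are computed without forming the tiny probability
stable = F.log_softmax(logits, dim=0)

print(naive)
print(stable)
```

The naive version loses the log-probability of the first class entirely, while `log_softmax` returns a finite value for it.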