
# PyTorch Softmax

The Softmax function is one of the most fundamental activation functions in deep learning, particularly for multi-class classification tasks. In PyTorch, `torch.nn.functional.softmax` and its module counterpart `torch.nn.Softmax` convert a vector of raw, unnormalized scores (logits) into a probability distribution over the predicted output classes.

---

## Introduction

The Softmax function takes a $K$-dimensional vector of real numbers and rescales it so that every element lies in the range $[0, 1]$ and all elements sum to $1$. For an input vector $\mathbf{z} = [z_1, z_2, \dots, z_K]$, the Softmax value of the $i$-th element is defined as:

$$\text{Softmax}(z_i) = \frac{e^{z_i}}{\sum_{j=1}^{K} e^{z_j}}$$

### Why is it used?

* **Probability Distribution:** It maps raw network outputs (logits) to a valid probability distribution, making model predictions easy to interpret.
* **Differentiability:** Unlike the `argmax` function, Softmax is fully differentiable, so gradients can flow backward during backpropagation.
* **Multi-Class Classification:** It is the standard final-layer activation for single-label, multi-class classification problems.

---

## Syntax and Parameters

PyTorch provides two primary ways to apply Softmax:

1. **Functional API:** `torch.nn.functional.softmax(input, dim=None, dtype=None)`
2. **Module API:** `torch.nn.Softmax(dim=None)`

### Parameters

| Parameter | Type | Description | Required / Optional |
| :--- | :--- | :--- | :--- |
| `input` | `torch.Tensor` | The input tensor containing logits (functional API only). | Required |
| `dim` | `int` | The dimension along which Softmax is computed. | Optional in the signature, but always pass it explicitly: omitting it is deprecated and emits a warning |
| `dtype` | `torch.dtype` | If specified, the input tensor is cast to this type before the operation (functional API only). | Optional |

### Input and Output Shapes

* **Input:** A tensor of any shape $(*)$, where $*$ means any number of dimensions.
* **Output:** A tensor with **exactly the same shape** as the input, with values normalized along the specified dimension `dim`.

### Understanding the `dim` Parameter

The `dim` parameter determines the axis along which the exponentials are summed and normalized:

* `dim=-1` (equivalently `dim=1` for a 2D tensor of shape `[batch_size, num_classes]`) applies Softmax across the columns (classes), so each row sums to $1$. This is the most common use case.
* `dim=0` applies Softmax down the rows (across the samples in a batch), so each column sums to $1$. This is rarely what you want in classification.

---

## Code Example

Below is a complete, runnable script demonstrating how to use Softmax on 1D and 2D tensors, how the `dim` parameter changes the output, and how to verify that the outputs sum to $1$.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# ----------------------------------------------------------------
# 1. 1D Tensor Example (Single Sample)
# ----------------------------------------------------------------
logits_1d = torch.tensor([1.0, 2.0, 5.0])
print("--- 1D Tensor Example ---")
print(f"Raw Logits: {logits_1d}")

# Apply functional softmax along the only dimension (dim=0)
probabilities_1d = F.softmax(logits_1d, dim=0)
print(f"Softmax Probabilities: {probabilities_1d}")
print(f"Sum of Probabilities: {probabilities_1d.sum().item():.4f}")

# Cross-check against the formula: exp(z_i) / sum_j exp(z_j)
manual_1d = logits_1d.exp() / logits_1d.exp().sum()
print(f"Matches formula: {torch.allclose(probabilities_1d, manual_1d)}\n")

# ----------------------------------------------------------------
# 2. 2D Tensor Example (Batch of Samples)
# ----------------------------------------------------------------
# Simulating a batch of 2 samples with 4 class logits each
logits_2d = torch.tensor([
    [2.0, 1.0, 0.1, 3.0],  # Sample 1
    [0.5, 4.5, 1.2, 0.2],  # Sample 2
])
print("--- 2D Tensor Example ---")
print(f"Raw Logits (Shape: {logits_2d.shape}):\n{logits_2d}\n")

# Apply Softmax across classes (dim=-1, same as dim=1 for a 2D tensor)
probabilities_2d = F.softmax(logits_2d, dim=-1)
print("Softmax Probabilities (dim=-1):")
print(probabilities_2d)

# Verify that each row (sample) sums to 1.0
row_sums = probabilities_2d.sum(dim=-1)
print(f"Sum along dim=-1 for each sample: {row_sums}\n")

# For comparison: dim=0 normalizes across the batch instead,
# so each column sums to 1.0 and the rows no longer do
probabilities_dim0 = F.softmax(logits_2d, dim=0)
print("Softmax Probabilities (dim=0):")
print(probabilities_dim0)
print(f"Sum along dim=0 for each class: {probabilities_dim0.sum(dim=0)}\n")

# ----------------------------------------------------------------
# 3. Object-Oriented Module API (nn.Softmax)
# ----------------------------------------------------------------
print("--- nn.Softmax Module Example ---")
# Instantiate the Softmax module, specifying the dimension
softmax_layer = nn.Softmax(dim=-1)

# Pass the tensor through the module (same result as F.softmax(logits_2d, dim=-1))
output = softmax_layer(logits_2d)
print(f"Module Output:\n{output}")
```

---

## Best Practices and Common Pitfalls

### 1. Avoid `nn.Softmax` with `nn.CrossEntropyLoss`

A very common beginner mistake is applying Softmax to the model's final-layer output when training with PyTorch's `nn.CrossEntropyLoss`.

* **Why?** `nn.CrossEntropyLoss` already combines `nn.LogSoftmax` and `nn.NLLLoss` (Negative Log-Likelihood Loss) in a single class, so it expects raw logits as input.
* **The Pitfall:** If your model's forward pass applies Softmax and you then pass those probabilities to `nn.CrossEntropyLoss`, Softmax is effectively applied twice. The probabilities, squeezed into $[0, 1]$, are treated as logits, which flattens the predicted distribution, weakens the gradients, and can significantly slow down or degrade training.
* **The Solution:** Have your model output raw, unnormalized **logits**, and apply Softmax explicitly only at inference time, when you need human-readable probabilities.
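The inference side of this advice can be sketched as follows. This is a minimal illustration, not a prescribed recipe: the tiny `nn.Linear` model and the random input batch are hypothetical stand-ins for a trained classifier. It also shows that Softmax preserves the ordering of the logits, so the predicted class can be read directly from the logits with `argmax`. Softmax is only needed when you want the probabilities themselves, for example as confidence scores.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)

# Hypothetical stand-in for a trained classifier: 8 input features -> 3 class logits
model = nn.Linear(8, 3)
model.eval()

batch = torch.randn(5, 8)  # 5 hypothetical input samples

with torch.no_grad():
    logits = model(batch)              # Raw outputs, shape [5, 3]
    probs = F.softmax(logits, dim=-1)  # Probabilities, each row sums to 1

# Softmax is monotonic within a row, so argmax over logits and probabilities agree
pred_from_logits = logits.argmax(dim=-1)
pred_from_probs = probs.argmax(dim=-1)
confidence = probs.max(dim=-1).values  # Probability assigned to the predicted class

print("Predicted classes:", pred_from_logits.tolist())
print("Confidences:", confidence.tolist())
```

Because the loss function consumes logits and `argmax` works on logits directly, the only place Softmax is strictly required here is the `confidence` computation.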
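In code, the difference at training time looks like this:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

# Example raw model outputs: a batch of 4 samples over 3 classes
logits = torch.randn(4, 3)
targets = torch.tensor([0, 2, 1, 2])  # Ground-truth class indices
criterion = nn.CrossEntropyLoss()

# INCORRECT for training
model_output = F.softmax(logits, dim=-1)
loss = criterion(model_output, targets)  # Double softmax!

# CORRECT for training
model_output = logits  # Raw outputs (logits)
loss = criterion(model_output, targets)
```

### 2. Always Explicitly Define `dim`

Omitting `dim` is deprecated. Calls such as `F.softmax(x)` or `nn.Softmax()` still run, but they emit a `UserWarning` ("Implicit dimension choice for softmax has been deprecated") and fall back to a legacy rule based only on the number of dimensions: `dim=0` for 0D, 1D, and 3D inputs, and `dim=1` otherwise. For a 3D tensor of shape `[batch, seq_len, num_classes]`, that rule silently normalizes across the batch. Always pass `dim` explicitly, usually `dim=-1` (the last dimension), or whichever dimension you intend to normalize.

### 3. Numerical Stability: Softmax vs. LogSoftmax

When you compute loss functions manually or work with very large or very small logit values, Softmax needs care. A naive implementation that evaluates $e^{z_i}$ directly can overflow for large logits, and taking the logarithm of Softmax outputs that have underflowed to $0$ produces $-\infty$.

* If you need log-probabilities, use `torch.nn.functional.log_softmax` instead of `torch.log(torch.softmax(...))`. `log_softmax` computes the result in a single, numerically stable step (using the log-sum-exp trick) rather than exponentiating, normalizing, and then taking the logarithm separately.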
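The stability difference is easiest to see with a deliberately extreme input. The sketch below uses arbitrarily chosen logits and compares two hypothetical helpers written only for this illustration: `naive_softmax`, which applies the formula literally, and `stable_softmax`, which uses the common max-subtraction trick. It then compares both with `F.softmax`, and compares `torch.log(F.softmax(...))` with `F.log_softmax`.

```python
import torch
import torch.nn.functional as F


def naive_softmax(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Hypothetical helper: applies exp(z_i) / sum_j exp(z_j) literally
    exp_z = torch.exp(z)
    return exp_z / exp_z.sum(dim=dim, keepdim=True)


def stable_softmax(z: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Hypothetical helper: subtracting the maximum leaves the result unchanged
    # mathematically (the factor e^{-max} cancels) but keeps every exponent <= 0
    shifted = z - z.max(dim=dim, keepdim=True).values
    exp_z = torch.exp(shifted)
    return exp_z / exp_z.sum(dim=dim, keepdim=True)


# In float32, exp() overflows to inf for inputs above roughly 88.7
logits = torch.tensor([1000.0, 10.0, 0.0])

print("naive:          ", naive_softmax(logits))  # first entry is nan (inf / inf)
print("stable (manual):", stable_softmax(logits))
print("F.softmax:      ", F.softmax(logits, dim=-1))

# Log-probabilities: log(softmax) turns underflowed zeros into -inf,
# while log_softmax keeps finite values
print("log(softmax):   ", torch.log(F.softmax(logits, dim=-1)))
print("log_softmax:    ", F.log_softmax(logits, dim=-1))
```

Here `log_softmax` returns approximately `[0, -990, -1000]`, whereas the two-step version loses the information in the last two entries entirely.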