

## PyTorch `torch.vmap` Tutorial

`torch.vmap` is the vectorizing map function in PyTorch. It takes a function as input and returns a new, vectorized function that applies the original function across a batch dimension of the input tensors. Like `vmap` in JAX, `torch.vmap` removes the need for manual Python loops or complex tensor reshaping. You write code for a single sample, and `vmap` scales it to a batch. This is often much faster than a Python loop, because each PyTorch operation runs once on the whole batch instead of once per sample.

---

## Syntax and Parameters

### Function Definition

```python
torch.vmap(func, in_dims=0, out_dims=0, randomness='error', *, chunk_size=None)
```

### Parameter Descriptions

* **`func`** *(callable)*: The function to be vectorized. It takes one or more tensors as arguments and must return one or more tensors.
* **`in_dims`** *(int | tuple | dict, optional)*: Specifies which dimension of each input tensor is mapped over (treated as the batch dimension).
  * If it is an integer, every input tensor is assumed to have its batch dimension at that index.
  * If it is a tuple, each element gives the batch dimension of the corresponding input argument. An element of `None` means that argument is not mapped, so the same value is passed to every per-sample call.
  * Default: `0`.
* **`out_dims`** *(int | tuple, optional)*: Specifies where the mapped (batch) dimension appears in the output tensors.
  * Default: `0`.
* **`randomness`** *(str, optional)*: Specifies how random operations (such as `torch.randn` or `torch.rand`) inside `func` behave across the batch.
  * `"error"` (default): Raises an error if any random operation is encountered, so you must choose one of the other modes explicitly.
  * `"different"`: Generates different random values for each batch element.
  * `"same"`: Generates the same random values for all batch elements.
* **`chunk_size`** *(int | None, keyword-only, optional)*: If `None` (default), the whole batch is processed in one vectorized call. If it is an integer, `vmap` processes that many samples at a time, which reduces peak memory use. `chunk_size=1` is equivalent to a for-loop.

---

## Code Examples

### Example 1: Basic Vectorization

This example vectorizes a simple function written for a single 1D tensor so that it can process a 2D batch of tensors.
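
For comparison, without `vmap` the usual way to apply a per-sample function to a batch is an explicit Python loop over the batch dimension followed by `torch.stack`. The sketch below uses a hypothetical helper, `loop_over_batch`, and a hypothetical function, `per_sample`, to show that loop and to check that `torch.vmap` returns the same tensor. It is meant to show *what* `vmap` computes, not *how* it computes it.

```python
import torch

def per_sample(x):
    # Hypothetical per-sample function: expects a 1D tensor of features
    return x * 2 + 1

def loop_over_batch(fn, batch):
    # Call fn once per row (iterating a tensor walks dim 0) and stack the results
    return torch.stack([fn(row) for row in batch])

batch = torch.randn(4, 3)
looped = loop_over_batch(per_sample, batch)
vmapped = torch.vmap(per_sample)(batch)

print(looped.shape)                     # torch.Size([4, 3])
print(torch.allclose(looped, vmapped))  # True
```

The `vmap` version of this computation, without the loop, follows.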
```python
import torch

# Define a simple function designed for a single sample (a 1D tensor)
def simple_func(x):
    return x * 2 + 1

# Vectorize the function using vmap
vectorized_func = torch.vmap(simple_func)

# Create a batch of inputs (batch dimension is 0, shape: [batch_size, features])
batch_input = torch.randn(4, 3)

# Apply the vectorized function to the whole batch
output = vectorized_func(batch_input)

print("Input shape:", batch_input.shape)
print("Output shape:", output.shape)
print("Output:\n", output)
```

**Output** (the tensor values change on every run, because the input comes from `torch.randn`):

```text
Input shape: torch.Size([4, 3])
Output shape: torch.Size([4, 3])
Output:
 tensor([[...],
        [...],
        [...],
        [...]])
```

---

### Example 2: Specifying Custom Batch Dimensions (`in_dims` and `out_dims`)

Sometimes the batch dimension is not at index `0`. Use `in_dims` to say where the batch dimension is in each input and `out_dims` to say where it should go in the output.

```python
import torch

# A function that computes the dot product of two 1D vectors
def dot_product(x, y):
    return torch.dot(x, y)

# Inputs where the batch dimension is at index 1 (shape: [features, batch_size])
x_batch = torch.randn(5, 4)
y_batch = torch.randn(5, 4)

# in_dims=(1, 1): the batch dimension is at index 1 for both inputs,
# so each call to dot_product receives one column of length 5.
# out_dims=0: place the batch dimension at index 0 of the output.
vectorized_dot = torch.vmap(dot_product, in_dims=(1, 1), out_dims=0)
output = vectorized_dot(x_batch, y_batch)

print("Input X shape:", x_batch.shape)
print("Input Y shape:", y_batch.shape)
print("Output shape (batch dimension at 0):", output.shape)
```

**Output:**

```text
Input X shape: torch.Size([5, 4])
Input Y shape: torch.Size([5, 4])
Output shape (batch dimension at 0): torch.Size([4])
```

Each call to `dot_product` returns a scalar, so stacking the 4 per-sample results gives an output of shape `[4]`.

---

### Example 3: Handling Randomness

When your function contains random operations, you must set the `randomness` argument explicitly. With the default `"error"`, `vmap` raises an error.
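
To see the default behavior first, the minimal sketch below calls a noise-adding function through `vmap` without setting `randomness`. The name `add_noise` is illustrative. The call is expected to raise a `RuntimeError`, which the sketch catches and reports.

```python
import torch

def add_noise(x):
    # randn_like is a random operation, so vmap needs a randomness mode for it
    return x + torch.randn_like(x)

batch = torch.zeros(3, 2)

try:
    torch.vmap(add_noise)(batch)  # randomness defaults to "error"
    raised = False
except RuntimeError as err:
    raised = True
    print("vmap refused:", err)

print(raised)  # True
```

Passing `randomness="different"` or `randomness="same"`, as in the code that follows, removes the error.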
```python
import torch

def random_addition(x):
    # Add a random tensor of the same shape as x
    return x + torch.randn_like(x)

# Case 1: different random values for each batch element
vmap_diff = torch.vmap(random_addition, randomness="different")

# Case 2: the same random values applied across all batch elements
vmap_same = torch.vmap(random_addition, randomness="same")

batch_input = torch.zeros(3, 2)

print("Randomness 'different':\n", vmap_diff(batch_input))
print("Randomness 'same':\n", vmap_same(batch_input))
```

Because the input is all zeros, each printed value is pure noise. With `"different"`, the three rows differ from one another; with `"same"`, all three rows are identical.

---

## Key Considerations & Best Practices

1. **Performance Benefits**: `torch.vmap` avoids the overhead of Python loops by running each operation once over the whole batch, using batched implementations of PyTorch operators. It is well suited to tasks such as computing Jacobians and Hessians, batching custom loss functions, or evaluating a model on many inputs in parallel.
2. **In-place Operations**: Prefer out-of-place operations inside functions passed to `torch.vmap`. An in-place operation such as `x.add_(y)` raises a runtime error when it would write a batched value into a tensor that is not batched, for example when `x` is passed with `in_dims=None` and `y` is mapped.
3. **Compatibility**: `torch.vmap` supports most built-in PyTorch operators, but some third-party operations or custom C++ extensions may not provide batching rules out of the box.
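
A minimal sketch of the in-place restriction described above: `x` is passed with `in_dims=None`, so it is one unbatched tensor shared by every sample, while `y` is mapped over dim 0. Writing the batched `y` into `x` with `add_` is expected to fail, and the out-of-place `x + y` works. The function names are illustrative.

```python
import torch

def add_inplace(x, y):
    # Mutates x: not allowed when x is unbatched but y is batched
    x.add_(y)
    return x

def add_out_of_place(x, y):
    # Returns a new tensor, which vmap can batch freely
    return x + y

x = torch.randn(1)     # shared by every sample (in_dims=None)
y = torch.randn(3, 1)  # batch of 3 samples, each of shape [1]

try:
    torch.vmap(add_inplace, in_dims=(None, 0))(x, y)
    inplace_failed = False
except RuntimeError:
    inplace_failed = True

result = torch.vmap(add_out_of_place, in_dims=(None, 0))(x, y)
print(inplace_failed)  # True
print(result.shape)    # torch.Size([3, 1])
```

The out-of-place version gives the same result as broadcasting `x + y` directly, with the batch dimension first.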