torch.inference_mode is a context manager that enables inference mode in PyTorch. It is stricter than torch.no_grad. Besides disabling gradient calculation, it also turns off additional autograd bookkeeping (view tracking and version-counter updates), and every tensor created inside it is marked as an inference tensor.
Because of this, it can run model inference with less overhead than no_grad. The trade-off is that inference tensors cannot later be used in operations that autograd records.
Function Definition
```python
torch.inference_mode(mode=True)
```
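Parameters:

- mode (bool, optional): If True (the default), inference mode is enabled inside the block. If False, inference mode is disabled inside the block, which is useful for temporarily switching it off within an enclosing inference_mode region.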
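To see what the mode argument does, here is a small sketch (assuming PyTorch 1.9 or later, where Tensor.is_inference() is available) that nests a mode=False block inside an inference_mode block and checks the state with torch.is_inference_mode_enabled(). The names states, a and b are illustrative only.

```python
import torch

states = []  # inference-mode flag observed at each step
states.append(torch.is_inference_mode_enabled())  # outside: off

with torch.inference_mode():
    states.append(torch.is_inference_mode_enabled())  # outer block: on
    a = torch.ones(3)  # created with inference mode on -> inference tensor

    # mode=False switches inference mode off for this nested block only
    with torch.inference_mode(mode=False):
        states.append(torch.is_inference_mode_enabled())  # off again
        b = torch.ones(3)  # created with inference mode off -> normal tensor

    states.append(torch.is_inference_mode_enabled())  # outer block: on again

print("states:", states)
print("a.is_inference():", a.is_inference())
print("b.is_inference():", b.is_inference())
```

Leaving each block restores the previous setting, so the flag returns to True after the nested block and to False after the outer one.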
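Return Value:

- Returns an object that can be used either as a context manager (with torch.inference_mode(): ...) or as a function decorator (@torch.inference_mode()). Gradient calculation and autograd tracking are disabled inside it.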
Usage Examples
Example 1: Basic Usage

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# In the inference_mode context
with torch.inference_mode():
    y = x * 2
print("In inference_mode:", y.requires_grad)

# In the no_grad context
with torch.no_grad():
    z = x * 2
print("In no_grad:", z.requires_grad)
```

Output:
```text
In inference_mode: False
In no_grad: False
```

Example 2: Comparing no_grad and inference_mode
```python
import torch

# Create a tensor
x = torch.randn(100, 100)

# In inference_mode
with torch.inference_mode():
    # Perform some computation
    for _ in range(10):
        x = torch.mm(x, x)
    result = x.sum()

# Tensors created inside inference_mode are inference tensors;
# they cannot later be used in operations that autograd records
print("In inference_mode:", result.is_inference())

# Perform the same computation under no_grad
x2 = torch.randn(100, 100)
with torch.no_grad():
    for _ in range(10):
        x2 = torch.mm(x2, x2)
    result2 = x2.sum()

# Tensors created under no_grad are ordinary tensors
print("In no_grad:", result2.is_inference())
```

Output:

```text
In inference_mode: True
In no_grad: False
```

Example 3: Model Inference
```python
import torch
import torch.nn as nn

# Define a simple model
model = nn.Sequential(
    nn.Linear(10, 20),
    nn.ReLU(),
    nn.Linear(20, 5),
)
model.eval()

# Create input data
x = torch.randn(100, 10)

# Use inference_mode for inference
with torch.inference_mode():
    output = model(x)

print("Output shape:", output.shape)
print("Output requires_grad:", output.requires_grad)

# It can also be used as a decorator
@torch.inference_mode()
def predict(x):
    return model(x)

result = predict(x)
print("Decorator approach - Output shape:", result.shape)
```

Output:

```text
Output shape: torch.Size([100, 5])
Output requires_grad: False
Decorator approach - Output shape: torch.Size([100, 5])
```

Example 4: Memory Optimization Comparison
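```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(1000, 1000),
    nn.ReLU(),
    nn.Linear(1000, 1000),
    nn.ReLU(),
    nn.Linear(1000, 10),
)

x = torch.randn(50, 1000)

# Check whether each mode records an autograd graph; a recorded graph
# keeps intermediate activations alive for a possible backward pass
print("Without using any context manager:")
out = model(x)
print("  graph recorded:", out.grad_fn is not None)

print("\nUsing no_grad:")
with torch.no_grad():
    out = model(x)
print("  graph recorded:", out.grad_fn is not None)

print("\nUsing inference_mode:")
with torch.inference_mode():
    out = model(x)
print("  graph recorded:", out.grad_fn is not None)
```

Without a context manager, autograd records the graph and keeps intermediate activations in memory; both no_grad and inference_mode skip this, which is where most of the memory saving comes from. inference_mode can optimize a little further because it also skips view tracking and version-counter updates.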
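To make the memory difference measurable rather than just stated, one option (a sketch, assuming PyTorch 1.10 or later for torch.autograd.graph.saved_tensors_hooks) is to count how many tensors autograd saves for the backward pass in each mode. The helper count_saved and the three wrapper functions are hypothetical names for illustration.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Linear(1000, 1000), nn.ReLU(),
    nn.Linear(1000, 1000), nn.ReLU(),
    nn.Linear(1000, 10),
)
x = torch.randn(50, 1000)

def count_saved(run):
    """Call run() and count the tensors autograd saves for backward."""
    saved = []

    def pack(t):
        saved.append(t)  # called once per tensor saved for backward
        return t

    def unpack(t):
        return t

    with torch.autograd.graph.saved_tensors_hooks(pack, unpack):
        run()
    return len(saved)

def plain():
    return model(x)

def with_no_grad():
    with torch.no_grad():
        return model(x)

def with_inference_mode():
    with torch.inference_mode():
        return model(x)

for name, fn in [("no context", plain),
                 ("no_grad", with_no_grad),
                 ("inference_mode", with_inference_mode)]:
    print(f"{name}: {count_saved(fn)} tensors saved for backward")
```

Only the plain forward pass saves activations; under no_grad and inference_mode the count is zero, because no graph is built.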
Related Functions
- torch.no_grad(): Disables gradient calculation; tensors created inside it are ordinary tensors that can later be used with autograd.
- torch.enable_grad(): Enables gradient calculation, for example inside a no_grad block.
- torch.is_inference_mode_enabled(): Returns True if inference mode is currently enabled.
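The query function torch.is_inference_mode_enabled() can be combined with torch.is_grad_enabled() to see how each context manager changes autograd's state. This is a small sketch; the helper name state() is hypothetical.

```python
import torch

def state():
    """Return (grad enabled, inference mode enabled) for the current thread."""
    return torch.is_grad_enabled(), torch.is_inference_mode_enabled()

default = state()

with torch.no_grad():
    in_no_grad = state()
    # enable_grad re-enables gradient calculation inside no_grad
    with torch.enable_grad():
        in_enable_grad = state()

with torch.inference_mode():
    in_inference = state()

for name, s in [("default", default),
                ("no_grad", in_no_grad),
                ("enable_grad inside no_grad", in_enable_grad),
                ("inference_mode", in_inference)]:
    print(f"{name}: grad={s[0]}, inference={s[1]}")
```

Both no_grad and inference_mode turn gradient calculation off, but only inference_mode also sets the inference-mode flag.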
Notes
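- inference_mode is stricter than no_grad and disables more autograd features.
- Tensors created in inference_mode are inference tensors (tensor.is_inference() returns True). Outside inference mode they cannot be saved for backward or modified in place; call .clone() outside inference mode to get an ordinary tensor.
- It is recommended to use inference_mode during model inference and evaluation for the best performance.
- inference_mode can be nested with itself and with no_grad; torch.inference_mode(mode=False) switches inference mode off inside an enclosing inference_mode block.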
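The restriction that tensors created in inference_mode cannot be used for backpropagation can be reproduced with a short sketch. Here w stands in for a trainable parameter and feats for features computed under inference_mode (both hypothetical names). The exact error message depends on the PyTorch version, so the code only checks that a RuntimeError is raised.

```python
import torch

w = torch.randn(3, requires_grad=True)  # stand-in for a trainable parameter

with torch.inference_mode():
    feats = torch.randn(3) * 2  # created in inference mode -> inference tensor

# Multiplying by w requires autograd to save feats for backward,
# which is not allowed for an inference tensor
try:
    loss = (w * feats).sum()
    loss.backward()
    failed = False
except RuntimeError as err:
    failed = True
    print("RuntimeError:", err)

# Workaround: clone outside inference mode to get an ordinary tensor
feats_normal = feats.clone()
loss = (w * feats_normal).sum()
loss.backward()
print("w.grad:", w.grad)  # gradient of sum(w * f) with respect to w is f
```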