PyTorch torch.is_inference_mode_enabled Function | CaiNiao Tutorial
torch.is_inference_mode_enabled is a PyTorch function that checks whether inference mode is currently enabled. It returns a boolean indicating whether the calling code is running inside an active torch.inference_mode() context.
This is useful when writing code that needs to execute different logic based on the inference mode state.
Function Definition
torch.is_inference_mode_enabled()
Parameters:
- No parameters.
Returns:
- A boolean: True if inference mode is currently enabled, otherwise False.
Usage Examples
Example 1: Basic Usage
import torch

# By default, inference mode is disabled
print("Default state:", torch.is_inference_mode_enabled())

# Inside an inference_mode context
with torch.inference_mode():
    print("Inside inference_mode:", torch.is_inference_mode_enabled())

# After exiting the context
print("After exit:", torch.is_inference_mode_enabled())
Output:
Default state: False
Inside inference_mode: True
After exit: False
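Besides the with statement, torch.inference_mode() can be applied as a decorator, and it accepts a mode argument (default True) that can switch inference mode back off inside an enclosing block. The sketch below checks what is_inference_mode_enabled reports in each case; the function names predict and nested_states are illustrative only.

```python
import torch

# torch.inference_mode() also works as a function decorator
@torch.inference_mode()
def predict(x):
    # Report the mode as seen from inside the decorated function
    return torch.is_inference_mode_enabled()

def nested_states():
    """Record is_inference_mode_enabled() at each nesting level."""
    states = []
    with torch.inference_mode():
        states.append(torch.is_inference_mode_enabled())      # outer block
        # mode=False turns inference mode off for the inner block only
        with torch.inference_mode(mode=False):
            states.append(torch.is_inference_mode_enabled())  # inner block
        states.append(torch.is_inference_mode_enabled())      # back in outer block
    return states

print("Decorated function:", predict(torch.randn(2)))
print("Nested states:", nested_states())
```

The value always reflects the innermost active context, which is why the outer block reads True again once the inner mode=False block exits.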
Example 2: Compare is_grad_enabled and is_inference_mode_enabled
import torch

# Inside a no_grad context
with torch.no_grad():
    print("Inside no_grad:")
    print(" is_grad_enabled:", torch.is_grad_enabled())
    print(" is_inference_mode_enabled:", torch.is_inference_mode_enabled())

# Inside an inference_mode context
with torch.inference_mode():
    print("Inside inference_mode:")
    print(" is_grad_enabled:", torch.is_grad_enabled())
    print(" is_inference_mode_enabled:", torch.is_inference_mode_enabled())

# Default state
print("Default state:")
print(" is_grad_enabled:", torch.is_grad_enabled())
print(" is_inference_mode_enabled:", torch.is_inference_mode_enabled())
Output:
Inside no_grad:
 is_grad_enabled: False
 is_inference_mode_enabled: False
Inside inference_mode:
 is_grad_enabled: False
 is_inference_mode_enabled: True
Default state:
 is_grad_enabled: True
 is_inference_mode_enabled: False
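Example 2 compares two ways of turning gradients off. A small survey that also includes torch.set_grad_enabled(False) makes the asymmetry explicit: every one of these contexts makes is_grad_enabled return False, but only inference_mode makes is_inference_mode_enabled return True. The helper names mode_flags and survey are illustrative, not part of PyTorch.

```python
import torch

def mode_flags():
    """Return (is_grad_enabled, is_inference_mode_enabled) for the current context."""
    return torch.is_grad_enabled(), torch.is_inference_mode_enabled()

def survey():
    """Collect the flag pair under each way of switching gradients off."""
    results = {"default": mode_flags()}
    with torch.no_grad():
        results["no_grad"] = mode_flags()
    with torch.set_grad_enabled(False):
        results["set_grad_enabled(False)"] = mode_flags()
    with torch.inference_mode():
        results["inference_mode"] = mode_flags()
    return results

for name, (grad, infer) in survey().items():
    print(f"{name:>24}: is_grad_enabled={grad}, is_inference_mode_enabled={infer}")
```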
Example 3: Use in Conditional Checks
import torch
import torch.nn as nn

def forward_pass(x, model):
    """Choose a code path based on the current autograd mode (all paths run the same model here)."""
    if torch.is_inference_mode_enabled():
        print("Using inference mode optimization")
        return model(x)
    elif torch.is_grad_enabled():
        print("Training mode")
        return model(x)
    else:
        # "Eval" here means gradients are off (e.g. torch.no_grad()); this is separate from model.eval()
        print("Eval mode")
        return model(x)

model = nn.Linear(10, 5)
x = torch.randn(5, 10)

# Test different states
print("=== Training Mode ===")
with torch.enable_grad():
    result = forward_pass(x, model)

print("\n=== Inference Mode ===")
with torch.inference_mode():
    result = forward_pass(x, model)

print("\n=== Eval Mode ===")
with torch.no_grad():
    result = forward_pass(x, model)
Output:
=== Training Mode ===
Training mode

=== Inference Mode ===
Using inference mode optimization

=== Eval Mode ===
Eval mode
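The flags checked in Example 3 also show up on the tensors a forward pass returns. This sketch (the describe helper is illustrative) records requires_grad and Tensor.is_inference() for the output of the same nn.Linear under each mode: only inference_mode produces an inference tensor, and only enable_grad produces an output that autograd tracks.

```python
import torch
import torch.nn as nn

model = nn.Linear(10, 5)
x = torch.randn(5, 10)

def describe(out):
    """Summarize how autograd treats a forward-pass output."""
    return {"requires_grad": out.requires_grad, "is_inference": out.is_inference()}

with torch.enable_grad():
    train_out = describe(model(x))
with torch.inference_mode():
    infer_out = describe(model(x))
with torch.no_grad():
    nograd_out = describe(model(x))

print("enable_grad:   ", train_out)
print("inference_mode:", infer_out)
print("no_grad:       ", nograd_out)
```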
Example 4: Detection in Custom Modules
import torch
import torch.nn as nn

class OptimizedLayer(nn.Module):
    def __init__(self):
        super().__init__()
        # (in_features, out_features): x of shape (N, 10) -> output of shape (N, 20)
        self.weight = nn.Parameter(torch.randn(10, 20))

    def forward(self, x):
        # Detect the current mode and choose a code path. All three paths
        # compute the same product here; a real layer could swap in a
        # faster kernel in the inference branch.
        if torch.is_inference_mode_enabled():
            print("Inference mode: use more efficient computation")
            return torch.mm(x, self.weight)
        elif torch.is_grad_enabled():
            # Training mode: autograd records this op for backward
            print("Training mode")
            return torch.mm(x, self.weight)
        else:
            # Eval mode: the caller's no_grad() has already disabled gradients
            print("Eval mode")
            return torch.mm(x, self.weight)

layer = OptimizedLayer()
x = torch.randn(5, 10)

print("=== Training Mode ===")
with torch.enable_grad():
    _ = layer(x)

print("\n=== Inference Mode ===")
with torch.inference_mode():
    _ = layer(x)

print("\n=== Eval Mode ===")
with torch.no_grad():
    _ = layer(x)
Output:
=== Training Mode ===
Training mode

=== Inference Mode ===
Inference mode: use more efficient computation

=== Eval Mode ===
Eval mode
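A layer that branches on inference mode should also account for what that mode produces. Tensors allocated inside inference_mode are inference tensors, and autograd refuses to save them for backward once inference mode is off. The sketch below, with an illustrative try_autograd helper and a stock nn.Linear standing in for a custom layer, shows the failure and the workaround of cloning the tensor outside inference mode.

```python
import torch
import torch.nn as nn

layer = nn.Linear(10, 20)

# A tensor created inside inference mode (e.g. cached features)
with torch.inference_mode():
    features = torch.randn(5, 10)

def try_autograd(inp):
    """Run the layer with autograd enabled; return True if the output is tracked."""
    try:
        out = layer(inp)
        return out.requires_grad
    except RuntimeError as err:
        # Raised because the inference tensor would have to be saved for backward
        print("RuntimeError:", str(err).splitlines()[0])
        return False

print("features.is_inference():", features.is_inference())
print("Direct use:", try_autograd(features))
# clone() outside inference mode yields an ordinary tensor that autograd accepts
print("After clone:", try_autograd(features.clone()))
```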
Related Functions
- torch.inference_mode(): context manager that enables inference mode.
- torch.no_grad(): context manager that disables gradient calculation.
- torch.is_grad_enabled(): checks whether gradient calculation is enabled.
Notes
- is_inference_mode_enabled is read-only: it reports the current state without modifying it.
- It detects inference_mode specifically, not no_grad.
- Inside inference_mode, is_grad_enabled also returns False, but the reverse does not hold: inside no_grad, is_inference_mode_enabled still returns False.
- Use this function to distinguish between these operation modes when writing generic code.
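Inference mode is also thread-local: entering torch.inference_mode() in one thread does not affect other threads, so is_inference_mode_enabled answers only for the calling thread. The sketch below (check_in_new_thread is an illustrative name) starts a worker thread from inside an inference_mode block and compares what each thread reports.

```python
import threading
import torch

def check_in_new_thread():
    """Return is_inference_mode_enabled() as seen by the main thread and a new worker thread."""
    seen = {}

    def worker():
        # Runs in a separate thread that never entered inference_mode
        seen["worker"] = torch.is_inference_mode_enabled()

    with torch.inference_mode():
        seen["main"] = torch.is_inference_mode_enabled()
        t = threading.Thread(target=worker)
        t.start()
        t.join()
    return seen

print(check_in_new_thread())
```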