# PyTorch torch.is_grad_enabled()
* * *
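`torch.is_grad_enabled()` returns whether gradient computation (PyTorch's grad mode) is currently enabled.

**Syntax**: `torch.is_grad_enabled()`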
**Parameters**:
* No parameters.
**Returns**:
* Returns a boolean: `True` if gradient computation is currently enabled; `False` otherwise.
* * *
## Usage Examples
### Example 1: Basic Usage
```python
import torch

# By default, gradient computation is enabled
print("Default state:", torch.is_grad_enabled())

# Inside a no_grad context, it is disabled
with torch.no_grad():
    print("In no_grad:", torch.is_grad_enabled())

# The previous state is restored after exiting the context
print("After exiting no_grad:", torch.is_grad_enabled())
```

Output:

```
Default state: True
In no_grad: False
After exiting no_grad: True
```
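`torch.no_grad()` can also be applied as a decorator, so that an entire function body runs with gradients disabled. A minimal sketch (the function name `evaluate` is hypothetical), assuming gradients are enabled before the call:

```python
import torch

@torch.no_grad()
def evaluate():
    # The whole function body runs with gradient computation disabled
    return torch.is_grad_enabled()

inside = evaluate()
print("Inside decorated function:", inside)
# The decorator restores the previous state when the function returns
print("After the call:", torch.is_grad_enabled())
```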
### Example 2: Using with set_grad_enabled
```python
import torch

# Check current state
print("Current state:", torch.is_grad_enabled())

# Disable gradients
torch.set_grad_enabled(False)
print("After disabling:", torch.is_grad_enabled())

# Enable gradients
torch.set_grad_enabled(True)
print("After enabling:", torch.is_grad_enabled())
```

Output:

```
Current state: True
After disabling: False
After enabling: True
```
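Calling `torch.set_grad_enabled(False)` as a plain function changes the state until it is explicitly changed back, which is easy to forget. The following is a minimal sketch of two safer patterns, assuming gradients start out enabled: saving and restoring the state with `try`/`finally`, or using `torch.set_grad_enabled` as a context manager. The variable names are illustrative only.

```python
import torch

# Save the current state, change it, and restore it even if an error occurs
prev = torch.is_grad_enabled()
try:
    torch.set_grad_enabled(False)
    inside_manual = torch.is_grad_enabled()
finally:
    torch.set_grad_enabled(prev)

# The same effect with set_grad_enabled used as a context manager
with torch.set_grad_enabled(False):
    inside_ctx = torch.is_grad_enabled()

print(inside_manual, inside_ctx, torch.is_grad_enabled())
```

The context-manager form is usually preferable because the restore step cannot be skipped.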
### Example 3: Using in Conditional Statements
```python
import torch

def process_tensor(x):
    """Process a tensor and report the current gradient state."""
    if torch.is_grad_enabled():
        print("Gradient computation enabled")
        # The operation is recorded by autograd, so backward() is possible later
        y = x * 2
    else:
        print("Gradient computation disabled")
        # No autograd graph is recorded, which saves memory
        y = x * 2
    return y

# Test different states
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
print("=== Gradient enabled ===")
result1 = process_tensor(x)
print("\n=== Gradient disabled ===")
with torch.no_grad():
    result2 = process_tensor(x)
```

Output:

```
=== Gradient enabled ===
Gradient computation enabled

=== Gradient disabled ===
Gradient computation disabled
```
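Both branches of `process_tensor` compute the same value; what differs is whether autograd records the operation. A short sketch that makes this visible by inspecting `requires_grad` and `grad_fn` on the results (the input is created with `requires_grad=True` so that there is a graph to record):

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)

# Grad mode on: the result is part of the autograd graph
y_enabled = x * 2
# Grad mode off: the same operation produces a tensor with no history
with torch.no_grad():
    y_disabled = x * 2

print(y_enabled.requires_grad, y_enabled.grad_fn is not None)
print(y_disabled.requires_grad, y_disabled.grad_fn is None)

# Only the graph-connected result supports backpropagation
y_enabled.sum().backward()
print(x.grad)
```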
### Example 4: Using in Custom Layers
```python
import torch
import torch.nn as nn

class CustomLayer(nn.Module):
    def __init__(self):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(10, 10))

    def forward(self, x):
        # Branch on the current grad mode (not on self.training)
        if torch.is_grad_enabled():
            print("Training mode")
            # Normal computation; autograd records the operation
            return torch.mm(x, self.weight)
        else:
            print("Inference mode")
            # Grad mode is already off here, so no extra no_grad() is needed
            return torch.mm(x, self.weight)

layer = CustomLayer()
x = torch.randn(5, 10)

# Training: gradients are enabled by default
layer.train()
output1 = layer(x)

# Inference: eval() alone does not disable gradients, so wrap the call in no_grad()
layer.eval()
with torch.no_grad():
    output2 = layer(x)
```

Output:

```
Training mode
Inference mode
```
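Because `CustomLayer.forward` branches on grad mode rather than on `self.training`, it helps to confirm that `layer.eval()` by itself leaves gradients enabled. A minimal sketch, using `nn.Linear` as a stand-in for any module:

```python
import torch
import torch.nn as nn

layer = nn.Linear(10, 10)
x = torch.randn(5, 10)

# eval() switches module behavior (e.g. dropout, batch norm) but not grad mode
layer.eval()
grad_after_eval = torch.is_grad_enabled()
out_eval = layer(x)
print("Grad enabled after eval():", grad_after_eval)
print("Output requires grad:", out_eval.requires_grad)

# Gradient tracking is turned off only by a grad-mode context
with torch.no_grad():
    out_nograd = layer(x)
print("Output requires grad under no_grad:", out_nograd.requires_grad)
```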
* * *
## Related Functions
* `torch.no_grad()`: Context manager that disables gradient computation.
* `torch.enable_grad()`: Context manager that enables gradient computation.
* `torch.set_grad_enabled(mode)`: Sets whether gradient computation is enabled; it can be called as a function or used as a context manager.
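These functions can be nested, and `torch.is_grad_enabled()` always reports the setting of the innermost active context. A small sketch, assuming gradients are enabled at the start:

```python
import torch

states = []
with torch.no_grad():
    states.append(torch.is_grad_enabled())      # disabled by no_grad
    with torch.enable_grad():
        states.append(torch.is_grad_enabled())  # re-enabled inside no_grad
    with torch.set_grad_enabled(True):
        states.append(torch.is_grad_enabled())  # also re-enabled
    states.append(torch.is_grad_enabled())      # back to the no_grad setting
states.append(torch.is_grad_enabled())          # restored after exiting
print(states)
```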
* * *
## Notes
* `torch.is_grad_enabled()` only queries the grad mode; it does not change any state.
* It reports the current grad mode, which is tracked per thread, not the `requires_grad` attribute of individual tensors.
* In general-purpose code, it can be used to choose a code path depending on whether operations will be recorded for autograd, for example skipping work that is only needed for backpropagation.
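The difference between the grad-mode flag and a tensor's `requires_grad` attribute, and the fact that grad mode is per thread (the `torch.no_grad` documentation describes it as thread local), can be checked directly. A minimal sketch; the `worker` function and `results` dictionary are illustrative:

```python
import threading
import torch

w = torch.ones(3, requires_grad=True)
results = {}

def worker():
    # A new thread has its own grad mode, enabled by default
    results["thread_grad_enabled"] = torch.is_grad_enabled()

with torch.no_grad():
    # requires_grad is a tensor attribute and is unaffected by grad mode
    results["w_requires_grad"] = w.requires_grad
    results["main_grad_enabled"] = torch.is_grad_enabled()
    t = threading.Thread(target=worker)
    t.start()
    t.join()

print(results)
```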
* * *