PyTorch torch.nn.Module
[ PyTorch torch.nn Reference Manual](#)
* * *
`torch.nn.Module` is the base class for all neural network modules in PyTorch.
All custom network models should inherit from this class, which provides parameter management, device migration, model saving, and other functionalities.
### Class Definition
torch.nn.Module
### Main Methods
* `parameters()`: Returns an iterator over all learnable parameters of the model
* `named_parameters()`: Returns an iterator over (name, parameter) pairs
* `children()`: Returns an iterator over the immediate sub-modules of the model
* `named_children()`: Returns an iterator over (name, module) pairs for the immediate sub-modules
* `modules()`: Returns an iterator over all modules in the network, including the model itself and nested sub-modules
* `state_dict()`: Returns a dictionary containing all parameters and persistent buffers
* * *
## Usage Examples
### Example 1: Creating a Custom Module
Inherit from nn.Module to create a custom network:
## Instance
import torch
import torch.nn as nn
class SimpleNet(nn.Module):
    """Two-layer fully connected network: Linear -> ReLU -> Linear.

    Args:
        input_dim: Number of features in each input sample.
        hidden_dim: Width of the hidden layer.
        output_dim: Number of features in each output sample.
    """

    def __init__(self, input_dim: int, hidden_dim: int, output_dim: int) -> None:
        # Initialise nn.Module first so the layers assigned below are
        # registered as sub-modules and their parameters are tracked.
        super().__init__()
        # Define network layers
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass.

        Args:
            x: Input whose last dimension equals ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``output_dim``.
        """
        # Define forward propagation
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x
# Instantiate model
model = SimpleNet(input_dim=784, hidden_dim=256, output_dim=10)

# Test: a random batch of 32 samples with 784 features each
x = torch.randn(32, 784)
output = model(x)  # calling the module invokes forward()

print("Model structure:")
print(model)
# "\n" inserts a blank line between the structure dump and the shapes
print("\nInput shape:", x.shape)
print("Output shape:", output.shape)
### Example 2: Managing Parameters
Access and manage model parameters:
## Instance
import torch
import torch.nn as nn
class Net(nn.Module):
    """Small CNN: Conv2d -> BatchNorm2d -> flatten -> Linear.

    The linear layer expects ``16 * 14 * 14`` flattened features. With the
    3x3 convolution (default stride 1, no padding) that corresponds to
    inputs of shape ``(N, 3, 16, 16)``.
    """

    def __init__(self) -> None:
        # Initialise nn.Module first so the layers assigned below are
        # registered as sub-modules and their parameters are tracked.
        super().__init__()
        self.conv1 = nn.Conv2d(3, 16, 3)
        self.bn1 = nn.BatchNorm2d(16)
        self.fc = nn.Linear(16 * 14 * 14, 10)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the forward pass and return a ``(N, 10)`` tensor."""
        x = self.conv1(x)
        x = self.bn1(x)
        # Flatten everything except the batch dimension for the linear layer
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x
model = Net()

# Count parameters
total_params = sum(p.numel() for p in model.parameters())
# Only parameters with requires_grad=True are updated during training
trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("Total parameters:", total_params)
print("Trainable parameters:", trainable_params)

# Access specific parameters through the sub-module attributes
# ("\n" inserts a blank line before this section of the output)
print("\nconv1 weight shape:", model.conv1.weight.shape)
print("fc bias shape:", model.fc.bias.shape)
### Example 3: Model Saving and Loading
Save and load models:
## Instance
import torch
import torch.nn as nn
import tempfile
import os
class Net(nn.Module):
    """Single linear layer mapping 10 input features to 5 outputs."""

    def __init__(self) -> None:
        # Initialise nn.Module first so the layer assigned below is
        # registered and appears in state_dict().
        super().__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` (last dimension must be 10)."""
        return self.fc(x)
model = Net()

# Write both checkpoints into a temporary directory. Unlike saving to the
# path of a still-open NamedTemporaryFile (which fails on Windows), this
# works on every platform, and the directory is removed automatically even
# if saving or loading raises.
with tempfile.TemporaryDirectory() as tmp_dir:
    path_full = os.path.join(tmp_dir, "model_full.pth")
    path_state = os.path.join(tmp_dir, "model_state.pth")

    # Save entire model (pickles the whole object, so loading it later
    # requires the same class definition to be importable)
    torch.save(model, path_full)

    # Save state_dict (recommended way)
    torch.save(model.state_dict(), path_state)

    # Load model: build a fresh instance, then restore the saved weights.
    # weights_only=True restricts unpickling to tensors and plain containers.
    loaded_model = Net()
    loaded_model.load_state_dict(torch.load(path_state, weights_only=True))
    loaded_model.eval()

# Test loaded model: both models should produce identical outputs
x = torch.randn(2, 10)
with torch.no_grad():  # inference only, no autograd graph needed
    output1 = model(x)
    output2 = loaded_model(x)
print("Original output:", output1.tolist())
print("Loaded output:", output2.tolist())
### Example 4: Device Migration
Migrate models between different devices:
## Instance
import torch
import torch.nn as nn
class Net(nn.Module):
    """Single linear layer mapping 10 input features to 5 outputs."""

    def __init__(self) -> None:
        # Initialise nn.Module first so the layer assigned below is
        # registered and moves together with the model across devices.
        super().__init__()
        self.fc = nn.Linear(10, 5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the linear layer to ``x`` (last dimension must be 10)."""
        return self.fc(x)
model = Net()

# Check current device
print("Parameter device:", model.fc.weight.device)

# Move to GPU (if available)
if torch.cuda.is_available():
    model = model.cuda()
    print("After moving to GPU:", model.fc.weight.device)

# Move back to CPU (a no-op when the model never left the CPU)
model = model.cpu()
print("After moving back to CPU:", model.fc.weight.device)
### Example 5: Using apply for Initialization
Use apply for recursive initialization:
## Instance
import torch
import torch.nn as nn
def init_weights(module: nn.Module) -> None:
    """Initialise one module in place; intended for ``model.apply(...)``.

    * ``nn.Linear``: Xavier-uniform weights and, when a bias exists, a
      zero bias.
    * ``nn.Conv2d``: Kaiming-normal weights (the bias is left untouched).
    * Any other module type is left unchanged.

    Args:
        module: The module to initialise.
    """
    if isinstance(module, nn.Linear):
        nn.init.xavier_uniform_(module.weight)
        # Linear layers created with bias=False have bias set to None
        if module.bias is not None:
            nn.init.zeros_(module.bias)
    elif isinstance(module, nn.Conv2d):
        nn.init.kaiming_normal_(module.weight)
class Net(nn.Module):
    """Two stacked linear layers: 10 -> 20 -> 5 features."""

    def __init__(self) -> None:
        # Initialise nn.Module first so both layers are registered as
        # sub-modules and are therefore visited by apply().
        super().__init__()
        self.fc1 = nn.Linear(10, 20)
        self.fc2 = nn.Linear(20, 5)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Pass ``x`` through ``fc1`` and then ``fc2``."""
        return self.fc2(self.fc1(x))
# Build the network, then let apply() call init_weights on every
# sub-module (and on the model itself).
model = Net()
model.apply(init_weights)

# Show the first three weights of the first output unit of each layer
fc1_head = model.fc1.weight[0, :3]
fc2_head = model.fc2.weight[0, :3]
print("fc1 weights:", fc1_head.tolist())
print("fc2 weights:", fc2_head.tolist())
### Example 6: Complex Network Structure
Build a network with branches:
## Instance
import torch
import torch.nn as nn
class BranchNet(nn.Module):
    """Shared backbone feeding two independent output branches.

    The backbone maps 10 input features to 20; branch A then produces 5
    outputs and branch B produces 3, both from the same backbone features.
    """

    def __init__(self) -> None:
        # Initialise nn.Module first so the layers assigned below are
        # registered as sub-modules and their parameters are tracked.
        super().__init__()
        # Backbone
        self.shared = nn.Linear(10, 20)
        # Branches
        self.branch_a = nn.Linear(20, 5)
        self.branch_b = nn.Linear(20, 3)

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Return ``(out_a, out_b)`` computed from the shared features.

        Args:
            x: Input whose last dimension is 10.
        """
        feat = self.shared(x)
        out_a = self.branch_a(feat)
        out_b = self.branch_b(feat)
        return out_a, out_b
# Run a random batch of 4 samples (10 features each) through the network
model = BranchNet()
x = torch.randn(4, 10)

# forward() returns one tensor per branch
outputs = model(x)
out_a, out_b = outputs

print("Branch A output:", out_a.shape)
print("Branch B output:", out_b.shape)
* * *
## Frequently Asked Questions
### Q1: Why must super().__init__() be called?
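It runs `nn.Module`'s own initializer, which sets up the internal registries for parameters, buffers, and sub-modules. Without it, assigning a layer such as `self.fc = nn.Linear(...)` raises an `AttributeError`, because there is nowhere to register it yet.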
### Q2: How to view the model structure?
Simply call `print(model)`, or use the `summary` function from the third-party `torchinfo` package to see per-layer output shapes and parameter counts.
YouTip