YouTip LogoYouTip

PyTorch torch.nn.Module

[![Image 1: PyTorch torch.nn Reference Manual](#) PyTorch torch.nn Reference Manual](#) * * * `torch.nn.Module` is the base class for all neural network modules in PyTorch. All custom network models should inherit from this class, which provides parameter registration and management, moving models between devices, saving and loading, and other functionality. ### Class Definition `class torch.nn.Module` ### Main Methods * `parameters()`: Returns an iterator over all parameters of the model, including those of its sub-modules * `named_parameters()`: Returns an iterator over (name, parameter) pairs * `children()`: Returns an iterator over the immediate sub-modules * `named_children()`: Returns an iterator over (name, module) pairs for the immediate sub-modules * `modules()`: Returns an iterator over all modules in the network, recursively, including the model itself * `state_dict()`: Returns a dictionary containing the model's parameters and persistent buffers (e.g., BatchNorm running statistics) * * * ## Usage Examples ### Example 1: Creating a Custom Module Inherit from nn.Module to create a custom network: ## Code import torch import torch.nn as nn class SimpleNet(nn.Module): def __init__ (self, input_dim, hidden_dim, output_dim): super(SimpleNet,self). __init__ () # Define network layers self.fc1= nn.Linear(input_dim, hidden_dim) self.relu= nn.ReLU() self.fc2= nn.Linear(hidden_dim, output_dim) def forward(self, x): # Define forward propagation x =self.fc1(x) x =self.relu(x) x =self.fc2(x) return x # Instantiate model model = SimpleNet(input_dim=784, hidden_dim=256, output_dim=10) # Test x = torch.randn(32,784) output = model(x) print("Model structure:") print(model) print("nInput shape:", x.shape) print("Output shape:", output.shape) ### Example 2: Managing Parameters Access and count model parameters (the layer sizes below assume 3×16×16 input images): ## Code import torch import torch.nn as nn class Net(nn.Module): def __init__ (self): super(Net,self). 
__init__ () self.conv1= nn.Conv2d(3,16,3) self.bn1= nn.BatchNorm2d(16) self.fc= nn.Linear(16 * 14 * 14,10) def forward(self, x): x =self.conv1(x) x =self.bn1(x) x = x.view(x.size(0), -1) x =self.fc(x) return x model = Net() # Count parameters total_params =sum(p.numel()for p in model.parameters()) trainable_params =sum(p.numel()for p in model.parameters()if p.requires_grad) print("Total parameters:", total_params) print("Trainable parameters:", trainable_params) # Access specific parameters print("nconv1 weight shape:", model.conv1.weight.shape) print("fc bias shape:", model.fc.bias.shape) ### Example 3: Saving and Loading Models Save a model to disk and load it back: ## Code import torch import torch.nn as nn import tempfile import os class Net(nn.Module): def __init__ (self): super(Net,self). __init__ () self.fc= nn.Linear(10,5) def forward(self, x): return self.fc(x) model = Net() # Save entire model with tempfile.NamedTemporaryFile(delete=False, suffix='.pth')as f: torch.save(model, f.name) path_full = f.name # Save state_dict (recommended way) with tempfile.NamedTemporaryFile(delete=False, suffix='.pth')as f: torch.save(model.state_dict(), f.name) path_state = f.name # Load model loaded_model = Net() loaded_model.load_state_dict(torch.load(path_state)) loaded_model.eval() # Test loaded model x = torch.randn(2,10) output1 = model(x) output2 = loaded_model(x) print("Original output:", output1.tolist()) print("Loaded output:", output2.tolist()) # Cleanup os.remove(path_full) os.remove(path_state) ### Example 4: Moving Models Between Devices Move a model between the CPU and a GPU: ## Code import torch import torch.nn as nn class Net(nn.Module): def __init__ (self): super(Net,self). 
__init__ () self.fc= nn.Linear(10,5) def forward(self, x): return self.fc(x) model = Net() # Check current device print("Parameter device:", model.fc.weight.device) # Move to GPU (if available) if torch.cuda.is_available(): model = model.cuda() print("After moving to GPU:", model.fc.weight.device) # Move back to CPU model = model.cpu() print("After moving back to CPU:", model.fc.weight.device) ### Example 5: Using apply() for Weight Initialization `apply(fn)` calls `fn` recursively on every sub-module and then on the model itself, which makes it a convenient way to initialize all layers: ## Code import torch import torch.nn as nn def init_weights(module): if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) elif isinstance(module, nn.Conv2d): nn.init.kaiming_normal_(module.weight) class Net(nn.Module): def __init__ (self): super(Net,self). __init__ () self.fc1= nn.Linear(10,20) self.fc2= nn.Linear(20,5) def forward(self, x): return self.fc2(self.fc1(x)) model = Net() model.apply(init_weights) print("fc1 weights:", model.fc1.weight[0, :3].tolist()) print("fc2 weights:", model.fc2.weight[0, :3].tolist()) ### Example 6: Complex Network Structure Build a network with a shared backbone and two output branches: ## Code import torch import torch.nn as nn class BranchNet(nn.Module): def __init__ (self): super(BranchNet,self). __init__ () # Backbone self.shared= nn.Linear(10,20) # Branches self.branch_a= nn.Linear(20,5) self.branch_b= nn.Linear(20,3) def forward(self, x): feat =self.shared(x) out_a =self.branch_a(feat) out_b =self.branch_b(feat) return out_a, out_b model = BranchNet() x = torch.randn(4,10) out_a, out_b = model(x) print("Branch A output:", out_a.shape) print("Branch B output:", out_b.shape) * * * ## Frequently Asked Questions ### Q1: Why must super().__init__() be called? It runs `nn.Module.__init__`, which sets up the internal structures used to register parameters, buffers, and sub-modules. Without it, assigning a layer such as `self.fc = nn.Linear(...)` raises an error. ### Q2: How do I view the model structure? Simply call `print(model)`. For a per-layer summary with output shapes and parameter counts, use the `summary` function from the third-party `torchinfo` package; torchvision does not provide a model summary function.

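Tip: Every custom Module must implement the `forward` method. Run it by calling the module (`model(x)`) rather than `model.forward(x)`, so that any registered hooks are also executed. * * * [PyTorch torch.nn Reference Manual](#)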

← PyTorch torch.nn.MultiheadAttention | PyTorch torch.nn.MSELoss →