
Pytorch Torch Load

## PyTorch `torch.load` Function

Saving and restoring models, tensors, and other serialized objects is a routine part of any PyTorch workflow. `torch.load` deserializes objects that were written to disk with `torch.save`.

---

## Function Definition

`torch.load` reads a serialized object from a file and returns it. The object can be a tensor, a model `state_dict`, a dictionary, or an entire model object.

```python
torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)
```

This is the signature as documented for PyTorch 2.6 and later. Earlier releases defaulted to `weights_only=False`.

### Parameters

* **`f`**: A file-like object (it has to implement `read()`, `readline()`, `tell()`, and `seek()`), or a string or `os.PathLike` object containing a file name.
* **`map_location`**: Specifies how to remap storage locations. It can be a function, a `torch.device`, a string (e.g., `'cpu'`, `'cuda:0'`), or a dictionary that maps one device tag to another (e.g., `{'cuda:0': 'cpu'}` to move GPU tensors to the CPU).
* **`weights_only`**: A boolean flag. When `True`, the unpickler is restricted to tensors, primitive types, dictionaries, and any types explicitly allowlisted via `torch.serialization.add_safe_globals()`. The default changed from `False` to `True` in PyTorch 2.6. Keeping it `True` is strongly recommended, because it prevents arbitrary code execution when loading untrusted files.

---

## Basic Usage Example

The following example saves a tensor to disk with `torch.save` and loads it back into memory with `torch.load`.

```python
import torch

# Create and save a tensor
x = torch.tensor([1, 2, 3, 4, 5])
torch.save(x, 'tensor.pt')

# Load the tensor back (weights_only=True is enough for plain tensors)
loaded_x = torch.load('tensor.pt', weights_only=True)
print(f"Loaded tensor: {loaded_x}")
# Output: Loaded tensor: tensor([1, 2, 3, 4, 5])
```

---

## Advanced Examples

### 1. Loading a Model's `state_dict` (Recommended Practice)

In production workflows, it is best practice to save and load only the model's parameters (`state_dict`) rather than the entire model object.
This approach keeps the file independent of the model class's import path and the project's directory layout, so the code can be refactored without invalidating saved weights.

```python
import torch
import torch.nn as nn

# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# Instantiate and save the state_dict
model = SimpleModel()
torch.save(model.state_dict(), 'model_weights.pth')

# Re-instantiate the model structure and load the weights
new_model = SimpleModel()
new_model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
new_model.eval()  # Set the model to evaluation mode
print("Model state_dict loaded successfully!")
```

### 2. Device Remapping with `map_location`

When a model is trained on a GPU (`cuda`) and deployed on a CPU-only machine (or vice versa), the saved tensors must be remapped to an available device at load time.

#### Loading a GPU-saved model onto a CPU:

```python
import torch

# Load a tensor/model saved on GPU directly to CPU
loaded_on_cpu = torch.load('model_weights.pth', map_location=torch.device('cpu'))
```

#### Loading a CPU-saved model onto a specific GPU:

```python
import torch

# Load directly to GPU 0 when CUDA is available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
loaded_on_gpu = torch.load('model_weights.pth', map_location=device)
```

### 3. Secure Loading with `weights_only`

To protect your system against malicious code embedded in untrusted `.pt` or `.pth` files, load them with `weights_only=True`. This is the default from PyTorch 2.6 onward, but passing it explicitly keeps the behavior the same on older versions.

```python
import torch

# Securely load weights without executing arbitrary pickle code
safe_loaded_weights = torch.load('model_weights.pth', weights_only=True)
```

---

## Important Considerations

1. **Security Warning**: `torch.load` uses Python's `pickle` module under the hood. With `weights_only=False`, a malicious file can execute arbitrary Python code during deserialization, so never load untrusted files that way. Always use `weights_only=True` when loading third-party weights.
2. **Class Definitions**: When saving an entire model object (e.g., `torch.save(model, 'model.pt')`) instead of just the `state_dict`, the serialized data is bound to the specific class definition and its module path. If you move or rename the model class in your codebase, `torch.load` will fail to find the class definition and raise an error. Loading such a file also requires `weights_only=False` on PyTorch 2.6 and later (or allowlisting the class), so do this only with files you trust.
3. **File Extensions**: PyTorch does not enforce a file extension, but the community convention is to use `.pt` or `.pth` for saved tensors and model weights.
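
The points above can be combined in one round trip. The following is a minimal sketch, not a definitive recipe: it saves a checkpoint dictionary to a file-like object and loads it back with both `map_location` and `weights_only` set. The `nn.Linear(4, 2)` model, the checkpoint keys (`epoch`, `model_state`), and the in-memory `io.BytesIO` buffer are illustrative assumptions chosen to keep the example self-contained.

```python
import io

import torch
import torch.nn as nn

# Hypothetical model, used only for this illustration.
model = nn.Linear(4, 2)

# Bundle the weights with plain-Python metadata in one checkpoint dict.
# The key names "epoch" and "model_state" are arbitrary choices.
checkpoint = {"epoch": 3, "model_state": model.state_dict()}

# torch.save and torch.load accept file-like objects, so an in-memory
# buffer stands in for a file on disk here.
buffer = io.BytesIO()
torch.save(checkpoint, buffer)
buffer.seek(0)  # rewind to the start before reading

# Restrict unpickling to tensors and basic containers, and place
# every loaded tensor on the CPU regardless of where it was saved.
restored = torch.load(buffer, map_location="cpu", weights_only=True)

# Rebuild the architecture, then copy the restored weights into it.
new_model = nn.Linear(4, 2)
new_model.load_state_dict(restored["model_state"])
new_model.eval()

weights_match = torch.equal(model.weight, new_model.weight)
print(restored["epoch"], weights_match)
```

Keeping the metadata to plain Python types (integers, strings, lists, dictionaries) is what lets the whole checkpoint load under `weights_only=True`.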
← Pytorch Torch LobpcgPytorch Torch Linalg Svd β†’