## PyTorch `torch.load` Function
In PyTorch, saving and restoring models, tensors, and other serialized objects is a fundamental part of the machine learning workflow. The `torch.load` function is the primary utility used to deserialize and load PyTorch objects that have been saved to disk using `torch.save`.
---
## Function Definition
The `torch.load` function reads a file written by `torch.save` and reconstructs the stored object: a tensor, a `state_dict`, an arbitrary dictionary, or an entire pickled `nn.Module`.
```python
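# Signature as documented for PyTorch 2.6+; earlier releases default to weights_only=False
torch.load(f, map_location=None, pickle_module=pickle, *, weights_only=True, mmap=None, **pickle_load_args)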
```
### Parameters
* **`f`**: A file-like object (it must implement `read()`, `readline()`, `tell()`, and `seek()`), or a string or `os.PathLike` object containing a file name.
* **`map_location`**: Specifies how to remap storage locations. It can be a function, a `torch.device`, a string (e.g., `'cpu'`, `'cuda:0'`), or a dictionary specifying how to map tensors from one device to another (e.g., mapping GPU tensors to the CPU).
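* **`weights_only`**: A boolean flag. When `True`, the unpickler is restricted to tensors, primitive Python types, dictionaries, and any types explicitly allowlisted with `torch.serialization.add_safe_globals()`. This prevents arbitrary code execution when loading untrusted files. The default was `False` before PyTorch 2.6 and is `True` from 2.6 onward, so pass the flag explicitly if your code must behave the same across versions.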
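
The parameters above can be combined in one call. The following is a minimal sketch, assuming an in-memory `io.BytesIO` buffer as the file-like object; the dictionary key `"weights"` is a made-up name for illustration.

```python
import io

import torch

# Serialize a small dictionary into an in-memory buffer instead of a file on disk.
# The key "weights" is a hypothetical name chosen for this example.
buffer = io.BytesIO()
torch.save({"weights": torch.arange(3)}, buffer)

# Rewind so torch.load starts reading from the first byte.
buffer.seek(0)

# map_location="cpu" keeps every tensor on the CPU; weights_only=True
# restricts unpickling to tensors and plain containers.
restored = torch.load(buffer, map_location="cpu", weights_only=True)
print(restored["weights"])
```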
---
## Basic Usage Example
The following example demonstrates how to save a PyTorch tensor to disk using `torch.save` and then load it back into memory using `torch.load`.
```python
import torch
# Create and save a tensor
x = torch.tensor([1, 2, 3, 4, 5])
torch.save(x, 'tensor.pt')
# Load the tensor back
loaded_x = torch.load('tensor.pt')
print(f"Loaded tensor: {loaded_x}")
# Output: Loaded tensor: tensor([1, 2, 3, 4, 5])
```
---
## Advanced Examples
### 1. Loading a Model's `state_dict` (Recommended Practice)
The recommended practice is to save and load only the model's parameters (`state_dict`) rather than the entire model object. A `state_dict` is a plain dictionary of tensors, so the saved file does not depend on the import path of your model class and can be loaded with `weights_only=True`.
```python
import torch
import torch.nn as nn
# Define a simple model
class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(10, 2)

    def forward(self, x):
        return self.fc(x)

# Instantiate and save the state_dict
model = SimpleModel()
torch.save(model.state_dict(), 'model_weights.pth')
# Re-instantiate the model structure and load the weights
new_model = SimpleModel()
new_model.load_state_dict(torch.load('model_weights.pth', weights_only=True))
new_model.eval() # Set the model to evaluation mode
print("Model state_dict loaded successfully!")
```
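The same pattern extends to a checkpoint that bundles several state dictionaries. This is a rough sketch rather than a prescribed layout: the key names (`"epoch"`, `"model_state"`, `"optimizer_state"`), the file name `checkpoint.pth`, and the use of a bare `nn.Linear` with SGD are all assumptions made for illustration.

```python
import os
import tempfile

import torch
import torch.nn as nn

model = nn.Linear(10, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

# Bundle what is needed to resume training into one dictionary.
# The key names and the epoch value are illustrative, not required by PyTorch.
checkpoint = {
    "epoch": 5,
    "model_state": model.state_dict(),
    "optimizer_state": optimizer.state_dict(),
}

with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "checkpoint.pth")
    torch.save(checkpoint, path)
    # The checkpoint holds only tensors and plain containers,
    # so it can be read with weights_only=True.
    restored = torch.load(path, map_location="cpu", weights_only=True)

# Rebuild the objects first, then copy the saved state into them.
new_model = nn.Linear(10, 2)
new_optimizer = torch.optim.SGD(new_model.parameters(), lr=0.01)
new_model.load_state_dict(restored["model_state"])
new_optimizer.load_state_dict(restored["optimizer_state"])
start_epoch = restored["epoch"] + 1
```

Because `torch.load` returns the dictionary exactly as it was saved, each entry is passed to the matching `load_state_dict` call on a freshly constructed object.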
### 2. Device Remapping with `map_location`
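When a file was saved from GPU (`cuda`) tensors and is loaded on a CPU-only machine, `torch.load` tries to restore each tensor to its original device and fails unless you remap the storage with `map_location`. In the opposite direction, `map_location` is optional: it places CPU-saved tensors directly on the chosen GPU at load time.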
#### Loading a GPU-saved model onto a CPU:
```python
# Load a tensor/model saved on GPU directly to CPU
loaded_on_cpu = torch.load('model_weights.pth', map_location=torch.device('cpu'))
```
#### Loading a CPU-saved model onto a specific GPU:
```python
# Load directly onto GPU 0, falling back to the CPU when CUDA is unavailable
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
loaded_on_gpu = torch.load('model_weights.pth', map_location=device)
```
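Besides a device or string, `map_location` also accepts a dictionary or a callable. The sketch below illustrates both forms under the assumption of a CPU-only environment, so it uses a CPU tensor in an in-memory buffer; the variable names are illustrative.

```python
import io

import torch

buffer = io.BytesIO()
torch.save(torch.ones(2), buffer)

# Dictionary form: storages tagged 'cuda:0' in the file are placed on the CPU;
# tags that are not listed (here 'cpu') are left as they are.
buffer.seek(0)
via_dict = torch.load(buffer, map_location={"cuda:0": "cpu"})

# Callable form: receives the deserialized storage and its saved location tag;
# returning the storage unchanged keeps it on the CPU.
buffer.seek(0)
via_callable = torch.load(buffer, map_location=lambda storage, location: storage)

print(via_dict.device, via_callable.device)
```

The dictionary form is convenient when device indices differ between machines, while the callable form allows a decision per storage.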
### 3. Secure Loading with `weights_only`
To protect your system against malicious code execution embedded in untrusted `.pt` or `.pth` files, use the `weights_only=True` parameter.
```python
# Securely load weights without executing arbitrary pickle code
safe_loaded_weights = torch.load('model_weights.pth', weights_only=True)
```
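
To see the restriction in action, the sketch below saves an object that is neither a tensor nor a plain container and then tries to load it with `weights_only=True`. `fractions.Fraction` is an arbitrary stand-in chosen for this example, and catching `pickle.UnpicklingError` reflects how recent PyTorch releases report the rejection; treat the exact exception type as an assumption for other versions.

```python
import io
import pickle
from fractions import Fraction

import torch

# Fraction stands in for "any object that is not a tensor or a plain container".
buffer = io.BytesIO()
torch.save({"ratio": Fraction(1, 3)}, buffer)

buffer.seek(0)
try:
    torch.load(buffer, weights_only=True)
    blocked = False
except pickle.UnpicklingError as err:
    # The restricted unpickler refuses to import fractions.Fraction.
    blocked = True
    print(f"Rejected by weights_only: {type(err).__name__}")
```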
---
## Important Considerations
1. **Security Warning**: `torch.load` relies on Python's `pickle` module. With `weights_only=False`, a malicious file can execute arbitrary Python code during deserialization, so use that setting only for files you trust. Pass `weights_only=True` when loading third-party weights.
2. **Class Definitions**: When you save an entire model object (e.g., `torch.save(model, 'model.pt')`) instead of just the `state_dict`, the pickle stores only a reference to the class (its module path and name), not the class code itself. If you move or rename the model class in your codebase, `torch.load` cannot resolve that reference and raises an error such as `ModuleNotFoundError` or `AttributeError`. Loading a full model object also generally requires `weights_only=False`.
3. **File Extensions**: While PyTorch does not enforce a strict file extension, the community standard is to use `.pt` or `.pth` for saved tensors and model weights.