# Saving and Loading Models in scikit-learn
In machine learning, the model training process is usually time-consuming. To avoid retraining the model every time, we can save the trained model for easy loading and prediction later.
Two commonly used tools for saving and loading `scikit-learn` models are `joblib` and `pickle`.
## 1. Using `joblib` to Save and Load Models
`joblib` is an efficient Python serialization tool, especially suitable for saving objects containing large numerical arrays (such as numpy arrays, scikit-learn models, etc.). Compared to `pickle`, `joblib` is more efficient when handling large-scale data.
joblib is a third-party Python library. It is installed automatically as a dependency of scikit-learn, and it can also be installed directly with the following command:
pip install joblib
### Save Model
joblib provides a simple API to save and load objects.
We can use the joblib.dump() method to save the model to a file.
## Example
import joblib
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
# Load data
data = load_iris()
X, y = data.data, data.target
# Split training set and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Create and train model
model = SVC(kernel='linear')
model.fit(X_train, y_train)
# Save model to file
joblib.dump(model, 'svm_model.joblib')
### Load Model
Use the joblib.load() method to load the saved model object.
## Example
# Load saved model
loaded_model = joblib.load('svm_model.joblib')
# Use loaded model for prediction
y_pred = loaded_model.predict(X_test)
# Print prediction results
print("Predictions:", y_pred)
Through the above steps, we successfully saved the trained model to a file and can load the model at any later time for prediction.
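As a quick sanity check on the save and load steps above, the following sketch compares the predictions of the original model with those of the reloaded one. It assumes the same iris data and linear `SVC` as the example; the temporary directory and the `predictions_match` variable are illustrative choices, not part of the original tutorial.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

model = SVC(kernel="linear")
model.fit(X_train, y_train)

# Save into a temporary directory so the sketch leaves no files behind
with tempfile.TemporaryDirectory() as tmp_dir:
    model_path = os.path.join(tmp_dir, "svm_model.joblib")
    joblib.dump(model, model_path)
    loaded_model = joblib.load(model_path)

# The reloaded model should reproduce the original predictions exactly
predictions_match = np.array_equal(
    model.predict(X_test), loaded_model.predict(X_test)
)
print("Predictions match:", predictions_match)
```

A check like this is cheap to run right after saving, before the file is handed to another application.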
* * *
## 2. Using `pickle` to Save and Load Models
pickle is a built-in Python module that allows serialization and deserialization of Python objects.
Although joblib is more suitable for handling large amounts of data, pickle is also a commonly used tool for saving and loading models and is suitable for general situations.
### Save Model
Similar to joblib, pickle also has a simple API to save and load objects.
The code to save the model is as follows:
## Example
import pickle
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC
# Load data
data = load_iris()
X, y = data.data, data.target
# Split training set and test set
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
# Create and train model
model = SVC(kernel='linear')
model.fit(X_train, y_train)
# Save model using pickle
with open('svm_model.pkl', 'wb') as f:
    pickle.dump(model, f)
### Load Model
Use pickle.load() to load the model:
## Example
# Use pickle to load saved model
with open('svm_model.pkl', 'rb') as f:
    loaded_model = pickle.load(f)
# Use loaded model for prediction
y_pred = loaded_model.predict(X_test)
# Print prediction results
print("Predictions:", y_pred)
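Besides writing to a file, `pickle` can serialize to an in-memory `bytes` object with `pickle.dumps()` and restore it with `pickle.loads()`, which may be convenient when the model is stored in a database or sent over a network. The sketch below assumes the same iris data and linear `SVC`; the variable names are illustrative.

```python
import pickle

from sklearn.datasets import load_iris
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
model = SVC(kernel="linear").fit(X, y)

# Serialize to an in-memory bytes object instead of a file
model_bytes = pickle.dumps(model, protocol=pickle.HIGHEST_PROTOCOL)

# Only unpickle data from a source you trust: loading can run arbitrary code
restored_model = pickle.loads(model_bytes)

same_predictions = bool((model.predict(X) == restored_model.predict(X)).all())
print("Serialized size in bytes:", len(model_bytes))
print("Same predictions:", same_predictions)
```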
* * *
## 3. joblib vs pickle
joblib and pickle both save and load models; they differ mainly in how well they handle large numerical data:
* **`joblib`**: Usually suitable for saving objects containing large amounts of numerical data (such as numpy arrays). `joblib` is more efficient than `pickle` when handling large-scale data.
* **`pickle`**: Suitable for saving smaller objects or regular Python objects. It is Python's built-in library and requires no additional installation.
If the model stores large numerical arrays or matrices (for example, a random forest with many trees, or a support vector machine with many support vectors), joblib is the recommended choice. For smaller models, or models that hold little numerical data, pickle is sufficient.
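One way to see the difference on array-heavy objects is to compare file sizes. The sketch below uses the `compress` argument of `joblib.dump()` on a hypothetical stand-in for model parameters: an array of zeros, which compresses unusually well. Real model files will usually shrink far less, so treat the numbers as an illustration rather than a benchmark.

```python
import os
import pickle
import tempfile

import joblib
import numpy as np

# Hypothetical stand-in for a model's parameters: a large, repetitive array
weights = np.zeros((1000, 1000))

with tempfile.TemporaryDirectory() as tmp_dir:
    pickle_path = os.path.join(tmp_dir, "weights.pkl")
    joblib_path = os.path.join(tmp_dir, "weights.joblib")

    with open(pickle_path, "wb") as f:
        pickle.dump(weights, f)

    # compress ranges from 0 (off) to 9 (smallest file, slowest)
    joblib.dump(weights, joblib_path, compress=3)

    pickle_size = os.path.getsize(pickle_path)
    joblib_size = os.path.getsize(joblib_path)
    restored = joblib.load(joblib_path)

print("pickle size (bytes):", pickle_size)
print("joblib compressed size (bytes):", joblib_size)
print("Round trip intact:", np.array_equal(weights, restored))
```

Compression trades smaller files for slower saving and loading, so a moderate level such as 3 is a common starting point.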
* * *
## 4. Save and Load Pipeline
In practical applications, a model is rarely a single estimator; it usually combines several processing steps (such as data preprocessing, feature selection, and model training). scikit-learn's `Pipeline` chains these steps into one object, which can be saved and loaded with joblib or pickle just like a single model.
Save Pipeline:
## Example
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC
import joblib
# Create a pipeline
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('svc', SVC(kernel='linear'))
])
# Train pipeline
pipeline.fit(X_train, y_train)
# Save pipeline to file
joblib.dump(pipeline,'pipeline_model.joblib')
Load Pipeline:
## Example
# Load pipeline
loaded_pipeline = joblib.load('pipeline_model.joblib')
# Use loaded pipeline for prediction
y_pred = loaded_pipeline.predict(X_test)
# Print prediction results
print("Predictions:", y_pred)
The process of saving and loading a pipeline is the same as for a single model; just make sure to save and load the entire pipeline object.
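Saving the whole pipeline matters because the fitted preprocessing state travels with the model. The sketch below, which assumes the same iris data and pipeline steps as above, checks that the reloaded `StandardScaler` keeps its learned means and that the reloaded pipeline accepts raw, unscaled inputs. The temporary directory is an illustrative choice.

```python
import os
import tempfile

import joblib
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42
)

pipeline = Pipeline([
    ("scaler", StandardScaler()),
    ("svc", SVC(kernel="linear")),
])
pipeline.fit(X_train, y_train)

with tempfile.TemporaryDirectory() as tmp_dir:
    pipeline_path = os.path.join(tmp_dir, "pipeline_model.joblib")
    joblib.dump(pipeline, pipeline_path)
    loaded_pipeline = joblib.load(pipeline_path)

# The fitted scaler statistics are restored together with the classifier
original_means = pipeline.named_steps["scaler"].mean_
loaded_means = loaded_pipeline.named_steps["scaler"].mean_
scaler_restored = np.allclose(original_means, loaded_means)

# Raw features go in; scaling happens inside the loaded pipeline
same_output = np.array_equal(
    pipeline.predict(X_test), loaded_pipeline.predict(X_test)
)
print("Scaler restored:", scaler_restored)
print("Same predictions:", same_output)
```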
* * *
## 5. Model Version Management
In practical applications of machine learning, model updates and version management are crucial. Each time you train and save a model, it is best to add a timestamp or version number to the model file name to distinguish between different versions of the model. For example:
## Example
import time
import joblib
# Create timestamp
timestamp = time.strftime("%Y%m%d-%H%M%S")
# Save model with timestamp
joblib.dump(model, f'svm_model_{timestamp}.joblib')
This way, we can manage different versions of the model based on timestamps, making it easy to trace back and update models.
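Because the `%Y%m%d-%H%M%S` format sorts chronologically when sorted as text, the newest version can be found by sorting file names. The helper below, `find_latest_model`, is a hypothetical name introduced for illustration; it assumes all versions live in one directory and share the `svm_model_` prefix used above. Empty placeholder files stand in for real saved models.

```python
import tempfile
from pathlib import Path


def find_latest_model(model_dir, prefix="svm_model_"):
    """Return the newest timestamped model file in model_dir, or None."""
    candidates = sorted(Path(model_dir).glob(f"{prefix}*.joblib"))
    return candidates[-1] if candidates else None


with tempfile.TemporaryDirectory() as tmp_dir:
    # Create empty placeholder files in place of real saved models
    for stamp in ("20240101-120000", "20240315-093000", "20231231-235959"):
        (Path(tmp_dir) / f"svm_model_{stamp}.joblib").touch()

    latest_name = find_latest_model(tmp_dir).name
    print("Latest model file:", latest_name)
```

The returned path can then be passed to `joblib.load()` to load the most recent version.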
* * *
## 6. Model Persistence
Once a model is trained and saved, later applications can load it and make predictions without retraining.
The saved model can be integrated into web services, batch jobs, or other applications and reused as often as needed.
### Using Loaded Model in Web Services
For example, suppose we use Flask to create a simple web service that exposes model predictions through an API. The service loads the saved model once at startup and then uses it for real-time prediction.
## Example
from flask import Flask, request, jsonify
import joblib
import numpy as np
app = Flask(__name__)

# Load model once at startup
model = joblib.load('svm_model.joblib')

@app.route('/predict', methods=['POST'])
def predict():
    data = request.get_json()  # Get input data
    features = np.array(data['features']).reshape(1, -1)  # Convert to format suitable for prediction
    prediction = model.predict(features)  # Use loaded model for prediction
    return jsonify({'prediction': prediction.tolist()})  # Return prediction result

if __name__ == '__main__':
    app.run(debug=True)  # debug=True is intended for local development only
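The endpoint can be exercised without starting a server by using Flask's built-in test client. The sketch below is self-contained, so it trains a small model in place instead of loading `svm_model.joblib`. It also adds a basic input check that the original service does not have; the check and its error message are assumptions made for illustration.

```python
from flask import Flask, jsonify, request
from sklearn.datasets import load_iris
from sklearn.svm import SVC

# Train a small model here in place of joblib.load('svm_model.joblib')
X, y = load_iris(return_X_y=True)
model = SVC(kernel="linear").fit(X, y)

app = Flask(__name__)


@app.route("/predict", methods=["POST"])
def predict():
    data = request.get_json(silent=True)
    features = data.get("features") if isinstance(data, dict) else None
    # Illustrative check: the iris model expects exactly 4 feature values
    if not isinstance(features, list) or len(features) != 4:
        return jsonify({"error": "expected 'features' with 4 numbers"}), 400
    prediction = model.predict([features])
    return jsonify({"prediction": prediction.tolist()})


# The test client sends requests directly to the app, with no server running
client = app.test_client()
ok_response = client.post("/predict", json={"features": [5.1, 3.5, 1.4, 0.2]})
bad_response = client.post("/predict", json={"features": [1.0]})

print(ok_response.status_code, ok_response.get_json())
print(bad_response.status_code, bad_response.get_json())
```

Against a running server, the same JSON body would be sent as a POST request to the `/predict` route.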