
# Sklearn Model Save and Load

In machine learning, training a model is often time-consuming. To avoid retraining every time, we can save the trained model to disk and load it later for prediction. Two tools are commonly used to save and load `scikit-learn` models: `joblib` and `pickle`.

## 1. Using `joblib` to Save and Load Models

`joblib` is an efficient Python serialization library, particularly suited to objects that contain large numerical arrays (such as NumPy arrays and scikit-learn models). For such objects, `joblib` is usually more efficient than `pickle`.

`joblib` is a separate package, but it is installed automatically as a dependency of scikit-learn. To install it on its own, run:

```
pip install joblib
```

### Save Model

`joblib` provides a simple API for saving and loading objects. Use `joblib.dump()` to save the model to a file:

```python
import joblib
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Load data
data = load_iris()
X, y = data.data, data.target

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create and train the model
model = SVC(kernel='linear')
model.fit(X_train, y_train)

# Save the model to a file
joblib.dump(model, 'svm_model.joblib')
```

### Load Model

Use `joblib.load()` to load the saved model object:

```python
import joblib

# Load the saved model
loaded_model = joblib.load('svm_model.joblib')

# Use the loaded model for prediction (X_test comes from the previous example)
y_pred = loaded_model.predict(X_test)

# Print prediction results
print("Predictions:", y_pred)
```

With these steps, the trained model is saved to a file and can be loaded at any later time for prediction. In a new Python session, the input data (here `X_test`) must be prepared again before calling `predict()`.

* * *

## 2. Using `pickle` to Save and Load Models

`pickle` is a built-in Python module for serializing and deserializing Python objects.
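Serialization does not have to target a file. As a minimal sketch, the snippet below round-trips a fitted model through an in-memory `bytes` object with `pickle.dumps()` and `pickle.loads()`, which can be handy when a model is stored somewhere other than a local file (that use case is an assumption, not part of the original tutorial). Training on the full Iris data here is only to keep the example self-contained.

```python
import pickle

import numpy as np
from sklearn.datasets import load_iris
from sklearn.svm import SVC

# Train a small model on the Iris data (same estimator as the examples above)
X, y = load_iris(return_X_y=True)
model = SVC(kernel='linear').fit(X, y)

# Serialize the fitted model to a bytes object instead of a file
payload = pickle.dumps(model)

# Deserialize the bytes back into a new, independent model object
restored = pickle.loads(payload)

# The restored model should make exactly the same predictions
same = np.array_equal(model.predict(X), restored.predict(X))
print(type(payload).__name__, len(payload), "bytes; identical predictions:", same)
```

`pickle.dump()`/`pickle.load()` used in the next examples do the same thing, but write to and read from an open binary file instead of returning bytes.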
Although `joblib` is better suited to large numerical data, `pickle` is also widely used for saving and loading models and works well in general cases.

### Save Model

Like `joblib`, `pickle` has a simple API for saving and loading objects. The code to save the model is as follows:

```python
import pickle
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.svm import SVC

# Load data
data = load_iris()
X, y = data.data, data.target

# Split into training and test sets
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

# Create and train the model
model = SVC(kernel='linear')
model.fit(X_train, y_train)

# Save the model with pickle (the file must be opened in binary write mode)
with open('svm_model.pkl', 'wb') as f:
    pickle.dump(model, f)
```

### Load Model

Use `pickle.load()` to load the model:

```python
import pickle

# Load the saved model with pickle (binary read mode)
with open('svm_model.pkl', 'rb') as f:
    loaded_model = pickle.load(f)

# Use the loaded model for prediction
y_pred = loaded_model.predict(X_test)

# Print prediction results
print("Predictions:", y_pred)
```

* * *

## 3. joblib vs pickle

Both are common ways to save and load models; they differ mainly in the kinds of objects they handle best:

* **`joblib`**: Suited to objects that contain large amounts of numerical data (such as NumPy arrays). For such objects it is usually more efficient than `pickle`.
* **`pickle`**: Suited to smaller objects or ordinary Python objects. It is part of Python's standard library and requires no extra installation.

If the model contains many large numerical arrays or matrices (for example, Support Vector Machines or Random Forests), `joblib` is recommended. For smaller models, or models that hold little numerical data, `pickle` is sufficient.
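To see the comparison on a concrete model, the sketch below saves one fitted `RandomForestClassifier` three ways (plain `joblib.dump()`, `joblib.dump()` with `compress=3`, and `pickle.dump()`), then reloads each copy and checks that the predictions match. It is an illustration, not a benchmark: the temporary directory, file names, and compression level are arbitrary choices, and the printed file sizes will depend on your library versions.

```python
import os
import pickle
import tempfile

import joblib
import numpy as np
from sklearn.datasets import load_iris
from sklearn.ensemble import RandomForestClassifier

X, y = load_iris(return_X_y=True)
# A random forest stores many per-tree arrays, so it is a reasonable test object
model = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

out_dir = tempfile.mkdtemp()
paths = {
    "joblib": os.path.join(out_dir, "rf.joblib"),
    "joblib_compressed": os.path.join(out_dir, "rf_compressed.joblib"),
    "pickle": os.path.join(out_dir, "rf.pkl"),
}

# joblib: a plain dump and a compressed dump (compress=3 uses zlib)
joblib.dump(model, paths["joblib"])
joblib.dump(model, paths["joblib_compressed"], compress=3)

# pickle: the file must be opened in binary mode
with open(paths["pickle"], "wb") as f:
    pickle.dump(model, f, protocol=pickle.HIGHEST_PROTOCOL)

# Reload each file
loaded = {
    "joblib": joblib.load(paths["joblib"]),
    "joblib_compressed": joblib.load(paths["joblib_compressed"]),
}
with open(paths["pickle"], "rb") as f:
    loaded["pickle"] = pickle.load(f)

# Every copy should predict exactly like the original model
reference = model.predict(X)
for name, path in paths.items():
    same = np.array_equal(loaded[name].predict(X), reference)
    print(f"{name:18s} {os.path.getsize(path):>8d} bytes  identical predictions: {same}")
```

The `compress` option trades save/load speed for a smaller file; integer levels from 0 to 9 are accepted, with higher values compressing more.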
* * *

## 4. Save and Load Pipeline

In practice, a model is often more than a single estimator: it may combine several processing steps (such as data preprocessing, feature selection, and model training). These steps can be chained with scikit-learn's `Pipeline`, and a fitted pipeline can be saved and loaded with `joblib` or `pickle` just like a single model.

Save the pipeline:

```python
import joblib
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler
from sklearn.svm import SVC

# Create a pipeline: scale the features, then classify
pipeline = Pipeline([
    ('scaler', StandardScaler()),
    ('svc', SVC(kernel='linear'))
])

# Train the pipeline (X_train and y_train come from the earlier train/test split)
pipeline.fit(X_train, y_train)

# Save the whole pipeline to a file
joblib.dump(pipeline, 'pipeline_model.joblib')
```

Load the pipeline:

```python
import joblib

# Load the pipeline
loaded_pipeline = joblib.load('pipeline_model.joblib')

# Use the loaded pipeline for prediction; the fitted scaler is applied automatically
y_pred = loaded_pipeline.predict(X_test)

# Print prediction results
print("Predictions:", y_pred)
```

Saving and loading a pipeline works the same way as for a single model; just make sure to save and load the entire pipeline object, so that the preprocessing steps fitted on the training data are restored together with the model.

* * *

## 5. Model Version Management

In real-world machine learning, models are retrained and updated over time, so version management matters. Each time you train and save a model, it is good practice to add a timestamp or version number to the file name to distinguish between versions. For example:

```python
import time

import joblib

# Create a timestamp such as 20240101-120000
timestamp = time.strftime("%Y%m%d-%H%M%S")

# Save the previously trained model with the timestamp in its file name
joblib.dump(model, f'svm_model_{timestamp}.joblib')
```

This makes it easy to tell versions apart, trace back to an earlier model, and roll out updates.

* * *

## 6. Model Persistence

Once a model is trained and saved, it can be loaded in later applications for prediction without retraining.
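In a batch job, loading a saved model can be combined with the timestamped file names from the version-management section so that the job always picks up the newest version. The sketch below assumes every version is saved in one directory under the `svm_model_<timestamp>.joblib` pattern; `load_latest_model` is a hypothetical helper written for this example, not part of joblib or scikit-learn.

```python
import glob
import os
import tempfile
import time

import joblib
from sklearn.datasets import load_iris
from sklearn.svm import SVC

X, y = load_iris(return_X_y=True)
model_dir = tempfile.mkdtemp()  # hypothetical model directory for this sketch

# Save two model versions using the timestamp naming scheme
for C in (0.1, 1.0):
    model = SVC(kernel='linear', C=C).fit(X, y)
    timestamp = time.strftime("%Y%m%d-%H%M%S")
    joblib.dump(model, os.path.join(model_dir, f'svm_model_{timestamp}.joblib'))
    time.sleep(1.1)  # make sure the next version gets a later timestamp


def load_latest_model(directory, prefix='svm_model_'):
    """Hypothetical helper: load the newest '<prefix><timestamp>.joblib' file."""
    candidates = sorted(glob.glob(os.path.join(directory, f'{prefix}*.joblib')))
    if not candidates:
        raise FileNotFoundError(f'no {prefix}*.joblib files in {directory}')
    # Zero-padded %Y%m%d-%H%M%S names sort chronologically as plain strings
    latest = candidates[-1]
    return latest, joblib.load(latest)


# A batch job would call this at startup instead of retraining
path, latest_model = load_latest_model(model_dir)
predictions = latest_model.predict(X[:5])
print("Loaded", os.path.basename(path), "with C =", latest_model.C)
print("Predictions:", predictions)
```

Because the timestamps are zero-padded and ordered from year down to second, sorting the file names as strings is enough; no date parsing is needed.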
The saved model can be integrated into web services, batch jobs, or other applications and reused repeatedly without retraining.

### Using a Loaded Model in Web Services

Suppose we use Flask to build a simple web service that exposes model predictions through an API. The saved model is loaded once when the service starts and is then used for real-time predictions:

```python
import joblib
import numpy as np
from flask import Flask, jsonify, request

app = Flask(__name__)

# Load the model once when the service starts
model = joblib.load('svm_model.joblib')

@app.route('/predict', methods=['POST'])
def predict():
    data = request.get_json()  # Get input data, e.g. {"features": [5.1, 3.5, 1.4, 0.2]}
    features = np.array(data['features']).reshape(1, -1)  # Reshape to a single-sample 2D array
    prediction = model.predict(features)  # Predict with the loaded model
    return jsonify({'prediction': prediction.tolist()})  # Return the prediction as JSON

if __name__ == '__main__':
    # debug=True is intended for local development only
    app.run(debug=True)
```
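The `/predict` endpoint can be exercised without starting a server by using Flask's built-in test client. The sketch below re-creates a minimal version of the service; training and saving a fresh model to a temporary path is an assumption made only so the example runs on its own. It then posts one Iris sample as JSON and reads back the prediction.

```python
import os
import tempfile

import joblib
import numpy as np
from flask import Flask, jsonify, request
from sklearn.datasets import load_iris
from sklearn.svm import SVC

# Train and save a model so the sketch does not depend on an existing file
X, y = load_iris(return_X_y=True)
model_path = os.path.join(tempfile.mkdtemp(), 'svm_model.joblib')
joblib.dump(SVC(kernel='linear').fit(X, y), model_path)

app = Flask(__name__)
model = joblib.load(model_path)  # loaded once, as in the service above


@app.route('/predict', methods=['POST'])
def predict():
    data = request.get_json()
    features = np.array(data['features']).reshape(1, -1)
    prediction = model.predict(features)
    return jsonify({'prediction': prediction.tolist()})


# The test client sends requests directly to the app, no server needed
client = app.test_client()
response = client.post('/predict', json={'features': X[0].tolist()})
print(response.status_code, response.get_json())
```

A real client would send the same JSON body over HTTP to the running service; the test client is simply a convenient way to check the endpoint in unit tests.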