一、GridSearchCV概述
GridSearchCV是scikit-learn中一個重要的調參工具,用於系統地遍歷多個參數組合,通過交叉驗證確定最佳參數。在機器學習算法中,各個算法有很多超參數,超參數的優化對算法的性能至關重要。而GridSearchCV,正是通過遍歷所有參數組合,找到最優參數從而提高模型在給定數據集上的性能。
二、逐步介紹GridSearchCV
1、如何使用GridSearchCV?
GridSearchCV需要提供一個估計器(estimator)和一組參數(dictionary),用於將參數遍歷作為一個估計器的參數組合進行交叉驗證。首先需要加載相關庫,如下:
from sklearn.model_selection import GridSearchCV
from sklearn.svm import SVC
接下來,加載數據並建立模型:
import numpy as np
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
iris = load_iris()
X, y = iris.data, iris.target
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3,random_state=0)
svc = SVC()
構建參數字典:
param_grid = {'C': [0.1, 1, 10, 100, 1000],
'gamma': ['scale', 'auto', 0.001, 0.0001],
'kernel': ['rbf', 'linear', 'poly', 'sigmoid']}
GridSearchCV的使用:
clf = GridSearchCV(svc, param_grid, cv=5)
clf.fit(X_train, y_train)
其中,SVC的參數指定為clf,param_grid指定了要搜索的參數組合,cv表示使用的交叉驗證的策略。5表示使用5折交叉驗證法。
2、參數的選擇
超參數的選擇是機器學習中非常重要的一步。參數的組合可能非常多,這使得參數的選擇非常困難。但是使用GridSearchCV可以一個一個地測試每種超參數的組合,從而找到最優的參數組合。在上述例子中,我們使用了3個參數:C、gamma和kernel。C是SVM中的參數,表示懲罰誤分類點的權重,gamma是核函數中的參數,表示向樣本中添加一個樣本的影響程度。kernel是SVM中的核函數類型。在param_grid中,C的取值範圍是[0.1, 1, 10, 100, 1000],gamma的取值是[‘scale’, ‘auto’, 0.001, 0.0001],kernel的取值是[‘rbf’, ‘linear’, ‘poly’, ‘sigmoid’]。使用GridSearchCV會嘗試每種參數組合,最終得出最佳的參數組合。
3、結果分析與輸出
GridSearchCV有很多輸出信息,其中最重要的是best_params_和best_score_。best_params_是最佳參數的集合,best_score_是最佳參數組合的得分。另外,可以使用cv_results_將所有的參數集合及其得分統計出來。代碼如下:
print("The best parameters are %s with a score of %0.2f" % (clf.best_params_, clf.best_score_))
print("The best parameters sets found on development set:")
print(clf.best_params_)
print("Grid scores on development set:")
means = clf.cv_results_['mean_test_score']
stds = clf.cv_results_['std_test_score']
for mean, std, params in zip(means, stds, clf.cv_results_['params']):
print("%0.3f (+/-%0.03f) for %r" % (mean, std * 2, params))
三、注意事項
1、使用更易搜索的參數
當需要使用GridSearchCV時,保證參數越少越好。如果搜索的空間很大,搜索過程就會非常耗時。使用太多參數會使搜索過程變得非常緩慢。確保每個參數慎重再慎重的選擇。
2、使用並行計算
sklearn提供了並行計算功能,這使得搜索過程更快。默認情況下,GridSearchCV使用的計算是單進程的,使用n_jobs參數將其更改為多進程的。
例如:
clf = GridSearchCV(svc, param_grid, cv=5, n_jobs=-1)
3、不要期望太高
網格搜索是一項非常強大的技術,但不要期望它能在所有數據集上表現良好。在某些情況下,它的表現可能不如其他優化技術。在實際應用時,最好記錄不同模型及其參數的得分。
四、總結
GridSearchCV是機器學習中一個應該掌握的重要調參工具。在這篇文章中,我們了解了GridSearchCV的基本原理,如何使用它找到模型的最佳參數,如何分析最佳參數的結果。
原創文章,作者:IIDO,如若轉載,請註明出處:https://www.506064.com/zh-hk/n/135771.html