一、基本概念
隨機梯度下降法(Stochastic Gradient Descent,SGD)相對於傳統的梯度下降法,是一種更為高效的機器學習優化算法。梯度下降法每次迭代都要遍歷整個訓練集,計算所有樣本的梯度才能更新權重,因此在大規模數據的情況下會十分耗時。而SGD每次只選取一個或一小批量樣本來計算梯度,從而使得每次迭代的計算量大大降低,具有更高的計算效率。
SGD的具體實現過程如下:對於目標函數 $J(w)$,權重 $w$,學習率 $\eta$,SGD每次從訓練集中隨機選取一個樣本 $x_i$,計算該樣本的梯度 $g_i$,然後使用梯度下降法的公式更新權重:$w = w – \eta g_i$。這樣不斷迭代更新,直到達到一定的迭代次數或者達到收斂要求即可。
下面給出SGD的偽代碼實現:
for i in range(0, max_iter): shuffle(X) # 打亂訓練集 for j in range(num_samples): # 隨機選取一個訓練樣本進行梯度計算 xi = X[j] yi = y[j] gi = compute_gradient(J, xi, yi, w) # 更新權重 w = w - eta * gi
二、SGD的優點
1、高效:SGD每次只需要計算一個樣本的梯度,計算量較小,適合大規模數據集的優化問題。
2、易於並行:SGD每次更新只操作一個樣本,易於實現並行化操作,從而大大縮短了計算時間。
3、可收斂到局部最優解:SGD的收斂路徑具有一定的隨機性,能在一定程度上跳出局部最優解,收斂到全局最優解的概率也相對較大。
三、SGD的缺點與改進
1、收斂速度慢:SGD每次只更新一個樣本,可能會出現跳出最優解的情況,同時也容易受到樣本噪聲的干擾,導致收斂速度慢。
2、有一定的不穩定性:由於每次只考慮一個樣本,SGD可能會受到單個樣本的影響,進而影響到整個模型的參數更新。
為了克服SGD的缺點,研究者們提出了一系列改進方法。其中最常見的是Batch SGD和Mini-batch SGD。Batch SGD每次更新所有的樣本梯度,Mini-batch SGD選取一個小批量樣本進行梯度計算,大小通常設置在2~256之間。這樣權衡了運算速度和參數更新的精度。
下面給出Mini-batch SGD的實現代碼:
for i in range(0, max_iter): shuffle(X) # 打亂訓練集 for j in range(0, num_samples, batch_size): # 隨機選取一個小批量訓練樣本進行梯度計算 batch_indices = range(j, min(j + batch_size, num_samples)) X_batch = X[batch_indices] y_batch = y[batch_indices] gi = compute_gradient(J, X_batch, y_batch, w) # 更新權重 w = w - eta * gi
四、實戰代碼示例
下面以sklearn庫中的breast_cancer數據集為例,展示如何使用SGDClassifier類進行二分類問題訓練。在示例代碼中,我們使用SGDClassifier類進行100次迭代的訓練,打印出了訓練集和測試集上的分類準確率。
from sklearn.datasets import load_breast_cancer from sklearn.linear_model import SGDClassifier from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler # 加載數據集 data = load_breast_cancer() X = data.data y = data.target # 數據歸一化 scaler = StandardScaler() X = scaler.fit_transform(X) # 劃分訓練集和測試集 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42) # 使用SGDClassifier進行訓練 clf = SGDClassifier(max_iter=100, tol=1e-3) clf.fit(X_train, y_train) # 打印結果 print("Train set score:", clf.score(X_train, y_train)) print("Test set score:", clf.score(X_test, y_test))
原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/258321.html