小批量梯度下降法的詳細闡述

一、什麼是小批量梯度下降法

1、小批量梯度下降法(Mini-batch Gradient Descent, MBGD)是一種介於梯度下降法(GD)和隨機梯度下降法(SGD)之間的優化算法。在每次迭代時,它不像GD一樣使用所有的訓練樣本,也不像SGD一樣只使用一個樣本,而是使用一小部分訓練樣本(通常是2到1000個)。

2、這樣可以在降低隨機性和提高算法效率之間取得平衡。與SGD相比,MBGD在每次迭代時使用更多的數據,因此更可能找到全局最優解;與GD相比 ,MBGD在每次迭代時使用更少的數據,因此更快、更節省內存。

二、小批量梯度下降法的實現

1、首先需要定義一個損失函數(Loss Function),通常是均方誤差函數,表示預測結果與真實結果之間的差距。

def loss_function(y_true, y_pred):
    return ((y_true-y_pred)**2).mean()

2、然後需要定義模型(Model),通常是一個線性回歸模型,也可以是其他類型的模型。

class LinearRegression:
    def __init__(self, n_features):
        self.n_features = n_features
        self.weights = np.random.randn(n_features)
        self.bias = np.random.randn()

    def forward(self, x):
        return np.dot(x, self.weights) + self.bias

    def backward(self, x, y, y_pred):
        n_samples = x.shape[0]
        d_weights = (2 / n_samples) * np.dot(x.T, (y_pred - y))
        d_bias = (2 / n_samples) * np.sum(y_pred - y)
        return d_weights, d_bias

3、在訓練過程中,需要隨機抽取一小部分樣本構成一個batch,計算這個batch的損失和梯度,然後更新模型參數。

def train_step(model, optimizer, x_batch, y_batch):
    # forward
    y_pred = model.forward(x_batch)
    # backward
    d_weights, d_bias = model.backward(x_batch, y_batch, y_pred)
    # update
    optimizer.update(model, d_weights, d_bias)
    # compute loss
    loss = loss_function(y_batch, y_pred)
    return loss

三、小批量梯度下降法的優點和缺點

1、優點:

(1)相對於梯度下降法,小批量梯度下降法更快,內存消耗更少,更適合大規模數據集的訓練;

(2)相對於隨機梯度下降法,小批量梯度下降法更穩定,更容易找到全局最優解;

(3)由於小批量梯度下降法使用了一部分數據,因此可以獲得比隨機梯度下降法更準確的梯度,從而更快地收斂。

2、缺點:

(1)需要調整batch size的大小,太小容易增加噪聲,太大會佔用過多的內存;

(2)需要調整學習率(learning rate)的大小,太小可能導致收斂過慢,太大可能導致震蕩不收斂;

(3)需要對數據進行shuffle,否則容易陷入局部最優解。

四、小批量梯度下降法的應用

1、小批量梯度下降法是深度學習中最常用的優化算法之一,廣泛應用於神經網絡的訓練;

2、小批量梯度下降法也可以應用於其他機器學習領域,如線性回歸、邏輯回歸、支持向量機等;

3、小批量梯度下降法的變種還有動量梯度下降法、Adam等,它們在小批量梯度下降法的基礎上加入了一些優化技巧,可以獲得更好的性能。

原創文章,作者:SSNSG,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/349483.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
SSNSG的頭像SSNSG
上一篇 2025-02-15 17:10
下一篇 2025-02-15 17:10

相關推薦

  • 預處理共軛梯度法

    預處理共軛梯度法是一種求解線性方程組的迭代方法,相比直接求解,其具有更高的效率和更快的速度。本文將從幾個方面對預處理共軛梯度法進行詳細的闡述,並給出完整的代碼示例。 一、預處理共軛…

    編程 2025-04-28
  • Python邏輯回歸梯度下降法

    本文將通過Python邏輯回歸梯度下降法,對於邏輯回歸的原理、實現方法和應用進行詳細闡述。 一、邏輯回歸原理 邏輯回歸是一種常用的分類算法,其原理可以用線性回歸模型來描述,將線性回…

    編程 2025-04-27
  • index.html怎麼打開 – 詳細解析

    一、index.html怎麼打開看 1、如果你已經擁有了index.html文件,那麼你可以直接使用任何一個現代瀏覽器打開index.html文件,比如Google Chrome、…

    編程 2025-04-25
  • Resetful API的詳細闡述

    一、Resetful API簡介 Resetful(REpresentational State Transfer)是一種基於HTTP協議的Web API設計風格,它是一種輕量級的…

    編程 2025-04-25
  • 關鍵路徑的詳細闡述

    關鍵路徑是項目管理中非常重要的一個概念,它通常指的是項目中最長的一條路徑,它決定了整個項目的完成時間。在這篇文章中,我們將從多個方面對關鍵路徑做詳細的闡述。 一、概念 關鍵路徑是指…

    編程 2025-04-25
  • AXI DMA的詳細闡述

    一、AXI DMA概述 AXI DMA是指Advanced eXtensible Interface Direct Memory Access,是Xilinx公司提供的基於AMBA…

    編程 2025-04-25
  • neo4j菜鳥教程詳細闡述

    一、neo4j介紹 neo4j是一種圖形數據庫,以實現高效的圖操作為設計目標。neo4j使用圖形模型來存儲數據,數據的表述方式類似於實際世界中的網絡。neo4j具有高效的讀和寫操作…

    編程 2025-04-25
  • c++ explicit的詳細闡述

    一、explicit的作用 在C++中,explicit關鍵字可以在構造函數聲明前加上,防止編譯器進行自動類型轉換,強制要求調用者必須強制類型轉換才能調用該函數,避免了將一個參數類…

    編程 2025-04-25
  • HTMLButton屬性及其詳細闡述

    一、button屬性介紹 button屬性是HTML5新增的屬性,表示指定文本框擁有可供點擊的按鈕。該屬性包括以下幾個取值: 按鈕文本 提交 重置 其中,type屬性表示按鈕類型,…

    編程 2025-04-25
  • Vim使用教程詳細指南

    一、Vim使用教程 Vim是一個高度可定製的文本編輯器,可以在Linux,Mac和Windows等不同的平台上運行。它具有快速移動,複製,粘貼,查找和替換等強大功能,尤其在面對大型…

    編程 2025-04-25

發表回復

登錄後才能評論