softmax交叉熵損失函數詳解

在機器學習和深度學習中,選擇恰當的損失函數是十分重要的,因為它確定了模型的訓練方向和結果。softmax交叉熵損失函數是神經網路中用於分類問題的一種常見的損失函數。它是通過softmax函數將分類結果正則化後,再計算交叉熵損失值得到的。在這篇文章中,我們將從多個方面對softmax交叉熵損失函數進行詳細的闡述,包括梯度、求導、Python和PyTorch實現,以及與其他常見損失函數的區別。

一、softmax交叉熵損失函數梯度

梯度是損失函數優化的關鍵之一,因為梯度能反映損失函數在當前點的變化趨勢和方向。在softmax交叉熵損失函數中,我們先計算出softmax函數的輸出結果,再計算它們與標籤之間的交叉熵損失值。設softmax函數的輸入為$z$,輸出為$a$,標籤為$y$,則softmax交叉熵損失函數可以表示為:

    L = -sum(y[i]*log(a[i]))  # i為類別數

其中$log$表示自然對數,$sum$表示求和。其梯度可以使用鏈式法則來求解,即先求出$L$對$a$的偏導數$\frac{\partial L}{\partial a}$,再求出$a$對$z$的偏導數$\frac{\partial a}{\partial z}$,最後將兩者相乘即可得到$L$對$z$的偏導數$\frac{\partial L}{\partial z}$。具體推導過程如下:

1. $L$對$a$的偏導數:

    dL/da = [-y[i]/a[i] for i in range(n)]  # n為樣本數

2. $a$對$z$的偏導數:

    da/dz = diag(a) - a*a.T  # diag表示對角線,a為列向量,T表示轉置

3. $L$對$z$的偏導數:

    dL/dz = da/dz * dL/da

二、softmax交叉熵損失函數求導

求導是梯度的基礎,因為梯度本質上就是損失函數的導數。在softmax交叉熵損失函數中,我們需要求解$z$對參數$w$的偏導數$\frac{\partial z}{\partial w}$。在計算$\frac{\partial z}{\partial w}$時,我們需要計算$\frac{\partial z}{\partial a}$和$\frac{\partial a}{\partial z}$,因為它們都涉及到參數$w$。具體推導過程如下:

1. $z$對$w$的偏導數:

    dz/dw = x.T  # x為輸入特徵矩陣

2. $z$對$a$的偏導數:

    dz/da = w.T

3. $a$對$z$的偏導數:

    da/dz = diag(a) - a*a.T

4. $L$對$z$的偏導數:

    dL/dz = da/dz * dL/da

5. $L$對$w$的偏導數:

    dL/dw = dz/dw * dz/da * da/dz * dL/da

三、交叉熵損失函數Python

交叉熵損失函數是分類問題中常見的損失函數之一。在Python中,我們可以使用NumPy庫來實現交叉熵損失函數。假設我們有$N$個樣本,$K$個類別,預測值為$p$,真實值為$t$,則交叉熵損失函數可以表示為:

    def cross_entropy_loss(p, t):
        loss = 0.0
        for i in range(N):
            for j in range(K):
                loss += -t[i][j]*math.log(p[i][j])
        return loss

其中,$N$為樣本數,$K$為類別數。

四、softmax損失函數

softmax損失函數是神經網路中用於分類問題的一種常見的損失函數,它可以將分類結果正則化成一個概率分布,便於計算交叉熵損失值。在Python中,我們可以使用NumPy庫來實現softmax函數。假設我們有$N$個樣本,$K$個類別,預測值為$z$,則softmax函數可以表示為:

    def softmax(z):
        a = np.exp(z) / np.sum(np.exp(z), axis=1, keepdims=True)
        return a

其中,np.exp表示自然指數函數,axis=1表示對每個樣本在行方向上做softmax,keepdims=True表示保持二維矩陣的維度不變。

五、交叉熵損失函數和MSE區別

交叉熵損失函數和均方誤差(MSE)損失函數都是常見的神經網路損失函數。它們之間的區別在於適用範圍和求導方式。

交叉熵損失函數適用於分類問題,它將模型輸出的類別概率正則化成一個概率分布,並計算與真實標籤之間的交叉熵損失值。求導時,可以使用鏈式法則將分類結果和損失結果相乘,得到最終的梯度。

MSE損失函數適用於回歸問題,它計算的是模型輸出值與真實標籤之間的均方誤差,求導時只需要將輸出值與真實標籤之間的誤差乘以2即可得到梯度。

六、交叉熵損失函數PyTorch

在PyTorch中,我們可以使用交叉熵損失函數來計算模型的損失值。假設我們有$N$個樣本,$K$個類別,預測值為$p$,真實值為$t$,則交叉熵損失函數可以表示為:

    import torch.nn.functional as F
    loss = F.cross_entropy(p, t)

其中,F表示PyTorch中的函數庫,cross_entropy表示交叉熵損失函數,p為模型輸出的類別概率分布,t為真實標籤。在模型訓練過程中,我們可以將損失函數的值作為模型的優化目標,並使用反向傳播演算法來更新模型參數。

原創文章,作者:TBOJE,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/316061.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
TBOJE的頭像TBOJE
上一篇 2025-01-09 12:14
下一篇 2025-01-09 12:14

相關推薦

  • Python中引入上一級目錄中函數

    Python中經常需要調用其他文件夾中的模塊或函數,其中一個常見的操作是引入上一級目錄中的函數。在此,我們將從多個角度詳細解釋如何在Python中引入上一級目錄的函數。 一、加入環…

    編程 2025-04-29
  • Python中capitalize函數的使用

    在Python的字元串操作中,capitalize函數常常被用到,這個函數可以使字元串中的第一個單詞首字母大寫,其餘字母小寫。在本文中,我們將從以下幾個方面對capitalize函…

    編程 2025-04-29
  • Python中set函數的作用

    Python中set函數是一個有用的數據類型,可以被用於許多編程場景中。在這篇文章中,我們將學習Python中set函數的多個方面,從而深入了解這個函數在Python中的用途。 一…

    編程 2025-04-29
  • 三角函數用英語怎麼說

    三角函數,即三角比函數,是指在一個銳角三角形中某一角的對邊、鄰邊之比。在數學中,三角函數包括正弦、餘弦、正切等,它們在數學、物理、工程和計算機等領域都得到了廣泛的應用。 一、正弦函…

    編程 2025-04-29
  • 單片機列印函數

    單片機列印是指通過串口或並口將一些數據列印到終端設備上。在單片機應用中,列印非常重要。正確的列印數據可以讓我們知道單片機運行的狀態,方便我們進行調試;錯誤的列印數據可以幫助我們快速…

    編程 2025-04-29
  • Python3定義函數參數類型

    Python是一門動態類型語言,不需要在定義變數時顯示的指定變數類型,但是Python3中提供了函數參數類型的聲明功能,在函數定義時明確定義參數類型。在函數的形參後面加上冒號(:)…

    編程 2025-04-29
  • Python實現計算階乘的函數

    本文將介紹如何使用Python定義函數fact(n),計算n的階乘。 一、什麼是階乘 階乘指從1乘到指定數之間所有整數的乘積。如:5! = 5 * 4 * 3 * 2 * 1 = …

    編程 2025-04-29
  • Python定義函數判斷奇偶數

    本文將從多個方面詳細闡述Python定義函數判斷奇偶數的方法,並提供完整的代碼示例。 一、初步了解Python函數 在介紹Python如何定義函數判斷奇偶數之前,我們先來了解一下P…

    編程 2025-04-29
  • Python函數名稱相同參數不同:多態

    Python是一門面向對象的編程語言,它強烈支持多態性 一、什麼是多態多態是面向對象三大特性中的一種,它指的是:相同的函數名稱可以有不同的實現方式。也就是說,不同的對象調用同名方法…

    編程 2025-04-29
  • 分段函數Python

    本文將從以下幾個方面詳細闡述Python中的分段函數,包括函數基本定義、調用示例、圖像繪製、函數優化和應用實例。 一、函數基本定義 分段函數又稱為條件函數,指一條直線段或曲線段,由…

    編程 2025-04-29

發表回復

登錄後才能評論