PyTorch參數初始化詳解

一、基礎知識

參數初始化是深度學習模型中的重要環節之一,它直接影響到模型的泛化能力和訓練效果。在 PyTorch 中給參數賦初值有兩種方式,分別是手動設置和自動初始化。在使用手動設置時需要注意參數的大小、含義和初始化方式。同時,PyTorch 提供了一些默認的初始化方式,可以方便地使用。

在 PyTorch 中,模型的參數是存儲在 Parameter 類型的變量中,其初始化方式主要包括:

  • 常量初始化
  • 隨機初始化
  • 預訓練模型初始化

二、常量初始化

常量初始化是最簡單的初始化方式,它將參數賦為固定的常量值。這種方式通常不是很常用,但有時候可以用於解決特殊的問題。例如,當我們想固定某些參數的值不變時,可以使用常量初始化。

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128*8*8, 1024)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(1024, 10)

        # 將conv1的權重初始化為0.1
        nn.init.constant_(self.conv1.weight, 0.1)

三、隨機初始化

隨機初始化是最常用的初始化方式之一,它可以使得參數在一定範圍內發生變化,增強模型的泛化能力。隨機初始化通常包括以下幾種方式:

  • 均勻分布初始化
  • 正態分布初始化
  • 截斷正態分布初始化
  • 自定義初始化

其中,均勻分布初始化和正態分布初始化在 PyTorch 中均有對應的函數,可以直接使用。

import torch.nn as nn

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1)
        self.fc1 = nn.Linear(128*8*8, 1024)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(1024, 10)

        # 均勻分布初始化
        nn.init.uniform_(self.fc1.weight, -0.1, 0.1)

        # 正態分布初始化
        nn.init.normal_(self.fc2.weight, mean=0, std=0.01)

四、預訓練模型初始化

預訓練模型初始化是指使用已經預先訓練好的模型來初始化當前模型的參數,該方式在遷移學習中應用廣泛。在 PyTorch 中,使用預訓練模型進行初始化通常有兩種方式,分別是從文件中加載和在線下載。

import torchvision.models as models

# 從文件中加載預訓練模型
resnet18 = models.resnet18(pretrained=True)

# 在線下載預訓練模型,需要聯網
vgg16 = models.vgg16(pretrained=True)

五、自定義初始化

有時候,我們需要使用一些特殊的方式來初始化模型參數,這時就需要自定義初始化函數。在 PyTorch 中,可以使用自定義初始化函數來實現這一點。

import torch.nn as nn

def my_init(m):
    if isinstance(m, nn.Linear):
        nn.init.xavier_uniform_(m.weight)
        nn.init.constant_(m.bias, 0.1)

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.fc1 = nn.Linear(128*8*8, 1024)
        self.fc2 = nn.Linear(1024, 10)

        # 自定義初始化
        self.apply(my_init)

六、小結

在 PyTorch 中,參數初始化是深度學習模型中不可或缺的重要部分。了解各種初始化方式的優缺點,根據不同的網絡結構和需求選擇合適的初始化方式,是提高模型性能的重要手段。

原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/160606.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
小藍的頭像小藍
上一篇 2024-11-21 01:17
下一篇 2024-11-21 01:17

相關推薦

  • 三星內存條參數用法介紹

    本文將詳細解釋三星內存條上面的各種參數,讓你更好地了解內存條並選擇適合自己的一款。 一、容量大小 容量大小是內存條最基本的參數,一般以GB為單位表示,常見的有2GB、4GB、8GB…

    編程 2025-04-29
  • Python3定義函數參數類型

    Python是一門動態類型語言,不需要在定義變量時顯示的指定變量類型,但是Python3中提供了函數參數類型的聲明功能,在函數定義時明確定義參數類型。在函數的形參後面加上冒號(:)…

    編程 2025-04-29
  • Python input參數變量用法介紹

    本文將從多個方面對Python input括號里參數變量進行闡述與詳解,並提供相應的代碼示例。 一、基本介紹 Python input()函數用於獲取用戶輸入。當程序運行到inpu…

    編程 2025-04-29
  • Spring Boot中發GET請求參數的處理

    本文將詳細介紹如何在Spring Boot中處理GET請求參數,並給出完整的代碼示例。 一、Spring Boot的GET請求參數基礎 在Spring Boot中,處理GET請求參…

    編程 2025-04-29
  • Python函數名稱相同參數不同:多態

    Python是一門面向對象的編程語言,它強烈支持多態性 一、什麼是多態多態是面向對象三大特性中的一種,它指的是:相同的函數名稱可以有不同的實現方式。也就是說,不同的對象調用同名方法…

    編程 2025-04-29
  • Python Class括號中的參數用法介紹

    本文將對Python中類的括號中的參數進行詳細解析,以幫助初學者熟悉和掌握類的創建以及參數設置。 一、Class的基本定義 在Python中,通過使用關鍵字class來定義類。類包…

    編程 2025-04-29
  • Hibernate日誌打印sql參數

    本文將從多個方面介紹如何在Hibernate中打印SQL參數。Hibernate作為一種ORM框架,可以通過打印SQL參數方便開發者調試和優化Hibernate應用。 一、通過配置…

    編程 2025-04-29
  • 全能編程開發工程師必知——DTD、XML、XSD以及DTD參數實體

    本文將從大體介紹DTD、XML以及XSD三大知識點,同時深入探究DTD參數實體的作用及實際應用場景。 一、DTD介紹 DTD是文檔類型定義(Document Type Defini…

    編程 2025-04-29
  • Python可變參數

    本文旨在對Python中可變參數進行詳細的探究和講解,包括可變參數的概念、實現方式、使用場景等多個方面,希望能夠對Python開發者有所幫助。 一、可變參數的概念 可變參數是指函數…

    編程 2025-04-29
  • XGBoost n_estimator參數調節

    XGBoost 是 處理結構化數據常用的機器學習框架之一,其中的 n_estimator 參數決定着模型的複雜度和訓練速度,這篇文章將從多個方面詳細闡述 n_estimator 參…

    編程 2025-04-28

發表回復

登錄後才能評論