一、基礎知識
參數初始化是深度學習模型中的重要環節之一,它直接影響到模型的泛化能力和訓練效果。在 PyTorch 中給參數賦初值有兩種方式,分別是手動設置和自動初始化。在使用手動設置時需要注意參數的大小、含義和初始化方式。同時,PyTorch 提供了一些默認的初始化方式,可以方便地使用。
在 PyTorch 中,模型的參數是存儲在 Parameter 類型的變量中,其初始化方式主要包括:
- 常量初始化
- 隨機初始化
- 預訓練模型初始化
二、常量初始化
常量初始化是最簡單的初始化方式,它將參數賦為固定的常量值。這種方式通常不是很常用,但有時候可以用於解決特殊的問題。例如,當我們想固定某些參數的值不變時,可以使用常量初始化。
import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.fc1 = nn.Linear(128*8*8, 1024) self.dropout = nn.Dropout(0.5) self.fc2 = nn.Linear(1024, 10) # 將conv1的權重初始化為0.1 nn.init.constant_(self.conv1.weight, 0.1)
三、隨機初始化
隨機初始化是最常用的初始化方式之一,它可以使得參數在一定範圍內發生變化,增強模型的泛化能力。隨機初始化通常包括以下幾種方式:
- 均勻分布初始化
- 正態分布初始化
- 截斷正態分布初始化
- 自定義初始化
其中,均勻分布初始化和正態分布初始化在 PyTorch 中均有對應的函數,可以直接使用。
import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(3, 64, kernel_size=3, padding=1) self.conv2 = nn.Conv2d(64, 128, kernel_size=3, padding=1) self.fc1 = nn.Linear(128*8*8, 1024) self.dropout = nn.Dropout(0.5) self.fc2 = nn.Linear(1024, 10) # 均勻分布初始化 nn.init.uniform_(self.fc1.weight, -0.1, 0.1) # 正態分布初始化 nn.init.normal_(self.fc2.weight, mean=0, std=0.01)
四、預訓練模型初始化
預訓練模型初始化是指使用已經預先訓練好的模型來初始化當前模型的參數,該方式在遷移學習中應用廣泛。在 PyTorch 中,使用預訓練模型進行初始化通常有兩種方式,分別是從文件中加載和在線下載。
import torchvision.models as models # 從文件中加載預訓練模型 resnet18 = models.resnet18(pretrained=True) # 在線下載預訓練模型,需要聯網 vgg16 = models.vgg16(pretrained=True)
五、自定義初始化
有時候,我們需要使用一些特殊的方式來初始化模型參數,這時就需要自定義初始化函數。在 PyTorch 中,可以使用自定義初始化函數來實現這一點。
import torch.nn as nn def my_init(m): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) nn.init.constant_(m.bias, 0.1) class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.fc1 = nn.Linear(128*8*8, 1024) self.fc2 = nn.Linear(1024, 10) # 自定義初始化 self.apply(my_init)
六、小結
在 PyTorch 中,參數初始化是深度學習模型中不可或缺的重要部分。了解各種初始化方式的優缺點,根據不同的網絡結構和需求選擇合適的初始化方式,是提高模型性能的重要手段。
原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/160606.html