一、.load_state_dict的介紹
.load_state_dict方法是PyTorch中一個十分重要的方法,它可以將預訓練模型的狀態字典載入到新的模型中。模型的狀態字典包含了模型的參數和緩衝器
該方法的作用是載入參數和緩衝器,並且使用嚴格的參數匹配,如果有對應不上的參數,會報錯。
def load_state_dict(self, state_dict: 'OrderedDict[str, Tensor]', strict: bool = True) -> None: r"""Loads a model's parameter dictionary (state_dict). Arguments: state_dict (dict): a dict containing parameters and persistent buffers. strict (bool, optional): whether to strictly enforce that the keys in :attr:`state_dict` match the keys returned by this module's :meth:`state_dict` function. Default: ``True`` Returns: None .. note:: The :attr:`strict` parameter has home-field advantage here. See the note in :meth:`torch.nn.Module.load_state_dict` for a description of how it's used. """
二、.load_state_dict方法的應用場景
.load_state_dict方法是在訓練中使用預訓練模型時常用的方法。預訓練模型的狀態字典不能直接複製到一個新模型中,需要使用.load_state_dict方法來恢復模型。
在遷移學習中,我們可以使用已訓練好的模型,將其參數作為新模型的初始參數,然後再在該基礎上進行訓練,從而加速我們的訓練過程,提高模型的性能。
下面是一段使用.load_state_dict方法載入預訓練模型並用來進行測試的代碼:
import torch import torch.nn as nn import torchvision.models as models model = models.resnet18(pretrained=True) fc_inputs = model.fc.in_features model.fc = nn.Sequential( nn.Linear(fc_inputs, 1024), nn.ReLU(inplace=True), nn.Linear(1024, 10)) model.load_state_dict(torch.load('resnet18.pth')) # test the model inputs = torch.randn(1, 3, 224, 224) outputs = model(inputs) print(outputs.shape)
三、.load_state_dict方法的常用參數
1、state_dict參數
state_dict是一個包含了參數和緩衝器的字典。這個字典可以從一個已經訓練好的模型中獲取,也可以通過state_dict()方法獲取當前模型的參數字典。
例如:
model = torchvision.models.resnet18(pretrained=True) state_dict = model.state_dict()
2、strict參數
strict參數是一個布爾類型的值,用於標記是否使用嚴格的參數匹配。
如果strict=True,則state_dict中的參數名稱必須與新模型中的參數名稱完全匹配,否則會報錯。
如果strict=False,則新模型中沒有指定的參數,就忽略掉,而不會報錯。
四、.load_state_dict方法的注意事項
1、模型的架構需要保持一致
.load_state_dict方法的使用需要注意模型的架構必須與原始模型的架構完全相同,否則將無法載入參數。如果想要更改模型的架構,可以使用torch.nn.Sequential()重新構造模型。
2、載入預訓練模型需要正確指定路徑
如果我們需要載入一個預訓練模型,需要正確指定預訓練模型的位置。一般來說,預訓練模型被保存為一個.pth文件。如果.pth文件和模型代碼不在同一個文件夾中,則需要使用正確的路徑來載入模型。
# 模型保存在model文件夾中的resnet18.pth文件中 model = models.resnet18(pretrained=True) model.load_state_dict(torch.load('model/resnet18.pth'))
3、.load_state_dict方法與.freeze_layers()方法的配合使用
當使用預訓練模型進行遷移學習時,我們常常需要固定一些層的參數,只更新特定的層。在這種情況下,我們可以使用.freeze_layers()方法來凍結層的參數,在反向傳播時不進行參數更新。在.load_state_dict()方法中,我們需要排除掉已凍結的層,否則這些層的參數將會被載入進去。
例如:
model = torchvision.models.resnet18(pretrained=True) for param in model.parameters(): param.requires_grad = False num_features = model.fc.in_features model.fc = nn.Linear(num_features, 10) # 假設已經凍結了卷積層的參數 params_to_update = [] for name, param in model.named_parameters(): if '.bn' not in name: params_to_update.append(param) optimizer = torch.optim.Adam(params_to_update)
在以上代碼中,.freeze_layers()方法已經凍結了所有的卷積層,現在我們只更新全連接層的參數。所以在.load_state_dict()方法中,我們需要指定只載入全連接層的參數:
model.load_state_dict(torch.load('model_weights.pth'), strict=False)
五、總結
在本文中,我們詳細講解了PyTorch中.load_state_dict()方法的使用方法及注意事項。通過本文的介紹,我們可以清楚地知道如何在訓練中使用預訓練模型,並且了解了一些需要注意的問題。
原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/192453.html