一、功能概述
load_state_dict是PyTorch中一個非常重要的方法,它可以將一個已經訓練好的模型的參數加載到另一個同樣結構的模型中。在實際使用中,它經常用於預訓練模型的遷移學習、模型參數的恢復等場景。在這一部分,我們將介紹load_state_dict方法的基本用法以及其調用的原理。
model_dict = model.state_dict() # 此時model還未更新過,其參數未被優化器更改 pretrained_dict = torch.load(PATH) # filter out unnecessary keys pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # overwrite entries in the existing state dict model_dict.update(pretrained_dict) model.load_state_dict(model_dict)
二、參數說明
load_state_dict方法有一個必要的參數,即pretrained_dict,表示已經訓練好的模型的參數,它是一個Python字典。該參數需要滿足以下兩個要求:
1、字典的鍵值對應着模型中各層的名稱
2、字典的值是一個已經訓練好的張量
在使用時需要注意,預訓練模型和目標模型的結構必須一致。
三、基本用法
load_state_dict方法的基本用法非常簡單,只需要通過Python字典構造函數構造一個預訓練模型的參數字典,然後使用load_state_dict方法將其加載到目標模型中即可。下面是一段簡單的示例代碼:
model = Net() pretrained_dict = torch.load(PATH) model.load_state_dict(pretrained_dict)
四、加載部分參數
在有些情況下,我們只需要加載模型的部分參數。例如,我們想僅加載預訓練模型中某些層的參數而保持目標模型中其他層的參數不變。在這種情況下,需要將pretrained_dict中不需要的部分剔除,可以使用Python字典的推導式來完成這一操作:
model_dict = model.state_dict() pretrained_dict = torch.load(PATH) # filter out unnecessary keys pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # overwrite entries in the existing state dict model_dict.update(pretrained_dict) model.load_state_dict(model_dict)
五、跨設備加載
在使用load_state_dict方法時,需要注意張量的設備類型和ID。如果預訓練模型和目標模型的設備類型或ID不同,就需要對預訓練模型中的參數進行相應的修改才能使其被成功加載。下面是一段示例代碼:
model = nn.DataParallel(model) pretrained_dict = torch.load(PATH) # create new OrderedDict that does not contain `module.` from collections import OrderedDict new_state_dict = OrderedDict() for k, v in pretrained_dict.items(): name = k[7:] # remove `module.` new_state_dict[name] = v # load params model.load_state_dict(new_state_dict)
六、加載到指定的層
有時候,我們可能只需要把預訓練模型的部分參數加載到目標模型的指定層中,而不需要覆蓋整個目標模型的參數。在這種情況下,我們需要手動獲取指定層的state_dict,並將預訓練模型中對應的參數賦值給該state_dict。下面是一段示例代碼:
model = Net() pretrained_dict = torch.load(PATH) # get the dict of a module net_dict = model.net.state_dict() pretrained_dict = {'.'.join(k.split('.')[1:]): v for k, v in pretrained_dict.items() if k.split('.')[1] == 'net'} # overwrite entries in the state dict for this module net_dict.update(pretrained_dict) model.net.load_state_dict(net_dict)
原創文章,作者:ODCF,如若轉載,請註明出處:https://www.506064.com/zh-hk/n/134473.html