一、Pytorch模型載入概述
PyTorch是一個使用GPU和CPU優化的深度學習張量庫,它也是一個動態神經網路構建工具。Pytorch模型載入是將已訓練好的模型載入到內存中,以便使用。模型載入是模型應用的前提。Pytorch模型載入涉及到模型的序列化,反序列化和模型參數的賦值等操作。Pytorch中支持多種不同的序列化和反序列化方法,包括pickle和h5py等,其中最常用的方法是torch.save()和torch.load()函數。
二、Pytorch模型的保存與載入
1、保存模型
import torch
# 定義模型
model = torch.nn.Sequential(
torch.nn.Linear(10, 100),
torch.nn.ReLU(),
torch.nn.Linear(100, 1),
torch.nn.Sigmoid()
)
# 保存模型
torch.save(model.state_dict(), 'model.pth')
在上面的代碼中,我們首先定義了一個簡單的模型,然後使用torch.save()函數將模型參數保存到了’model.pth’文件中。
2、載入模型
import torch
# 定義模型
model = torch.nn.Sequential(
torch.nn.Linear(10, 100),
torch.nn.ReLU(),
torch.nn.Linear(100, 1),
torch.nn.Sigmoid()
)
# 載入模型
model.load_state_dict(torch.load('model.pth'))
在上面的代碼中,我們首先定義了一個模型,然後使用torch.load()函數載入’model.pth’文件中的參數,最後使用load_state_dict()函數將參數賦值給模型。
三、Pytorch模型載入的不同形式
1、載入整個模型
import torch
# 保存模型
torch.save(model, 'model.pth')
# 載入模型
model = torch.load('model.pth')
在這個例子中,我們使用了torch.save()函數保存整個模型,並使用torch.load()載入整個模型。
2、多個模型的保存與載入
import torch
# 定義多個模型
model1 = torch.nn.Sequential(
torch.nn.Linear(10, 100),
torch.nn.ReLU(),
torch.nn.Linear(100, 1),
torch.nn.Sigmoid()
)
model2 = torch.nn.Sequential(
torch.nn.Linear(10, 100),
torch.nn.ReLU(),
torch.nn.Linear(100, 10),
torch.nn.Softmax()
)
# 保存多個模型
torch.save({
'model1': model1.state_dict(),
'model2': model2.state_dict()
}, 'multi_model.pth')
# 載入多個模型
checkpoint = torch.load('multi_model.pth')
model1.load_state_dict(checkpoint['model1'])
model2.load_state_dict(checkpoint['model2'])
在這個例子中,我們定義了兩個模型,然後使用torch.save()函數保存多個模型的參數,並使用torch.load()載入多個模型的參數,最後使用load_state_dict()函數將參數賦值給對應的模型。
3、CPU/GPU間的模型載入
import torch
# 保存模型
torch.save(model.state_dict(), 'model.pth')
# 在CPU中載入模型
model_cpu = torch.nn.Sequential(
torch.nn.Linear(10, 100),
torch.nn.ReLU(),
torch.nn.Linear(100, 1),
torch.nn.Sigmoid()
)
model_cpu.load_state_dict(torch.load('model.pth', map_location=torch.device('cpu')))
# 在GPU中載入模型
model_gpu = torch.nn.Sequential(
torch.nn.Linear(10, 100),
torch.nn.ReLU(),
torch.nn.Linear(100, 1),
torch.nn.Sigmoid()
).to('cuda')
model_gpu.load_state_dict(torch.load('model.pth'))
在這個例子中,我們首先使用torch.save()函數保存模型參數,然後使用torch.load()函數在CPU和GPU中載入模型。需要注意的是,在載入模型時需要使用map_location參數將模型參數映射到對應的設備上。如果我們想要將模型載入到GPU上,則需要通過.to(‘cuda’)將模型轉移到GPU上。
四、總結
PyTorch模型載入是將已訓練好的模型載入到內存中,以便進行推理或微調。在PyTorch中,我們可以使用torch.save()函數保存模型權重,使用torch.load()函數載入模型權重,並使用load_state_dict()函數將權重賦值給模型。同時,我們還可以保存和載入多個模型,將模型載入到不同的設備上運行。
原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/308407.html