一、常見的importtorchvision模塊
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.utils as utils
torchvision是PyTorch的一個計算機視覺包,其中最為常用的模塊有transforms、datasets、models、utils。transforms模塊提供了常用的圖像預處理方法;datasets模塊提供了常見的視覺數據集(如CIFAR10、MNIST等);models模塊提供了經典的預訓練模型(如ResNet、VGG等);utils模塊提供了一些常用的工具函數。
使用這些模塊,我們可以方便地搭建計算機視覺領域的深度學習模型。
二、利用transforms模塊進行數據預處理
transform_train = transforms.Compose([
transforms.RandomCrop(32, padding=4),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])])
trainset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform_train)
transforms模塊提供了方便的圖像預處理方法,可以在訓練網絡前對輸入數據進行必要的處理。在上述示例中,我們使用了transforms.Compose方法將多個變換組合在一起來處理輸入數據。其中:
- transforms.RandomCrop隨機裁剪圖像
- transforms.RandomHorizontalFlip隨機水平翻轉圖像
- transforms.ToTensor將圖像轉換為張量類型
- transforms.Normalize對圖像張量進行標準化處理,即將圖像張量減去均值再除以標準差,使得圖像張量的值在(-1, 1)之間
三、通過datasets模塊加載數據集
trainset = datasets.ImageFolder(root='path/to/data',
transform=transforms.ToTensor())
trainloader = torch.utils.data.DataLoader(trainset, batch_size=32,
shuffle=True, num_workers=4)
datasets模塊提供了常見的視覺數據集,使用datasets.ImageFolder方法可以加載自己的圖像數據集。其中,root參數指定圖像數據的根目錄,以文件夾的形式將不同類別的圖像分別存儲在不同的子目錄中。transform參數指定對輸入圖像進行的預處理方法。
使用torch.utils.data.DataLoader方法可以將數據集轉換為可供訓練功能使用的批量數據。
四、使用models模塊搭建深度學習模型
import torch.nn as nn
import torch.optim as optim
model = models.resnet18(pretrained=True)
num_ftrs = model.fc.in_features
model.fc = nn.Linear(num_ftrs, 2)
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)
models模塊提供了經典的預訓練模型,使用這些模型可以快速搭建自己的深度學習模型。在上述示例中,我們使用了pretrained=True指定使用預訓練的ResNet18模型,並且修改了全連接層的輸出結構。nn.CrossEntropyLoss指定損失函數,optim.SGD指定優化器。
五、通過utils模塊進行結果可視化
import numpy as np
import torchvision.transforms.functional as F
def imshow(img):
npimg = img.numpy()
plt.imshow(np.transpose(npimg, (1, 2, 0)))
plt.show()
dataiter = iter(trainloader)
images, labels = dataiter.next()
imshow(utils.make_grid(images))
print(' '.join('%5s' % classes[labels[j]] for j in range(32)))
utils模塊提供了一些常用的工具函數,例如可視化結果。在上述示例中,我們使用了utils.make_grid將輸入數據製成網格圖,imshow函數將該圖可視化。結果如下所示:
cat dog dog deer deer plane horse dog deer cat ship frog truck deer dog truck horse deer deer dog deer truck car cat truck deer deer car dog truck dog truck plane plane car
六、總結
importtorchvision模塊是PyTorch計算機視覺方向的重要組成部分,提供了豐富的預處理方法(transforms)、數據集加載方法(datasets)、預訓練模型(models)以及工具函數(utils)。我們可以基於這些組件快速搭建自己的計算機視覺深度學習模型,並且通過utils模塊提供的可視化工具進行結果展示。
原創文章,作者:CSTDG,如若轉載,請註明出處:https://www.506064.com/zh-hk/n/334227.html