PyTorch 是目前深度學習領域最流行的框架之一。其提供了豐富的功能和靈活性,使其成為科學家和開發人員的首選選擇。在 PyTorch 中,transforms 是用於轉換圖像和數據的重要組件之一。transforms 模塊提供了各種常用的圖像變換方法,既可以應用於訓練數據集的標註,也可以應用於模型的輸入數據。本文將從多個方面詳細闡述 transforms 模塊,旨在讓讀者更深入理解 PyTorch 中的 transforms 模塊。
一、預處理
在 PyTorch 中,圖像預處理是模型訓練中非常重要的一環,transforms 模塊提供了多種處理方式。使用 transforms.Compose 方法,可以將多種變換組合在一起,並且按照給定的順序依次應用。
下面是一個簡單的示例代碼:
import torch
from torchvision import transforms
from PIL import Image
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
img = Image.open('image.jpg')
img_t = transform(img)
這個例子中,我們首先定義了一個 Compose 變換,其中包括了四個變換,分別是:
- Resize:將圖片調整為給定尺寸。
- CenterCrop:將圖片按中心進行剪裁。
- ToTensor:將 PIL 圖片轉換為 PyTorch Tensor。
- Normalize:對 Tensor 進行標準化,使其具有零均值和單位方差。
在變換後,我們使用 PIL 庫打開一張 jpeg 格式的圖片,並應用定義好的變換。最終,我們得到了一個經過預處理的 Tensor。
二、數據增強
在深度學習任務中,數據增強是一個非常重要的方法,目的是通過對原始數據進行變換,擴充訓練集的規模,增加模型的魯棒性。transforms 模塊提供了多種用於數據增強的變換,在 PyTorch 中非常容易使用。
下面是一個包括 RandomCrop、RandomHorizontalFlip 和 ColorJitter 的數據增強示例:
transform = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
在這個例子中,我們使用了 RandomResizedCrop 可以對圖片進行隨機大小裁剪,RandomHorizontalFlip 可以隨機翻轉圖片。另外,ColorJitter 可以進行隨機顏色調整。
三、自定義變換
在 PyTorch 中,我們可以自定義一些數據變換類,來滿足特定的需求。為了創建一個自定義的變換類,我們需要繼承 torchvision.transforms 中的 Transform 類,並實現 __call__ 方法。__call__ 函數需要接受一個 PIL 圖像,並返回預處理後的圖像。下面是一個將圖片進行鏡像翻轉的自定義變換:
class Mirror(object):
def __call__(self, img):
return img.transpose(method=Image.FLIP_LEFT_RIGHT)
transform = transforms.Compose([
transforms.Resize(256),
Mirror(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])
在這個例子中,我們定義了一個 Mirror 變換,它繼承了 Transform 類,並在 __call__ 方法中實現了鏡像翻轉操作。在 Compose 中使用 Mirror 變換可以讓圖片在水平方向上進行翻轉。
四、應用於 Torchvision 庫
PyTorch 的 Torchvision 庫是一個用於計算機視覺任務的常用庫。Torchvision 提供了許多常用的數據集和模型,同時也提供了一些常用的 transforms 變換,可以用於數據集的預處理和數據增強。為了使用 Torchvision 中的 transforms,我們可以直接在 torchvision.transforms 中導入對應的變換。
下面是一個使用 Torchvision 庫的示例:
import torchvision.transforms as transforms
train_transforms = transforms.Compose([
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
val_transforms = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
在這個例子中,我們定義了兩個用於訓練集和驗證集的 transforms。對於訓練集,我們使用了 RandomResizedCrop 和 RandomHorizontalFlip 來進行數據增強;對於驗證集,我們使用了 CenterCrop 來進行預處理。
總結
本文詳細介紹了 PyTorch 中的 transforms 模塊。transforms 提供了用於預處理和數據增強的常見變換和自定義變換。除了可以應用於 PyTorch 訓練數據集的標註之外,transforms 可以應用於模型的輸入數據以及 Torchvision 庫中的數據集。
原創文章,作者:LGNPH,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/372421.html