一、基本概念
ImageFolder是PyTorch中一個非常實用的類,它可以將一個文件夾中的圖片按照預先定義好的transform操作轉換為PyTorch中可以使用的Tensor。
示例代碼:
import torch.utils.data as data
from torchvision.datasets.folder import IMG_EXTENSIONS, default_loader
class ImageFolder(data.Dataset):
def __init__(self, root, transform=None, target_transform=None,
loader=default_loader, is_valid_file=None):
...
def __getitem__(self, index):
...
def __len__(self):
...
其中root指定需要讀取的文件夾路徑,transform指定需要進行的數據轉換操作(如樣本隨機旋轉、隨機裁剪等),target_transform指定目標標籤的轉換操作,比如將字符串類型轉換為數字類型;loader指定需要使用哪個文件讀取器,默認為PIL.ImageLoader。
二、應用場景
ImageFolder可以方便地獲取文件夾中的所有圖片,並進行各種數據增強操作,通常應用於圖像分類、目標檢測、圖像分割等領域。比如在圖像分類中,可以通過ImageFolder將每一類圖片存放在一個文件夾中,文件夾的名稱即為類別名,提高了文件夾的結構化程度。
示例代碼:
import torchvision.transforms as transforms
import torchvision.datasets as datasets
# 對數據進行隨機旋轉和裁剪
train_transform = transforms.Compose([
transforms.RandomRotation(30),
transforms.RandomResizedCrop(224),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
# 加載數據集
train_dataset = datasets.ImageFolder('./train', transform=train_transform)
首先定義了一個train_transform,其中包括了隨機旋轉、隨機裁剪、隨機水平翻轉等操作;接着使用ImageFolder加載文件夾中的所有圖片,並將train_transform作為參數傳入,實現對數據集的增強操作。
三、數據預處理
在使用ImageFolder時,需要對數據進行預處理,包括數據增強、歸一化等操作,以加快訓練速度、提升模型精度。具體可以根據不同的實際應用場景選擇不同的預處理方式。
示例代碼:
import torchvision.transforms as transforms
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
以上代碼定義了一個名為transform的預處理操作,首先將圖片resize到256×256,然後進行中心裁剪到224×224,接着將圖片轉換成Tensor格式,並進行歸一化操作。
四、可視化數據
ImageFolder還可以用於可視化數據,方便我們觀察樣本的特點和變化。以下代碼展示了如何將10張圖片隨機可視化出來。
示例代碼:
import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np
# 加載數據集
transform = transforms.Compose([
transforms.Resize(256),
transforms.CenterCrop(224),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
train_dataset = datasets.ImageFolder('./train', transform=transform)
# 隨機可視化數據集中的10張圖片
def display_imgs(imgs, labels):
# 將Tensor轉換為numpy數組
imgs = imgs.numpy().transpose((0, 2, 3, 1))
mean = np.array([0.485, 0.456, 0.406])
std = np.array([0.229, 0.224, 0.225])
imgs = std * imgs + mean
imgs = np.clip(imgs, 0, 1)
# 可視化圖片
fig = plt.figure(figsize=(25, 20))
for i in range(10):
ax = fig.add_subplot(2, 5, i + 1, xticks=[], yticks=[])
ax.imshow(imgs[i])
ax.set_title(train_dataset.classes[labels[i]])
fig.show()
data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10,
shuffle=True, num_workers=4)
data_iter = iter(data_loader)
imgs, labels = data_iter.next()
display_imgs(imgs, labels)
以上代碼首先使用ImageFolder加載數據集並進行預處理,然後隨機選擇10張樣本進行可視化。其中display_imgs函數用於將Tensor格式的圖片轉換為numpy數組,並進行可視化。最終結果為10張隨機選擇的圖片以及其對應的類別。
五、小結
ImageFolder是PyTorch中一個非常實用的類,可以方便地讀取文件夾中的圖片,並進行數據增強、數據預處理、可視化等操作。在圖像分類、目標檢測、圖像分割等領域中都有廣泛的應用。
原創文章,作者:YMFTV,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/368629.html