深入了解PyTorch中的ImageFolder

一、基本概念

ImageFolder是PyTorch中一個非常實用的類,它可以將一個文件夾中的圖片按照預先定義好的transform操作轉換為PyTorch中可以使用的Tensor。

示例代碼:

import torch.utils.data as data
from torchvision.datasets.folder import IMG_EXTENSIONS, default_loader

class ImageFolder(data.Dataset):
    def __init__(self, root, transform=None, target_transform=None,
                 loader=default_loader, is_valid_file=None):
        ...
    def __getitem__(self, index):
        ...
    def __len__(self):
        ...

其中root指定需要讀取的文件夾路徑,transform指定需要進行的數據轉換操作(如樣本隨機旋轉、隨機裁剪等),target_transform指定目標標籤的轉換操作,比如將字符串類型轉換為數字類型;loader指定需要使用哪個文件讀取器,默認為PIL.ImageLoader。

二、應用場景

ImageFolder可以方便地獲取文件夾中的所有圖片,並進行各種數據增強操作,通常應用於圖像分類、目標檢測、圖像分割等領域。比如在圖像分類中,可以通過ImageFolder將每一類圖片存放在一個文件夾中,文件夾的名稱即為類別名,提高了文件夾的結構化程度。

示例代碼:

import torchvision.transforms as transforms
import torchvision.datasets as datasets

# 對數據進行隨機旋轉和裁剪
train_transform = transforms.Compose([
    transforms.RandomRotation(30),
    transforms.RandomResizedCrop(224),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

# 加載數據集
train_dataset = datasets.ImageFolder('./train', transform=train_transform)

首先定義了一個train_transform,其中包括了隨機旋轉、隨機裁剪、隨機水平翻轉等操作;接着使用ImageFolder加載文件夾中的所有圖片,並將train_transform作為參數傳入,實現對數據集的增強操作。

三、數據預處理

在使用ImageFolder時,需要對數據進行預處理,包括數據增強、歸一化等操作,以加快訓練速度、提升模型精度。具體可以根據不同的實際應用場景選擇不同的預處理方式。

示例代碼:

import torchvision.transforms as transforms

transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

以上代碼定義了一個名為transform的預處理操作,首先將圖片resize到256×256,然後進行中心裁剪到224×224,接着將圖片轉換成Tensor格式,並進行歸一化操作。

四、可視化數據

ImageFolder還可以用於可視化數據,方便我們觀察樣本的特點和變化。以下代碼展示了如何將10張圖片隨機可視化出來。

示例代碼:

import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import matplotlib.pyplot as plt
import numpy as np

# 加載數據集
transform = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])
train_dataset = datasets.ImageFolder('./train', transform=transform)

# 隨機可視化數據集中的10張圖片
def display_imgs(imgs, labels):
    # 將Tensor轉換為numpy數組
    imgs = imgs.numpy().transpose((0, 2, 3, 1))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    imgs = std * imgs + mean
    imgs = np.clip(imgs, 0, 1)
    # 可視化圖片
    fig = plt.figure(figsize=(25, 20))
    for i in range(10):
        ax = fig.add_subplot(2, 5, i + 1, xticks=[], yticks=[])
        ax.imshow(imgs[i])
        ax.set_title(train_dataset.classes[labels[i]])
    fig.show()

data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=10,
                                          shuffle=True, num_workers=4)
data_iter = iter(data_loader)
imgs, labels = data_iter.next()
display_imgs(imgs, labels)

以上代碼首先使用ImageFolder加載數據集並進行預處理,然後隨機選擇10張樣本進行可視化。其中display_imgs函數用於將Tensor格式的圖片轉換為numpy數組,並進行可視化。最終結果為10張隨機選擇的圖片以及其對應的類別。

五、小結

ImageFolder是PyTorch中一個非常實用的類,可以方便地讀取文件夾中的圖片,並進行數據增強、數據預處理、可視化等操作。在圖像分類、目標檢測、圖像分割等領域中都有廣泛的應用。

原創文章,作者:YMFTV,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/368629.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
YMFTV的頭像YMFTV
上一篇 2025-04-12 01:13
下一篇 2025-04-12 01:13

相關推薦

  • PyTorch模塊簡介

    PyTorch是一個開源的機器學習框架,它基於Torch,是一個Python優先的深度學習框架,同時也支持C++,非常容易上手。PyTorch中的核心模塊是torch,提供一些很好…

    編程 2025-04-27
  • 深入解析Vue3 defineExpose

    Vue 3在開發過程中引入了新的API `defineExpose`。在以前的版本中,我們經常使用 `$attrs` 和` $listeners` 實現父組件與子組件之間的通信,但…

    編程 2025-04-25
  • 深入理解byte轉int

    一、字節與比特 在討論byte轉int之前,我們需要了解字節和比特的概念。字節是計算機存儲單位的一種,通常表示8個比特(bit),即1字節=8比特。比特是計算機中最小的數據單位,是…

    編程 2025-04-25
  • 深入理解Flutter StreamBuilder

    一、什麼是Flutter StreamBuilder? Flutter StreamBuilder是Flutter框架中的一個內置小部件,它可以監測數據流(Stream)中數據的變…

    編程 2025-04-25
  • 深入探討OpenCV版本

    OpenCV是一個用於計算機視覺應用程序的開源庫。它是由英特爾公司創建的,現已由Willow Garage管理。OpenCV旨在提供一個易於使用的計算機視覺和機器學習基礎架構,以實…

    編程 2025-04-25
  • 深入了解scala-maven-plugin

    一、簡介 Scala-maven-plugin 是一個創造和管理 Scala 項目的maven插件,它可以自動生成基本項目結構、依賴配置、Scala文件等。使用它可以使我們專註於代…

    編程 2025-04-25
  • 深入了解LaTeX的腳註(latexfootnote)

    一、基本介紹 LaTeX作為一種排版軟件,具有各種各樣的功能,其中腳註(footnote)是一個十分重要的功能之一。在LaTeX中,腳註是用命令latexfootnote來實現的。…

    編程 2025-04-25
  • 深入探討馮諾依曼原理

    一、原理概述 馮諾依曼原理,又稱“存儲程序控制原理”,是指計算機的程序和數據都存儲在同一個存儲器中,並且通過一個統一的總線來傳輸數據。這個原理的提出,是計算機科學發展中的重大進展,…

    編程 2025-04-25
  • 深入了解Python包

    一、包的概念 Python中一個程序就是一個模塊,而一個模塊可以引入另一個模塊,這樣就形成了包。包就是有多個模塊組成的一個大模塊,也可以看做是一個文件夾。包可以有效地組織代碼和數據…

    編程 2025-04-25
  • 深入剖析MapStruct未生成實現類問題

    一、MapStruct簡介 MapStruct是一個Java bean映射器,它通過註解和代碼生成來在Java bean之間轉換成本類代碼,實現類型安全,簡單而不失靈活。 作為一個…

    編程 2025-04-25

發表回復

登錄後才能評論