CIFAR-10數據集:理解、使用和優化

一、CIFAR-10數據集的基本介紹

CIFAR-10數據集是深度學習中常用的一個圖像分類數據集,其中包含了60000張32×32的彩色圖像,共分為10個類別,每個類別中有6000張圖片。CIFAR-10數據集中的10類物體分類分別是:’airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’, ‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’。其中50000張圖片用於訓練集,10000張圖片用於測試集。

CIFAR-10數據集可以幫助開發者在小型問題上進行深度學習訓練和測試,可以使得深度學習的應用更加輕便快捷。

二、CIFAR-10數據集的難點及其優化

1、CIFAR-10數據集的樣本數量較少,只有50000張訓練數據,與ImageNet等大型數據集相比時,缺乏數據大小和方差(variance)分布等差異性。

2、CIFAR-10數據集中的圖像大小只有32×32,相對於其他圖像數據集,這個圖像大小較小,因此圖像的解析度會受到影響。此時可以選擇調整網路模型結構以適應小尺寸圖像,也可以選擇使用圖像增強方法以提高解析度。

3、CIFAR-10數據集中存在著一些可遮擋性、背景雜亂無章、尺度變化較大等難題,這會對模型的訓練造成一定的影響。可以通過圖像增強方法解決這些問題,也可以選擇改變網路模型結構以適應這些特殊問題。

三、CIFAR-10分類識別

在CIFAR-10數據集分類識別任務中,我們通常需要經過以下步驟:

首先,我們需要將CIFAR-10數據集準備好,將數據集讀取進來,進行訓練集和測試集的分類。

import numpy as np
import tensorflow as tf
from keras.datasets import cifar10

(X_train, y_train), (X_test, y_test) = cifar10.load_data()

然後,我們需要對數據集進行預處理,包括,圖像歸一化和one-hot編碼。

X_train_norm = tf.keras.utils.normalize(X_train, axis=1)
X_test_norm = tf.keras.utils.normalize(X_test, axis=1)

y_train_onehot = tf.keras.utils.to_categorical(y_train, num_classes=10)
y_test_onehot = tf.keras.utils.to_categorical(y_test, num_classes=10)

接下來,我們開始構建深度學習模型,可以使用卷積神經網路(CNN)作為模型,因為CNN在圖像識別方面表現不俗。

model = tf.keras.Sequential()

# block 1
model.add(tf.keras.layers.Conv2D(32, (3,3), activation='relu', padding='same', input_shape=(32, 32, 3)))
model.add(tf.keras.layers.MaxPooling2D((2,2)))
model.add(tf.keras.layers.Dropout(0.25))

#block 2
model.add(tf.keras.layers.Conv2D(64, (3,3), activation='relu', padding='same'))
model.add(tf.keras.layers.MaxPooling2D((2,2)))
model.add(tf.keras.layers.Dropout(0.25))

model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation='relu'))
model.add(tf.keras.layers.Dropout(0.5))
model.add(tf.keras.layers.Dense(10, activation='softmax'))

在模型訓練方面,我們可以通過Keras提供的fit()函數,來訓練這個模型。

model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy'])

history = model.fit(X_train_norm, y_train_onehot, batch_size=64, epochs=40, validation_data=(X_test_norm, y_test_onehot))

四、CIFAR-10數據集和CIFAR-100數據集的比較

CIFAR-10數據集和CIFAR-100數據集非常相似,分別具有10個類和100個類的情況下都有著相同的圖像大小和顏色數量。CIFAR-100數據集則是在CIFAR-10的基礎上,增加了90個類別,每個類別下有600張訓練圖片和100張測試圖片。與CIFAR-10數據集相比,CIFAR-100數據集包含更多的類別,因此更適合用於大規模多分類問題的研究。

五、CIFAR-10的準確率排名和召回率怎麼計算

CIFAR-10的準確率排名計算公式如下:

準確率排名 = 正確分類的總數 ÷ 測試數據集總數

召回率則是指所有正確分類的數據個數除以總測試集中所包含的該類別的個數。例如,假設測試集裡面一共有100個樣本被標記成了某一種類別,而分類器只能正確識別了80個,則這種類別的召回率便是80%。

六、CIFAR-10的PyTorch實現

PyTorch是由Facebook提供的神經網路框架,通過torchvision.libs.datasets模塊中的torchvision.datasets.CIFAR10()方法可以實現CIFAR-10數據集的讀取,通過定義網路模型來實現CIFAR-10的分類識別。

import torchvision.transforms as transforms
from torch.autograd import Variable
import torch.nn.functional as F
import torch.optim as optim
import os
import torchvision.datasets as datasets

# 數據預處理器
normalize = transforms.Normalize(mean=[0.4914, 0.4822, 0.4465],
                                 std=[0.2023, 0.1994, 0.2010])

# 數據增強
train_transforms = transforms.Compose([
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
])

val_transform = transforms.Compose([
    transforms.ToTensor(),
    normalize,
])

# 載入數據集
train_dataset = datasets.CIFAR10('data/', train=True, transform=train_transforms)
val_dataset = datasets.CIFAR10('data/', train=False, transform=val_transform)

# 搭建神經網路模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 32, 3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.fc1 = nn.Linear(128 * 4 * 4, 512)
        self.fc2 = nn.Linear(512, 10)

    def forward(self, x):
        out = F.relu(self.conv1(x))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv3(out))
        out = F.max_pool2d(out, 2)
        out = out.view(-1, 128 * 4 * 4)
        out = F.relu(self.fc1(out))
        out = self.fc2(out)
        return out

model = Net()

# 模型訓練
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

losses = []
accuracies = []
val_losses = []
val_accuracies = []

def train(epoch):
    model.train()
    train_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(train_loader):
        if cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        optimizer.zero_grad()
        inputs, targets = Variable(inputs), Variable(targets)
        outputs = model(inputs)
        loss = criterion(outputs, targets)
        loss.backward()
        optimizer.step()

        train_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

    acc = 100. * correct / total
    losses.append(train_loss / len(train_loader))
    accuracies.append(acc)

    print('Train Epoch: {}\t Loss: {:.6f}\t Accuracy: {:.6f}'.format(
        epoch, train_loss / len(train_loader), acc))

def validate(epoch):
    model.eval()
    val_loss = 0
    correct = 0
    total = 0

    for batch_idx, (inputs, targets) in enumerate(val_loader):
        if cuda:
            inputs, targets = inputs.cuda(), targets.cuda()
        inputs, targets = Variable(inputs, volatile=True), Variable(targets)
        outputs = model(inputs)
        loss = criterion(outputs, targets)

        val_loss += loss.data[0]
        _, predicted = torch.max(outputs.data, 1)
        total += targets.size(0)
        correct += predicted.eq(targets.data).cpu().sum()

    acc = 100. * correct / total
    val_losses.append(val_loss / len(val_loader))
    val_accuracies.append(acc)

    print('Validation Epoch: {}\t Loss: {:.6f}\t Accuracy: {:.6f}'.format(
        epoch, val_loss / len(val_loader), acc))

for epoch in range(1, 101):
    train(epoch)
    validate(epoch)

原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/259362.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
小藍的頭像小藍
上一篇 2024-12-15 16:28
下一篇 2024-12-15 16:28

相關推薦

  • Python讀取CSV數據畫散點圖

    本文將從以下方面詳細闡述Python讀取CSV文件並畫出散點圖的方法: 一、CSV文件介紹 CSV(Comma-Separated Values)即逗號分隔值,是一種存儲表格數據的…

    編程 2025-04-29
  • Python中讀入csv文件數據的方法用法介紹

    csv是一種常見的數據格式,通常用於存儲小型數據集。Python作為一種廣泛流行的編程語言,內置了許多操作csv文件的庫。本文將從多個方面詳細介紹Python讀入csv文件的方法。…

    編程 2025-04-29
  • 如何用Python統計列表中各數據的方差和標準差

    本文將從多個方面闡述如何使用Python統計列表中各數據的方差和標準差, 並給出詳細的代碼示例。 一、什麼是方差和標準差 方差是衡量數據變異程度的統計指標,它是每個數據值和該數據值…

    編程 2025-04-29
  • Python多線程讀取數據

    本文將詳細介紹多線程讀取數據在Python中的實現方法以及相關知識點。 一、線程和多線程 線程是操作系統調度的最小單位。單線程程序只有一個線程,按照程序從上到下的順序逐行執行。而多…

    編程 2025-04-29
  • Python兩張表數據匹配

    本篇文章將詳細闡述如何使用Python將兩張表格中的數據匹配。以下是具體的解決方法。 一、數據匹配的概念 在生活和工作中,我們常常需要對多組數據進行比對和匹配。在數據量較小的情況下…

    編程 2025-04-29
  • Python爬取公交數據

    本文將從以下幾個方面詳細闡述python爬取公交數據的方法: 一、準備工作 1、安裝相關庫 import requests from bs4 import BeautifulSou…

    編程 2025-04-29
  • Python數據標準差標準化

    本文將為大家詳細講述Python中的數據標準差標準化,以及涉及到的相關知識。 一、什麼是數據標準差標準化 數據標準差標準化是數據處理中的一種方法,通過對數據進行標準差標準化可以將不…

    編程 2025-04-29
  • 如何使用Python讀取CSV數據

    在數據分析、數據挖掘和機器學習等領域,CSV文件是一種非常常見的文件格式。Python作為一種廣泛使用的編程語言,也提供了方便易用的CSV讀取庫。本文將介紹如何使用Python讀取…

    編程 2025-04-29
  • Python如何打亂數據集

    本文將從多個方面詳細闡述Python打亂數據集的方法。 一、shuffle函數原理 shuffle函數是Python中的一個內置函數,主要作用是將一個可迭代對象的元素隨機排序。 在…

    編程 2025-04-29
  • Python根據表格數據生成折線圖

    本文將介紹如何使用Python根據表格數據生成折線圖。折線圖是一種常見的數據可視化圖表形式,可以用來展示數據的趨勢和變化。Python是一種流行的編程語言,其強大的數據分析和可視化…

    編程 2025-04-29

發表回復

登錄後才能評論