PyTorch實現手寫數字識別模型

一、背景介紹

手寫數字識別是一個非常實用的任務,它可以被應用於很多場景中,例如銀行付款單據的自動識別、手寫信件自動識別等。在深度學習技術的發展下,越來越多的人開始嘗試使用機器學習算法來解決手寫數字識別問題。其中,PyTorch是一種非常流行的深度學習框架,它提供了一系列的工具和API,可以幫助開發者快速構建、訓練和部署深度神經網絡模型,因此,在這篇文章中,我們將會使用PyTorch來構建一個手寫數字識別模型。

二、數據集概述

在深度神經網絡中,數據集是非常重要的一環,因為模型的效果很大程度上依賴於所使用的數據集。在這篇文章中,我們將會使用MNIST數據集,這是一個非常經典的手寫數字數據集,包含了大約70000張28×28像素的灰度圖像。其中60000張為訓練集,10000張為測試集。這個數據集可以通過PyTorch的Dataset和DataLoader API直接加載,並進行高效的數據預處理。

import torch
from torchvision import datasets, transforms

batch_size = 64

# 加載數據集
train_data = datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST(root='data', train=False, transform=transforms.ToTensor(), download=True)

# 創建 DataLoader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

三、模型架構

在PyTorch中,我們可以通過繼承torch.nn.Module基類來構建自定義的神經網絡模型。在這篇文章中,我們將會使用一個簡單的卷積神經網絡(Convolutional Neural Network,CNN)來實現手寫數字的識別。我們的CNN模型包含兩個卷積層、兩個池化層、兩個全連接層和一個輸出層,具體細節請看下面的代碼:

import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5, padding=2)
        self.fc1 = nn.Linear(32 * 7 * 7, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))   # 第一層卷積、激活、池化
        x = self.pool(F.relu(self.conv2(x)))   # 第二層卷積、激活、池化
        x = x.view(-1, 32 * 7 * 7)             # 將二維特徵圖轉換成一維特徵向量
        x = F.relu(self.fc1(x))                # 第一層全連接、激活
        x = F.relu(self.fc2(x))                # 第二層全連接、激活
        x = self.fc3(x)                        # 輸出層
        return x

# 實例化模型
cnn = CNN()

四、訓練過程

在確定了數據集和模型架構之後,我們就可以開始訓練我們的模型了。在這裡,我們將使用交叉熵(Cross Entropy)損失函數和隨機梯度下降(SGD)優化算法來進行模型的訓練。為了使模型能夠更好地泛化到未見過的數據集上,我們在訓練過程中還將使用一些簡單的技術,例如學習率衰減和早期停止,以幫助提高模型的精度。

import torch.optim as optim

learning_rate = 0.1
num_epochs = 10

# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(cnn.parameters(), lr=learning_rate, momentum=0.9)

# 訓練模型
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # 將輸入數據和目標變量轉換成正確的類型
        images = images.float()
        labels = labels.long()

        # 前向傳播
        outputs = cnn(images)
        loss = criterion(outputs, labels)

        # 反向傳播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 輸出訓練結果
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))

    # 在每個 epoch 結束之後,使用測試數據集評估模型的性能
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            # 將輸入數據和目標變量轉換成正確的類型
            images = images.float()
            labels = labels.long()

            # 前向傳播並計算準確率
            outputs = cnn(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

五、模型評估

在經過訓練之後,我們需要評估模型的性能。在這裡,我們簡單地使用在測試數據集上的整體分類準確率來評估模型的性能。整體分類準確率是指模型在測試數據集上正確分類的樣本數占所有樣本數的比例。

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        # 將輸入數據和目標變量轉換成正確的類型
        images = images.float()
        labels = labels.long()

        # 前向傳播並計算準確率
        outputs = cnn(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

六、總結

在這篇文章中,我們使用PyTorch實現了一個手寫數字識別模型。我們首先介紹了MNIST數據集,並使用PyTorch的Dataset和DataLoader API來加載和預處理數據。然後,我們構建了一個簡單的卷積神經網絡模型,並使用交叉熵損失函數和SGD優化算法來訓練模型。最後,我們使用測試數據集來評估模型的性能。本文代碼已經在Colab上執行。

原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-hk/n/235999.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
小藍的頭像小藍
上一篇 2024-12-12 11:57
下一篇 2024-12-12 11:57

相關推薦

  • TensorFlow Serving Java:實現開發全功能的模型服務

    TensorFlow Serving Java是作為TensorFlow Serving的Java API,可以輕鬆地將基於TensorFlow模型的服務集成到Java應用程序中。…

    編程 2025-04-29
  • Python循環符合要求數字求和

    這篇文章將詳細介紹如何通過Python循環符合要求數字求和。如果你想用Python求和但又不想手動輸入數字,那麼本文將是一個不錯的選擇。 一、使用while循環實現求和 sum =…

    編程 2025-04-29
  • Python訓練模型後如何投入應用

    Python已成為機器學習和深度學習領域中熱門的編程語言之一,在訓練完模型後如何將其投入應用中,是一個重要問題。本文將從多個方面為大家詳細闡述。 一、模型持久化 在應用中使用訓練好…

    編程 2025-04-29
  • Python基本數字類型

    本文將介紹Python中基本數字類型,包括整型、布爾型、浮點型、複數型,並提供相應的代碼示例以便讀者更好的理解。 一、整型 整型即整數類型,Python中的整型沒有大小限制,所以可…

    編程 2025-04-29
  • Python打印數字三角形

    本文將詳細闡述如何使用Python打印數字三角形,包括從基本代碼實現到進階操作的應用。通過本文的學習,您可以掌握Python的基礎語法,同時加深對Python循環和函數的理解,提高…

    編程 2025-04-29
  • Python數字求和怎麼寫

    在Python中實現數字求和非常簡單,下面將從多個方面對Python數字求和的實現方法做詳細的闡述。 一、直接使用「+」符號進行求和 a = 10 b = 20 c = a + b…

    編程 2025-04-29
  • Python提取連續數字

    本文將介紹如何使用Python提取一個字符串中的連續數字。 一、使用正則表達式提取 正則表達式是一種可以匹配文本片段的模式。Python內置了re模塊,可以使用正則表達式進行字符串…

    編程 2025-04-29
  • Python中如何判斷字符為數字

    判斷字符是否為數字是Python編程中常見的需求,本文將從多個方面詳細闡述如何使用Python進行字符判斷。 一、isdigit()函數判斷字符是否為數字 Python中可以使用i…

    編程 2025-04-29
  • Python實現一元線性回歸模型

    本文將從多個方面詳細闡述Python實現一元線性回歸模型的代碼。如果你對線性回歸模型有一些了解,對Python語言也有所掌握,那麼本文將對你有所幫助。在開始介紹具體代碼前,讓我們先…

    編程 2025-04-29
  • ARIMA模型Python應用用法介紹

    ARIMA(自回歸移動平均模型)是一種時序分析常用的模型,廣泛應用於股票、經濟等領域。本文將從多個方面詳細闡述ARIMA模型的Python實現方式。 一、ARIMA模型是什麼? A…

    編程 2025-04-29

發表回復

登錄後才能評論