一、背景介紹
手寫數字識別是一個非常實用的任務,它可以被應用於很多場景中,例如銀行付款單據的自動識別、手寫信件自動識別等。在深度學習技術的發展下,越來越多的人開始嘗試使用機器學習算法來解決手寫數字識別問題。其中,PyTorch是一種非常流行的深度學習框架,它提供了一系列的工具和API,可以幫助開發者快速構建、訓練和部署深度神經網絡模型,因此,在這篇文章中,我們將會使用PyTorch來構建一個手寫數字識別模型。
二、數據集概述
在深度神經網絡中,數據集是非常重要的一環,因為模型的效果很大程度上依賴於所使用的數據集。在這篇文章中,我們將會使用MNIST數據集,這是一個非常經典的手寫數字數據集,包含了大約70000張28×28像素的灰度圖像。其中60000張為訓練集,10000張為測試集。這個數據集可以通過PyTorch的Dataset和DataLoader API直接加載,並進行高效的數據預處理。
import torch
from torchvision import datasets, transforms
batch_size = 64
# 加載數據集
train_data = datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST(root='data', train=False, transform=transforms.ToTensor(), download=True)
# 創建 DataLoader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)
三、模型架構
在PyTorch中,我們可以通過繼承torch.nn.Module基類來構建自定義的神經網絡模型。在這篇文章中,我們將會使用一個簡單的卷積神經網絡(Convolutional Neural Network,CNN)來實現手寫數字的識別。我們的CNN模型包含兩個卷積層、兩個池化層、兩個全連接層和一個輸出層,具體細節請看下面的代碼:
import torch.nn as nn
import torch.nn.functional as F
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
self.conv1 = nn.Conv2d(1, 16, 5, padding=2)
self.pool = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, 5, padding=2)
self.fc1 = nn.Linear(32 * 7 * 7, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = self.pool(F.relu(self.conv1(x))) # 第一層卷積、激活、池化
x = self.pool(F.relu(self.conv2(x))) # 第二層卷積、激活、池化
x = x.view(-1, 32 * 7 * 7) # 將二維特徵圖轉換成一維特徵向量
x = F.relu(self.fc1(x)) # 第一層全連接、激活
x = F.relu(self.fc2(x)) # 第二層全連接、激活
x = self.fc3(x) # 輸出層
return x
# 實例化模型
cnn = CNN()
四、訓練過程
在確定了數據集和模型架構之後,我們就可以開始訓練我們的模型了。在這裡,我們將使用交叉熵(Cross Entropy)損失函數和隨機梯度下降(SGD)優化算法來進行模型的訓練。為了使模型能夠更好地泛化到未見過的數據集上,我們在訓練過程中還將使用一些簡單的技術,例如學習率衰減和早期停止,以幫助提高模型的精度。
import torch.optim as optim
learning_rate = 0.1
num_epochs = 10
# 定義損失函數和優化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(cnn.parameters(), lr=learning_rate, momentum=0.9)
# 訓練模型
for epoch in range(num_epochs):
for i, (images, labels) in enumerate(train_loader):
# 將輸入數據和目標變量轉換成正確的類型
images = images.float()
labels = labels.long()
# 前向傳播
outputs = cnn(images)
loss = criterion(outputs, labels)
# 反向傳播
optimizer.zero_grad()
loss.backward()
optimizer.step()
# 輸出訓練結果
if (i + 1) % 100 == 0:
print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))
# 在每個 epoch 結束之後,使用測試數據集評估模型的性能
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
# 將輸入數據和目標變量轉換成正確的類型
images = images.float()
labels = labels.long()
# 前向傳播並計算準確率
outputs = cnn(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
五、模型評估
在經過訓練之後,我們需要評估模型的性能。在這裡,我們簡單地使用在測試數據集上的整體分類準確率來評估模型的性能。整體分類準確率是指模型在測試數據集上正確分類的樣本數占所有樣本數的比例。
with torch.no_grad():
correct = 0
total = 0
for images, labels in test_loader:
# 將輸入數據和目標變量轉換成正確的類型
images = images.float()
labels = labels.long()
# 前向傳播並計算準確率
outputs = cnn(images)
_, predicted = torch.max(outputs.data, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
六、總結
在這篇文章中,我們使用PyTorch實現了一個手寫數字識別模型。我們首先介紹了MNIST數據集,並使用PyTorch的Dataset和DataLoader API來加載和預處理數據。然後,我們構建了一個簡單的卷積神經網絡模型,並使用交叉熵損失函數和SGD優化算法來訓練模型。最後,我們使用測試數據集來評估模型的性能。本文代碼已經在Colab上執行。
原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-hk/n/235999.html