PyTorch实现手写数字识别模型

一、背景介绍

手写数字识别是一个非常实用的任务,它可以被应用于很多场景中,例如银行付款单据的自动识别、手写信件自动识别等。在深度学习技术的发展下,越来越多的人开始尝试使用机器学习算法来解决手写数字识别问题。其中,PyTorch是一种非常流行的深度学习框架,它提供了一系列的工具和API,可以帮助开发者快速构建、训练和部署深度神经网络模型,因此,在这篇文章中,我们将会使用PyTorch来构建一个手写数字识别模型。

二、数据集概述

在深度神经网络中,数据集是非常重要的一环,因为模型的效果很大程度上依赖于所使用的数据集。在这篇文章中,我们将会使用MNIST数据集,这是一个非常经典的手写数字数据集,包含了大约70000张28×28像素的灰度图像。其中60000张为训练集,10000张为测试集。这个数据集可以通过PyTorch的Dataset和DataLoader API直接加载,并进行高效的数据预处理。

import torch
from torchvision import datasets, transforms

batch_size = 64

# 加载数据集
train_data = datasets.MNIST(root='data', train=True, transform=transforms.ToTensor(), download=True)
test_data = datasets.MNIST(root='data', train=False, transform=transforms.ToTensor(), download=True)

# 创建 DataLoader
train_loader = torch.utils.data.DataLoader(train_data, batch_size=batch_size, shuffle=True)
test_loader = torch.utils.data.DataLoader(test_data, batch_size=batch_size, shuffle=False)

三、模型架构

在PyTorch中,我们可以通过继承torch.nn.Module基类来构建自定义的神经网络模型。在这篇文章中,我们将会使用一个简单的卷积神经网络(Convolutional Neural Network,CNN)来实现手写数字的识别。我们的CNN模型包含两个卷积层、两个池化层、两个全连接层和一个输出层,具体细节请看下面的代码:

import torch.nn as nn
import torch.nn.functional as F

class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = nn.Conv2d(1, 16, 5, padding=2)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(16, 32, 5, padding=2)
        self.fc1 = nn.Linear(32 * 7 * 7, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))   # 第一层卷积、激活、池化
        x = self.pool(F.relu(self.conv2(x)))   # 第二层卷积、激活、池化
        x = x.view(-1, 32 * 7 * 7)             # 将二维特征图转换成一维特征向量
        x = F.relu(self.fc1(x))                # 第一层全连接、激活
        x = F.relu(self.fc2(x))                # 第二层全连接、激活
        x = self.fc3(x)                        # 输出层
        return x

# 实例化模型
cnn = CNN()

四、训练过程

在确定了数据集和模型架构之后,我们就可以开始训练我们的模型了。在这里,我们将使用交叉熵(Cross Entropy)损失函数和随机梯度下降(SGD)优化算法来进行模型的训练。为了使模型能够更好地泛化到未见过的数据集上,我们在训练过程中还将使用一些简单的技术,例如学习率衰减和早期停止,以帮助提高模型的精度。

import torch.optim as optim

learning_rate = 0.1
num_epochs = 10

# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(cnn.parameters(), lr=learning_rate, momentum=0.9)

# 训练模型
for epoch in range(num_epochs):
    for i, (images, labels) in enumerate(train_loader):
        # 将输入数据和目标变量转换成正确的类型
        images = images.float()
        labels = labels.long()

        # 前向传播
        outputs = cnn(images)
        loss = criterion(outputs, labels)

        # 反向传播
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 输出训练结果
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch + 1, num_epochs, i + 1, len(train_loader), loss.item()))

    # 在每个 epoch 结束之后,使用测试数据集评估模型的性能
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            # 将输入数据和目标变量转换成正确的类型
            images = images.float()
            labels = labels.long()

            # 前向传播并计算准确率
            outputs = cnn(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

        print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

五、模型评估

在经过训练之后,我们需要评估模型的性能。在这里,我们简单地使用在测试数据集上的整体分类准确率来评估模型的性能。整体分类准确率是指模型在测试数据集上正确分类的样本数占所有样本数的比例。

with torch.no_grad():
    correct = 0
    total = 0
    for images, labels in test_loader:
        # 将输入数据和目标变量转换成正确的类型
        images = images.float()
        labels = labels.long()

        # 前向传播并计算准确率
        outputs = cnn(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    print('Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))

六、总结

在这篇文章中,我们使用PyTorch实现了一个手写数字识别模型。我们首先介绍了MNIST数据集,并使用PyTorch的Dataset和DataLoader API来加载和预处理数据。然后,我们构建了一个简单的卷积神经网络模型,并使用交叉熵损失函数和SGD优化算法来训练模型。最后,我们使用测试数据集来评估模型的性能。本文代码已经在Colab上执行。

原创文章,作者:小蓝,如若转载,请注明出处:https://www.506064.com/n/235999.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
小蓝的头像小蓝
上一篇 2024-12-12 11:57
下一篇 2024-12-12 11:57

相关推荐

  • TensorFlow Serving Java:实现开发全功能的模型服务

    TensorFlow Serving Java是作为TensorFlow Serving的Java API,可以轻松地将基于TensorFlow模型的服务集成到Java应用程序中。…

    编程 2025-04-29
  • Python循环符合要求数字求和

    这篇文章将详细介绍如何通过Python循环符合要求数字求和。如果你想用Python求和但又不想手动输入数字,那么本文将是一个不错的选择。 一、使用while循环实现求和 sum =…

    编程 2025-04-29
  • Python训练模型后如何投入应用

    Python已成为机器学习和深度学习领域中热门的编程语言之一,在训练完模型后如何将其投入应用中,是一个重要问题。本文将从多个方面为大家详细阐述。 一、模型持久化 在应用中使用训练好…

    编程 2025-04-29
  • Python基本数字类型

    本文将介绍Python中基本数字类型,包括整型、布尔型、浮点型、复数型,并提供相应的代码示例以便读者更好的理解。 一、整型 整型即整数类型,Python中的整型没有大小限制,所以可…

    编程 2025-04-29
  • Python打印数字三角形

    本文将详细阐述如何使用Python打印数字三角形,包括从基本代码实现到进阶操作的应用。通过本文的学习,您可以掌握Python的基础语法,同时加深对Python循环和函数的理解,提高…

    编程 2025-04-29
  • Python数字求和怎么写

    在Python中实现数字求和非常简单,下面将从多个方面对Python数字求和的实现方法做详细的阐述。 一、直接使用“+”符号进行求和 a = 10 b = 20 c = a + b…

    编程 2025-04-29
  • Python提取连续数字

    本文将介绍如何使用Python提取一个字符串中的连续数字。 一、使用正则表达式提取 正则表达式是一种可以匹配文本片段的模式。Python内置了re模块,可以使用正则表达式进行字符串…

    编程 2025-04-29
  • Python中如何判断字符为数字

    判断字符是否为数字是Python编程中常见的需求,本文将从多个方面详细阐述如何使用Python进行字符判断。 一、isdigit()函数判断字符是否为数字 Python中可以使用i…

    编程 2025-04-29
  • Python实现一元线性回归模型

    本文将从多个方面详细阐述Python实现一元线性回归模型的代码。如果你对线性回归模型有一些了解,对Python语言也有所掌握,那么本文将对你有所帮助。在开始介绍具体代码前,让我们先…

    编程 2025-04-29
  • ARIMA模型Python应用用法介绍

    ARIMA(自回归移动平均模型)是一种时序分析常用的模型,广泛应用于股票、经济等领域。本文将从多个方面详细阐述ARIMA模型的Python实现方式。 一、ARIMA模型是什么? A…

    编程 2025-04-29

发表回复

登录后才能评论