PyTorch二分类详解

一、基本概念

PyTorch是一个使用动态计算图的开源机器学习库,而二分类问题是机器学习中最基本、最常见的问题。在PyTorch中,二分类问题最常用的算法是逻辑回归(Logistic Regression)。

二、数据准备

在二分类问题中,我们通常需要准备好两个类别的数据,同时将数据集分为训练集和测试集。具体实现方式如下:

import torch
from torch.utils.data import Dataset, DataLoader

# 定义数据集
class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data.float()
        self.labels = labels
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        x = self.data[index]
        y = self.labels[index]
        return x, y

# 准备数据
train_data = torch.tensor([[0.12, 0.23], [0.03, 0.45], [0.67, 0.39], [0.13, 0.52], [0.55, 0.69]])
train_labels = torch.tensor([0, 1, 1, 0, 1])
train_dataset = MyDataset(train_data, train_labels)

test_data = torch.tensor([[0.19, 0.66], [0.44, 0.83], [0.87, 0.29], [0.76, 0.09]])
test_labels = torch.tensor([1, 0, 1, 0])
test_dataset = MyDataset(test_data, test_labels)

# 加载数据
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

三、模型构建

逻辑回归模型可以通过定义一个类来实现。在类中,需要定义模型的结构以及前向传播的过程。在二分类问题中,通常使用sigmoid函数作为输出层的激活函数。

import torch.nn as nn

# 定义模型
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(2, 1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        out = self.linear(x)
        out = self.sigmoid(out)
        return out

model = LogisticRegression()

四、损失函数和优化器

在二分类问题中,通常使用二元交叉熵损失函数(Binary Cross Entropy Loss)作为损失函数,使用随机梯度下降(SGD)或Adam优化器。

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

五、模型训练

在进行模型训练时,通常需要通过多次迭代训练模型,并在每一次迭代后计算模型在验证集上的准确率和损失值。

num_epochs = 100
for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        # 前向传播
        outputs = model(inputs)
        loss = criterion(outputs, labels.float().view(-1, 1))
        
        # 后向传播及优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 每100次迭代输出一次信息
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))

    # 计算在训练集和验证集上的准确率及损失值
    with torch.no_grad():
        train_correct = 0
        train_total = 0
        train_loss = 0.0
        for inputs, labels in train_loader:
            outputs = model(inputs)
            predicted = (outputs >= 0.5).float()
            train_correct += (predicted == labels.float().view(-1, 1)).sum().item()
            train_total += len(labels)
            train_loss += criterion(outputs, labels.float().view(-1, 1)).item() * len(labels)

        test_correct = 0
        test_total = 0
        test_loss = 0.0
        for inputs, labels in test_loader:
            outputs = model(inputs)
            predicted = (outputs >= 0.5).float()
            test_correct += (predicted == labels.float().view(-1, 1)).sum().item()
            test_total += len(labels)
            test_loss += criterion(outputs, labels.float().view(-1, 1)).item() * len(labels)

        train_acc = train_correct / train_total
        train_loss /= train_total
        test_acc = test_correct / test_total
        test_loss /= test_total

        print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.2f}%, Test Loss: {:.4f}, Test Acc: {:.2f}%'.format(epoch+1, num_epochs, train_loss, train_acc*100, test_loss, test_acc*100))

六、模型评估

在模型训练完成后,可以使用测试集来评估模型的性能。

with torch.no_grad():
    correct = 0
    total = 0
    for inputs, labels in test_loader:
        outputs = model(inputs)
        predicted = (outputs >= 0.5).float()
        correct += (predicted == labels.float().view(-1, 1)).sum().item()
        total += len(labels)
    accuracy = correct / total
    print('Test Accuracy: {:.2f}%'.format(accuracy*100))

七、总结

通过本文,我们详细介绍了使用PyTorch进行二分类的基本流程,包括数据准备、模型构建、损失函数和优化器的设置、模型训练和模型评估。在实践中,我们可以通过修改模型结构、损失函数和优化器的设置来进一步提高模型性能。

原创文章,作者:IEUEF,如若转载,请注明出处:https://www.506064.com/n/366328.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
IEUEFIEUEF
上一篇 2025-04-02 01:02
下一篇 2025-04-02 01:28

相关推荐

  • PyTorch模块简介

    PyTorch是一个开源的机器学习框架,它基于Torch,是一个Python优先的深度学习框架,同时也支持C++,非常容易上手。PyTorch中的核心模块是torch,提供一些很好…

    编程 2025-04-27
  • Linux sync详解

    一、sync概述 sync是Linux中一个非常重要的命令,它可以将文件系统缓存中的内容,强制写入磁盘中。在执行sync之前,所有的文件系统更新将不会立即写入磁盘,而是先缓存在内存…

    编程 2025-04-25
  • 神经网络代码详解

    神经网络作为一种人工智能技术,被广泛应用于语音识别、图像识别、自然语言处理等领域。而神经网络的模型编写,离不开代码。本文将从多个方面详细阐述神经网络模型编写的代码技术。 一、神经网…

    编程 2025-04-25
  • Linux修改文件名命令详解

    在Linux系统中,修改文件名是一个很常见的操作。Linux提供了多种方式来修改文件名,这篇文章将介绍Linux修改文件名的详细操作。 一、mv命令 mv命令是Linux下的常用命…

    编程 2025-04-25
  • git config user.name的详解

    一、为什么要使用git config user.name? git是一个非常流行的分布式版本控制系统,很多程序员都会用到它。在使用git commit提交代码时,需要记录commi…

    编程 2025-04-25
  • Python安装OS库详解

    一、OS简介 OS库是Python标准库的一部分,它提供了跨平台的操作系统功能,使得Python可以进行文件操作、进程管理、环境变量读取等系统级操作。 OS库中包含了大量的文件和目…

    编程 2025-04-25
  • C语言贪吃蛇详解

    一、数据结构和算法 C语言贪吃蛇主要运用了以下数据结构和算法: 1. 链表 typedef struct body { int x; int y; struct body *nex…

    编程 2025-04-25
  • 详解eclipse设置

    一、安装与基础设置 1、下载eclipse并进行安装。 2、打开eclipse,选择对应的工作空间路径。 File -> Switch Workspace -> [选择…

    编程 2025-04-25
  • Python输入输出详解

    一、文件读写 Python中文件的读写操作是必不可少的基本技能之一。读写文件分别使用open()函数中的’r’和’w’参数,读取文件…

    编程 2025-04-25
  • MPU6050工作原理详解

    一、什么是MPU6050 MPU6050是一种六轴惯性传感器,能够同时测量加速度和角速度。它由三个传感器组成:一个三轴加速度计和一个三轴陀螺仪。这个组合提供了非常精细的姿态解算,其…

    编程 2025-04-25

发表回复

登录后才能评论