PyTorch二分類詳解

一、基本概念

PyTorch是一個使用動態計算圖的開源機器學習庫,而二分類問題是機器學習中最基本、最常見的問題。在PyTorch中,二分類問題最常用的演算法是邏輯回歸(Logistic Regression)。

二、數據準備

在二分類問題中,我們通常需要準備好兩個類別的數據,同時將數據集分為訓練集和測試集。具體實現方式如下:

import torch
from torch.utils.data import Dataset, DataLoader

# 定義數據集
class MyDataset(Dataset):
    def __init__(self, data, labels):
        self.data = data.float()
        self.labels = labels
    
    def __len__(self):
        return len(self.data)
    
    def __getitem__(self, index):
        x = self.data[index]
        y = self.labels[index]
        return x, y

# 準備數據
train_data = torch.tensor([[0.12, 0.23], [0.03, 0.45], [0.67, 0.39], [0.13, 0.52], [0.55, 0.69]])
train_labels = torch.tensor([0, 1, 1, 0, 1])
train_dataset = MyDataset(train_data, train_labels)

test_data = torch.tensor([[0.19, 0.66], [0.44, 0.83], [0.87, 0.29], [0.76, 0.09]])
test_labels = torch.tensor([1, 0, 1, 0])
test_dataset = MyDataset(test_data, test_labels)

# 載入數據
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=False)

三、模型構建

邏輯回歸模型可以通過定義一個類來實現。在類中,需要定義模型的結構以及前向傳播的過程。在二分類問題中,通常使用sigmoid函數作為輸出層的激活函數。

import torch.nn as nn

# 定義模型
class LogisticRegression(nn.Module):
    def __init__(self):
        super(LogisticRegression, self).__init__()
        self.linear = nn.Linear(2, 1)
        self.sigmoid = nn.Sigmoid()
    
    def forward(self, x):
        out = self.linear(x)
        out = self.sigmoid(out)
        return out

model = LogisticRegression()

四、損失函數和優化器

在二分類問題中,通常使用二元交叉熵損失函數(Binary Cross Entropy Loss)作為損失函數,使用隨機梯度下降(SGD)或Adam優化器。

criterion = nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

五、模型訓練

在進行模型訓練時,通常需要通過多次迭代訓練模型,並在每一次迭代後計算模型在驗證集上的準確率和損失值。

num_epochs = 100
for epoch in range(num_epochs):
    for i, (inputs, labels) in enumerate(train_loader):
        # 前向傳播
        outputs = model(inputs)
        loss = criterion(outputs, labels.float().view(-1, 1))
        
        # 後向傳播及優化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        # 每100次迭代輸出一次信息
        if (i + 1) % 100 == 0:
            print('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}'.format(epoch+1, num_epochs, i+1, len(train_loader), loss.item()))

    # 計算在訓練集和驗證集上的準確率及損失值
    with torch.no_grad():
        train_correct = 0
        train_total = 0
        train_loss = 0.0
        for inputs, labels in train_loader:
            outputs = model(inputs)
            predicted = (outputs >= 0.5).float()
            train_correct += (predicted == labels.float().view(-1, 1)).sum().item()
            train_total += len(labels)
            train_loss += criterion(outputs, labels.float().view(-1, 1)).item() * len(labels)

        test_correct = 0
        test_total = 0
        test_loss = 0.0
        for inputs, labels in test_loader:
            outputs = model(inputs)
            predicted = (outputs >= 0.5).float()
            test_correct += (predicted == labels.float().view(-1, 1)).sum().item()
            test_total += len(labels)
            test_loss += criterion(outputs, labels.float().view(-1, 1)).item() * len(labels)

        train_acc = train_correct / train_total
        train_loss /= train_total
        test_acc = test_correct / test_total
        test_loss /= test_total

        print('Epoch [{}/{}], Train Loss: {:.4f}, Train Acc: {:.2f}%, Test Loss: {:.4f}, Test Acc: {:.2f}%'.format(epoch+1, num_epochs, train_loss, train_acc*100, test_loss, test_acc*100))

六、模型評估

在模型訓練完成後,可以使用測試集來評估模型的性能。

with torch.no_grad():
    correct = 0
    total = 0
    for inputs, labels in test_loader:
        outputs = model(inputs)
        predicted = (outputs >= 0.5).float()
        correct += (predicted == labels.float().view(-1, 1)).sum().item()
        total += len(labels)
    accuracy = correct / total
    print('Test Accuracy: {:.2f}%'.format(accuracy*100))

七、總結

通過本文,我們詳細介紹了使用PyTorch進行二分類的基本流程,包括數據準備、模型構建、損失函數和優化器的設置、模型訓練和模型評估。在實踐中,我們可以通過修改模型結構、損失函數和優化器的設置來進一步提高模型性能。

原創文章,作者:IEUEF,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/366328.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
IEUEF的頭像IEUEF
上一篇 2025-04-02 01:02
下一篇 2025-04-02 01:28

相關推薦

  • PyTorch模塊簡介

    PyTorch是一個開源的機器學習框架,它基於Torch,是一個Python優先的深度學習框架,同時也支持C++,非常容易上手。PyTorch中的核心模塊是torch,提供一些很好…

    編程 2025-04-27
  • Linux sync詳解

    一、sync概述 sync是Linux中一個非常重要的命令,它可以將文件系統緩存中的內容,強制寫入磁碟中。在執行sync之前,所有的文件系統更新將不會立即寫入磁碟,而是先緩存在內存…

    編程 2025-04-25
  • 神經網路代碼詳解

    神經網路作為一種人工智慧技術,被廣泛應用於語音識別、圖像識別、自然語言處理等領域。而神經網路的模型編寫,離不開代碼。本文將從多個方面詳細闡述神經網路模型編寫的代碼技術。 一、神經網…

    編程 2025-04-25
  • Linux修改文件名命令詳解

    在Linux系統中,修改文件名是一個很常見的操作。Linux提供了多種方式來修改文件名,這篇文章將介紹Linux修改文件名的詳細操作。 一、mv命令 mv命令是Linux下的常用命…

    編程 2025-04-25
  • git config user.name的詳解

    一、為什麼要使用git config user.name? git是一個非常流行的分散式版本控制系統,很多程序員都會用到它。在使用git commit提交代碼時,需要記錄commi…

    編程 2025-04-25
  • Python安裝OS庫詳解

    一、OS簡介 OS庫是Python標準庫的一部分,它提供了跨平台的操作系統功能,使得Python可以進行文件操作、進程管理、環境變數讀取等系統級操作。 OS庫中包含了大量的文件和目…

    編程 2025-04-25
  • C語言貪吃蛇詳解

    一、數據結構和演算法 C語言貪吃蛇主要運用了以下數據結構和演算法: 1. 鏈表 typedef struct body { int x; int y; struct body *nex…

    編程 2025-04-25
  • 詳解eclipse設置

    一、安裝與基礎設置 1、下載eclipse並進行安裝。 2、打開eclipse,選擇對應的工作空間路徑。 File -> Switch Workspace -> [選擇…

    編程 2025-04-25
  • Python輸入輸出詳解

    一、文件讀寫 Python中文件的讀寫操作是必不可少的基本技能之一。讀寫文件分別使用open()函數中的’r’和’w’參數,讀取文件…

    編程 2025-04-25
  • MPU6050工作原理詳解

    一、什麼是MPU6050 MPU6050是一種六軸慣性感測器,能夠同時測量加速度和角速度。它由三個感測器組成:一個三軸加速度計和一個三軸陀螺儀。這個組合提供了非常精細的姿態解算,其…

    編程 2025-04-25

發表回復

登錄後才能評論