殘差結構:從原理到應用

一、殘差結構的原理

殘差結構在深度學習中的應用越來越廣泛,其核心原理是將輸入特徵和參考特徵拼接在一起進行訓練,以增強模型的學習能力和泛化能力。

具體地,殘差結構引入了跨層連接,使得模型可以直接利用淺層的信息跟蹤梯度,從而更好地學習到深層次特徵表示。這種跨越層次的連接機制將輸入和輸出進行殘差學習,通過殘差的計算和積累,模型可以更快速、準確地逼近真實函數。

下面是一個簡單的殘差結構示例:

import torch.nn as nn
class ResidualBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride=1):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(out_channels)
        self.relu1 = nn.ReLU(inplace=True)
        self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(out_channels)
        self.relu2 = nn.ReLU(inplace=True)
        self.shortcut = nn.Sequential()
        if stride != 1 or in_channels != out_channels:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(out_channels)
            )
    def forward(self, x):
        out = self.conv1(x)
        out = self.bn1(out)
        out = self.relu1(out)
        out = self.conv2(out)
        out = self.bn2(out)
        out += self.shortcut(x)
        out = self.relu2(out)
        return out

二、殘差結構的改進

雖然殘差結構在很多模型中表現出了優異的性能,但它也存在著一些問題,例如如果網路結構複雜度過高,很容易產生梯度消失問題。

因此,一些改進方法被提出,例如Res2Net、PVT、HRNet等。Res2Net與ResNeXt類似,但是它將多個並行分支的特徵更多地交互,以增強網路表達能力。PVT(Pyramid Vision Transformer)則是基於Transformer的一種新型網路模型,它擁有多級金字塔特徵融合機制,可以更好地應對不同大小的對象。HRNet(High-Resolution Network)則是一種具有多分支特徵提取模塊的網路結構,可以更好地保留高解析度信息。

三、殘差結構的應用

殘差結構已經被廣泛應用於圖像分類、目標檢測、語音識別等各種領域。以下是在圖像分類上的一個例子:

import torch
import torchvision.transforms as transforms
import torchvision.datasets as datasets
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader

train_transforms = transforms.Compose([transforms.Resize(256),
                                       transforms.RandomCrop(224),
                                       transforms.RandomHorizontalFlip(),
                                       transforms.ToTensor(),
                                       transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                            std=[0.229, 0.224, 0.225])])

test_transforms = transforms.Compose([transforms.Resize(256),
                                      transforms.CenterCrop(224),
                                      transforms.ToTensor(),
                                      transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                           std=[0.229, 0.224, 0.225])])

train_dataset = datasets.CIFAR10('data/', train=True, transform=train_transforms, download=True)
test_dataset = datasets.CIFAR10('data/', train=False, transform=test_transforms, download=True)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False, num_workers=4)


class ResNet(nn.Module):
    def __init__(self, block, num_blocks, num_classes=10):
        super(ResNet, self).__init__()
        self.in_channels = 64
        self.conv = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn = nn.BatchNorm2d(64)
        self.relu = nn.ReLU(inplace=True)
        self.layer1 = self.make_layer(block, 64, num_blocks[0], stride=1)
        self.layer2 = self.make_layer(block, 128, num_blocks[1], stride=2)
        self.layer3 = self.make_layer(block, 256, num_blocks[2], stride=2)
        self.layer4 = self.make_layer(block, 512, num_blocks[3], stride=2)
        self.avg_pool = nn.AdaptiveAvgPool2d((1, 1))
        self.fc = nn.Linear(512 * block.expansion, num_classes)
        
    def make_layer(self, block, out_channels, num_blocks, stride):
        strides = [stride] + [1] * (num_blocks - 1)
        layers = []
        for stride in strides:
            layers.append(block(self.in_channels, out_channels, stride))
            self.in_channels = out_channels * block.expansion
        return nn.Sequential(*layers)
    
    def forward(self, x):
        out = self.conv(x)
        out = self.bn(out)
        out = self.relu(out)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = self.avg_pool(out)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        return out

def ResNet18():
    return ResNet(ResidualBlock, [2, 2, 2, 2])

model = ResNet18()

criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=1e-4)
lr_scheduler = optim.lr_scheduler.MultiStepLR(optimizer, milestones=[150, 225], gamma=0.1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)
criterion.to(device)

num_epochs = 300

for epoch in range(num_epochs):
    model.train()
    for i, (images, labels) in enumerate(train_loader):
        images = images.to(device)
        labels = labels.to(device)
        optimizer.zero_grad()
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
    lr_scheduler.step()
    model.eval()
    with torch.no_grad():
        correct = 0
        total = 0
        for images, labels in test_loader:
            images = images.to(device)
            labels = labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    accuracy = 100 * correct / total
    print('Epoch [{}/{}], Loss: {:.4f}, Accuracy: {:.2f}%'.format(epoch+1, num_epochs, loss.item(), accuracy))

原創文章,作者:DMZUX,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/371481.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
DMZUX的頭像DMZUX
上一篇 2025-04-23 18:08
下一篇 2025-04-23 18:08

相關推薦

  • Harris角點檢測演算法原理與實現

    本文將從多個方面對Harris角點檢測演算法進行詳細的闡述,包括演算法原理、實現步驟、代碼實現等。 一、Harris角點檢測演算法原理 Harris角點檢測演算法是一種經典的計算機視覺演算法…

    編程 2025-04-29
  • 瘦臉演算法 Python 原理與實現

    本文將從多個方面詳細闡述瘦臉演算法 Python 實現的原理和方法,包括該演算法的意義、流程、代碼實現、優化等內容。 一、演算法意義 隨著科技的發展,瘦臉演算法已經成為了人們修圖中不可缺少…

    編程 2025-04-29
  • Vue TS工程結構用法介紹

    在本篇文章中,我們將從多個方面對Vue TS工程結構進行詳細的闡述,涵蓋文件結構、路由配置、組件間通訊、狀態管理等內容,並給出對應的代碼示例。 一、文件結構 一個好的文件結構可以極…

    編程 2025-04-29
  • 神經網路BP演算法原理

    本文將從多個方面對神經網路BP演算法原理進行詳細闡述,並給出完整的代碼示例。 一、BP演算法簡介 BP演算法是一種常用的神經網路訓練演算法,其全稱為反向傳播演算法。BP演算法的基本思想是通過正…

    編程 2025-04-29
  • Python程序的三種基本控制結構

    控制結構是編程語言中非常重要的一部分,它們指導著程序如何在不同的情況下執行相應的指令。Python作為一種高級編程語言,也擁有三種基本的控制結構:順序結構、選擇結構和循環結構。 一…

    編程 2025-04-29
  • GloVe詞向量:從原理到應用

    本文將從多個方面對GloVe詞向量進行詳細的闡述,包括其原理、優缺點、應用以及代碼實現。如果你對詞向量感興趣,那麼這篇文章將會是一次很好的學習體驗。 一、原理 GloVe(Glob…

    編程 2025-04-27
  • Lidar避障與AI結構光避障哪個更好?

    簡單回答:Lidar避障適用於需要高精度避障的場景,而AI結構光避障更適用於需要快速響應的場景。 一、Lidar避障 Lidar,即激光雷達,通過激光束掃描環境獲取點雲數據,從而實…

    編程 2025-04-27
  • 編譯原理語法分析思維導圖

    本文將從以下幾個方面詳細闡述編譯原理語法分析思維導圖: 一、語法分析介紹 1.1 語法分析的定義 語法分析是編譯器中將輸入的字元流轉換成抽象語法樹的一個過程。該過程的目的是確保輸入…

    編程 2025-04-27
  • Python字典底層原理用法介紹

    本文將以Python字典底層原理為中心,從多個方面詳細闡述。字典是Python語言的重要組成部分,具有非常強大的功能,掌握其底層原理對於學習和使用Python將是非常有幫助的。 一…

    編程 2025-04-25
  • Switch C:多選結構的利器

    在編寫程序時,我們經常需要根據某些條件執行不同的代碼,這時就需要使用選擇結構。在C語言中,有if語句、switch語句等多種選擇結構可供使用。其中,switch語句是一種非常強大的…

    編程 2025-04-25

發表回復

登錄後才能評論