ResNet-18是一種非常著名的深度神經網路,它在ImageNet數據集上表現優異,被廣泛應用於計算機視覺領域。本文將從網路結構、Skip connection、殘差模塊、全局平均池化等多個方面對ResNet-18進行詳細的闡述。
一、網路結構
ResNet-18是由18個卷積層和全連接層組成的深度卷積神經網路,這個網路結構中每一個卷積層都有一個殘差塊,其中包含若干個卷積層和batch normalization層。在卷積層之間,使用了stride=2,stride=1,和shortcut來改善性能。
import torch.nn as nn
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)
def forward(self, x):
out = nn.ReLU()(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = nn.ReLU()(out)
return out
class ResNet(nn.Module):
def __init__(self, block, num_blocks, num_classes=10):
super(ResNet, self).__init__()
self.in_planes = 64
self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(64)
self.layer1 = self._make_layer(block, 64, num_blocks[0], stride=1)
self.layer2 = self._make_layer(block, 128, num_blocks[1], stride=2)
self.layer3 = self._make_layer(block, 256, num_blocks[2], stride=2)
self.layer4 = self._make_layer(block, 512, num_blocks[3], stride=2)
self.linear = nn.Linear(512*block.expansion, num_classes)
def _make_layer(self, block, planes, num_blocks, stride):
strides = [stride] + [1]*(num_blocks-1)
layers = []
for stride in strides:
layers.append(block(self.in_planes, planes, stride))
self.in_planes = planes * block.expansion
return nn.Sequential(*layers)
def forward(self, x):
out = nn.ReLU()(self.bn1(self.conv1(x)))
out = self.layer1(out)
out = self.layer2(out)
out = self.layer3(out)
out = self.layer4(out)
out = nn.AvgPool2d(kernel_size=4)(out)
out = out.view(out.size(0), -1)
out = self.linear(out)
return out
def resnet18():
return ResNet(BasicBlock, [2,2,2,2])
這個網路結構包含4個殘差模塊,每個殘差模塊包含多個BasicBlock,最後一個BasicBlock中做了全局平均池化。輸出層為一個完全連接層,用於分類。
二、Skip Connection
ResNet-18網路的核心是skip connection。在傳統CNN網路中,前面的層處理信息經過多次池化和卷積後被深度神經網路較深的層所覆蓋,導致前面的信息被遺忘,難以訓練。skip connection解決了這個問題,可以將前面的信息一路留給更深的網路層,可以直接傳遞信號而不會使其消失。
在ResNet中,skip connection是一種shortcut,負責將輸入源直接傳遞到殘差模塊中去。例如,在一個兩個卷積層組成的模塊中,輸入x通過第一個卷積層和relu激活函數之後,跳過第二個卷積層,直接到達該模塊的輸出,如下所示:
x = conv1(x)
out = conv2(x) + x
這種shortcut實在是太經典了,使得Deep Residual Network成為當時最優秀的網路之一。
三、殘差模塊
殘差模塊是ResNet-18網路的基礎部分。每一個殘差模塊都是由兩個連接級聯的卷積層組合而成。其中的第一個卷積層可以是3*3,5*5或者7*7,第二個卷積層通常都是3*3。
下面是一個基本的殘差模塊的實現:
class BasicBlock(nn.Module):
expansion = 1
def __init__(self, in_planes, planes, stride=1):
super(BasicBlock, self).__init__()
self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(planes)
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(planes)
self.shortcut = nn.Sequential()
if stride != 1 or in_planes != self.expansion*planes:
self.shortcut = nn.Sequential(
nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(self.expansion*planes)
)
def forward(self, x):
out = nn.ReLU()(self.bn1(self.conv1(x)))
out = self.bn2(self.conv2(out))
out += self.shortcut(x)
out = nn.ReLU()(out)
return out
在這個殘差模塊中,殘差塊有兩個卷積層,分別用於卷積源數據和傳遞經過第一個卷積的結果到一個shortcut通道,使上一個殘差塊的輸出可以和這個殘差塊的輸入之間消除一些高階多項式的雜訊,最終輸出結果。
四、全局平均池化
全局平均池化是ResNet-18網路的最後一步操作。它的目的是將特徵圖中每個像素劃分為整個空間的均值,此操作將特徵圖進一步壓縮為單個值,以用作分類器的輸入。
下面是一個實現全局平均池化的代碼塊:
out = nn.AvgPool2d(kernel_size=4)(out)
out = out.view(out.size(0), -1)
在這個代碼塊中,第一行中的AvgPool2d函數用於把特徵映射進行均值池化,第二行基於當前大小調整輸出的形狀以便於將數據輸入到全連接層中。
總結
ResNet-18是一種非常成功的深度神經網路,它不僅可以支持更深層次的架構,而且可以在一定程度上減少過擬合的問題。在本文中,我們對ResNet-18的網路結構、Skip Connection機制、殘差模塊、全局平均池化等多個方面進行了詳盡的解析,並提供了相應的代碼示例,幫助讀者更好的理解這個深度神經網路。
原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/160678.html