WideResNet (WRN) is a refinement of ResNet, a deep neural network proposed by Sergey Zagoruyko and Nikos Komodakis in the paper "Wide Residual Networks". Starting from the original ResNet, it widens the network instead of deepening it, and at the time achieved state-of-the-art results on several image classification benchmarks.
1. Introduction to WideResNet
Traditional ResNets are built from residual blocks, mainly to combat the vanishing-gradient problem. WideResNet improves on this by increasing the channel width of each layer, which gives the model more representational capacity and better classification accuracy.
The main characteristics of the WideResNet model are:
- Widening factor (k): instead of stacking more layers, every convolutional layer uses k times as many channels (feature maps), so each layer can extract more discriminative features (see the sketch after this list).
- Depth factor: network depth is controlled by the number of residual blocks per stage.
- Dropout: dropout is applied between the convolutions of a block to reduce overfitting and keep the wider model from developing high variance.
- Batch Normalization: batch normalization is used to speed up training and improve generalization.
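As a quick illustration, the following sketch (an assumption based on the WRN paper and the WRN class defined later in this article) shows how the depth and widening factors translate into a concrete configuration:

depth, widen_factor = 28, 10  # WRN-28-10, the configuration trained below

# Each of the 3 stages holds n basic blocks with 2 convolutions each,
# plus 4 layers outside the blocks, hence depth = 6 * n + 4.
n = (depth - 4) // 6  # -> 4 blocks per stage
channels = [16, 16 * widen_factor, 32 * widen_factor, 64 * widen_factor]

print(f"WRN-{depth}-{widen_factor}: {n} blocks per stage, channels {channels}")
# WRN-28-10: 4 blocks per stage, channels [16, 160, 320, 640]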
2. Structure of WideResNet
WideResNet is made up of a few main components: an input layer, convolutional layers, residual blocks, a global average pooling layer, a fully connected layer, and an output layer.
A WideResNet residual block contains two convolutional layers and a skip connection: after the two convolutions, the input is added back to the block output through the shortcut.
import torch.nn as nn
import torch.nn.functional as F


def conv3x3(in_channels, out_channels, stride=1):
"""3x3 convolution with padding"""
return nn.Conv2d(
in_channels,
out_channels,
kernel_size=3,
stride=stride,
padding=1,
bias=False)
class BasicBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, dropout_rate=0.3):
super().__init__()
self.bn1 = nn.BatchNorm2d(in_channels)
self.conv1 = conv3x3(in_channels, out_channels, stride)
self.dropout = nn.Dropout2d(dropout_rate) if dropout_rate > 0 else None
self.bn2 = nn.BatchNorm2d(out_channels)
self.conv2 = conv3x3(out_channels, out_channels)
if in_channels != out_channels or stride > 1:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels)
)
else:
self.shortcut = nn.Identity()
def forward(self, x):
out = F.relu(self.bn1(x), inplace=True)
out = self.conv1(out)
if self.dropout:
out = self.dropout(out)
out = self.conv2(F.relu(self.bn2(out), inplace=True))
out += self.shortcut(x)
return out
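A quick way to verify the block is to push a dummy tensor through it; this is only a sanity-check sketch (the shapes assume CIFAR-sized 32x32 feature maps and the imports shown above):

import torch

# A strided block halves the spatial resolution and changes the channel
# count via the 1x1 shortcut convolution.
block = BasicBlock(in_channels=16, out_channels=32, stride=2, dropout_rate=0.3)
x = torch.randn(8, 16, 32, 32)
print(block(x).shape)  # -> torch.Size([8, 32, 16, 16])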
In this implementation, the stem convolutional layer consists of a 3x3 convolution, a batch normalization layer, a ReLU activation, optional dropout, and a 2x2 max pooling layer.
class ConvLayer(nn.Module):
def __init__(self, in_channels, out_channels, stride=1, dropout_rate=0.3):
super().__init__()
self.conv = nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False
)
self.bn = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.dropout = nn.Dropout2d(dropout_rate) if dropout_rate > 0 else None
self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
def forward(self, x):
out = self.conv(x)
out = self.bn(out)
out = self.relu(out)
if self.dropout:
out = self.dropout(out)
out = self.pool(out)
return out
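The same kind of shape check applies to the stem layer; note that the max pooling step halves the input resolution right away (again a sketch, continuing from the check above):

# The stem maps 3-channel 32x32 images to 16 feature maps at 16x16.
stem = ConvLayer(3, 16, dropout_rate=0.3)
print(stem(torch.randn(8, 3, 32, 32)).shape)  # -> torch.Size([8, 16, 16, 16])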
3. Applications of WideResNet
WideResNet is used for a wide range of computer vision tasks, such as object recognition, image segmentation, and scene understanding. On ImageNet, wide variants such as WRN-50-2 reach top-1 and top-5 accuracy comparable to or better than much deeper ResNets (e.g. ResNet-152) while being faster to train.
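For ImageNet-scale work one normally starts from pretrained weights rather than training from scratch; torchvision ships a WRN-50-2 model. A minimal sketch, assuming torchvision >= 0.13 for the weights API:

import torch
from torchvision import models

# Load the pretrained wide_resnet50_2 variant shipped with torchvision.
weights = models.Wide_ResNet50_2_Weights.IMAGENET1K_V1
wrn50_2 = models.wide_resnet50_2(weights=weights).eval()

# Classify one (random) image using the matching ImageNet preprocessing.
preprocess = weights.transforms()
dummy = torch.rand(3, 224, 224)
with torch.no_grad():
    logits = wrn50_2(preprocess(dummy).unsqueeze(0))
print(logits.shape)  # -> torch.Size([1, 1000])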
Below is the complete code for training WideResNet on the CIFAR-10 dataset:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torchvision import datasets, transforms
class WRN(nn.Module):
def __init__(self, depth=28, widen_factor=10, dropout_rate=0.3, num_classes=10):
super(WRN, self).__init__()
self.depth = depth
self.widen_factor = widen_factor
self.dropout_rate = dropout_rate
self.num_classes = num_classes
k = widen_factor # width multiplier
# Network architecture
n = (depth - 4) // 6
block = BasicBlock
channels = [16, 16 * k, 32 * k, 64 * k]
self.features = nn.Sequential(
ConvLayer(3, channels[0], dropout_rate=dropout_rate),
            self._make_layer(block, channels[0], channels[1], n, dropout_rate=dropout_rate),
            self._make_layer(block, channels[1], channels[2], n, stride=2, dropout_rate=dropout_rate),
            self._make_layer(block, channels[2], channels[3], n, stride=2, dropout_rate=dropout_rate),
nn.BatchNorm2d(channels[3]),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d(1),
)
self.classifier = nn.Linear(channels[3], num_classes)
# Initialization
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.normal_(m.weight, 0, 0.01)
nn.init.constant_(m.bias, 0)
    def _make_layer(self, block, in_channels, out_channels, num_blocks, stride=1, dropout_rate=0):
        # Only the first block of a stage changes the channel count and
        # (optionally) downsamples; the remaining blocks keep the shape.
        layers = []
        for i in range(num_blocks):
            layers.append(
                block(
                    in_channels if i == 0 else out_channels,
                    out_channels,
                    stride if i == 0 else 1,
                    dropout_rate=dropout_rate,
                )
            )
        return nn.Sequential(*layers)
def forward(self, x):
out = self.features(x)
out = out.view(out.size(0), -1)
out = self.classifier(out)
return out
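# Sanity check (an added sketch, not part of the original script): WRN-28-10
# should have roughly 36M parameters, and a CIFAR-sized batch should map to
# one logit per class.
_check = WRN(depth=28, widen_factor=10, dropout_rate=0.3, num_classes=10)
print(f"Parameters: {sum(p.numel() for p in _check.parameters()) / 1e6:.1f}M")
print(_check(torch.randn(4, 3, 32, 32)).shape)  # -> torch.Size([4, 10])
del _check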
# Training settings
batch_size = 128
epochs = 50
lr = 0.1
momentum = 0.9
weight_decay = 5e-4
device = "cuda" if torch.cuda.is_available() else "cpu"
# Load data
train_loader = torch.utils.data.DataLoader(
datasets.CIFAR10(
root="./data",
train=True,
download=True,
transform=transforms.Compose(
[
transforms.RandomHorizontalFlip(),
transforms.RandomCrop(32, padding=4),
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
),
]
),
),
batch_size=batch_size,
shuffle=True,
num_workers=4,
pin_memory=True,
)
test_loader = torch.utils.data.DataLoader(
datasets.CIFAR10(
root="./data",
train=False,
download=True,
transform=transforms.Compose(
[
transforms.ToTensor(),
transforms.Normalize(
(0.4914, 0.4822, 0.4465), (0.2023, 0.1994, 0.2010)
),
]
),
),
batch_size=batch_size,
shuffle=False,
num_workers=4,
pin_memory=True,
)
# Create model
model = WRN().to(device)
# Optimization
optimizer = optim.SGD(
model.parameters(), lr=lr, momentum=momentum, weight_decay=weight_decay
)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, epochs)
# Training loop
for epoch in range(epochs):
model.train()
for batch_idx, (data, target) in enumerate(train_loader):
data, target = data.to(device), target.to(device)
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
scheduler.step()
model.eval()
test_loss = 0
correct = 0
with torch.no_grad():
for data, target in test_loader:
data, target = data.to(device), target.to(device)
output = model(data)
test_loss += F.cross_entropy(output, target, reduction="sum").item()
pred = output.max(1, keepdim=True)[1]
correct += pred.eq(target.view_as(pred)).sum().item()
test_loss /= len(test_loader.dataset)
print(f"Epoch {epoch + 1}/{epochs}")
print(f"Train - loss: {loss.item():.4f}")
print(f"Test - loss: {test_loss:.4f}, accuracy: {correct}/{len(test_loader.dataset)} ({100. * correct / len(test_loader.dataset):.2f}%)\n")