ResNet18 结构分析

一、概述

ResNet(残差神经网络)于 2015 年提出,是 ImageNet 图像分类任务上的冠军。其核心思想是通过引入跨层连接(shortcut connection)解决了深层神经网络训练过程中梯度消失和梯度爆炸的问题。ResNet18 是 ResNet 的一个较小版本,共有 18 层,包括 16 层卷积神经网络层和 2 层全连接层。

二、ResNet18 细节

1.卷积层

ResNet18 的卷积层共包含16层,在 3×3 的卷积核后,采用 ReLU 激活函数,步长为1,不使用池化层。其中,前7层卷积没有跨层连接,8-16 层使用了跨层连接。具体来说,8-16 层中每隔一个残差块就有一条跨层连接,用于将上一层特征图与下一层残差块的输出相加。

2.残差块

ResNet18 的每个残差块包含 2-3 个卷积层,每个卷积层都跟随着 Batch Normalization 层和 ReLU 激活函数。具体来说,第一个卷积层的卷积核大小为 3 × 3,第二个卷积层的卷积核大小也为 3 × 3。如果跨层连接存在,还需要添加一层 1 × 1 的卷积层用于调整维度。

3.全连接层

ResNet18 的最后两层全连接层都含有 512 个神经元,倒数第二层使用 ReLU 激活函数,而最后一层使用 Softmax 函数产生对类的预测输出。

三、代码示例

1.定义 ResNet18 结构

import torch.nn as nn

class BasicBlock(nn.Module):
    expansion = 1

    def __init__(self, in_planes, planes, stride=1):
        super(BasicBlock, self).__init__()
        self.conv1 = nn.Conv2d(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(planes)
        self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(planes)

        self.shortcut = nn.Sequential()
        if stride != 1 or in_planes != self.expansion*planes:
            self.shortcut = nn.Sequential(
                nn.Conv2d(in_planes, self.expansion*planes, kernel_size=1, stride=stride, bias=False),
                nn.BatchNorm2d(self.expansion*planes)
            )

    def forward(self, x):
        out = nn.functional.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        out += self.shortcut(x)
        out = nn.functional.relu(out)
        return out

class ResNet18(nn.Module):
    def __init__(self, num_classes=10):
        super(ResNet18, self).__init__()
        self.in_planes = 64

        self.conv1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(64)
        self.layer1 = self._make_layer(64, 2, stride=1)
        self.layer2 = self._make_layer(128, 2, stride=2)
        self.layer3 = self._make_layer(256, 2, stride=2)
        self.layer4 = self._make_layer(512, 2, stride=2)
        self.linear = nn.Linear(512*BasicBlock.expansion, num_classes)

    def _make_layer(self, planes, num_blocks, stride):
        strides = [stride] + [1]*(num_blocks-1)
        layers = []
        for stride in strides:
            layers.append(BasicBlock(self.in_planes, planes, stride))
            self.in_planes = planes * BasicBlock.expansion
        return nn.Sequential(*layers)

    def forward(self, x):
        out = nn.functional.relu(self.bn1(self.conv1(x)))
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = nn.functional.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        out = self.linear(out)
        return out

2.实例化 ResNet18 模型

resnet18 = ResNet18()

3.模型训练

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(resnet18.parameters())

for epoch in range(num_epochs):
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        optimizer.zero_grad()
        outputs = resnet18(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        if i % 2000 == 1999:    
            print('[%d, %5d] loss: %.3f' %
                  (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

原创文章,作者:QWYJ,如若转载,请注明出处:https://www.506064.com/n/135397.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
QWYJQWYJ
上一篇 2024-10-04 00:12
下一篇 2024-10-04 00:12

相关推荐

  • Vue TS工程结构用法介绍

    在本篇文章中,我们将从多个方面对Vue TS工程结构进行详细的阐述,涵盖文件结构、路由配置、组件间通讯、状态管理等内容,并给出对应的代码示例。 一、文件结构 一个好的文件结构可以极…

    编程 2025-04-29
  • Python程序的三种基本控制结构

    控制结构是编程语言中非常重要的一部分,它们指导着程序如何在不同的情况下执行相应的指令。Python作为一种高级编程语言,也拥有三种基本的控制结构:顺序结构、选择结构和循环结构。 一…

    编程 2025-04-29
  • Lidar避障与AI结构光避障哪个更好?

    简单回答:Lidar避障适用于需要高精度避障的场景,而AI结构光避障更适用于需要快速响应的场景。 一、Lidar避障 Lidar,即激光雷达,通过激光束扫描环境获取点云数据,从而实…

    编程 2025-04-27
  • Switch C:多选结构的利器

    在编写程序时,我们经常需要根据某些条件执行不同的代码,这时就需要使用选择结构。在C语言中,有if语句、switch语句等多种选择结构可供使用。其中,switch语句是一种非常强大的…

    编程 2025-04-25
  • Python分支结构的详细阐述

    一、if语句的基本语法 if 条件: 代码语句1 代码语句2 …… if语句是Python分支结构中最基本也是最常用的结构,它的基本语法如上所示。if语句会先判断条件是否成立,如果…

    编程 2025-04-24
  • 深入理解 Vue 目录结构

    Vue 是一款由 Evan You 开发的流行 JavaScript 框架。Vue 具有响应式视图和组件化的思想,让开发者可以轻松构建交互式的 Web 应用。那么在 Vue 开发中…

    编程 2025-04-24
  • JS递归遍历树结构详解

    一、JS递归遍历树结构并修改 function traverse(node) { if(node == null) return; //遍历结束 node.value++; // …

    编程 2025-04-24
  • 详解数组结构

    一、数组的基本概念 数组是一种有序的数据结构,可以容纳一组相同数据类型的元素。每个元素有一个唯一的索引(下标),可以通过下标来访问数组的元素。数组一般分为一维和多维,也可以具有不同…

    编程 2025-04-23
  • 残差结构:从原理到应用

    一、残差结构的原理 残差结构在深度学习中的应用越来越广泛,其核心原理是将输入特征和参考特征拼接在一起进行训练,以增强模型的学习能力和泛化能力。 具体地,残差结构引入了跨层连接,使得…

    编程 2025-04-23
  • LTE帧结构详解

    一、帧结构简介 LTE网络中的帧结构是由多个子帧和时隙构成的。每个子帧由14个符号组成,符号的长度为0.5ms。每个符号中又包含7个资源块,一个资源块可以传输12个子载波。一个子帧…

    编程 2025-04-22

发表回复

登录后才能评论