ASPP模块详解

ASPP(Atrous Spatial Pyramid Pooling)是一种用于图像分割任务的模块,旨在解决语义分割中空间上下文信息不足的问题。该模块在多个深度学习框架中得到了广泛的应用,如在DeepLab系列中发挥了关键作用。下面将从多个方面对ASPP模块进行详细的阐述。

一、ASPP模块原理

ASPP模块是基于空洞卷积(或称孔卷积,dilated convolution)的思想。空洞卷积是一种可以在不增加网络参数的情况下,增大感受野的技术,可以帮助模型获取更大范围的图像信息。ASPP模块采用多个空洞卷积,不同采样率的空洞卷积可捕获不同尺度的局部信息,最终得到具有不同感受野的特征图。下面是ASPP模块的代码实现:


import torch.nn as nn
import torch.nn.functional as F

class ASPP(nn.Module):
    def __init__(self, in_channels, out_channels, rates):
        super(ASPP, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[0], dilation=rates[0])
        self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[1], dilation=rates[1])
        self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[2], dilation=rates[2])
        self.conv5 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout(0.5)

    def forward(self, x):
        feat1 = self.conv1(x)
        feat2 = self.conv2(x)
        feat3 = self.conv3(x)
        feat4 = self.conv4(x)
        out = torch.cat((feat1, feat2, feat3, feat4), dim=1)
        out = self.bn(self.conv5(out))
        out = F.relu(out)
        out = self.dropout(out)
        return out

ASPP模块实现了上述原理,使用四个不同采样率(rates)的空洞卷积,之后对输出进行合并,再通过一次卷积和BatchNorm层得到最终的输出。该模块中还加入了Dropout层防止过拟合。

二、多尺度ASPP模块

为进一步提高模型的准确性,可以在ASPP模块中引入多尺度的特征图。具体方法是在不同大小的特征图上分别使用ASPP模块,之后将它们合并得到最终的输出。多尺度ASPP模块的代码实现如下:


import torch

def ASPP_module(x, in_channels, out_channels, rates):
    feat1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1)(x)
    feat2 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[0], dilation=rates[0])(x)
    feat3 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[1], dilation=rates[1])(x)
    feat4 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[2], dilation=rates[2])(x)
    out = torch.cat((feat1, feat2, feat3, feat4), dim=1)
    out = torch.nn.BatchNorm2d(out_channels)(out)
    out = torch.nn.ReLU()(out)
    out = torch.nn.Dropout2d()(out)
    return out

class MultiScaleASPP(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        rates = [1, 6, 12]
        self.aspp1 = ASPP_module(in_channels, out_channels, [1, 1, 1])
        self.aspp2 = ASPP_module(in_channels, out_channels, [6, 12, 18])
        self.aspp3 = ASPP_module(in_channels, out_channels, rates)
        self.global_avg_pool = nn.Sequential(
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(in_channels, out_channels, 1, stride=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU()
        )
        self.conv = nn.Conv2d(out_channels*4, out_channels, kernel_size=1, bias=False)
        self.bn = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout2d(p=0.1)

    def forward(self, x):
        feat1 = self.aspp1(x)
        feat2 = self.aspp2(x)
        feat3 = self.aspp3(x)
        global_avg_pool = self.global_avg_pool(x).expand(x.size()[0], -1, x.size()[2], x.size()[3])
        out = torch.cat([feat1, feat2, feat3, global_avg_pool], dim=1)
        out = self.conv(out)
        out = self.bn(out)
        out = torch.nn.ReLU()(out)
        out = self.dropout(out)
        return out

利用多尺度ASPP模块,可以容易地在已有的ASPP模块中实现定制化的模型结构。

三、ASPP模块在DeepLab系列网络中的应用

DeepLab是语义分割任务中的一类经典网络,使用ASPP模块在网络中成功地解决了空间上下文信息不足问题,取得了较好的效果。下面以DeepLab-v3+网络为例,说明ASPP模块在其中的应用。该网络在ImageNet数据集上预训练,在PASCAL VOC、Cityscapes等数据集上微调。


import torch.nn as nn

class DeepLabv3(nn.Module):
    def __init__(self, backbone, classifier, aspp_dilate=[6,12,18]):
        super(DeepLabv3, self).__init__()
        self.backbone = backbone
        self.classifier = classifier
        self.aspp = MultiScaleASPP(in_channels=2048, out_channels=256)
        self.final_conv = nn.Conv2d(256, 256, kernel_size=1)
        self._init_weight()

    def forward(self, x):
        input_shape = x.shape[-2:]
        feature_map = self.backbone(x)
        feature_map = self.aspp(feature_map)
        feature_map = self.final_conv(feature_map)
        output = self.classifier(feature_map)
        output = F.interpolate(output, size=input_shape,
                               mode='bilinear', align_corners=False)
        return output

    def _init_weight(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

在DeepLabv3网络中,ASPP模块的输出经过一次卷积和上采样操作之后用于分类器进行预测。该网络在PASCAL VOC数据集上取得了当时最优秀的性能。

四、ASPP模块的优化

由于ASPP模块经常被用于深度学习网络的预测部分,而该部分常常需要对每个像素进行操作,因此ASPP模块的计算量很大。为此,研究者尝试减少ASPP模块的计算量,提出了多种方法,如使用深度可分离卷积(depthwise separable convolution)等。下面是一种改进ASPP模块的方法:


import torch.nn as nn

class GDASPP(nn.Module):
    def __init__(self, in_channels, out_channels, rates):
        super(GDASPP, self).__init__()
        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.conv2 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[0], dilation=rates[0], groups=out_channels)
        self.conv3 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[1], dilation=rates[1], groups=out_channels)
        self.conv4 = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=rates[2], dilation=rates[2], groups=out_channels)
        self.conv5 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
        self.bn = nn.BatchNorm2d(out_channels)
        self.dropout = nn.Dropout2d(0.5)

    def forward(self, x):
        feat1 = self.conv1(x)
        feat2 = self.conv2(x)
        feat3 = self.conv3(x)
        feat4 = self.conv4(x)
        out = torch.cat((feat1, feat2, feat3, feat4), dim=1)
        out = self.bn(self.conv5(out))
        out = F.relu(out)
        out = self.dropout(out)
        return out

所述改进的ASPP模块将普通卷积替换为深度可分离卷积,可以大大降低计算量,同时保持模型准确性。该模块应用于DeepLabv3+中可以取得比原版ASPP模块更好的结果。

至此,我们详细地介绍了ASPP模块及其应用。ASPP模块在图像分割任务中具有重要作用,值得广大研究者深入研究。

原创文章,作者:小蓝,如若转载,请注明出处:https://www.506064.com/n/236080.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
小蓝小蓝
上一篇 2024-12-12 11:57
下一篇 2024-12-12 11:57

相关推荐

  • 光模块异常,SFP未认证(entityphysicalindex=6743835)——解决方案和

    如果您遇到类似optical module exception, sfp is not certified. (entityphysicalindex=6743835)的问题,那么…

    编程 2025-04-29
  • Python模块下载与安装指南

    如果想要扩展Python的功能,可以使用Python模块来实现。但是,在使用之前,需要先下载并安装对应的模块。本文将从以下多个方面对Python模块下载与安装进行详细的阐述,包括使…

    编程 2025-04-29
  • Python编程三剑客——模块、包、库

    本文主要介绍Python编程三剑客:模块、包、库的概念、特点、用法,以及在实际编程中的实际应用,旨在帮助读者更好地理解和应用Python编程。 一、模块 1、概念:Python模块…

    编程 2025-04-29
  • 如何使用pip安装模块

    pip作为Python默认的包管理系统,是安装和管理Python包的一种方式,它可以轻松快捷地安装、卸载和管理Python的扩展库、模块等。下面从几个方面详细介绍pip的使用方法。…

    编程 2025-04-28
  • Python如何下载第三方模块

    想要使Python更加强大且具备跨平台性,我们可以下载许多第三方模块。下面将从几个方面详细介绍如何下载第三方模块。 一、使用pip下载第三方模块 pip是Python的软件包管理器…

    编程 2025-04-28
  • Python datetime和time模块用法介绍

    本文将详细阐述Python datetime和time模块的用法和应用场景,以帮助读者更好地理解和运用这两个模块。 一、datetime模块 datetime模块提供了处理日期和时…

    编程 2025-04-28
  • Idea创建模块时下面没有启动类的解决方法

    本文将从以下几个方面对Idea创建模块时下面没有启动类进行详细阐述: 一、创建SpringBoot项目时没有启动类的解决方法 在使用Idea创建SpringBoot项目时,有可能会…

    编程 2025-04-28
  • l9110风扇传感器模块原理图解析

    本文将从原理图概述、硬件特性、软件实现等多个方面对l9110风扇传感器模块进行详细解析,并给出对应代码实例。 一、原理图概述 l9110风扇传感器模块主要由驱动芯片l9110、电位…

    编程 2025-04-28
  • 掌握Python3中datetime模块的使用

    Python3中的datetime模块是处理日期和时间的常用模块之一,它提供了一些函数和类,可以轻松处理日期和时间,包括日期和时间的计算、格式化、解析、时区转换等。本文将从多个方面…

    编程 2025-04-28
  • Python导入模块方法

    在Python编程中,模块是管理函数和变量之类内容的一种方式。Python标准库提供了许多有用的模块,让我们可以方便地实现对底层硬件和网络等的控制。本文将介绍Python中常用的导…

    编程 2025-04-28

发表回复

登录后才能评论