soft-nms详解

一、什么是soft-nms

Soft-nms是一种非极大值抑制(Non-Maximum Suppression,NMS)的改进方法,与传统的NMS比较,soft-nms使用一个相对较为缓和的函数逐步降低重叠框的得分,从而保留了更多的框,提高了检测的精度。

二、soft-nms的工作原理

传统的NMS方法是根据重叠区域的大小将得分低的框删除,而soft-nms则先利用计算出的较为缓和的函数对框的得分进行降权,然后再根据剩余框的得分进行排序,最后按照类似于传统NMS的方式进行筛选。

def soft_nms(dets, sigma=0.5, Nt=0.3, threshold=0.001, method=1):
    """
    PyTorch implementation of SoftNMS algorithm.
    # Arguments
        dets:        detections, size[N,5], format[x1,y1,x2,y2,score]
        sigma:       variance of Gaussian function, scalar
        Nt:          threshold for box overlap, scalar
        threshold:   score threshold, scalar
        method:      0=Max, 1=Linear, 2=Gaussian
    # Returns
        dets:        detections after SoftNMS, size[K,5]
    """

    # Indexes concatenate detection boxes with the score
    N = dets.shape[0]
    indexes = np.array([np.arange(N)])
    dets = np.concatenate((dets, indexes.T), axis=1)

    for i in range(N):
        # intermediate parameters for later parameters exchange
        si = dets[i, 4]
        xi = dets[i, :4]
        area_i = (xi[2] - xi[0] + 1) * (xi[3] - xi[1] + 1)

        if method == 1:  # Linear
            weight = np.ones((N - i))
            weight[0] = si
        else:  # Gaussian
            # Compute Gaussian weight coefficients
            xx = np.arange(i, N).astype(np.float32)
            if method == 2:
                sigma = 0.5
            ii = np.ones((xx.shape[0], 1)) * i
            # print(sigma)
            # print((xx - ii).shape)
            gauss = np.exp(-1.0 * ((xx - ii) ** 2) / (2 * sigma * sigma))

            if method == 2:
                weight = gauss
            else:
                weight = np.zeros((N - i))
                weight[0] = 1.0
                weight[1:] = gauss / np.sum(gauss)

        # Sort boxes by score
        idx = np.arange(i, N)
        idx_max = np.argmax(dets[idx, 4])
        idx_max += i

        # Swap boxes and scores
        dets[i, 4], dets[idx_max, 4] = dets[idx_max, 4], dets[i, 4]
        dets[i, :4], dets[idx_max, :4] = dets[idx_max, :4], dets[i, :4]
        dets[i, 5], dets[idx_max, 5] = dets[idx_max, 5], dets[i, 5]

        # Compute overlap ratios
        xx1 = np.maximum(dets[i, 0], dets[idx, 0])
        yy1 = np.maximum(dets[i, 1], dets[idx, 1])
        xx2 = np.minimum(dets[i, 2], dets[idx, 2])
        yy2 = np.minimum(dets[i, 3], dets[idx, 3])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h

        # Update weights
        if method == 0:  # Max
            weight[idx_max - i + 1:] = np.where(inter > Nt, 0.0, 1.0)
        else:  # Linear / Gaussian
            weight_matrix = np.zeros((weight.shape[0], weight.shape[0]))
            weight_matrix[0, :] = weight
            weight_matrix[1:, :] = np.diag(weight[1:])
            iou = inter / (area_i + dets[idx, 4] * (1 - inter))
            weight[idx - i + 1] = np.matmul(weight_matrix, (1.0 - iou).reshape(-1, 1)).reshape(-1)
            weight[idx_max - i + 1:] = np.where(iou > Nt, 0.0, weight[idx_max - i + 1:])

        # Apply weight
        dets[idx, 4] = dets[idx, 4] * weight

        # Weigh small scores
        suppress_small = np.where(dets[idx, 4]  0)[0]
    dets = dets[idx_keep]

    return dets[:, :5]

三、soft-nms的实现过程

Soft-nms的实现过程可以分为几个步骤:

1. 输入预测框

输入神经网络预测输出的所有框,每个框有四个坐标和一个类别得分。

2. 对于每个框计算其权重

权重可以使用三种不同的函数:max、linear和Gaussian。

3. 重复以下步骤,直到不再有框被删除

(1)选出最高得分的框,令其权重为1,与第一个框进行交换。

(2)计算当前框与剩余框的重叠率。

(3)根据重叠率和选定的函数计算权重。

(4)根据权重更新每个框的得分。

(5)剔除得分小于设定阈值的框。

4. 输出筛选后的结果

四、soft-nms的优点

Soft-nms与传统的NMS相比有以下优点:

1. 保留更多的框

相对于传统的NMS,soft-nms不会直接删除重叠较多的框,而是通过降权,保留了更多的框。

2. 精度更高

相对于传统的NMS,soft-nms保留了更多的框,因此可以提高检测的精度。

3. 可以自适应地调整阈值

soft-nms的函数参数可以根据实际情况进行调整,从而自适应地调整阈值。

五、结语

Soft-nms是非常有用的一种NMS改进技术,可以在图像检测中提高精度,值得研究和应用。

原创文章,作者:JPBXE,如若转载,请注明出处:https://www.506064.com/n/372627.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
JPBXEJPBXE
上一篇 2025-04-24 06:40
下一篇 2025-04-24 06:40

相关推荐

  • Linux sync详解

    一、sync概述 sync是Linux中一个非常重要的命令,它可以将文件系统缓存中的内容,强制写入磁盘中。在执行sync之前,所有的文件系统更新将不会立即写入磁盘,而是先缓存在内存…

    编程 2025-04-25
  • 神经网络代码详解

    神经网络作为一种人工智能技术,被广泛应用于语音识别、图像识别、自然语言处理等领域。而神经网络的模型编写,离不开代码。本文将从多个方面详细阐述神经网络模型编写的代码技术。 一、神经网…

    编程 2025-04-25
  • Linux修改文件名命令详解

    在Linux系统中,修改文件名是一个很常见的操作。Linux提供了多种方式来修改文件名,这篇文章将介绍Linux修改文件名的详细操作。 一、mv命令 mv命令是Linux下的常用命…

    编程 2025-04-25
  • Python输入输出详解

    一、文件读写 Python中文件的读写操作是必不可少的基本技能之一。读写文件分别使用open()函数中的’r’和’w’参数,读取文件…

    编程 2025-04-25
  • nginx与apache应用开发详解

    一、概述 nginx和apache都是常见的web服务器。nginx是一个高性能的反向代理web服务器,将负载均衡和缓存集成在了一起,可以动静分离。apache是一个可扩展的web…

    编程 2025-04-25
  • MPU6050工作原理详解

    一、什么是MPU6050 MPU6050是一种六轴惯性传感器,能够同时测量加速度和角速度。它由三个传感器组成:一个三轴加速度计和一个三轴陀螺仪。这个组合提供了非常精细的姿态解算,其…

    编程 2025-04-25
  • 详解eclipse设置

    一、安装与基础设置 1、下载eclipse并进行安装。 2、打开eclipse,选择对应的工作空间路径。 File -> Switch Workspace -> [选择…

    编程 2025-04-25
  • Python安装OS库详解

    一、OS简介 OS库是Python标准库的一部分,它提供了跨平台的操作系统功能,使得Python可以进行文件操作、进程管理、环境变量读取等系统级操作。 OS库中包含了大量的文件和目…

    编程 2025-04-25
  • Java BigDecimal 精度详解

    一、基础概念 Java BigDecimal 是一个用于高精度计算的类。普通的 double 或 float 类型只能精确表示有限的数字,而对于需要高精度计算的场景,BigDeci…

    编程 2025-04-25
  • git config user.name的详解

    一、为什么要使用git config user.name? git是一个非常流行的分布式版本控制系统,很多程序员都会用到它。在使用git commit提交代码时,需要记录commi…

    编程 2025-04-25

发表回复

登录后才能评论