soft-nms詳解

一、什麼是soft-nms

Soft-nms是一種非極大值抑制(Non-Maximum Suppression,NMS)的改進方法,與傳統的NMS比較,soft-nms使用一個相對較為緩和的函數逐步降低重疊框的得分,從而保留了更多的框,提高了檢測的精度。

二、soft-nms的工作原理

傳統的NMS方法是根據重疊區域的大小將得分低的框刪除,而soft-nms則先利用計算出的較為緩和的函數對框的得分進行降權,然後再根據剩餘框的得分進行排序,最後按照類似於傳統NMS的方式進行篩選。

def soft_nms(dets, sigma=0.5, Nt=0.3, threshold=0.001, method=1):
    """
    PyTorch implementation of SoftNMS algorithm.
    # Arguments
        dets:        detections, size[N,5], format[x1,y1,x2,y2,score]
        sigma:       variance of Gaussian function, scalar
        Nt:          threshold for box overlap, scalar
        threshold:   score threshold, scalar
        method:      0=Max, 1=Linear, 2=Gaussian
    # Returns
        dets:        detections after SoftNMS, size[K,5]
    """

    # Indexes concatenate detection boxes with the score
    N = dets.shape[0]
    indexes = np.array([np.arange(N)])
    dets = np.concatenate((dets, indexes.T), axis=1)

    for i in range(N):
        # intermediate parameters for later parameters exchange
        si = dets[i, 4]
        xi = dets[i, :4]
        area_i = (xi[2] - xi[0] + 1) * (xi[3] - xi[1] + 1)

        if method == 1:  # Linear
            weight = np.ones((N - i))
            weight[0] = si
        else:  # Gaussian
            # Compute Gaussian weight coefficients
            xx = np.arange(i, N).astype(np.float32)
            if method == 2:
                sigma = 0.5
            ii = np.ones((xx.shape[0], 1)) * i
            # print(sigma)
            # print((xx - ii).shape)
            gauss = np.exp(-1.0 * ((xx - ii) ** 2) / (2 * sigma * sigma))

            if method == 2:
                weight = gauss
            else:
                weight = np.zeros((N - i))
                weight[0] = 1.0
                weight[1:] = gauss / np.sum(gauss)

        # Sort boxes by score
        idx = np.arange(i, N)
        idx_max = np.argmax(dets[idx, 4])
        idx_max += i

        # Swap boxes and scores
        dets[i, 4], dets[idx_max, 4] = dets[idx_max, 4], dets[i, 4]
        dets[i, :4], dets[idx_max, :4] = dets[idx_max, :4], dets[i, :4]
        dets[i, 5], dets[idx_max, 5] = dets[idx_max, 5], dets[i, 5]

        # Compute overlap ratios
        xx1 = np.maximum(dets[i, 0], dets[idx, 0])
        yy1 = np.maximum(dets[i, 1], dets[idx, 1])
        xx2 = np.minimum(dets[i, 2], dets[idx, 2])
        yy2 = np.minimum(dets[i, 3], dets[idx, 3])

        w = np.maximum(0.0, xx2 - xx1 + 1)
        h = np.maximum(0.0, yy2 - yy1 + 1)
        inter = w * h

        # Update weights
        if method == 0:  # Max
            weight[idx_max - i + 1:] = np.where(inter > Nt, 0.0, 1.0)
        else:  # Linear / Gaussian
            weight_matrix = np.zeros((weight.shape[0], weight.shape[0]))
            weight_matrix[0, :] = weight
            weight_matrix[1:, :] = np.diag(weight[1:])
            iou = inter / (area_i + dets[idx, 4] * (1 - inter))
            weight[idx - i + 1] = np.matmul(weight_matrix, (1.0 - iou).reshape(-1, 1)).reshape(-1)
            weight[idx_max - i + 1:] = np.where(iou > Nt, 0.0, weight[idx_max - i + 1:])

        # Apply weight
        dets[idx, 4] = dets[idx, 4] * weight

        # Weigh small scores
        suppress_small = np.where(dets[idx, 4]  0)[0]
    dets = dets[idx_keep]

    return dets[:, :5]

三、soft-nms的實現過程

Soft-nms的實現過程可以分為幾個步驟:

1. 輸入預測框

輸入神經網絡預測輸出的所有框,每個框有四個坐標和一個類別得分。

2. 對於每個框計算其權重

權重可以使用三種不同的函數:max、linear和Gaussian。

3. 重複以下步驟,直到不再有框被刪除

(1)選出最高得分的框,令其權重為1,與第一個框進行交換。

(2)計算當前框與剩餘框的重疊率。

(3)根據重疊率和選定的函數計算權重。

(4)根據權重更新每個框的得分。

(5)剔除得分小於設定閾值的框。

4. 輸出篩選後的結果

四、soft-nms的優點

Soft-nms與傳統的NMS相比有以下優點:

1. 保留更多的框

相對於傳統的NMS,soft-nms不會直接刪除重疊較多的框,而是通過降權,保留了更多的框。

2. 精度更高

相對於傳統的NMS,soft-nms保留了更多的框,因此可以提高檢測的精度。

3. 可以自適應地調整閾值

soft-nms的函數參數可以根據實際情況進行調整,從而自適應地調整閾值。

五、結語

Soft-nms是非常有用的一種NMS改進技術,可以在圖像檢測中提高精度,值得研究和應用。

原創文章,作者:JPBXE,如若轉載,請註明出處:https://www.506064.com/zh-hk/n/372627.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
JPBXE的頭像JPBXE
上一篇 2025-04-24 06:40
下一篇 2025-04-24 06:40

相關推薦

  • Linux sync詳解

    一、sync概述 sync是Linux中一個非常重要的命令,它可以將文件系統緩存中的內容,強制寫入磁盤中。在執行sync之前,所有的文件系統更新將不會立即寫入磁盤,而是先緩存在內存…

    編程 2025-04-25
  • 神經網絡代碼詳解

    神經網絡作為一種人工智能技術,被廣泛應用於語音識別、圖像識別、自然語言處理等領域。而神經網絡的模型編寫,離不開代碼。本文將從多個方面詳細闡述神經網絡模型編寫的代碼技術。 一、神經網…

    編程 2025-04-25
  • Linux修改文件名命令詳解

    在Linux系統中,修改文件名是一個很常見的操作。Linux提供了多種方式來修改文件名,這篇文章將介紹Linux修改文件名的詳細操作。 一、mv命令 mv命令是Linux下的常用命…

    編程 2025-04-25
  • Python輸入輸出詳解

    一、文件讀寫 Python中文件的讀寫操作是必不可少的基本技能之一。讀寫文件分別使用open()函數中的’r’和’w’參數,讀取文件…

    編程 2025-04-25
  • nginx與apache應用開發詳解

    一、概述 nginx和apache都是常見的web服務器。nginx是一個高性能的反向代理web服務器,將負載均衡和緩存集成在了一起,可以動靜分離。apache是一個可擴展的web…

    編程 2025-04-25
  • MPU6050工作原理詳解

    一、什麼是MPU6050 MPU6050是一種六軸慣性傳感器,能夠同時測量加速度和角速度。它由三個傳感器組成:一個三軸加速度計和一個三軸陀螺儀。這個組合提供了非常精細的姿態解算,其…

    編程 2025-04-25
  • 詳解eclipse設置

    一、安裝與基礎設置 1、下載eclipse並進行安裝。 2、打開eclipse,選擇對應的工作空間路徑。 File -> Switch Workspace -> [選擇…

    編程 2025-04-25
  • Python安裝OS庫詳解

    一、OS簡介 OS庫是Python標準庫的一部分,它提供了跨平台的操作系統功能,使得Python可以進行文件操作、進程管理、環境變量讀取等系統級操作。 OS庫中包含了大量的文件和目…

    編程 2025-04-25
  • Java BigDecimal 精度詳解

    一、基礎概念 Java BigDecimal 是一個用於高精度計算的類。普通的 double 或 float 類型只能精確表示有限的數字,而對於需要高精度計算的場景,BigDeci…

    編程 2025-04-25
  • git config user.name的詳解

    一、為什麼要使用git config user.name? git是一個非常流行的分佈式版本控制系統,很多程序員都會用到它。在使用git commit提交代碼時,需要記錄commi…

    編程 2025-04-25

發表回復

登錄後才能評論