深入理解 Lovasz Loss

Lovasz Loss 是一種用於訓練分割模型的損失函數,它通過最小化真實標籤和預測標籤之間的差異來提高模型的準確性和穩定性,被廣泛應用於醫學圖像分割、自然語言處理、圖像識別、社交網路分析等領域。

一、Lovasz Loss 簡介

Lovasz Loss 的核心在於求解兩個集合之間的距離,其中一個集合是真實標籤集合,另一個集合是預測標籤集合。距離的計算方法是基於 Lovasz 擴展理論的,該理論主要用於研究無序的、不可比的有限偏序集的性質。

在分割模型中,我們通常使用 Dice Loss 或交叉熵損失作為評價指標,但是這些損失函數不太適用於非平衡數據集,因為它們會導致分類結果傾向於具有較多樣本的類別。

Lovasz Loss 的主要優點是,它可以有效地處理非平衡數據集,並且在處理稀疏邊界問題時非常有效。此外,Lovasz Loss 與直接優化非概率評分函數(如 IoU 或 Dice 等)相比具有更好的數學性質。

二、計算 Lovasz Loss

Lovasz Loss 的核心在於計算預測序列的排列代價,它可以表示為以下公式:

def lovasz_grad(gt_sorted):
    """
    計算 Lovasz Loss 的梯度
    """
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1.0 - intersection / union
    if p > 1:
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard

def flatten_binary_scores(scores, labels):
    """
    將概率評分函數與對應的標籤轉換為二進位函數
    """
    num_classes = scores.shape[1]
    all_thresh = torch.unsqueeze(torch.arange(num_classes), dim=0).cuda()
    all_scores = torch.unsqueeze(scores, dim=0)
    all_labels = torch.unsqueeze(labels, dim=0)

    gt = all_labels.long().cuda()
    scores = all_scores.float().cuda()
    scores = (scores > torch.unsqueeze(all_thresh, dim=2)).float()
    scores_sorted, _ = torch.sort(scores, dim=1, descending=True)

    grad = []
    loss = []
    for i in range(num_classes):
        gt_i = gt[:, i].float()
        grad_i = lovasz_grad(gt_i * 2 - 1)
        grad.append(grad_i)
        loss_i = torch.dot(torch.relu(scores_sorted[:, i] - gt_i * 2 + 1), grad_i)
        loss.append(loss_i)
    return torch.stack(loss), torch.stack(grad)

其中 gt_sorted 是通過對真實標籤集合進行排序得到的標籤序列,scores 是模型產生的預測標籤序列。這個函數將概率評分函數與對應的標籤轉換為二進位函數,然後計算二進位函數的 Lovasz Loss。

下面是 Lovasz Loss 的標準表達式:

def multi_lovasz_loss(scores, labels):
    """
    計算多類別 Lovasz Loss
    """
    num_classes = scores.shape[1]
    if num_classes == 1:
        loss, _ = lovasz_hinge(scores.squeeze().float(), labels.float())
        return loss.unsqueeze(0)
    losses = []
    grad = None
    for i in range(num_classes):
        loss_i, grad_i = lovasz_hinge(scores[:, i], labels[:, i], per_image=False)
        losses.append(loss_i)
        if grad is None:
            grad = torch.empty(num_classes, grad_i.size(0)).cuda()
        grad[i] = grad_i
    loss = torch.cat(losses).mean()
    return loss, grad

該函數可以計算多類別 Lovasz Loss,如果只有一個類別,它會使用 Lovasz Hinge Loss。

三、應用 Lovasz Loss

Lovasz Loss 在分割模型、圖像識別、社交網路分析等領域都得到了廣泛的應用。下面是一個利用 Lovasz Loss 進行圖像分割的實例:

class SegmentationLoss(nn.Module):
    """
    基於 Lovasz Loss 的圖像分割損失函數
    """
    def __init__(self, mode='binary', per_image=False):
        super(SegmentationLoss, self).__init__()
        self.mode = mode
        self.per_image = per_image

    def forward(self, outputs, labels):
        if self.mode == 'binary':
            loss, grad = lovasz_hinge(outputs.squeeze(), labels.squeeze(), per_image=self.per_image)
        elif self.mode == 'multiclass':
            loss, grad = multi_lovasz_loss(outputs, labels)

        return loss

我們可以通過定義一個繼承自 nn.Module 的 SegmentationLoss 類來使用 Lovasz Loss 訓練分割模型。根據需要,可以選擇單類別分割或多類別分割。

四、總結

Lovasz Loss 在非平衡數據集的圖像分割中具有很好的性能,特別是在處理稀疏圖像邊界問題時非常有效。然而,它也有一些缺點,例如在計算上相對複雜,訓練時間相對較長。

通過深入理解 Lovasz Loss 的核心思想和計算方法,我們可以更好地應用它來提高分割模型的準確性和穩定性。

原創文章,作者:SRVNJ,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/334862.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
SRVNJ的頭像SRVNJ
上一篇 2025-02-05 13:05
下一篇 2025-02-05 13:05

相關推薦

  • eslint no-loss-of-precision requires at least eslint v7.1.0

    這篇文章將從以下幾個方面詳細闡述eslint no-loss-of-precision requires至少需要eslint v7.1.0版本的問題: 一、概述 如果使用較老的es…

    編程 2025-04-29
  • 深入解析Vue3 defineExpose

    Vue 3在開發過程中引入了新的API `defineExpose`。在以前的版本中,我們經常使用 `$attrs` 和` $listeners` 實現父組件與子組件之間的通信,但…

    編程 2025-04-25
  • 深入理解byte轉int

    一、位元組與比特 在討論byte轉int之前,我們需要了解位元組和比特的概念。位元組是計算機存儲單位的一種,通常表示8個比特(bit),即1位元組=8比特。比特是計算機中最小的數據單位,是…

    編程 2025-04-25
  • 深入理解Flutter StreamBuilder

    一、什麼是Flutter StreamBuilder? Flutter StreamBuilder是Flutter框架中的一個內置小部件,它可以監測數據流(Stream)中數據的變…

    編程 2025-04-25
  • 深入探討OpenCV版本

    OpenCV是一個用於計算機視覺應用程序的開源庫。它是由英特爾公司創建的,現已由Willow Garage管理。OpenCV旨在提供一個易於使用的計算機視覺和機器學習基礎架構,以實…

    編程 2025-04-25
  • 深入了解scala-maven-plugin

    一、簡介 Scala-maven-plugin 是一個創造和管理 Scala 項目的maven插件,它可以自動生成基本項目結構、依賴配置、Scala文件等。使用它可以使我們專註於代…

    編程 2025-04-25
  • 深入了解LaTeX的腳註(latexfootnote)

    一、基本介紹 LaTeX作為一種排版軟體,具有各種各樣的功能,其中腳註(footnote)是一個十分重要的功能之一。在LaTeX中,腳註是用命令latexfootnote來實現的。…

    編程 2025-04-25
  • 深入探討馮諾依曼原理

    一、原理概述 馮諾依曼原理,又稱「存儲程序控制原理」,是指計算機的程序和數據都存儲在同一個存儲器中,並且通過一個統一的匯流排來傳輸數據。這個原理的提出,是計算機科學發展中的重大進展,…

    編程 2025-04-25
  • 深入剖析MapStruct未生成實現類問題

    一、MapStruct簡介 MapStruct是一個Java bean映射器,它通過註解和代碼生成來在Java bean之間轉換成本類代碼,實現類型安全,簡單而不失靈活。 作為一個…

    編程 2025-04-25
  • 深入了解Python包

    一、包的概念 Python中一個程序就是一個模塊,而一個模塊可以引入另一個模塊,這樣就形成了包。包就是有多個模塊組成的一個大模塊,也可以看做是一個文件夾。包可以有效地組織代碼和數據…

    編程 2025-04-25

發表回復

登錄後才能評論