AttentionUNet:一種全新的醫學圖像分割方式

一、AttentionUNet速度

AttentionUNet是一個新穎的網絡結構,它有效地將U-Net(一種流行的醫學圖像分割框架)與注意力機制相結合,可以在更少的時間內實現高質量的醫學圖像分割。

相比於傳統的U-Net模型,AttentionUNet的速度要快得多。因為AttentionUNet引入了注意力機制,可以只關注有用的特徵,從而減少了網絡的計算複雜度。

下面是使用AttentionUNet進行醫學圖像分割的示例代碼:

import torch
import torch.nn as nn
import torch.nn.functional as F

class ConvBlock(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(ConvBlock, self).__init__()
    self.conv1 = nn.Conv2d(in_channels, out_channels, 3, padding=1)
    self.bn1 = nn.BatchNorm2d(out_channels)
    self.conv2 = nn.Conv2d(out_channels, out_channels, 3, padding=1)
    self.bn2 = nn.BatchNorm2d(out_channels)
    
  def forward(self, x):
    x = F.relu(self.bn1(self.conv1(x)))
    x = F.relu(self.bn2(self.conv2(x)))
    return x
    
class AttentionBlock(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(AttentionBlock, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, 1)
    self.bn = nn.BatchNorm2d(out_channels)
    
    self.theta = nn.Conv2d(out_channels, out_channels // 8, 1)
    self.phi = nn.Conv2d(out_channels, out_channels // 8, 1)
    self.g = nn.Conv2d(out_channels, out_channels // 2, 1)
    
    self.W = nn.Conv2d(out_channels // 2, out_channels, 1)
    
  def forward(self, x):
    h = F.relu(self.bn(self.conv(x)))
    
    theta = self.theta(h)
    phi = F.max_pool2d(self.phi(h), [2, 2])
    g = F.max_pool2d(self.g(h), [2, 2])
    
    theta = theta.view(-1, theta.size(1), theta.size(2) * theta.size(3))
    theta = theta.permute(0, 2, 1)
    phi = phi.view(-1, phi.size(1), phi.size(2) * phi.size(3))
    f = torch.matmul(theta, phi)
    f = F.softmax(f, dim=-1)
    
    g = g.view(-1, g.size(1), g.size(2) * g.size(3))
    out = torch.matmul(f, g)
    out = out.permute(0, 2, 1).contiguous()
    out = out.view(-1, self.W.size(1), h.size(2), h.size(3))
    out = self.W(out)
    
    return out + h
    
class AttentionUNet(nn.Module):
  def __init__(self, in_channels=3, out_channels=1, init_features=32):
    super(AttentionUNet, self).__init__()
    self.downsamples = nn.ModuleList([])
    self.upsamples = nn.ModuleList([])
    
    features = init_features
    self.conv1 = nn.Conv2d(in_channels, features, 3, padding=1)
    self.bn1 = nn.BatchNorm2d(features)
    self.conv2 = nn.Conv2d(features, features, 3, padding=1)
    self.bn2 = nn.BatchNorm2d(features)
    
    for i in range(4):
      self.downsamples.append(ConvBlock(features, features * 2))
      features = features * 2
      
    features = features * 2
    self.bridge = ConvBlock(features, features)
    
    for i in range(4):
      self.upsamples.append(AttentionBlock(features, features // 2))
      features = features // 2
    self.conv3 = nn.Conv2d(init_features, out_channels, 1)
    
  def forward(self, x):
    residuals = []
    out = F.relu(self.bn1(self.conv1(x)))
    out = F.relu(self.bn2(self.conv2(out)))
    residuals.append(out)
    
    for downsample in self.downsamples:
      out = downsample(out)
      residuals.append(out)
      
    out = self.bridge(out)
    
    for i in range(len(self.upsamples)):
      attention = self.upsamples[i](out)
      out = F.interpolate(out, scale_factor=2, mode='bilinear', align_corners=True)
      out = torch.cat([out, attention], dim=1)
      
    out = self.conv3(torch.cat([residuals[-1], out], dim=1))
    return out
  
if __name__ == '__main__':
  model = AttentionUNet(in_channels=3, out_channels=1, init_features=32)
  print(model)

二、AttentionUNet代碼

AttentionUNet模型的代碼可以從上面的示例中簡單地看出來,它是由若干個卷積塊和注意力塊組成,並在卷積塊之間添加了下採樣和上採樣操作,從而得到更好的分辨率。

注意力塊在這裡起到了非常重要的作用,能夠專註於有用的特徵,從而幫助網絡更快地學習到有意義的信息。

代碼中的模型結構在訓練醫學圖像分割模型時特別有用,下面是數據準備和模型訓練的示例代碼:

from torch.utils.data import DataLoader
from torchvision import transforms

train_transforms = transforms.Compose([
  transforms.RandomHorizontalFlip(p=0.5),
  transforms.RandomVerticalFlip(p=0.5),
  transforms.ToTensor(),
])

val_transforms = transforms.Compose([
  transforms.ToTensor(),
])

train_data = MedicalImageSegmentationDataset(data_dir='train', transforms=train_transforms)
val_data = MedicalImageSegmentationDataset(data_dir='val', transforms=val_transforms)

train_loader = DataLoader(train_data, batch_size=4, shuffle=True, num_workers=2)
val_loader = DataLoader(val_data, batch_size=1, shuffle=False, num_workers=1)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = AttentionUNet(in_channels=3, out_channels=1, init_features=32).to(device)

criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(10):
  model.train()
  train_loss = 0
  
  for i, data in enumerate(train_loader):
    inputs, labels = data['input'].to(device), data['label'].to(device)
    
    optimizer.zero_grad()
    outputs = model(inputs)
    loss = criterion(outputs, labels)
    loss.backward()
    optimizer.step()
    
    train_loss += loss.item()
  
  train_loss /= len(train_loader)
  
  model.eval()
  val_loss = 0
  
  with torch.no_grad():
    for i, data in enumerate(val_loader):
      inputs, labels = data['input'].to(device), data['label'].to(device)
      
      outputs = model(inputs)
      loss = criterion(outputs, labels)
      
      val_loss += loss.item()
      
  val_loss /= len(val_loader)
  
  print(f'Epoch {epoch + 1}, Training Loss: {train_loss:.4f}, Validation Loss: {val_loss:.4f}')

三、AttentionUNet參數量

AttentionUNet相對於傳統的U-Net模型來說有更多的參數,但是這些參數是經過仔細設計的,能夠幫助網絡更好地學習長期依賴關係。此外,注意力機制還可以降低網絡的計算複雜度,使得AttentionUNet在醫學圖像分割方面的實際表現要更加出色。

AttentionUNet網絡的總參數量隨着模型深度的增加而逐漸增加,但是相比於其他一些現有的醫學圖像分割方法,AttentionUNet的參數量並不是非常大,訓練也可以在合理的時間內完成。

下面是獲取AttentionUNet模型的總參數量的代碼:

from torchsummary import summary

model = AttentionUNet(in_channels=3, out_channels=1, init_features=32)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

summary(model, (3, 256, 256))

四、AttentionUNet中的注意力選取

AttentionUNet中的注意力選取是通過一系列的卷積操作來實現的,這個過程被稱為自注意力機制。自注意力機制可以幫助網絡專註於有用的特徵,從而加速模型的學習過程。在AttentionUNet中,自注意力機制被應用於每個注意力塊中,以選擇最有用的特徵並將其提供給下一層。

下面是在AttentionBlock中實現注意力選取的代碼:

class AttentionBlock(nn.Module):
  def __init__(self, in_channels, out_channels):
    super(AttentionBlock, self).__init__()
    self.conv = nn.Conv2d(in_channels, out_channels, 1)
    self.bn = nn.BatchNorm2d(out_channels)
    
    self.theta = nn.Conv2d(out_channels, out_channels // 8, 1)
    self.phi = nn.Conv2d(out_channels, out_channels // 8, 1)
    self.g = nn.Conv2d(out_channels, out_channels // 2, 1)
    
    self.W = nn.Conv2d(out_channels // 2, out_channels, 1)
    
  def forward(self, x):
    h = F.relu(self.bn(self.conv(x)))
    
    theta = self.theta(h)
    phi = F.max_pool2d(self.phi(h), [2, 2])
    g = F.max_pool2d(self.g(h), [2, 2])
    
    theta = theta.view(-1, theta.size(1), theta.size(2) * theta.size(3))
    theta = theta.permute(0, 2, 1)
    phi = phi.view(-1, phi.size(1), phi.size(2) * phi.size(3))
    f = torch.matmul(theta, phi)
    f = F.softmax(f, dim=-1)
    
    g = g.view(-1, g.size(1), g.size(2) * g.size(3))
    out = torch.matmul(f, g)
    out = out.permute(0, 2, 1).contiguous()
    out = out.view(-1, self.W.size(1), h.size(2), h.size(3))
    out = self.W(out)
    
    return out + h

通過上面的代碼,我們可以很清楚地看到注意力選取是如何在AttentionBlock中實現的,具體來說,它通過三個卷積函數來計算每個像素點的注意力權重。這些函數在網絡中不斷交替使用,以將有價值的信息提供給下一層,從而更好地分割醫學圖像。

原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-hk/n/197114.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
小藍的頭像小藍
上一篇 2024-12-03 13:27
下一篇 2024-12-03 13:27

相關推薦

  • 如何在Java中拼接OBJ格式的文件並生成完整的圖像

    OBJ格式是一種用於表示3D對象的標準格式,通常由一組頂點、面和紋理映射坐標組成。在本文中,我們將討論如何將多個OBJ文件拼接在一起,生成一個完整的3D模型。 一、讀取OBJ文件 …

    編程 2025-04-29
  • 如何實現圖像粘貼到蒙版

    本文將從多個方面介紹圖像粘貼到蒙版的實現方法。 一、創建蒙版 首先,在HTML中創建一個蒙版元素,用於接收要粘貼的圖片。 <div id=”mask” style=”widt…

    編程 2025-04-29
  • Python緩存圖片的處理方式

    本文將從多個方面詳細闡述Python緩存圖片的處理方式,包括緩存原理、緩存框架、緩存策略、緩存更新和緩存清除等方面。 一、緩存原理 緩存是一種提高應用程序性能的技術,在網絡應用中流…

    編程 2025-04-29
  • Python圖像黑白反轉用法介紹

    本文將從多個方面詳細闡述Python圖像黑白反轉的方法和技巧。 一、Pillow模塊介紹 Pillow是Python的一個圖像處理模塊,可以進行圖片的裁剪、旋轉、縮放等操作。使用P…

    編程 2025-04-28
  • Matlab二值圖像全面解析

    本文將全面介紹Matlab二值圖像的相關知識,包括二值圖像的基本原理、如何對二值圖像進行處理、如何從二值圖像中提取信息等等。通過本文的學習,你將能夠掌握Matlab二值圖像的基本操…

    編程 2025-04-28
  • Python在線編輯器的優勢與實現方式

    Python在線編輯器是Python語言愛好者的重要工具之一,它可以讓用戶方便快捷的在線編碼、調試和分享代碼,無需在本地安裝Python環境。本文將從多個方面對Python在線編輯…

    編程 2025-04-28
  • Python實現圖像轉化為灰度圖像

    本文將從多個方面詳細闡述如何使用Python將圖像轉化為灰度圖像,包括圖像的概念、灰度圖像的概念、Python庫的使用以及完整的Python代碼實現。 一、圖像與灰度圖像 圖像是指…

    編程 2025-04-28
  • 圖像與信號處理期刊級別

    本文將從多個方面介紹圖像與信號處理期刊級別的相關知識,包括圖像壓縮、人臉識別、關鍵點匹配等等。 一、圖像壓縮 圖像在傳輸和存儲中佔據了大量的空間,因此圖像壓縮成為了很重要的技術。常…

    編程 2025-04-28
  • Java表單提交方式

    Java表單提交有兩種方式,分別是get和post。下面我們將從以下幾個方面詳細闡述這兩種方式。 一、get方式 1、什麼是get方式 在get方式下,表單的數據會以查詢字符串的形…

    編程 2025-04-27
  • 用Pythonic的方式編寫高效代碼

    Pythonic是一種編程哲學,它強調Python編程風格的簡單、清晰、優雅和明確。Python應該描述為一種語言而不是一種編程語言。Pythonic的編程方式不僅可以使我們在編碼…

    編程 2025-04-27

發表回復

登錄後才能評論