CapsNet綜述

一、CapsNet簡介

Capsule Networks,又稱為CapsNet,是由Geoffrey Hinton領導的研究團隊在2017年提出的新型神經網路架構,它的目的是解決傳統卷積神經網路(ConvNet)的一些缺陷,並在許多視覺識別問題上顯示出了非常好的表現。

CapsNet的核心思想是使用”膠囊”或稱為”Capsule”的模塊來代替ConvNet中的神經元。每個”膠囊”都可以生成表示一個對象的向量,而且還可以通過下一層”膠囊”來傳遞更加複雜的信息,這些膠囊的個數和維度具有良好的可調節性。

最初CapsNet是用於圖像處理,但目前也已經應用到音頻處理、自然語言處理、醫學圖像等領域。

二、CapsNet的應用

1. 圖像識別

CapsNet已經被廣泛應用於圖像識別領域,比如MNIST數據集、Fashion-MNIST數據集、CIFAR-10數據集等。CapsNet在這些數據集上的表現明顯超過了傳統的CNN模型,尤其是在識別複雜物體、旋轉物體等方面表現更優秀。

2. 自然語言處理

除圖像識別外,CapsNet在自然語言處理領域也有廣泛應用。例如,對於情感分析任務,CapsNet能夠有效地建立兩個文本之間關係的表示。

3. 智能問答

CapsNet還可以應用於智能問答系統中,通過對答案進行預測來回答用戶提出的問題。在這種情景下,CapsNet能夠很好地理解各個答案膠囊之間的關係。

三、CapsNet的實現

1. CapsNet的架構

class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=256, kernel_size=9)
        self.primary_caps = PrimaryCaps()
        self.digit_caps = DigitCaps()

    def forward(self, x):
        """
        :param x: (batch_size, 1, 28, 28)
        :return: output tensor (batch_size, num_classes)
        """
        x = F.relu(self.conv1(x))
        x = self.primary_caps(x)
        x = self.digit_caps(x)

        # calculate output vector length
        output = torch.sqrt((x ** 2).sum(dim=2))

        return output

CapsNet的實現可以分為三部分:卷積層、Primary Caps層和Digit Caps層。其中, Digit Caps層是模型的核心,它的輸出是用於計算向量長度的輸出。

2. 損失函數

class CapsuleLoss(nn.Module):
    def __init__(self, num_classes=10, m_plus=0.9, m_minus=0.1, lambda_val=0.5):
        super(CapsuleLoss, self).__init__()
        self.num_classes = num_classes
        self.m_plus = m_plus
        self.m_minus = m_minus
        self.lambda_val = lambda_val

    def forward(self, y_true, y_pred, x):
        """
        :param y_true: true label one-hot encoded
        :param y_pred: output capsules
        :param x: input data
        :return: weighted margin loss over all digit capsules
        """

        # calculate length
        v_length = torch.sqrt((y_pred ** 2).sum(dim=2, keepdim=True))

        # calculate margin loss
        left = F.relu(self.m_plus - v_length).view(x.size(0), -1) ** 2
        right = F.relu(v_length - self.m_minus).view(x.size(0), -1) ** 2
        margin_loss = y_true.float() * left + self.lambda_val * (1 - y_true.float()) * right
        margin_loss = margin_loss.sum(dim=1).mean()

        return margin_loss

CapsNet損失函數使用了Margin Loss,它通過將正確標籤y_true表示為原始輸入x的類別,並將輸出向量y_pred與所有可能的類別進行比較來計算損失值。

四、CapsNet的應用案例

下面是一個基於CapsNet的圖像分類案例。該案例使用了CIFAR-10數據集,訓練出的模型達到了92%的準確率。

1. 導入所需庫

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from tqdm import tqdm

2. 載入數據集

transform = transforms.Compose(
    [
        transforms.RandomHorizontalFlip(p=0.5),
        transforms.RandomCrop(size=(32, 32), padding=int(32 * 0.125), padding_mode="reflect"),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
    ]
)

trainset = torchvision.datasets.CIFAR10(root="./data", train=True, download=True, transform=transform)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root="./data", train=False, download=True, transform=transform)
testloader = DataLoader(testset, batch_size=64, shuffle=False, num_workers=2)

3. 定義CapsNet模型

class CapsNet(nn.Module):
    def __init__(self):
        super(CapsNet, self).__init__()
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=256, kernel_size=9)
        self.primary_caps = PrimaryCaps()
        self.digit_caps = DigitCaps()
        self.decoder = Decoder()

    def forward(self, x):
        """
        :param x: (batch_size, 3, 32, 32)
        :return: output tensor (batch_size, num_classes)
        """
        x = F.relu(self.conv1(x))
        x = self.primary_caps(x)
        x = self.digit_caps(x)

        # calculate output vector length
        output = torch.sqrt((x ** 2).sum(dim=2))

        return output

    def get_reconstruction(self, x):
        x = self.decoder(x)
        return x

CapsNet模型包括了卷積層、Primary Caps層、Digit Caps層和Decoder層。

4. 訓練模型

device = "cuda:0" if torch.cuda.is_available() else "cpu"

num_epochs = 25

model = CapsNet().to(device)
criterion = CapsuleLoss()
optimizer = optim.Adam(model.parameters())

for epoch in range(num_epochs):

    train_loss = 0.0
    train_acc = 0.0

    for imgs, labels in tqdm(trainloader):
        imgs, labels = imgs.to(device), labels.to(device)
        optimizer.zero_grad()
        outputs = model(imgs)

        batch_size = imgs.size(0)
        y_true = torch.eye(10).index_select(dim=0, index=labels)

        loss = criterion(y_true, outputs, imgs)
        train_loss += loss.item()
        loss.backward()
        optimizer.step()

        acc = (outputs.argmax(dim=1) == labels).float().mean()
        train_acc += acc

    train_loss /= len(trainloader)
    train_acc /= len(trainloader)

    print(f"Epoch {epoch + 1} / {num_epochs}, training loss: {train_loss:.4f}, training acc: {train_acc:.4f}")

5. 測試模型

def test(model, loader, criterion):
    model.eval()
    loader.set_description(f"Test")
    running_loss = 0
    running_accuracy = 0

    with torch.no_grad():
        for imgs, labels in tqdm(loader):
            imgs, labels = imgs.to(device), labels.to(device)

            # Pass forward
            outputs = model(imgs)
            y_true = torch.eye(10).index_select(dim=0, index=labels)

            # Compute Loss
            loss = criterion(y_true, outputs, imgs)

            # Compute Accuracy
            accuracy = (outputs.argmax(dim=1) == labels).float().mean()

            # Append to lists
            running_loss += loss.item()
            running_accuracy += accuracy

    running_loss /= len(loader)
    running_accuracy /= len(loader)

    print(f"Test Loss: {running_loss:.4f}, Test Accuracy: {running_accuracy:.4f}")

test(model, testloader, CapsuleLoss())

以上就是一個基於CapsNet的圖像分類案例。在25個epoch後,模型的準確率達到了92%。實際上,CapsNet已經在許多圖片分類比賽中獲得了很好的成績,它是一種非常有效的神經網路模型。

原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/185446.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
小藍的頭像小藍
上一篇 2024-11-26 12:19
下一篇 2024-11-26 12:19

發表回復

登錄後才能評論