CRNN網絡詳解

一、CRNN概述

CRNN(Convolutional Recurrent Neural Network)是由全卷積神經網絡(FCN)和循環神經網絡(RNN)結合而成,主要應用於圖像與文本中的場景文本識別(Scene Text Recognition,STR)任務。CRNN網絡結合了CNN網絡能夠提取高維特徵的優點和RNN網絡能夠捕捉上下文關係的優點,因此在文本識別任務中取得了優秀的表現。

二、CRNN結構

CRNN網絡結構包括卷積層(Convolutional Layer)、循環層(Recurrent Layer)和轉錄層(Transcription Layer)三個部分。

1.卷積層

卷積層負責從原始圖像中提取特徵。一般的,訓練好的卷積層包括了數個卷積層和池化層,其中卷積層負責提取特徵,池化層負責保證計算速度和空間不變性。最後在特徵圖上進行特徵選擇,刪去無用特徵。

import torch.nn as nn
import torch

class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1,
                 norm_layer=None, activation_layer=None, bias=True):
        super(Conv, self).__init__()
        if norm_layer is None:
            norm_layer = nn.BatchNorm2d
        if activation_layer is None:
            activation_layer = nn.ReLU(inplace=True)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias)
        self.bn = norm_layer(out_channels)
        self.act = activation_layer

    def forward(self, x):
        x = self.conv(x)
        x = self.bn(x)
        x = self.act(x)
        return x

2.循環層

循環層負責對特徵序列進行處理。由於文本向量是一個序列,需要一種能夠捕捉序列信息的算法。RNN即循環神經網絡,它的輸出狀態一方面與上一次的狀態相關,一方面與當前的輸入相關。

class BidirectionalLSTM(nn.Module):

    def __init__(self, nIn, nHidden, nOut):
        super(BidirectionalLSTM, self).__init__()

        self.rnn = nn.LSTM(nIn, nHidden, bidirectional=True)
        self.embedding = nn.Linear(nHidden * 2, nOut)

    def forward(self, input):
        recurrent, _ = self.rnn(input)
        T, b, h = recurrent.size()
        t_rec = recurrent.view(T*b, h)

        output = self.embedding(t_rec)  # [T * b, nOut]
        output = output.view(T, b, -1)

        return output

3.轉錄層

轉錄層負責將特徵圖轉化為文本。具體來說是對卷積層和循環層處理後為一個序列的特徵圖進行轉錄。轉錄可以採用CTC算法(Connectionist Temporal Classification)。

class Transcription(nn.Module):
    def __init__(self, n_class):
        super(Transcription, self).__init__()

        self.fc = nn.Linear(512, n_class)

    def forward(self, x):
        T = x.size(0)
        x = x.view(T, -1)
        x = self.fc(x)

        return x

三、CRNN參數設置

CRNN網絡參數設置如下:

n_class = 37  # 26個字母+數字+一些特殊符號
input_height = 32  # 圖像高度
n_channel = 1  # 圖像通道數,黑白圖像為1
n_hidden = 256  # 循環層隱藏單元個數

四、CRNN訓練

CRNN網絡的訓練需要準備訓練集和驗證集數據,並按照批次大小(batch size)進行訓練。

from torchvision import transforms, datasets
from torch.utils.data import DataLoader

transform = transforms.Compose([
    transforms.Grayscale(),  # 將彩色圖像轉為灰度圖像
    transforms.Resize((input_height, 100)),  # 將圖像高度設置為32,寬度壓縮到100
    transforms.ToTensor(),  # 將圖像轉化為Tensor
])

train_dataset = datasets.ImageFolder(root="./train", transform=transform)
train_loader = DataLoader(train_dataset, batch_size=2, shuffle=True)

test_dataset = datasets.ImageFolder(root="./test", transform=transform)
test_loader = DataLoader(test_dataset, batch_size=2, shuffle=True)

crnn = CRNN(n_channel, n_hidden, n_class)

optimizer = torch.optim.Adam(crnn.parameters(), lr=0.0001)

loss_fn = nn.CTCLoss()

num_epoch = 20

for epoch in range(num_epoch):
    train_loss = 0.0
    for idx, (image, label) in enumerate(train_loader):
        image = image.to(device)
        label = label.to(device)
        output = crnn(image)
        output_size = torch.IntTensor([output.size(0)] * output.size(1))
        loss = loss_fn(output, label, output_size, label.size(0))
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        train_loss += loss.item()

    print("Epoch: ", epoch, "Loss: ", train_loss/len(train_loader))

五、CRNN識別

CRNN網絡可以通過輸入待識別的圖像,得到對應的文本。代碼如下:

image_path = "./test/1.png"
image = Image.open(image_path)
image = transform(image).unsqueeze(0)
image = image.to(device)

crnn.eval()
output = crnn(image)
output_argmax = output.argmax(dim=2).squeeze()
predicted_sentence = convert_to_text(output_argmax, id_to_char)
print("Predicted sentence: ", predicted_sentence)

原創文章,作者:小藍,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/275820.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
小藍的頭像小藍
上一篇 2024-12-17 16:06
下一篇 2024-12-17 16:06

相關推薦

  • 使用Netzob進行網絡協議分析

    Netzob是一款開源的網絡協議分析工具。它提供了一套完整的協議分析框架,可以支持多種數據格式的解析和可視化,方便用戶對協議數據進行分析和定製。本文將從多個方面對Netzob進行詳…

    編程 2025-04-29
  • 微軟發布的網絡操作系統

    微軟發布的網絡操作系統指的是Windows Server操作系統及其相關產品,它們被廣泛應用於企業級雲計算、數據庫管理、虛擬化、網絡安全等領域。下面將從多個方面對微軟發布的網絡操作…

    編程 2025-04-28
  • 蔣介石的人際網絡

    本文將從多個方面對蔣介石的人際網絡進行詳細闡述,包括其對政治局勢的影響、與他人的關係、以及其在歷史上的地位。 一、蔣介石的政治影響 蔣介石是中國現代歷史上最具有政治影響力的人物之一…

    編程 2025-04-28
  • 基於tcifs的網絡文件共享實現

    tcifs是一種基於TCP/IP協議的文件系統,可以被視為是SMB網絡文件共享協議的衍生版本。作為一種開源協議,tcifs在Linux系統中得到廣泛應用,可以實現在不同設備之間的文…

    編程 2025-04-28
  • 如何開發一個網絡監控系統

    網絡監控系統是一種能夠實時監控網絡中各種設備狀態和流量的軟件系統,通過對網絡流量和設備狀態的記錄分析,幫助管理員快速地發現和解決網絡問題,保障整個網絡的穩定性和安全性。開發一套高效…

    編程 2025-04-27
  • 用Python爬取網絡女神頭像

    本文將從以下多個方面詳細介紹如何使用Python爬取網絡女神頭像。 一、準備工作 在進行Python爬蟲之前,需要準備以下幾個方面的工作: 1、安裝Python環境。 sudo a…

    編程 2025-04-27
  • 如何使用Charles Proxy Host實現網絡請求截取和模擬

    Charles Proxy Host是一款非常強大的網絡代理工具,它可以幫助我們截取和模擬網絡請求,方便我們進行開發和調試。接下來我們將從多個方面詳細介紹如何使用Charles P…

    編程 2025-04-27
  • 網絡拓撲圖的繪製方法

    在計算機網絡的設計和運維中,網絡拓撲圖是一個非常重要的工具。通過拓撲圖,我們可以清晰地了解網絡結構、設備分布、鏈路情況等信息,從而方便進行故障排查、優化調整等操作。但是,要繪製一張…

    編程 2025-04-27
  • 網絡爬蟲什麼意思?

    網絡爬蟲(Web Crawler)是一種程序,可以按照制定的規則自動地瀏覽互聯網,並將獲取到的數據存儲到本地或者其他指定的地方。網絡爬蟲通常用於搜索引擎、數據採集、分析和處理等領域…

    編程 2025-04-27
  • 網絡數據爬蟲技術用法介紹

    網絡數據爬蟲技術是指通過一定的策略、方法和技術手段,獲取互聯網上的數據信息並進行處理的一種技術。本文將從以下幾個方面對網絡數據爬蟲技術做詳細的闡述。 一、爬蟲原理 網絡數據爬蟲技術…

    編程 2025-04-27

發表回復

登錄後才能評論