深入探究PyTorch中torch.nn.lstm

一、LSTM模型介紹

LSTM(Long Short-Term Memory)是一種常用的循環神經網絡模型,它具有較強的記憶功能和長短期依賴學習能力,常用於序列數據的建模。相較於傳統的RNN,LSTM擁有三個門,分別為輸入門、遺忘門、輸出門,能夠有效地控制信息的流動和遺忘。在PyTorch中,LSTM模型封裝在torch.nn.lstm類中。

二、torch.nn.lstm參數解析

在實例化torch.nn.lstm類時,需要傳遞的參數如下:


torch.nn.LSTM(
    input_size: int, # 輸入數據的特徵維度,即輸入數據的最後一維的大小
    hidden_size: int, # 隱藏層的特徵維度
    num_layers: int = 1, # LSTM模型的層數,默認是1層,表示只有一個LSTM單元
    bias: bool = True, # 是否使用偏置,默認為True
    batch_first: bool = False, # 是否將batch_size放在輸入數據的第一維,默認為False,即維度順序為(seq_len, batch, input_size)
    dropout: float = 0., # 在LSTM網絡中,對輸出的dropout比例,默認為0,即不對輸出進行dropout處理
    bidirectional: bool = False, # 是否使用雙向LSTM,默認為False,即單向LSTM
) -> None

其中,最重要的參數為input_size和hidden_size。input_size表示輸入數據的特徵維度,即輸入數據的最後一維的大小;hidden_size表示隱藏層的特徵維度。

三、torch.nn.lstm輸入數據格式

在使用torch.nn.lstm進行序列數據建模時,輸入數據需要滿足以下格式要求:

  1. 輸入數據維度為(seq_len, batch, input_size),其中seq_len表示序列長度,batch表示批量大小,input_size表示每個時間步的特徵維度。
  2. 如果使用batch_first=True,輸入數據維度需調整為(batch, seq_len, input_size)。即將批量大小放在第一維。
  3. 如果輸入數據的seq_len不足LSTM模型所要求的長度,可以使用torch.nn.utils.rnn.pad_sequence()函數進行填充,使得所有輸入數據的seq_len一致。

四、torch.nn.lstm輸出數據格式

LSTM模型的輸出包含兩個部分,一個是每個時間步的輸出,另一個是每個時間步的隱藏狀態和cell狀態。

  1. 每個時間步的輸出為(seq_len, batch, num_directions * hidden_size),其中num_directions表示LSTM模型是否為雙向LSTM。
  2. 每個時間步的隱藏狀態和cell狀態為(num_layers * num_directions, batch, hidden_size),其中num_layers表示LSTM模型的層數。

五、torch.nn.lstm使用示例

1、基本使用方法

以下代碼展示了如何使用torch.nn.lstm類創建一個單層單向LSTM模型,並對輸入數據進行前向傳播。


import torch

# 定義輸入數據
input_data = torch.randn(10, 32, 128)

# 定義LSTM模型
lstm_model = torch.nn.LSTM(input_size=128, hidden_size=256)

# 對輸入數據進行前向傳播
output, (h_n, c_n) = lstm_model(input_data)

2、使用雙向LSTM進行情感分析

以下代碼展示了如何使用torch.nn.lstm類創建一個雙向LSTM模型,並對情感分類數據進行訓練和預測。


import torch
import torchtext
from torchtext.datasets import IMDB
from torchtext.data import get_tokenizer
from torchtext.vocab import Vocab
from torch.utils.data import DataLoader

# 加載情感分類數據集IMDB
train_data, test_data = IMDB(split=('train', 'test'))

# 構建詞彙表
tokenizer = get_tokenizer('basic_english') # 使用basic_english分詞器進行分詞
counter = torchtext.vocab.Counter()
for data in train_data:
    counter.update(tokenizer(data.text))
vocab = Vocab(counter, min_freq=10)
vocab.set_default_index(vocab[''])

# 定義對輸入文本進行預處理的函數
def text_transform(text):
    tokens = tokenizer(text)
    return [vocab[token] for token in tokens]

# 加載訓練集和測試集,並對數據進行處理後,轉化為數據加載器
train_data_processed = [(text_transform(data.text), data.label) for data in train_data]
test_data_processed = [(text_transform(data.text), data.label) for data in test_data]
train_data_loader = DataLoader(train_data_processed, batch_size=64, shuffle=True)
test_data_loader = DataLoader(test_data_processed, batch_size=64, shuffle=False)

# 定義雙向LSTM模型
class BiLSTM(torch.nn.Module):
    def __init__(self, vocab_size, embedding_size, hidden_size, num_classes):
        super(BiLSTM, self).__init__()
        self.embedding = torch.nn.Embedding(num_embeddings=vocab_size, embedding_dim=embedding_size)
        self.lstm = torch.nn.LSTM(input_size=embedding_size, hidden_size=hidden_size, bidirectional=True)
        self.fc = torch.nn.Linear(in_features=hidden_size*2, out_features=num_classes)
        
    def forward(self, input_dict):
        embedding = self.embedding(input_dict['text']) # 對文本進行嵌入
        lstm_output, _ = self.lstm(embedding) # 對嵌入後的文本進行LSTM處理
        last_output = lstm_output[-1] # 取最後一個時間步的輸出
        output = self.fc(last_output) # 進行分類輸出
        return output

# 實例化模型,並對模型進行訓練和測試
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = BiLSTM(vocab_size=len(vocab), embedding_size=128, hidden_size=256, num_classes=2).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
criterion = torch.nn.CrossEntropyLoss()
for epoch in range(10):
    for input_dict in train_data_loader:
        input_dict = {key: value.to(device) for key, value in input_dict.items()}
        optimizer.zero_grad()
        output = model(input_dict)
        loss = criterion(output, input_dict['label'])
        loss.backward()
        optimizer.step()
    with torch.no_grad():
        total_correct = 0.
        for input_dict in test_data_loader:
            input_dict = {key: value.to(device) for key, value in input_dict.items()}
            output = model(input_dict)
            pred = output.argmax(dim=-1)
            total_correct += (pred == input_dict['label']).sum().item()
        accuracy = total_correct / len(test_data)
        print(f'Epoch {epoch}, test accuracy: {accuracy:.2%}')

六、總結

本文通過對PyTorch中的torch.nn.lstm類進行詳細的解析,詳細介紹了LSTM模型的特點、類的參數、輸入數據格式、輸出數據格式以及使用方法。並以一個雙向LSTM情感分類模型的例子進行了實際應用,希望能夠對使用PyTorch進行序列數據建模的讀者有所幫助。

原創文章,作者:HRHII,如若轉載,請註明出處:https://www.506064.com/zh-hant/n/368655.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
HRHII的頭像HRHII
上一篇 2025-04-12 01:13
下一篇 2025-04-12 01:13

相關推薦

  • PyTorch模塊簡介

    PyTorch是一個開源的機器學習框架,它基於Torch,是一個Python優先的深度學習框架,同時也支持C++,非常容易上手。PyTorch中的核心模塊是torch,提供一些很好…

    編程 2025-04-27
  • 深入解析Vue3 defineExpose

    Vue 3在開發過程中引入了新的API `defineExpose`。在以前的版本中,我們經常使用 `$attrs` 和` $listeners` 實現父組件與子組件之間的通信,但…

    編程 2025-04-25
  • 深入理解byte轉int

    一、字節與比特 在討論byte轉int之前,我們需要了解字節和比特的概念。字節是計算機存儲單位的一種,通常表示8個比特(bit),即1字節=8比特。比特是計算機中最小的數據單位,是…

    編程 2025-04-25
  • 深入理解Flutter StreamBuilder

    一、什麼是Flutter StreamBuilder? Flutter StreamBuilder是Flutter框架中的一個內置小部件,它可以監測數據流(Stream)中數據的變…

    編程 2025-04-25
  • 深入探討OpenCV版本

    OpenCV是一個用於計算機視覺應用程序的開源庫。它是由英特爾公司創建的,現已由Willow Garage管理。OpenCV旨在提供一個易於使用的計算機視覺和機器學習基礎架構,以實…

    編程 2025-04-25
  • 深入了解scala-maven-plugin

    一、簡介 Scala-maven-plugin 是一個創造和管理 Scala 項目的maven插件,它可以自動生成基本項目結構、依賴配置、Scala文件等。使用它可以使我們專註於代…

    編程 2025-04-25
  • 深入了解LaTeX的腳註(latexfootnote)

    一、基本介紹 LaTeX作為一種排版軟件,具有各種各樣的功能,其中腳註(footnote)是一個十分重要的功能之一。在LaTeX中,腳註是用命令latexfootnote來實現的。…

    編程 2025-04-25
  • 深入探討馮諾依曼原理

    一、原理概述 馮諾依曼原理,又稱“存儲程序控制原理”,是指計算機的程序和數據都存儲在同一個存儲器中,並且通過一個統一的總線來傳輸數據。這個原理的提出,是計算機科學發展中的重大進展,…

    編程 2025-04-25
  • 深入了解Python包

    一、包的概念 Python中一個程序就是一個模塊,而一個模塊可以引入另一個模塊,這樣就形成了包。包就是有多個模塊組成的一個大模塊,也可以看做是一個文件夾。包可以有效地組織代碼和數據…

    編程 2025-04-25
  • 深入剖析MapStruct未生成實現類問題

    一、MapStruct簡介 MapStruct是一個Java bean映射器,它通過註解和代碼生成來在Java bean之間轉換成本類代碼,實現類型安全,簡單而不失靈活。 作為一個…

    編程 2025-04-25

發表回復

登錄後才能評論