深度學習中的LSTM與GRU

深度學習中有許多的RNN(循環神經網路)結構,其中LSTM(長短時記憶網路)與GRU(門限遞歸單元)是應用比較廣泛的兩種結構。本文將重點介紹這兩種結構的原理和應用,並提供完整的代碼示例。

一、LSTM介紹

LSTM最早是由Hochreiter等人在1997年提出的,它可以避免RNN中的梯度消失問題(當反向傳播過程中,梯度值變得很小,導致很難訓練)。

LSTM的核心是使用三個門結構(輸入門、輸出門和遺忘門)來控制信息的流動,保證信息在訓練過程中可以根據需要選擇保留或遺忘。其中,輸入門可以控制當前信息的「重要性」,輸出門可以控制信息的輸出程度,遺忘門可以控制需要遺忘的信息。

import tensorflow as tf

# 定義LSTM結構
class LSTM(tf.keras.Model):
    def __init__(self, units):
        super(LSTM, self).__init__()
        self.units = units
        self.forget_gate = tf.keras.layers.Dense(units, activation='sigmoid')
        self.input_gate = tf.keras.layers.Dense(units, activation='sigmoid')
        self.output_gate = tf.keras.layers.Dense(units, activation='sigmoid')
        self.memory_gate = tf.keras.layers.Dense(units, activation='tanh')

    def call(self, inputs, memory, state):
        concat_inputs = tf.concat([inputs, memory], axis=-1)
        forget = self.forget_gate(concat_inputs)
        input = self.input_gate(concat_inputs)
        output = self.output_gate(concat_inputs)
        memory_ = forget * state + input * self.memory_gate(concat_inputs)
        state_ = output * tf.tanh(memory_)

        return state_, memory_

二、GRU介紹

GRU是於2014年由Cho等人提出的,它對LSTM進行了簡化,將輸入門和遺忘門合併為「重置門」,將輸出門合併為「更新門」。

GRU的優點在於計算速度比LSTM快,同時也相對容易訓練,因此在一些較為簡單的任務中,GRU的表現可以與LSTM相當甚至更好。

# 定義GRU結構
class GRU(tf.keras.Model):
    def __init__(self, units):
        super(GRU, self).__init__()
        self.units = units
        self.reset_gate = tf.keras.layers.Dense(units, activation='sigmoid')
        self.update_gate = tf.keras.layers.Dense(units, activation='sigmoid')
        self.memory_gate = tf.keras.layers.Dense(units, activation='tanh')

    def call(self, inputs, state):
        concat_inputs = tf.concat([inputs, state], axis=-1)
        reset = self.reset_gate(concat_inputs)
        update = self.update_gate(concat_inputs)
        memory = self.memory_gate(tf.concat([inputs, reset * state], axis=-1))
        state_ = update * state + (1 - update) * memory

        return state_

三、應用案例

在很多自然語言處理(NLP)的任務中,LSTM和GRU都得到了廣泛的應用,例如語言模型、機器翻譯、情感分析等。下面以情感分析為例,展示如何使用LSTM和GRU對文本進行情感分類。

情感分析的數據集通常包括一系列帶有標記的文本數據,如0表示負面情緒,1表示正面情緒。我們可以使用LSTM或GRU對文本進行處理,並利用全連接層進行分類。以下是使用Keras框架實現的完整代碼示例。

import tensorflow as tf
from tensorflow.keras.datasets import imdb
from tensorflow.keras.preprocessing.sequence import pad_sequences

# 模型參數
max_features = 10000
max_len = 200
embedding_dim = 128
lstm_units = 64
batch_size = 32
epochs = 10

# 載入數據集
(x_train, y_train), (x_val, y_val) = imdb.load_data(num_words=max_features)

# 填充序列
x_train = pad_sequences(x_train, maxlen=max_len)
x_val = pad_sequences(x_val, maxlen=max_len)

# 定義模型
model = tf.keras.Sequential([
    tf.keras.layers.Embedding(max_features, embedding_dim, input_length=max_len),
    tf.keras.layers.LSTM(lstm_units, dropout=0.2, recurrent_dropout=0.2),
    tf.keras.layers.Dense(1, activation='sigmoid')
])

# 編譯模型
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy'])

# 訓練模型
model.fit(x_train, y_train, batch_size=batch_size, epochs=epochs, validation_data=(x_val, y_val))

# 測試
test_loss, test_acc = model.evaluate(x_val, y_val)
print('Test Accuracy:', test_acc)

以上代碼中,我們使用了Keras自帶的IMDB數據集,利用LSTM和全連接層進行情感分析任務的訓練,並在測試集上進行測試,最終輸出測試的準確率。

原創文章,作者:PRXOK,如若轉載,請註明出處:https://www.506064.com/zh-tw/n/368623.html

(0)
打賞 微信掃一掃 微信掃一掃 支付寶掃一掃 支付寶掃一掃
PRXOK的頭像PRXOK
上一篇 2025-04-12 01:13
下一篇 2025-04-12 01:13

相關推薦

  • 深度查詢宴會的文化起源

    深度查詢宴會,是指通過對一種文化或主題的深度挖掘和探究,為參與者提供一次全方位的、深度體驗式的文化品嘗和交流活動。本文將從多個方面探討深度查詢宴會的文化起源。 一、宴會文化的起源 …

    編程 2025-04-29
  • Python下載深度解析

    Python作為一種強大的編程語言,在各種應用場景中都得到了廣泛的應用。Python的安裝和下載是使用Python的第一步,對這個過程的深入了解和掌握能夠為使用Python提供更加…

    編程 2025-04-28
  • Python遞歸深度用法介紹

    Python中的遞歸函數是一個函數調用自身的過程。在進行遞歸調用時,程序需要為每個函數調用開闢一定的內存空間,這就是遞歸深度的概念。本文將從多個方面對Python遞歸深度進行詳細闡…

    編程 2025-04-27
  • Spring Boot本地類和Jar包類載入順序深度剖析

    本文將從多個方面對Spring Boot本地類和Jar包類載入順序做詳細的闡述,並給出相應的代碼示例。 一、類載入機制概述 在介紹Spring Boot本地類和Jar包類載入順序之…

    編程 2025-04-27
  • 深度解析Unity InjectFix

    Unity InjectFix是一個非常強大的工具,可以用於在Unity中修復各種類型的程序中的問題。 一、安裝和使用Unity InjectFix 您可以通過Unity Asse…

    編程 2025-04-27
  • 深度剖析:cmd pip不是內部或外部命令

    一、問題背景 使用Python開發時,我們經常需要使用pip安裝第三方庫來實現項目需求。然而,在執行pip install命令時,有時會遇到「pip不是內部或外部命令」的錯誤提示,…

    編程 2025-04-25
  • 動手學深度學習 PyTorch

    一、基本介紹 深度學習是對人工神經網路的發展與應用。在人工神經網路中,神經元通過接受輸入來生成輸出。深度學習通常使用很多層神經元來構建模型,這樣可以處理更加複雜的問題。PyTorc…

    編程 2025-04-25
  • 深度解析Ant Design中Table組件的使用

    一、Antd表格兼容 Antd是一個基於React的UI框架,Table組件是其重要的組成部分之一。該組件可在各種瀏覽器和設備上進行良好的兼容。同時,它還提供了多個版本的Antd框…

    編程 2025-04-25
  • 深度解析MySQL查看當前時間的用法

    MySQL是目前最流行的關係型資料庫管理系統之一,其提供了多種方法用於查看當前時間。在本篇文章中,我們將從多個方面來介紹MySQL查看當前時間的用法。 一、當前時間的獲取方法 My…

    編程 2025-04-24
  • 深度學習魚書的多個方面詳解

    一、基礎知識介紹 深度學習魚書是一本系統性的介紹深度學習的圖書,主要介紹深度學習的基礎知識和數學原理,並且通過相關的應用案例來幫助讀者理解深度學習的應用場景和方法。在了解深度學習之…

    編程 2025-04-24

發表回復

登錄後才能評論