LSTM时间序列预测

一、LSTM基础原理

LSTM(Long Short-Term Memory)是一种递归神经网络,广泛用于自然语言处理、时间序列预测等领域。LSTM的主要特点是能够捕捉长期依赖关系,即能够在序列中保留多个时刻的信息。

LSTM包含一个单元(cell),可用于存储状态和控制流程。单元的核心是三个门(gate):输入门(input gate)、遗忘门(forget gate)和输出门(output gate)。输入门控制激活状态信息的输入,遗忘门控制遗忘状态信息的输入,输出门控制输出信息的输入。

具体而言,输入门、遗忘门和输出门的计算方式如下:

<img src="input.png">
<img src="forget.png">
<img src="output.png">

其中,W、U和b为可学习参数,σ为sigmoid函数。

二、Keras实现LSTM时间序列预测

Keras是一个用于构建神经网络的高级API,可用于快速构建、训练和评估各种类型的神经网络模型。下面将介绍如何使用Keras实现LSTM时间序列预测模型。

1. 数据准备

首先,我们需要准备数据。假设我们要预测某公司2021年1月至6月的销售额,我们可以使用该公司过去一年的销售额数据作为训练集。

我们将训练集按照时间顺序排序,然后取最后n个数据作为测试集(n为自定义的测试集大小)。接下来,我们需要对数据进行标准化处理(将所有数据缩放到[0,1]的范围内)。

# 加载数据
data = pd.read_csv('sales_data.csv', header=None)

# 排序
data = data.values[1:].astype('float32')
data = data[~np.isnan(data).any(axis=1)]
data = data[np.argsort(data[:, 0])]

# 划分训练集和测试集
train_size = int(len(data) * 0.8)
test_size = len(data) - train_size
train, test = data[0:train_size, :], data[train_size:len(data), :]

# 标准化处理
scaler = MinMaxScaler(feature_range=(0, 1))
train = scaler.fit_transform(train)
test = scaler.transform(test)

2. 创建数据集

接下来,我们需要创建数据集。LSTM模型需要输入一个有序的、三维的数据集。具体而言,对于每个样本,我们需要提供n_steps_in个时间步长(用于预测的历史数据)以及n_features个特征(在我们的例子中,这个特征就是销售额本身)。同时,我们也需要提供n_steps_out个时间步长的目标值(即预测的值)。

# 创建数据集
def create_dataset(dataset, n_steps_in, n_steps_out):
    dataX, dataY = [], []
    for i in range(len(dataset)-n_steps_in-n_steps_out+1):
        x = dataset[i:(i+n_steps_in), :]
        y = dataset[(i+n_steps_in):(i+n_steps_in+n_steps_out), 0]
        dataX.append(x)
        dataY.append(y)
    return np.array(dataX), np.array(dataY)

n_steps_in = 12
n_steps_out = 6

trainX, trainY = create_dataset(train, n_steps_in, n_steps_out)
testX, testY = create_dataset(test, n_steps_in, n_steps_out)

3. 构建模型

接下来,我们需要构建LSTM模型。在这个例子中,我们使用了一个含有两个LSTM层的网络,每个LSTM层含有50个神经元。如果你需要更好的性能,可以使用更深、更宽的网络。

# 构建模型
model = Sequential()
model.add(LSTM(50, activation='relu', return_sequences=True, input_shape=(n_steps_in, n_features)))
model.add(LSTM(50, activation='relu'))
model.add(Dense(n_steps_out))
model.compile(optimizer='adam', loss='mse')

4. 训练模型

接下来,我们需要训练模型。LSTM在序列数据中表现良好,但是它也需要更长的时间来训练。我们在训练过程中使用了EarlyStopping和ModelCheckpoint回调函数,以便在损失不再下降时停止训练并保存最好的模型。

# 训练模型
es = EarlyStopping(monitor='val_loss', patience=10)
mc = ModelCheckpoint('best_model.h5', monitor='val_loss', save_best_only=True)
history = model.fit(trainX, trainY, epochs=100, batch_size=64, validation_data=(testX, testY), callbacks=[es, mc])

5. 测试模型

最后,我们需要使用测试集验证模型的性能。我们可以使用模型的predict()方法来进行预测,并将预测结果与实际值进行比较。

# 测试模型
model = load_model('best_model.h5')
predictions = model.predict(testX)
predictions = scaler.inverse_transform(predictions)

# 计算R²分数
r2score = r2_score(testY, predictions)
print('R²分数:{}'.format(r2score))

三、小结

本文介绍了LSTM时间序列预测的基础原理和Keras实现。在实现过程中,我们对数据进行了准备和标准化处理,创建了LSTM模型,并使用训练集进行训练,最后使用测试集进行预测。LSTM相比于其他算法具有更好的性能和更强的泛化能力,可广泛应用于自然语言处理、时间序列预测等领域。

原创文章,作者:AAHDD,如若转载,请注明出处:https://www.506064.com/n/369397.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
AAHDDAAHDD
上一篇 2025-04-12 13:01
下一篇 2025-04-12 13:01

相关推荐

  • 解决docker-compose 容器时间和服务器时间不同步问题

    docker-compose是一种工具,能够让您使用YAML文件来定义和运行多个容器。然而,有时候容器的时间与服务器时间不同步,导致一些不必要的错误和麻烦。以下是解决方法的详细介绍…

    编程 2025-04-29
  • Python序列的常用操作

    Python序列是程序中的重要工具,在数据分析、机器学习、图像处理等很多领域都有广泛的应用。Python序列分为三种:列表(list)、元组(tuple)和字符串(string)。…

    编程 2025-04-28
  • 想把你和时间藏起来

    如果你觉得时间过得太快,每天都过得太匆忙,那么你是否曾经想过想把时间藏起来,慢慢享受每一个瞬间?在这篇文章中,我们将会从多个方面,详细地阐述如何想把你和时间藏起来。 一、一些时间管…

    编程 2025-04-28
  • 计算斐波那契数列的时间复杂度解析

    斐波那契数列是一个数列,其中每个数都是前两个数的和,第一个数和第二个数都是1。斐波那契数列的前几项为:1,1,2,3,5,8,13,21,34,…。计算斐波那契数列常用…

    编程 2025-04-28
  • 时间戳秒级可以用int吗

    时间戳是指从某个固定的时间点开始计算的已经过去的时间。在计算机领域,时间戳通常使用秒级或毫秒级来表示。在实际使用中,我们经常会遇到需要将时间戳转换为整数类型的情况。那么,时间戳秒级…

    编程 2025-04-28
  • 如何在ACM竞赛中优化开发时间

    ACM竞赛旨在提高程序员的算法能力和解决问题的实力,然而在比赛中优化开发时间同样至关重要。 一、规划赛前准备 1、提前熟悉比赛规则和题目类型,了解常见算法、数据结构和快速编写代码的…

    编程 2025-04-28
  • 使用JavaScript日期函数掌握时间

    在本文中,我们将深入探讨JavaScript日期函数,并且从多个视角介绍其应用方法和重要性。 一、日期的基本表示与获取 在JavaScript中,使用Date对象来表示日期和时间,…

    编程 2025-04-28
  • Python整数序列求和

    本文主要介绍如何使用Python求解整数序列的和,给出了多种方法和示例代码。 一、基本概念 在Python中,整数序列指的是一组整数的集合,可以使用列表(list)或元组(tupl…

    编程 2025-04-27
  • Python序列最大值的实现方法

    本篇文章主要介绍如何使用Python寻找序列中的最大值,在文章中我们将通过多个方面,详细阐述如何实现。 一、Python内置函数max() 使用Python内置函数max()可以快…

    编程 2025-04-27
  • Java Date时间大小比较

    本文将从多个角度详细阐述Java中Date时间大小的比较,包含了时间字符串转换、日期相减、使用Calendar比较、使用compareTo方法比较等多个方面。相信这篇文章能够对你解…

    编程 2025-04-27

发表回复

登录后才能评论