栏目分类:
子分类:
返回
文库吧用户登录
快速导航关闭
当前搜索
当前分类
子分类
实用工具
热门搜索
文库吧 > IT > 软件开发 > 后端开发 > Python

在Keras中为截断的BPTT准备序列预测的迷你教程

Python 更新时间: 发布时间: IT归档 最新发布 模块sitemap 名妆网 法律咨询 聚返吧 英语巴士网 伯小乐 网商动力

在Keras中为截断的BPTT准备序列预测的迷你教程

文章目录
  • 内容介绍
  • 通过时间截断反向传播
  • TBPTT的Keras实现
  • 在 Keras 中为 TBPTT 准备序列数据

内容介绍

循环神经网络能够学习序列预测问题中跨多个时间步长的时间依赖性。

现代循环神经网络,如长短期记忆或 LSTM,网络是用反向传播算法的变体进行训练的,称为反向传播时间。该算法已被进一步修改,以提高非常长序列的序列预测问题的效率,称为Truncated Backpropagation Through Time。

使用截断反向传播训练循环神经网络(如 LSTM)时的一个重要配置参数是决定使用多少时间步作为输入。也就是说,如何准确地将很长的输入序列拆分为子序列以获得最佳性能。

在这篇文章中,您将发现 6 种不同的方法,您可以拆分非常长的输入序列,以使用 Keras 在 Python 中使用截断反向传播有效地训练循环神经网络。

通过时间截断反向传播

反向传播是一种训练算法,用于更新神经网络中的权重,以最小化给定输入的预期输出和预测输出之间的误差。

对于观察之间存在顺序依赖性的序列预测问题,使用循环神经网络代替经典的前馈神经网络。循环神经网络使用反向传播算法的变体进行训练,该算法称为 Backpropagation Through Time,简称 BPTT。

实际上,BPTT 展开循环神经网络,并在整个输入序列上向后传播错误,一次一个时间步。然后用累积的梯度更新权重。

BPTT 在输入序列很长的问题上训练循环神经网络可能很慢。除了速度之外,在如此多的时间步长上累积梯度可能会导致值缩小到零,或者最终溢出或爆炸的值增长。

BPTT 的一个修改是限制反向传播使用的时间步数,实际上估计用于更新权重的梯度而不是完全计算它。

这种变化称为时间截断反向传播,或 TBPTT。

TBPTT训练算法有两个参数:

  • k1:定义在前向传递中向网络显示的时间步数。
  • k2:定义在估计反向传播的梯度时要查看的时间步数。

因此,我们可以在考虑如何配置训练算法时使用符号 TBPTT(k1, k2),其中 k1 = k2 = n,其中 n 是经典非截断 BPTT 的输入序列长度。

TBPTT 配置对 RNN 序列模型的影响
像 LSTM 这样的现代循环神经网络可以使用它们的内部状态来记住很长的输入序列。例如超过数千个时间步。

这意味着 TBPTT 的配置不一定定义您正在通过选择时间步数优化的网络的内存。您可以选择何时将网络的内部状态与用于更新网络权重的机制分开重置。

相反,TBPTT 参数的选择会影响网络如何估计用于更新权重的误差梯度。更一般地说,配置定义了可以考虑网络来模拟序列问题的时间步数。

我们可以将其正式表述为:

yhat(t) = f(X(t), X(t-1), X(t-2), ... X(t-n))

其中 yhat 是特定时间步长的输出,f(…) 是循环神经网络逼近的关系,X(t) 是特定时间步长的观测值。

它与在时间序列问题上训练的多层感知器上的窗口大小或线性时间序列模型(如 ARIMA)的 p 和 q 参数在概念上相似(但在实践中完全不同)。TBPTT 定义了模型在训练期间输入序列的范围。

TBPTT的Keras实现

Keras 深度学习库提供了用于训练循环神经网络的 TBPTT 实现。

实现比上面列出的通用版本更受限制。

具体来说,k1 和 k2 值彼此相等且固定。

  • TBPTT(k1, k2),其中 k1 = k2

这是通过训练循环神经网络(如长短期记忆网络或 LSTM)所需的固定大小的 3D 输入来实现的。

LSTM 期望输入数据具有以下维度:样本、时间步长和特征。

这是此输入格式的第二个维度,时间步长定义了用于序列预测问题的前向和后向传递的时间步长数。

因此,在为 Keras 中的序列预测问题准备输入数据时,必须仔细选择指定的时间步数。

时间步长的选择将影响两者:

  • 前向传递期间累积的内部状态。
  • 用于更新反向传播权重的梯度估计。

注意,默认情况下,网络的内部状态会在每批后重置,但可以通过使用所谓的有状态 LSTM 并手动调用重置操作来实现对内部状态何时重置的更明确的控制。

在 Keras 中为 TBPTT 准备序列数据

分解序列数据的方式将定义 BPTT 向前和向后传递中使用的时间步数。

因此,您必须仔细考虑如何准备训练数据。

本节列出了您可以考虑的 6 种技术。

1. 按原样使用数据

如果每个序列中的时间步数不多,例如几十或几百个时间步长,您可以按原样使用您的输入序列。

已经建议了大约 200 到 400 个时间步长的 TBPTT 的实际限制。

如果您的序列数据小于或等于此范围,您可以将序列观测值重塑为输入数据的时间步长。

例如,如果您有 25 个时间步长的 100 个单变量序列的集合,则可以将其重构为 100 个样本、25 个时间步长和 1 个特征或 [100, 25, 1]。

2. 朴素的数据拆分

如果您有很长的输入序列,例如数千个时间步长,您可能需要将长输入序列分解为多个连续的子序列。

这将需要在 Keras 中使用有状态的 LSTM,以便在子序列的输入中保留内部状态,并且仅在真正更完整的输入序列的末尾重置。

例如,如果您有 50,000 个时间步长的 100 个输入序列,那么每个输入序列可以分为 500 个时间步长的 100 个子序列。一个输入序列将变成 100 个样本,因此 100 个原始样本将变成 10,000 个。Keras 的输入维度为 10,000 个样本、500 个时间步长和 1 个特征或 [10000, 500, 1]。需要注意保存每 100 个子序列的状态,并在每 100 个样本后明确地或使用 100 的批量大小重置内部状态。

将整个序列整齐地划分为固定大小的子序列的拆分是首选。全序列的因子(子序列长度)的选择是任意的,因此得名“naive data split”。

将序列拆分为子序列并没有考虑有关合适数量的时间步长的域信息来估计用于更新权重的误差梯度。

3. 特定领域的数据拆分

很难知道提供有用的误差梯度估计所需的正确时间步数。

我们可以使用朴素的方法(上图)快速得到一个模型,但模型可能远未优化。

或者,我们可以使用特定领域的信息来估计在学习问题时与模型相关的时间步数。

例如,如果序列问题是回归时间序列,也许对自相关和偏自相关图的回顾可以告知时间步数的选择。

如果序列问题是自然语言处理问题,也许可以将输入序列按句子分割然后填充到固定长度,或者根据域中的平均句子长度进行分割。

广泛思考并考虑您可以使用哪些特定于您的领域的知识将序列分成有意义的块。

4. 系统数据拆分(例如网格搜索)

您可以针对序列预测问题系统地评估一组不同的子序列长度,而不是猜测合适的时间步数。

您可以对每个子序列长度执行网格搜索,并采用导致平均性能最佳模型的配置。

如果您正在考虑使用这种方法,请注意以下几点:

从作为完整序列长度因子的子序列长度开始。
如果探索不是完整序列长度的因素的子序列长度,请使用填充和可能的掩码。
考虑使用比解决问题所需的稍微过度规定的网络(更多的记忆单元和更多的训练时期),以帮助排除网络容量限制您的实验。
取每种不同配置多次运行(例如 30 次)的平均性能。
如果计算资源不是限制,则建议对不同数量的时间步进行系统调查。

5. 使用 TBPTT(1, 1) 严重依赖内部状态

您可以将序列预测问题重新表述为每个时间步长一个输入和一个输出。

例如,如果您有 50 个时间步长的 100 个序列,则每个时间步长都将成为一个新样本。100 个样本将变成 5,000 个。三维输入将变为 5,000 个样本、1 个时间步长和 1 个特征,或 [5000, 1, 1]。

同样,这将需要在序列的每个时间步长内保留内部状态,并在每个实际序列(50 个样本)结束时重置。

这会将学习序列预测问题的负担放在循环神经网络的内部状态上。根据问题的类型,它可能超出网络的处理能力,预测问题可能无法学习。

个人经验表明,这种公式可能适用于需要对序列进行记忆的预测问题,但当结果是过去观察的复杂函数时表现不佳。

6.解耦前向和后向序列长度

Keras 深度学习库用于支持通过时间截断反向传播的前向和后向传递的解耦数量的时间步长。

本质上,k1 参数可以由输入序列的时间步数指定,k2 参数可以由 LSTM 层上的“truncate_gradient”参数指定。

这不再受支持,但有些人希望将此功能重新添加到库中。它是目前还不清楚究竟为什么它被删除,但有证据表明这是出于效率的考虑做。

您可以在 Keras 中探索这种方法。一些想法包括:

安装并使用支持“truncate_gradient”参数的旧版 Keras 库(大约 2015 年)。
在 Keras 中扩展 LSTM 层实现以支持“truncate_gradient”类型的行为。
也许有第三方扩展可用于支持这种行为的 Keras。

转载请注明:文章转载自 www.wk8.com.cn
本文地址:https://www.wk8.com.cn/it/280172.html
我们一直用心在做
关于我们 文章归档 网站地图 联系我们

版权所有 (c)2021-2022 wk8.com.cn

ICP备案号:晋ICP备2021003244-6号