论文结果难复现?本文教你完美实现深度强化学习算法DQN

百家 作者:机器之心 2017-11-24 04:34:05

选自arXiv

作者:Melrose Roderick等

机器之心编译


论文的复现一直是很多研究者和开发者关注的重点,近日有研究者详细论述了他们在复现深度 Q 网络所踩过的坑与训练技巧。本论文不仅重点标注了实现过程中的终止条件和优化算法等关键点,同时还讨论了实现的性能改进方案。机器之心简要介绍了该论文,更详细的实现细节请查看原论文。


过去几年来,深度强化学习逐渐流行,因为它在有超大状态空间(state-spaces)的领域上要比先前的方法有更好的表现。DQN 几乎在所有的游戏上超越了之前的强化学习方法,并在大部分游戏上比人类表现更好。随着更多的研究人员用深度强化学习方法解决强化学习问题,并提出替代性算法,DQN 论文的结果经常被用作展示进步的基准。因此,实现 DQN 算法对复现 DQN 论文结果和构建原算法都很重要。


我们部署了一个 DQN 来玩 Atari 游戏并重复 Mnih 等人的结果。我们的实现要比原始实现快 4 倍,且已经在网上开源。此外,该实现在设计上,对不同的神经网络架构、ALE 之外领域也更为灵活。在重复这些结果时,我们发现实现这些系统的过程的几个关键。在这篇论文中,我们强调了一些关键的技术,这些技术对于获得优良的性能和重复 Mnih 等人的结果是很基本的,其中包括了终止条件和梯度下降优化算法,以及算法的期望结果(也就是网络的性能波动)。


论文:Implementing the Deep Q-Network



论文地址:https://arxiv.org/abs/1711.07478


Mnih 等人在 2015 年提出的深度 Q 网络已经成为了一项基准,也是许多深度强化学习研究的基点。然而,复现复杂系统的结果总是非常难,因为最初的文献经常无法详细描述每个重要的参数和软件工程的解决方案。在此论文中,我们复现了 DQN 的论文结果。此外,我们重点标注了实现过程中的关键点,从而让研究人员能更容易地复现结果,包括终止条件、梯度下降算法等。而这些点是原论文没有详细描述的。最后,我们讨论了改进计算性能的方法,并给出我们的实现,该实现可广泛应用,而不是只能在原论文中的 Arcade 学习环境(ALE)中实现。


3 深度 Q 学习


深度 Q 学习(DQN)是经典 Q 学习算法的变体,有 3 个主要贡献:(1)深度卷积神经网络架构用于 Q 函数近似;(2)使用小批量随机训练数据而不是在上一次经验上进行一步更新;(3)使用旧的网络参数来评估下一个状态的 Q 值。DQN 的伪代码(复制自 Mnih et al. [2015])见算法 1。深度卷积架构提供一个通用机制从图像帧的短历史(尤其是最后 4 帧)中评估 Q 函数的值。后面两个贡献主要关于如何使迭代的 Q 函数估计保持稳定。



监督式深度学习研究中,在小批量数据上执行梯度下降通常是一种高效训练网络的方式。在 DQN 中,它扮演了另外一个角色。具体来说,DQN 保存大量最近经验的历史,每个经验有五个元组(s, a, s', r, T):智能体在状态 s 执行动作 a,然后到达状态 s',收到奖励 r;T 是一个布尔值,指示 s'是否为最终状态。在环境中的每一步之后,智能体添加经验至内存。在少量步之后(DQN 论文使用了 4 步),智能体从内存中进行小批量随机采样,然后在上面执行 Q 函数更新。在 Q 函数更新中重用先前的经验叫作经验回放(experience replay)[Lin, 1992]。但是,尽管强化学习中的经验回放通常用于加快奖励备份(backup of rewards),DQN 从内存中进行小批量完全随机采样有助于去除样本和环境的相关性,否则容易引起函数近似估计中出现偏差。


最终的主要贡献是使用旧的网络参数来评估一个经验中下一个状态的 Q 值,且只在离散的多步间隔(many-step interval)上更新旧的网络参数。该方法对 DQN 很有用,因为它为待拟合的网络函数提供了一个稳定的训练目标,并给予充分的训练时间(根据训练样本数量决定)。因此,估计误差得到了更好地控制。


尽管这些贡献和整体算法在概念层面上是很直接的,但要想达到 Mnih et al. [2015] 报告中相同的性能水平需要考虑大量重要细节,设计者必须牢记学习过程的重要特性。下文将具体描述细节。


3.1 实现细节


由于原始的科研文献经常无法提供重要参数设置和软件工程解决方案的细节,因此,很多大型系统(比如 DQN)都难以实现。因此,DQN 论文并没有明确地提及或完整地说明一些重要的算法基础细节。本文,我们将强调其中一些额外的关键实现细节(根据原论文的 DQN 代码总结)。


首先,每一个 episode 从随机数量(0 到 30 之间)的「No-op」低级别 Atari 动作开始(相对于将智能体的动作(action)重复 4 个帧),以抵消智能体所看见的帧,这是因为智能体每次只能看到 4 个 Atari 帧。类似地,用作 CNN 输入的 m 个帧历史是智能体最后看见的 m 个帧,而不是最后的 m 个 Atari 帧。此外,在使用梯度下降迭代之前,我们会执行 50000 步的随机策略作为补充经验以避免对早期经验的过拟合。


另一个值得注意的参数是网络更新频率(network update frequency)。原始的 DQN 实现仅在算法的每 4 个环境步骤后执行一个梯度下降步骤,这和算法 1 截然不同(每一个环境步骤执行一个梯度下降步骤)。这不仅仅大大加快了训练速度(由于网络学习步骤的计算量比前向传播大得多),还使得经验内存更加相似于当前策略的状态分布(由于训练步骤之间需要添加 4 个新的帧到内存中,这和添加 1 个帧是截然不同的),可能有防止过拟合的作用。


3.2 DQN 的性能波动(fluctuating performance)


图 1 展示了最佳网络和最差网络(在 Breakout 游戏的开始阶段使用相同的输入启动训练)的 Q 值近似。第一帧展示了这样的场景:智能体可以采取任意的动作,都不会使球在经过未来几个动作之后就掉落。但是在弹回球之前做出的动作也可以帮助智能体瞄准球的位置。这个例子中两个网络的 Q 值是很相近的,但是各自选择的动作是不同的。在第二帧的场景中,假如智能体没有采取向左移动的动作,球就会掉落,游戏终止。在这个例子中,两个网络的 Q 值差别是很大的。因此,当运行这个算法的时候,可能会出现这种性能波动。


图 1:经过 3000 万步的训练后进行测试,最优和最差网络的 Q 值对比。阴影线区域代表 Q 值最高的动作。最上面一帧对应动作对近期奖励无显著影响的情况,底部帧代表必须执行左侧动作以免损失生命的情况。「Release」动作指在每局开始的时候释放球,或当球已经开始运动时什么也不做(和「无操作」(No-op)一样)。


5 结果


我们的结果与 DQN 论文关于 Pong、Breakout 和 Seaquest 的结果对比见表 1。我们的实现中每个训练过程大约用时 3 天,而我们配置的原始实现用时大约 10.5 天。


表 1:我们的 DQN 实现和原 DQN 论文获得的平均游戏分数的对比。


6 核心训练技巧


我们在实现 DQN 时,发现了只在 DQN 论文中简要提及的两种方法,但是它们对算法的整体表现至关重要。下面我们将展示这两种方法,并解释为什么它们对网络训练的影响如此之大。


6.1 掉命终止


绝大多数 Atari 游戏中,玩家都有几条「命」,对应游戏结束之前玩家可以失败的次数。为了提升表现,Mnih et al. [2015] 选择在训练中把生命数的损失(在涉及生命数的游戏中)作为 MDP 的最终状态。这一终止条件在 DQN 论文中没有提及太多,但却对提升性能至关重要。


图 2 展示了在 Breakout 和 Seaquest 中,把和不把生命数损失作为最终状态的区别。在 Breakout 中,使用生命数的结束作为最终状态的学习器的平均分值增长要远快于另一个学习器。但是,训练大约进行一半时,另一个学习器获得相似表现,却带有更高的方差。Seaquest 是一个更为复杂的游戏,其中使用生命数作为最终状态的学习器在整个训练中表现要远好于另一个学习器。这些图表明这一额外的先验信息非常有利于早期训练和稳定性,并在更复杂的游戏中显著提升了整体表现。


如上所述,MDP 中的最终状态意味着智能体无法再获得更多奖励。几乎所有的 Atari 游戏给出正面奖励,因此这一附加信息很关键地告知智能体无论如何都要避免失去生命数,这看起来确实很理性:很多玩家一开始就知道在 Atari 中损失生命数很糟糕,并且很难想象出其中最优策略是失去生命数的场景。


但是,执行该约束存在多个理论问题。首先,由于初始状态分布依赖于当前策略,该过程将不再是马尔科夫性质的。一个相关的例子是在 Breakout 游戏中:如果智能体表现很好,在失去一条生命之前破坏了很多砖,则新生命的初始状态拥有的砖,会比智能体在上个生命中表现不好、破坏不多砖时更少。另一个问题是该信号为 DQN 提供了很强的额外信息,从而使扩展至没有强信号的领域变得困难(如现实机器人或开放性更强的电子游戏)。


图 2:Breakout 和 Seaquest 在每个测试集上使用命数和游戏结束作为最终状态时分别得到的平均训练测试分数(epoch = 250,000 steps)。


ALE 为每个游戏保存了剩余生命数,但它没有向所有界面提供这个信息。为了解决这个局限,我们修改了 ALE 的 FIFO 界面以在屏幕上提供剩余生命数、奖励和最终状态布尔值的信息。我们的 fork 在 FIFO 界面上提供了该数据,大家可在线免费访问。


6.2 梯度下降优化算法


在使用 Mnih et al. [2015] 所提供的超参数时,我们会遇到一个潜在问题,即原论文并不是直接使用许多深度学习库(如 Caffe)所定义的 RMSProp 优化算法。RMSProp 梯度下降优化算法最初是由 Gerffrey Hinton 所提出来的,Hinton 的 RMSProp 针对每个参数保持一个滑动平均(running average)梯度。这种滑动平均梯度的更新规则可以写为:



其中,w 对应单个网络的参数,γ 为梯度衰减参数,E 为经验损失。参数的更新过程可以写为:



其中α为学习率,ε为非常小的常量以避免分母为零。


即使 Mnih et al. [2015] 引用了 Hinton 的 RMSProp,但他们使用的最优化算法仍然略有不同。这个不同点可以在他们的 GitHub 中找到(以下地址),即在 NeuralQLearner.lua 文件的第 266 行到 273 行代码中。该变体将动量因子加入到了 RMSProp 算法中,因此梯度的更新规则可以写为:



Mnih 等人实现地址:www.github.com/kuz/DeepMind-Atari-Deep-Q-Learner


其中η为动量衰减因素,参数的更新规则可以写为:



为了解决优化算法中的这种大幅变化,我们必须将学习率修改为远低于 Mnih et al. [2015] 在实现中设定的学习率,即将他们的 0.00025 修改为 0.00005。我们并没有选择实现这种 RMSProp 变体,因为用 Java-Caffe 捆绑包实现是很重要的,且 Hinton 的一般 RMSPorp 算法产生了类似的效果。


7 性能加速


我们的实现要比原论文使用 Lua 和 Torch 的实现快 4 倍,且测试这些实现的配置是两张 NVIDIA GTX 980 TI 显卡和一个 Intel i7 处理器。我们性能的提升很大部分可以归因于 cuDNN 库的帮助,我们在训练过程中以每秒约 985 Atari 帧(fps)的速度进行,测试中以每秒约 1584 帧(fps)的速度进行。


我们使用了 cuDNN 进行实验,而 Lua 并没有在 Torch 中使用该加速库。为了完成对比,我们在没有使用 cuDNN 的 Caffe 上训练和测试时,速度分别为 268fps 和 485fps。这要比原论文 Lua 实现慢一些。


8. 结论


为了让研究人员更好地实现自己的 DQN,我们在此论文中展现了实现 Mnih 等人提出的 DQN 时的关键点,这些关键点对此算法的整体表现极为重要,但在原论文中却没有提到,以帮助研究者更容易地实现该算法的个人版本。我们也重点标注了在灾难性遗忘(catastrophic forgetting)这样的大型状态空间中用 CNN 逼近 Q 函数时的难点。之后,我们把自己的实现开源到了网上,也鼓励研究人员使用它实现全新的算法,并与 Mnih 等人的结果做比较。 


本论文的 GitHub 实现地址:https://github.com/h2r/burlap_caffe



点击「阅读原文」,参与 NIPS 2017 线上分享第三期。

关注公众号:拾黑(shiheibook)了解更多

[广告]赞助链接:

四季很好,只要有你,文娱排行榜:https://www.yaopaiming.com/
让资讯触达的更精准有趣:https://www.0xu.cn/

公众号 关注网络尖刀微信公众号
随时掌握互联网精彩
赞助链接