Paper Note: Remembering for the Right Reasons: Explanations Reduce Catastrophic Forgetting
Remembering for the Right Reasons: Explanations Reduce
Catastrophic Forgetting
Sayna Ebrahimi, Suzanne Petryk, Akash Gokul, William Gan, Joseph E. Gonzalez, Marcus Rohrbach, trevor darrell
ICLR 2021
Main Idea
我们认为,“灾难性遗忘”现象的出现部分原因是由于对于先前的观察,不能再依赖于与之前相同的推理。因此,我们假设当模型被鼓励记住先前做出决定时的证据(evidence)时,可以减少遗忘。换句话说,一个模型可以记住其最终决策并可以重建相同的先验推理。基于这种方法,我们开发了一种新颖的策略来利用可解释的模型来提高推测表现。
在 XAI (explainable artificial intelligence) 中提出的各种可解释性技术中,显着性方法(saliency methods)已成为一种流行的工具,以此找出输入中的相关特征对模型预测的支持。 这些方法生成 saliency maps(定义为网络在其上做出决策的视觉证据区域)。 我们的目标是研究用解释重放(explanation replay)增强的经验重放(experience replay)是否可以减少遗忘,以及强制记住解释将如何影响解释本身。Figure 1说明了我们提出的方法。

Figure 1
在这项工作中,我们提出了RRR,一种可以由任意白盒(white-box)可微解释方法生成的模型解释指导下的训练策略;RRR为持续学习添加了一个解释性损失项(explanation loss)。 白盒方法通过使用模型的某些内部状态(比如梯度)来产生解释,从而使其可以在端到端训练中使用。我们将RRR方法组合到了几种SOTA的CIL (class incremental learning) 方法当中,包括 iTAML (Rajasegaran et al., 2020)、EEIl (Castro et al., 2018)、BiC (Wu et al., 2019)、TOPIC (Tao et al., 2020)、iCaRL (Rebuffi et al., 2017)、EWC (Kirkpatrick et al., 2017) 以及 LwF (Li & Hoiem, 2016) 当中。 请注意,RRR在测试时不需要任务ID。 接下来,我们以 saliency maps 的形式对模型解释进行了定性和定量分析,并证明 RRR 由于需要聚焦于正确的证据(evidence),因而记住了其在一系列任务中的早期决策。另外,文章中的实验表明 RRR 对于 experience replay 以及其它 memory-based 的方法都能够提高准确率和避免遗忘。
我们的贡献是三方面的:我们首先提出了一种新颖、简单但有效的 memory constraint,我们将其称为 “Remembering for the Right Reasons”(RRR),并通过鼓励模型进行决策时,采用与最初发现的相同的解释,来减少“灾难性遗忘”。其次,我们展示了如何将RRR轻松地与 memory-based 和regularization-based 的持续学习方法结合起来以提高性能。第三,我们展示了引导 continual learner 记住其解释可以提高解释本身的质量;即,模型在做出正确的决策时会聚焦于图像中的正确区域,而在对对象进行错误分类时会将其最大的注意力置于背景上。
Method
考虑以下持续学习场景:有按时间序列到来的 个任务,对应的数据分布为 ,其中 是任务 的数据分布,由 个样例元组构成。目标是为每个任务依时间顺序习得模型 ,并且保持模型对于之前任务的性能。我们希望通过使用 memory 来增强知识迁移、避免“灾难性遗忘”,以此实现上述目标。我们假设有两个有限大小的 memory pools: 用于存储原始样例, 用于存储模型解释,即由在 Sec 2 中讨论的解释性方法()之一基于 为 中图像生成的 saliency maps,其中 是 为任务 所训练的。我们使用single-head结构,在推断时不需要任务ID 。
当完成第 个任务时,我们从每个任务的训练数据中随机选择 个样例,并以此更新我们的 replay buffer memory 。RRR 使用基于所存储样例的模型解释来进行连续学习,从而使模型保留其对先前观察到的观测结果的推理。我们探索了几种解释性技术,使用 为 replay buffer 中存储的样例计算 saliency maps,以填充 。所存储的 saliency maps 将在对未来任务的学习期间用作参照解释,以防止模型参数被更改而导致对相同样例的不同推理。我们使用在训练完一个新的任务之后生成的 saliency maps 与所存储的对应的 reference evidence 之间的残差的 loss 项来实现 RRR。
其中 表示使用在最后的任务 上的数据训练所得的模型来计算 saliency maps 的解释性方法, 是在学习完 之前的每个任务之后,由 生成的 reference saliency maps,存储在 当中。
论文随后展示了将 RRR 组合到 SOTA 的 memory-based 和 regularization-based 方法的目标函数当中,可以改善性能。应当注意,选择大的 可能会阻碍参数学习新任务。