散光弱视

首页 » 常识 » 诊断 » 动物与人类的关键学习期,深度神经网络也有
TUhjnbcbe - 2023/6/21 19:28:00

机器之心分析师网络

作者:Jiying

编辑:H4O

在这篇文章中,作者提出了这样一个概念:对于深度神经网络来说,与动物和人类的学习过程类似,其对于技能的学习过程也存在一个“关键学习期”。

0引言

我们这篇文章讨论的问题是根据ICLR中的一篇文章而来:《CRITICALLEARNINGPERIODSINDEEPNETWORKS》[1]。在这篇文章中,作者提出了这样一个概念:对于深度神经网络来说,与动物和人类的学习过程类似,其对于技能的学习过程也存在一个“关键学习期”。从生物学角度来看,关键期(criticalperiods)是指出生后早期发育的时间窗口,在这期间,感知缺陷可能导致永久性的技能损伤。生物学领域的研究人员已经发现并记录了影响一系列物种和系统的关键期,包括小猫的视力、鸟类的歌曲学习等等。对于人类来说,在视觉发育的关键时期,未被矫正的眼睛缺陷(如斜视、白内障)会导致1/50的成人弱视。

生物学领域的研究人员已经确定,人类或动物存在关键期的原因是对神经元可塑性窗口的生物化学调控(thebiochemicalmodulationofwindowsofneuronalplasticity)[2]。神经网络最早起源就是期望模拟人脑神经元的工作模式,Achille等在文献[1]中证明了深度神经网络对感觉缺陷的反应方式与在人类和动物模型中观察到的类似。在动物模型中最终造成的损害的程度取决于缺陷窗口的开始(onset)和长度(length),而在神经网络中则取决于神经网络的大小。不过,在神经网络中,缺陷并不会影响低层次的统计特征,如图像的垂直翻转,同时对性能并没有持久的影响,以及最终可以通过进一步的训练来克服。这一研究发现引发了作者的思考,他认为,深度神经网络学习中存在的“关键期”可能来自于信息处理,而不是生化现象[1]。这一发现最终引发了本文所讨论的问题,即DNNs中的关键学习期现象。

与此类似,我们也看到了其它一些讨论相关问题的文章。当然,这些文章并没有从“关键期”的角度来讨论这个问题,只不过其所揭示的规律与[1]中关于DNNs中的关键期现象的规律非常相似,主要探讨的是深度神经网络训练早期阶段的问题,即在深度神经网络的训练过程中,早期阶段与其它阶段具有不同的“特点”。由于这些研究能够从另外的角度证实DNNs中存在“关键学习期”,所以我们也将它们纳入到本文的讨论中。

例如,来自纽约大学等多家大学和研究机构的研究人员的工作《TheBreak-EvenPointonOptimizationTrajectoriesofDeepNeuralNetworks》[5],提出了一种模拟DNNs早期训练轨迹的简化模型。作者表示,损失面的局部曲率(Hessian的频谱范数)沿DNNs优化轨迹单调地增加或减少。梯度下降在DNNs训练早期阶段会最终达到一个点,在这个点上梯度下降会沿着损失面的最弯曲方向振动,这一点称为损益平衡点(break-evenpoint)。此外,来自Princeton大学和Google大脑团队的研究人员发表的《TheSurprisingSimplicityoftheEarly-TimeLearningDynamicsofNeuralNetworks》[4]指出,可以通过训练一个简单模型来模仿双层全连接神经网络早期学习阶段的梯度下降动态变化。当只训练第一层时,这个简单的模型是输入特征的线性函数;当训练第二层或两层时,它是特征和其L2-norm的线性函数。这一结果意味着,神经网络直到训练后期才会完全发挥其非线性能力。最后一篇文章发表在PLOSCOMPUTATIONALBIOLOGY中,提出了一个模仿人类视觉系统行为的前馈卷积网络,作者具体分析了分析了不同层次的网络表征(virtualfMRI),并研究了网络容量(即单元数量)对内部表征的影响。

1深度网络中的关键学习期[1]

1.1问题阐述

一个非常著名的影响人类的关键期缺陷的示例是人类在婴儿期或儿童期白内障引起的弱视(一只眼睛的视力下降)[6]。即使在手术矫正白内障后,患者恢复患眼正常视力的能力也取决于视力缺陷的持续时间和发病年龄,早期和长期的视力缺陷会造成更严重的影响。本文的的目标是研究DNN中类似缺陷的影响。为此,作者训练了一个标准的All-CNN架构,对CIFAR-10数据库中的32x32大小图像中的物体进行分类。实验中使用SGD进行训练。为了模拟白内障的影响,在最初的t_0个epoch中,数据库中的图像被下采样为8x8大小,然后使用双线性插值上采样为32x32大小以得到模糊处理的图像,破坏了小尺度图像细节。之后,继续训练个epoch以确保网络收敛,并确保它能够得到与对照组(t_0=0)实验中相同数量的未损坏的图像。

图1给出了受缺陷影响的网络的最终性能,具体的,将该性能展示为纠正缺陷epocht_0的函数。我们可以很容易地从图1中观察到一个关键时期的存在。如果在最初的40-60个epoch中没有去除模糊,那么与基线方法相比,最终的性能会严重下降(误差最多会增加三倍)。这种性能的下降遵循在动物身上普遍观察到的趋势,例如早期研究中证实的在小猫出生后被剥夺单眼的情况下观察到的视觉敏锐度的损失与缺陷的长度有关[7]。

图1.DNN中显示出的关键期

由上述实验给出的结果人们很自然地会提问:是否输入数据分布的任何变化都会有一个相应的学习关键期?作者表示,对于神经元网络来说,情况并非如此,它们有足够的可塑性来适应感觉处理(sensoryprocessing)的high-level变化。例如,成年人类能够迅速适应某些剧烈的变化,如视野的倒置。在图2中,我们观察到DNN也基本上不受high-level缺陷的影响—比如图像的垂直翻转或输出标签的随机排列。在缺陷修正之后,网络很快就恢复了它的基线性能。这暗示了数据分布的结构和优化算法之间存在更精细的相互作用,进而导致存在一个关键期。

接下来,作者对网络施加了一个更激烈的缺陷攻击,令每个图像都被白噪声取代。图2显示,这种极端的缺陷所表现出的效果明显比只模糊图像所得到的效果要轻。用白噪声训练网络并不会提供任何关于自然图像的信息,因此,与其它缺陷(例如,图像模糊)相比,白噪声的效果更温和。不过,白噪声中包含了一些信息,从而导致网络(错误地)学习图像中并没有存在的精细结构。

图2.(左)High-level的扰动并不会导致关键期。当缺陷只影响high-level特征(图像的垂直翻转)或CNN的最后一层(标签互换)时,网络不会表现出关键期(测试准确度基本保持平稳)。另一方面,类似于感知剥夺的缺陷(图像被随机噪声取代)确实会导致缺陷,但其影响没有图像模糊的情况那么严重。(右)关键期曲线对网络深度的依赖情况。添加更多的卷积层会增大关键期缺陷的影响。

图3显示,在MNIST库上训练的全连接网络也存在图像模糊缺陷的关键期。因此,作者认为(对于重现模型训练的关键期)卷积结构不是必需的,使用自然图像也不是必需的。同样,在CIFAR-10上训练的ResNet-18也有一个关键期,它也比标准卷积网络中的关键期明显更清晰(图1)。作者分析,ResNets允许梯度更容易地反向传播到低层,其关键期的存在可以表明关键期不是由梯度消失引起的。图2(右)显示,关键期的存在确实关键地取决于网络的深度。在图3中,作者确认,即使在网络以恒定的学习速率训练时,也存在一个关键期。图3(右下角)显示,当使用Adam作为优化器时,使用其前两个时刻的运行平均值对梯度进行重归一化,我们仍然观察到一个与标准SGD类似的关键期。改变优化的超参数可以改变关键期的形状:图3(左下角)显示,增加权重衰减(weightdecay)使关键期更长,更不尖锐。这可以解释为它既减慢了网络的收敛速度,又限制了high-level为克服缺陷而改变的能力,从而鼓励low-level也学习新特征。

图3.不同DNN架构和优化方案中的关键期

1.2Fisher信息分析

作者根据经验确定,在动物和DNN中,训练的早期阶段对训练过程的结果至关重要。在动物中,这与缺陷有关的区域的大脑结构变化密切相关。这在人工网络中不可避免地有所不同,因为它们的连接性在训练期间一直都是固定的。然而,并不是所有的连接对网络都同样有用。考虑一个编码近似后验分布p_ω(y

x)的网络,其中,ω表示权重参数。来自特定连接的最终输出的依赖性可以通过扰动相应的权重和观察最终分布的变化幅度来估计。给定权重扰动ω=ω+δω,p_ω(y

x)和由扰动生成的p_ω(y

x)之间的偏差可以由K-L散度度量,即:

其中的F为Fisher信息矩阵(FisherInformationMatrix,FIM):

FIM可以被认为是一个局部指标,用于衡量一个单一权重(或一个权重组合)的扰动对网络输出的影响程度。特别是,具有低Fisher信息的权重可以被改变或修剪,对网络的性能影响不大。这表明,Fisher信息可以作为DNN有效连接的衡量指标,或者,更广泛地说,作为连接的突触强度(synapticstrength)的衡量标准。最后,FIM也是损失函数Hessian的半定逼近,因此也是训练过程中某一点ω的损失情况的曲率,在FIM和优化程序之间提供了一种关联性。

不幸的是,完整的FIM太大,无法计算。因此,本文作者使用它的轨迹来测量全局或逐层的连接强度。作者提出使用以下方法计算FIM:

为了捕捉非对角线项的行为,作者还尝试使用Kronecker-Factorized近似计算全矩阵的对数行列式。作者观察到了与trace相同的定性趋势。由于FIM是一个局部测量,它对损失情况的不规则性非常敏感。因此,作者在文中主要使用ResNets,ResNets具备相对平滑的损失情况。对于其他架构,作者则使用一个基于在权重中注入噪声的更稳健的FIM估计器。

FIM可以被确定为对模型中包含的训练数据信息量的一种衡量。在此基础上,人们会期望随着从经验中获得信息,连接(connection)的总体强度会单调地增加。然而,情况并非如此。虽然在早期阶段网络就获得了有关数据的信息,从而使得连接强度的大幅增加,但一旦任务的表现开始趋于平稳,网络就开始降低其连接的整体强度。然而,这并不对应于性能的降低,相反,性能一直在缓慢提高。这可以被看作是一个遗忘或压缩阶段,在这个阶段,多余的连接被消除,数据中不相关的变化被抛弃。在学习和大脑发育过程中,消除(修剪)不必要的突触是一个基本的过程,这一点已经得到了前期研究的证实(图4,中心)[8]。在图4(左)中,类似的现象在DNN中得到了清晰和定量的显示。

连接强度的这些变化与对关键期诱发的缺陷(如图像模糊)的敏感性密切相关,如图1中使用滑动窗口方法计算。在图4中,我们看到敏感性与FIM的趋势密切相关。FIM是在没有缺陷的情况下在网络训练过程中的一个点上计算的局部数量,而关键期的敏感性是在有缺陷的网络训练结束后,使用测试数据计算的。图4(右)进一步强调了缺陷对FIM的影响:在存在缺陷的情况下,FIM会增长,甚至在缺陷消除后仍然大幅增长。作者分析,这可能是由于当数据被破坏到无法分类时,网络被迫记忆标签,因此增加了执行相同任务所需的信息量。

图4.DNN的关键期可追溯到Fisher信息的变化

对FIM的逐层分析进一步揭示了缺陷对网络的影响。在没有缺陷的情况下训练网络时(在这种情况下是All-CNN,它比ResNet有更清晰的层次划分),最重要的连接是在中间层(图5,左),它可以在最有信息量的中间尺度上处理输入的CIFAR-10图像。然而,如果网络最初是在模糊的数据上训练的(图5,右上方),连接的强度是由顶层(第6层)主导的。作者分析,这是因为图像的低层和中层结构被破坏了。然而,如果在训练的早期消除缺陷(图5,顶部中心),网络会设法重组,以减少最后一层所包含的信息,同时增加中间层的信息。作者把这些现象称为信息可塑性的变化。然而,如果数据变化发生在巩固阶段(consolidationphase)之后,网络就无法改变其有效连接。每层的连接强度基本上保持不变。此时,网络失去了它的信息可塑性,错过了它的关键期。

图5.各层权重所含信息的归一化数量与训练epoch的关系。(左上)在没有缺陷的情况下,网络主要依靠中间层(3-4-5)来解决任务。(右上)在存在图像模糊缺陷的情况下,直到第个epoch,更多的资源被分配到高层(6-7),而不是中间层。(顶部中心)当缺陷在较早的epoch被消除时,各层可以部分地重新配置(例如,第6层中信息的快速损失)。(最下面一行)同样的图,但引入的是翻转缺陷,并不会诱发关键期

最后,对FIM的分析也揭示了损失函数的几何形状和学习动态。由于FIM可以被解释为残余分布(landscape)的局部曲率,图4显示,学习需要越过瓶颈阶段。在初始阶段,网络进入高曲率的区域(高Fisher信息),一旦开始进入巩固阶段,曲率就会下降,使其能够跨越瓶颈以进入后续阶段。收敛的早期阶段是引导网络走向正确的收敛结果的关键。关键期的结束是在网络跨越了所有的瓶颈(从而学会了特征)并进入一个收敛区域(低曲率的权重空间区域,或低Fisher信息)之后。

1.3讨论

到目前为止,关键期仍被认为是一种专门的生物现象。同时,对DNN的分析主要集中在其渐进特性上,而忽略了其初始的瞬态行为。作者表示,本文是第一个探讨人工神经网络临界期现象的文章,并强调瞬态在决定人工神经网络的渐进性能中的关键作用。受突触连接在调节关键期作用的启发,作者引入了Fisher信息来研究这个阶段。文章表明,对缺陷的最初敏感性与FIM的变化密切相关,既是全局性的,因为网络首先迅速增加,然后减少储存的信息量;也是分层的,因为网络重组其有效连接,以最佳方式处理信息。

本文工作与生物学中关于关键期的大量文献相关。尽管人工网络是神经元网络的一种极其简化的近似,但它们表现出的行为与在人类和动物模型中观察到的关键期有本质上的相似。本文给出的信息分析表明,DNN中最初的快速记忆阶段之后是信息可塑性的损失,这反过来又进一步提高了其性能。在文献[9]中,作者观察到并讨论了训练的两个不同阶段的存在,他们的分析建立在激活的(香农)信息上,而不是权重的(费雪)Fisher信息。在多层感知器(MLP)上,文献[9]根据经验将这两个阶段与梯度协方差的突然增加联系起来。然而,必须注意的是,FIM的计算是使用与模型预测有关的梯度,而不是与ground-truth标签有关的梯度,这就会导致质量差异。图6显示梯度的均值和标准偏差在有缺陷和无缺陷的训练中没有表现出明显的趋势,因此,与FIM不同,它与对关键期的敏感性没有关联。

图6.训练期间梯度均值(实线)和标准偏差(虚线)的对数值。(左)不存在缺陷,(中)第70个epoch后出现模糊缺陷,(右)最后一个epoch出现缺陷。

除了与关键期的缺陷敏感性有密切的关系外,Fisher信息还具有一些技术优势,包括对角线易估计、对互信息的选择估计器不敏感,以及能够辅助探测人工神经网络中各层有效连接的变化情况。

对激活的完整分析不仅要考虑到信息量(包括与任务有关的和与干扰有关的),还要考虑其可及性,例如,与任务有关的信息能多容易被一个线性分类器提取出来。按照类似的想法,Montavon等人[10]通过对每层表征的径向基函数(RBF)核嵌入进行主成分分析(PCA),研究了表征的简单性的逐层或空间(不是时间)的演变。他们表明,在多层感知器上,与任务相关的信息更多地集中在表征嵌入的第一个主成分上,从而使得它们变得更容易被逐层访问。本文工作专注于权重的时间演变。一个具有较简单权重的网络(由FIM测量)也需要一个较简单的平滑表示(如由RBF嵌入测量),以抵抗权重的扰动从而正常运行。因此,本文分析与Montavon等人的工作是一致的。同时使用这两个框架来研究网络的联合时空演变情况将会非常有趣。

1
查看完整版本: 动物与人类的关键学习期,深度神经网络也有