BN——虽然玄学,但是养活了很多炼丹师。


BN——虽然玄学,但是养活了很多炼丹师。

大家好我是栗子鑫,先祝大家520快乐!但是谁能拒绝在约会的时候看一集《炼丹传》呢?😎😎

今天主要介绍一下笔者在秋招的面试某厂的时候,面试真题:详细介绍BN。遇到了这个问题笔者主要从如下几个方面来回答。文末附上源码。


提出背景

随着深度学习的发展,研究的问题变的越来越复杂,面对这些复杂的问题我们需要深层次的网络进行训练,从而导致需要大量的时间进行调参。深层网络之所以如此难训练,因为层与层之间存在高度的关联性与耦合性。

关联性会导致:随着训练的进行网络中的参数也随着梯度下降在不停的更新。一方面:当底层网络中的参数发生微弱变化时,由于每一层中的线性变换和非线性激活映射,这一些微弱变化随着网络层数的加深而被放大。另一方面:深度神经网络涉及到很多层的叠加,每一层的参数更新会导致上层的输入数据分布发生变化,通过层层叠加,高层的输入分布变化会非常剧烈,这就使得高层需要不断去重新适应底层的参数更新,使模型的训练变得困难。

Google 将这一现象总结为 Internal Covariate Shift,简称 ICS. ICS具体什么呢?笔者这里引用了一位专业人士的回答:

大家都知道在统计机器学习中的一个经典假设是“源空间(source domain)和目标空间(target domain)的数据分布(distribution)是一致的”。如果不一致,那么就出现了新的机器学习问题,如 transfer learning / domain adaptation 等。而 covariate shift 就是分布不一致假设之下的一个分支问题,它是指源空间和目标空间的条件概率是一致的,但是其边缘概率不同,即:对所有但是

大家细想便会发现,的确,对于神经网络的各层输出,由于它们经过了层内操作作用,其分布显然与各层对应的输入信号分布不同,而且差异会随着网络深度增大而增大,可是它们所能“指示”的样本标记(label)仍然是不变的,这便符合了covariate shift的定义。由于是对层间信号的分析,也即是“internal”的来由。

ICS 会导致什么问题:

简而言之,每个神经元的输入数据不再是“独立同分布”。

  1. 上层网络需要不停调整来适应输入数据分布,导致网络学习速度的降低。
  2. 下层输入的变化可能趋向于变大或者变小,导致上层落入饱和区,使得学习过早停止。
  3. 每层的更新都会影响到其它层,因此每层的参数更新策略需要尽可能的谨慎。

面对ICS问题如何解决

由于 ICS 问题的存在,输入的分布可能相差很大。因此在输入之前对数据进行”白化“处理

白化操作:规范数据分布的方法,对输入数据分布进行变换

“白化”的目的主要有两个目的:

  1. 使得输入特征分布具有相同的均值和方差
  2. 去除特征之间的相关性

白化最典型的方法就是PCA,具体可以查阅。然而“理论正确”的方法就是对每一层的数据都进行白化操作。然而标准的白化操作代价高昂,特别是我们还希望白化操作是可微的,保证白化操作可以通过反向传播来更新梯度。因此,下文介绍的以 BN 为代表的 Normalization 方法退而求其次,进行了简化的白化操作。


BN是什么

BN全称 Batch Normalization译为“批规范化”。,即在每次SGD时,通过mini-batch来对相应的activation做规范化操作,使得结果(输出信号各个维度)的均值为0,方差为1. 而最后的“scale and shift”操作则是为了让因训练所需而“刻意”加入的BN能够有可能还原最初的输入(即当),从而保证整个网络结构的capacity。

capacity的解释:实际上BN可以看作是在原模型上加入的“新操作”,这个新操作很大可能会改变某层原来的输入。当然也可能不改变,不改变的时候就是“还原原来输入”。如此一来,既可以改变同时也可以保持原输入,那么模型的容纳能力(capacity)就提升了

具体的操作流程如下图所示:

BN——虽然玄学,但是养活了很多炼丹师。

上述算法简化的“白化”,直接在每个mini-batch中计算得到mini-batch mean和variance来替代整体训练集的mean和variance. 从某种意义上来说,代表的其实是输入数据分布的方差和偏移。对于没有BN的网络,这两个值与前一层网络带来的非线性性质有关,而经过变换后,就跟前面一层无关,变成了当前层的一个学习参数,这更加有利于优化并且不会降低网络的能力。

BN的参数更新

怎样学BN的参数在此就不赘述了,就是经典的链式法则,具体操作如图:

BN——虽然玄学,但是养活了很多炼丹师。

测试阶段如何使用BN?

BN在每一层计算的均值和方差都是基于当前batch中的训练数据,测试阶段没有像训练样本中那么多的数据,因此一定是有偏估的,应该保留训练阶段每一组mini-batch训练数据在网络的每一层的,然后使用整个样本的统计量来对测试数据进行归一化。


BN的优缺点

BN优点:

  1. 减轻了对参数初始化的依赖,这是利于调参的朋友们的。
  2. 训练更快,可以使用更高的学习率。
  3. BN一定程度上增加了泛化能力,dropout等技术可以去掉。

BN的缺点:

从上面可以看出,BN依赖于batch的大小,当batch值很小时,计算的均值和方差不稳定。研究表明对于ResNet类模型在ImageNet数据集上,batch从16降低到8时开始有非常明显的性能下降,在训练过程中计算的均值和方差不准确,而在测试的时候使用的就是训练过程中保持下来的均值和方差。

这一个特性,导致BN不适合以下的几种场景。

  1. Batch非常小,比如训练资源有限无法应用较大的Batch,也比如在线学习等使用单例进行模型参数更新的场景。
  2. RNN系列,因为它是一个动态的网络结构,同一个Batch中训练实例有长有短,导致每一个时间步长必须维持各自的统计量,这使得BN并不能正确的使用。在RNN中,对BN进行改进也非常的困难。不过,困难并不意味着没人做,事实上现在仍然可以使用的,不过这超出了咱们初识境的学习范围。

当时面试管追问了几个问题:

1.BN除了上述优点还有什么优点?

使用BN后,可以使用更大的学习率,从而跳出不好的局部极值,增强泛化能力

2.BN训练时为什么不用全量训练集的均值和方差呢?

用全量训练集的均值和方差容易过拟合,对于BN,其实就是对每一批数据进行归一化到一个相同的分布,而每一批数据的均值和方差会有一定的差别,而不是用固定的值,这个差别实际上能够增加模型的鲁棒性,也会在一定程度上减少过拟合。

3.BN一般设定的位置

卷积->bn->激活函数,但并非绝对。(denseNet网络则建议 bn->激活函数->卷积 )

总结:

相信给朋友看到这篇笔记,对于这道面试以及整个BN的提出背景,作用和意义有了更深的了解。假如读者遇到该道面试题,相信你们能够很好的面对,切记面试不要死记硬背,面试官主要考察的不是你的记忆能力,而是你的理解能力。祝大家能够早日拿到心意的offer。

源码解析

主要计算过程:

  1. 求均值。
  2. 求方差。
  3. 对数据进行标准化(将数据规范到标准正态分布)。
  4. 训练参数γ和β。
  5. 通过线性变换输出
import torch
import torch.nn as nn

def batch_norm(x,gamma,beta,moving_mean,moving_var,eps=1e-5, momentum=0.9):
#eps 和 momentu为超参数 Hyperparameter
#测试和训练的均值方差计算是不同的
if not torch.is_grad_enabled():
#测试的时候直接使用全局的
x_hat = (x-moving_mean ) / torch.sqrt(moving_var+eps)
else:
assert (x.shape) in (2,4)
if (x.shape) == 2:#liner
mean = x.mean(dim=0)
var = ((x-mean)**2).mean(dim=0)
else:#cnn
mean = x.mean(dim=(0,2,3),keepdim=True)
var = ((x-mean)**2).mean(dim=(0,2,3),keepdim=True)
x_hat = (x-mean ) / torch.sqrt(var+eps)
#然后更新 moving_mean_var
moving_mean = moving_mean*momentum + mean*(1-momentum)
moving_var = moving_var * moving_var + var*(1-momentum)
y = gamma*x_hat +beta
return y,moving_mean,moving_var


class BatchNorm(nn.Module):
def __init__(self,num_features,num_dims):
super(BatchNorm, self).__init__()
self.batch_norm = batch_norm
if num_dims==2:#如果是liner bn后的shape
res_shape = (1,num_features)
else:#卷积的bn后的shape
res_shape = (1,num_features,1,1)
#学习的参数gamma_beta
self.gamma = nn.Parameter(torch.ones(res_shape))
self.beta = nn.Parameter(torch.ones(res_shape))

#均值方差
self.moving_mean = torch.zeros(res_shape)
self.moving_var = torch.zeros(res_shape)

def forward(self,x):
#device
if x.device != self.moving_mean.device:
self.moving_mean = self.moving_mean.to(x.device)
self.moving_var = self.moving_var.to(x.device)
Y,self.moving_mean,self.moving_var = self.batch_norm(x,self.gamma,self.beta,
self.moving_mean,self.moving_var)
return Y

参考文献:

  1. PCA Whitening[http://ufldl.stanford.edu/tutorial/unsupervised/PCAWhitening/]

作者   栗子鑫   

编辑   一口栗子



原文始发于微信公众号(六只栗子):BN——虽然玄学,但是养活了很多炼丹师。

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 举报,一经查实,本站将立刻删除。

文章由极客之音整理,本文链接:https://www.bmabk.com/index.php/post/88774.html

(0)
小半的头像小半

相关推荐

发表回复

登录后才能评论
极客之音——专业性很强的中文编程技术网站,欢迎收藏到浏览器,订阅我们!