SMASH

one-shot model architecture search through hypernetworks

Posted by JY on June 11, 2020

1. Introduction

为深度神经网络设计架构需要专业知识和大量计算时间(expert knowledge and substantial computation time)

  • 本文通过训练一个辅助模型(HyperNet),去训练搜索过程中的候选模型,这个超网络动态地生成具有可变结构的主模型的权值(generate the weights o f a main model conditioned on that model’s architecture)

2. Method

One-shot Model Architecture Search through HyperNet (SMASH)

目标是根据一个网络的验证集性能对一组网络的性能进行排序,这个任务是通过超网络生成权重来完成的

  • 在每一个训练step,随机采样一个网络结构,用超网络生成它的权重,然后通过反向传播对整个网络进行端对端训练
  • 当模型训练好后,在随机采样一些网络,它们的权重是由超网络生成的,在验证集上直接评估它们的性能
  • 选择具有最佳性能网络最终的训练测试

img

img

SMASH有两个核心要点:

网络结构生成的方法:开发了前馈网络存储库视图(memory-bank view of feed-forward networks)

根据给定的采样结构生成权重的方法:使用HyperNet

  • Defining Variable Network Configurations

    为了探索不同深度、连接模式和层数的网络效果,并能将网络结构简单地编码成向量输入到HyperNet中,提出了memory-bank机制

    • 将网络视为一组可以读写的存储库,每一层是一个操作,这个操作从存储库的一个子集中读取数据、修改数据,然后将它写入到另一个存储库子集中。

      对于一个单分支结构,网络包含一个大的可读可写的存储体,每次操作都覆盖存储体的内容(Figure 1左),对于单DenseNet结构,读取之前所有存储体中的数据然后将新得到的数据写入一个空的存储体中(Figure 1中),FractalNet定义了一个更复杂的读写模式(Figure 1右)

      img

    • 基础网络(Figure2(b))由多个block组成,每个block包含基于给定空间分辨率的存储体,并与大多数CNN架构一样,其空间分辨率依次减半(基础网络的权重由HyperNet产生),通过$1\times1$的卷积加池化进行下采样(下采样层和全连接层的权重都是由网络学习到的,不是由HyperNet产生的

    • 当采样结构时,每个block中的存储体数量和每个存储体中的通道数都是随机采样的,当在一个block内定义每个层时,随机选择的读写模式和操作类型op;当从多个存储体中读数据时,将读到的数据进行通道维度拼接,当将得到的数据写入一个存储体时通常是将张量相加

    • 每个op(Figure2(a))是由用于降维的$1\times1$卷积层、数量不等的非线性卷积层以及非线性激活层组成,每次随机选择4个卷积中的一个激活(卷积核尺寸、空洞因子、组的数量和输出单元的数量),$1\times1$卷积的输出channel数与op的输出channel数成一定比例,比例也是随机选取的

      • $1\times 1$卷积的权重由HyperNet生成,其它卷积则通过正常训练获得

      • 为了保证可变的深度,每个block仅学习4个卷积,并且在block的op操作中共享其权值。限制最大卷积核大小以及最大输出channel数,假设选择的op操作的参数小于最大值,则将权重裁剪至目标大小

      • 下采样卷积和输出层同样基于输入的channel数对权重进行裁剪

      img

  • Learning to map architectures to weights

对网络结构进行编码,使其可以作为HyperNet的输入

img

  • 对采样网络结构进行编码,得到嵌入向量$c\in R^{1\times (2M+d_{max})\times(N_{max}/N)^2\times n_{ch}/D}$
    • $M$是一个block中存储体的最大数量
    • $d_{max}$是最大的空洞因子
    • $n_{ch}$是主网络所有$1\times 1$卷积输入通道数总和
    • $N$是存储体中的通道数
    • $N_{max}$和$D$为超参数
    • 每层的输出单元数必须被N整除,输入单元数必须被D整除,$D$和$N$的限制使得嵌入向量的尺寸减小了$DN^2$倍
    • 条件嵌入向量$c$是内存体的one-hot编码,batch=1,有$2M+d_{max}$个通道,其中前M个通道表示从M个存储体读取数据,后M个通道表示要写入的存储体,最后$d_{max}$个通道表示应用到$3\times3$卷积空洞因子的one-hot编码,Height维度和每层的单元数目相关,Length维度与网络深度相关(即输入通道的总数)

img

  • HyperNet有$4DN^2$个输出通道(channel on/off),$W=H(c) \in R^{1 \times 4 D N^{2} \times\left(N_{\max } / N\right)^{2} \times n_{c h} / D}$,可以reshape为$W=H(c) \in R^{N_{m a x} \times 4 N_{m a x} n_{c h} \times 1 \times 1}$,可以视为对每层产生权重

3. Conclusion

该方法通过训练一次辅助网络,避免每个网络都需要训练的情况,使得训练时间大大减小