Overview of Few-shot Segmentation

Posted by JY on March 19, 2021

1. Few-Shot 分割任务

  • 在测试任务上:利用很少的具有分割标签的support set来获得query set目标的分割结果(query set和support set两者要求类别相同)
  • 训练任务的要求:训练任务的support set和query set都需要给出分割的mask,support set的mask用于提供指导信息query set mask 用于计算预测损失
  • 训练任务与测试任务的要求:训练任务中的类别与测试任务中的类别是不同的

2. Method

主流的Few-shot Segmentation 框架

  • Network Parameter Prediction框架:OSLSM
  • Metric Learning Based框架
    • Dense Comparison框架:CANet
    • Similarity Guided 框架:SG-One
    • Prototype Alignment框架:PANet、PMMs

2.1 OSLSM

《One-Shot Learning for Semantic Segmentation》 BMVC 2017

首次提出双分支的网路用于few-shot segmentation,使用conditioning branch直接预测sementation branch的权重参数

对segmentation branch提取到的特征进行加权融合,经过上采样后得到分割结果。

1

conditioning branch:

  • 输入:原始三通道彩色图像与其目标对应的mask相乘得到,保证了输入只包含目标的信息
  • 为了使参数量与segmentation branch的特征图的通道数相对应,采用weight hashing的策略,将输出的1000维向量映射为4097维

segmentation branch:

  • 采用FCN-32s结构,对query image进行特征提取,将其与conditioning branch所得参数进行点乘得到分割结果

2

2.2 CANet

《CANet: Class-Agnostic Segmentation Networks with Iterative Refinement and Attentive Few-Shot Learning》 CVPR 2019

模型包括三部分:密集比较模块(Dense comparison module),迭代优化模块(iterative optimization module)和ASPP模块

3

DCM:

  • backbone提取特征网络为Res-50,只用了block2和block3中间层的特征,block2后经过dilation=2的卷积得到block3,然后拼接后再经过一层卷积重新表示图像特征信息。
  • 接着mask和图像特征点乘操作,得到的结果只包含了类别的信息,去除了图像中无关的背景信息。
  • 最后使用全局平均池化得到特定类别的特征向量(类别信息都包含在这个特征向量中)
  • 密集比较:support和query生成相同维度的特征表示,两者拼接在一起,目的在于对空间每个位置都进行比较,最后得到匹配后的结果

IOM:

  • 同一类别的不同图像,存在明显的外观差异,上一步密集比较可能只匹配了目标的一部分,因此需要进行迭代优化得到准确的分割结果

  • 该模块的输入是由密集比较模块生成的要素图和最后一次迭代预测的mask。采用残差的形式来融合预测的mask信息,相当于自监督学习 \(M_t = x + F(x,y_{t-1})\)

ASPP:

  • 主要为了捕获多尺度信息

4

引入注意力机制:

  • 在DCM中添加了一个与密集比较卷积平行的注意力模块
  • 每个注意力模块由两个卷积块组成,第一个卷积块是256$\times$3,并使用3$\times$3的池化,第二个只有一个3$\times$3的卷积核,采用的是全局平均池化,最终得到该support image下的影响分数,并将所有的影响分数做softmax后得到最终每个图像对结果的影响分数
  • 然后将得到的分数与每个support image特征表示相乘得到新的特征表示,最后将所有support image的特征表示累加之后得到最终support set在该类别下的特征表示

5

2.3 SG-One

《SG-One: Similarity Guidance Network for One-Shot Semantic Segmentation》

网络包含一个由CNN构成的主干网络Stem和两个分支Guidance Branch, Segmentation Branch,这两个分支共用三个卷积块

6

  • 提出了使用Masked Average Pooling来提取support set中目标的表征向量
    • 全卷积网络(FCN)能够保留输入图像的中每个像素相对位置;因此通过将二值mask与提取到的特征图相乘就可以完全保留目标的特征信息,排除掉背景等无关类别的特征
  • 采用余弦相似度来度量query set的表征向量与support set的表征向量之间的距离,用于指导query set的分割

7

2.4 PANet

《PANet: Few-Shot Image Semantic Segmentation with Prototype Alignment》ICCV 2019

基于metric-learning的思想,提出 Prototype Alignment network,能够更好地利用support set中包括的语义信息,实现小样本语义分割

  1. 将support和query images映射到同一个表征空间,得到对应的feature maps
  2. 基于support features,为每个类别构建其对应的prototype
  3. Each pixel of the query image is labeled by referring to the class-specific prototypes nearest its embedding representation,得到对应的predicted masks
  4. 在训练过程中,还采用Prototype Alignment Regularization(PAR),使得信息可以从query set传向support set。即利用predicted mask + query features,生成对应的prototypes,并预测原始support set中的图像
    • 优点:generate more consistent prototypes between support and query images, offering better generalization performance.

创新点:

  • 采用基于原型网络的结构实现了少样本语义分割任务
  • 设计PAR方法,充分利用支持集图像信息,提高分割的精度

img

Details of Model:

  • Step 1: Feature Extraction

    • 对于每个Episode,首先提取support and query features by a shared backbone network
    • 采用了后混合掩码(late fusion mask)的方式,先提取特征图,再对特征图进行掩码。能够保证输入到feature extractor中的query和support images保持一致性。
  • Step 2: 计算Prototypes

    • 在support features上应用Masked Average Pooling,计算每个类别的prototype

    img

    img

    分子:像素灰度和;分母:像素个数

  • Step3: 生成预测结果

    label each pixel of query images as the class of the nearest prototypes

    • Calculate the distance between the query feature vector $F_q^{(x,y)}$ at each spatial location with each computed prototypes $p_j$.

    • Apply a softmax over the distances to produce a probability over semantic classes (including background).

    • The final predicted result $\hat{M}_q^{(x,y)}$ is then given by the max function. \(\tilde{M}_{q ; j}^{(x, y)}=\frac{\exp \left(-\alpha d\left(F_{q}^{(x, y)}, p_{j}\right)\right)}{\sum_{p_{j} \in \mathcal{P}} \exp \left(-\alpha d\left(F_{q}^{(x, y)}, p_{j}\right)\right)}\)

      \[\hat{M}_{q}^{(x, y)}=\underset{j}{\arg \max } \tilde{M}_{q ; j}^{(x, y)}\]
    • 对QuerySet计算分割损失函数 \(\mathcal{L}_{\mathrm{seg}}=-\frac{1}{N} \sum_{x, y} \sum_{p_{i} \in \mathcal{P}} \mathbb{1}\left[M_{q}^{(x, y)}=j\right] \log \tilde{M}_{q ; j}^{(x, y)}\)

  • Step4:PAR

    • 将predicted mask和query features 作为新的support set,求取其对应的prototypes。再对原来的support set进行分割,从而获得反向信息。使得query和support互相监督,学习更多有效信息

      \[\tilde{M}_{c, k ; j}^{(x, y)}=\frac{\exp \left(-\alpha d\left(F_{c, k}^{(x, y)}, \bar{p}_{j}\right)\right)}{\sum_{\bar{p}_{j} \in\left\{\bar{p}_{c}, \bar{p}_{\mathrm{bg}}\right\}} \exp \left(-\alpha d\left(F_{c, k}^{(x, y)}, \bar{p}_{j}\right)\right)}\] \[\overline{\mathcal{P}}=\left\{\bar{p}_{c} \mid c \in\right. \left.\mathcal{C}_{i}\right\} \cup\left\{\bar{p}_{\mathrm{bg}}\right\}\] \[\mathcal{L}_{\mathrm{PAR}}=-\frac{1}{C K N} \sum_{c, k, x, y} \sum_{p_{j} \in \mathcal{P}} \mathbb{1}\left[M_{q}^{(x, y)}=j\right] \log \tilde{M}_{q ; j}^{(x, y)}\]

      最终损失函数: \(\mathcal{L} = \mathcal{L}_{seg} + \lambda \mathcal{L}_{PAR}\)

    img

算法优点:

  • segmentation阶段,没有引入额外的learnable parameters,减少了over-fitting的可能性
  • Prototype embedding & prediction都是直接作用于feature maps,无需在网络结构汇总进行额外的信息传递
  • 由于PAR还在训练阶段进行,因此不会影响inference cost

2.5 PMMs

《Prototype Mixture Models for Few-Shot Semantic Segmentation》 ECCV 2020

目前主流的Few-shot segmentation方法主要是基于metric-learning framework,并主要采用了prototype model。存在以下两点不足:

  • prototype model主要是基于global average pooling (GAP) guided by GT masks来计算prototype的,但是在GAP算法中,the spatial layout of objects is completely dropped,导致easily mix semantic from various parts
  • Single prototype 不足以包含充足的信息,表征能力有限

    以上两点不足会导致 Semantic ambiguity problem

PMM在训练过程中,使用EM算法来估计prototypes,并分别为前景和背景进行建模,计算对应的multiple prototypes,以提高模型的判断能力

优点:

  • PMM可以作为spatially squeezed representation, 与query features 匹配以激活与object class相关的feature channels
  • each prototype vector 可以看作一个C维的线性分类器,将P-conv与query features进行element-wised,生成相应的probability map

9

Prototype Mixture Model (PMM)

  • 混合概率模型:linearly combine probabilities from base distributions ($w_k$使用相同权重)

\(p\left(s_{i} \mid \theta\right)=\sum_{k=1}^{K} w_{k} p_{k}\left(s_{i} \mid \theta\right)\)

  • 一个probability model基于核距离函数

\(p_{k}\left(s_{i} \mid \theta\right)=\beta(\theta) e^{\operatorname{Kernel}\left(s_{i}, \mu_{k}\right)}\)

  • 核距离函数可以使用:Gaussian模型(RBF距离)或者VMF模型(cosine距离) \(p_{k}\left(s_{i} \mid \theta\right)=\beta_{c}(\kappa) e^{\kappa \mu_{k}^{T} s_{i}}\)
  • 根据矢量距离,将PMM定义为:

\(p_{k}\left(s_{i} \mid \theta\right)=\beta_{c}(\kappa) e^{\kappa \mu_{k}^{T} s_{i}}\) ​ $\theta={\mu, \kappa}$

​ $\beta_{c}(\kappa)=\frac{\kappa^{c / 2-1}}{(2 \pi)^{c / 2} I_{c / 2-1}(\kappa)}$ is the normalization coefficient.

  • Models Learning:使用EM算法估计$\mu^+$和$\mu^-$作为prototypes vectors

    • E-step:给定模型参数和sample feature extracted,计算sample $s_i$的期望公式 \(E_{i k}=\frac{p_{k}\left(s_{i} \mid \theta\right)}{\sum_{k=1}^{K} p_{k}\left(s_{i} \theta\right)}=\frac{e^{\kappa \mu_{k}^{T} s_{i}}}{\sum_{k=1}^{K} e^{\kappa \mu_{k}^{T} s_{i}}}\)

    • M-step:使用上一步计算得到的期望更新PMM’s的mean vectors \(\mu_{k}=\frac{\sum_{i=1}^{N} E_{i k} s_{i}}{\sum_{i=1}^{N} E_{i k}}\)

P-Match

  • 每个foreground prototypes都包含一个object part对应的表征信息,将其组合在一起则可包含更对的表征信息,尽可能表示完整的对象范围 \(Q^{\prime}=\mathrm{P}-\operatorname{Match}\left(\mu_{k}^{+}, Q\right), k=1, \ldots, K\)

P-Conv

  • 每个prototype vector 都包含着discriminative information,都可以看作一个classifier,用于计算probability maps \(M_{k}=\mathrm{P}-\operatorname{Conv}\left(\mu_{k}^{+}, \mu_{k}^{-}, Q\right), k=1, \ldots, K \\ M_{p}^{+}=\sum_{k} M_{k}^{+} \\ M_{p}^{-}=\sum_{k} M_{k}^{-}\)

  • 之后再与$Q’$进行concatenate,得到activate query map $Q’’$ \(Q^{\prime \prime}=M_{p}^{+} \oplus M_{p}^{-} \oplus Q^{\prime}\)
  • 最后,将activated query map $Q’‘$输入ASPP,在输入卷积层,生成predicted mask

8

Residual Prototype Mixture Models (RPMM)

  • 通过堆叠多个PMM来实现集成模型

10

References

  • One-shot learning for semantic segmentation
  • Conditional networks for few-shot semantic segmentation
  • Few-Shot Segmentation Propagation with Guided Networks
  • Few-Shot Semantic Segmentation with Prototype Learning
  • One-Shot Segmentation in Clutter
  • SG-One: Similarity Guidance Network for One-Shot Semantic Segmentation
  • PANet: Few-Shot Image Semantic Segmentation with Prototype Alignment
  • Attention-based Multi-Context Guiding for Few-Shot Semantic Segmentation
  • CANet: Class-Agnostic Segmentation Networks with Iterative Refinement and Attentive Few-Shot Learning
  • Prototype Mixture Models for Few-shot Semantic Segmentation
  • Part-aware Prototype Network for Few-shot Semantic Segmentation
  • Few-Shot Semantic Segmentation with Democratic Attention Networks