transformer 在 CV 中的应用(三) ViT 分类网络

Posted by feizaipp on May 10, 2021

我的博客

0. 参考资料

1. 网络结构

       ViT 网络是谷歌在 2020 年提出的基于纯 Transformer 实现的分类网络,它完全抛弃了 CNN 网络。 ViT 网络中的 Transformer 与传统意义上的 Transformer 存在明显的不同,传统的 Transformer 是用于 NLP 中的机器翻译任务中,它的结构由 Encoder 和 Decoder 两部分组成,因为要将输入的序列通过 Encoder 网络进行编码,然后将编码后的序列通过 Decoder 解码,最终得到目标语言。但是在视觉任务中,我们只需要对图像进行特征提取,然后将特征通过全链接层输出目标类别。

       我们知道,在 NLP 模型中 Transformer 的输入是一个序列,那么对于图像数据, ViT 是将一张图像分成一个个小的 patch ,然后对这些 patch 进行编码。

       ViT 网络的特征提取是使用了一个叫做 class token 的结构,该结构是可学习的。它与图像编码后的张量在 dim=1 处进行 cat 操作得到一个新的张量,新张量的第一个元素就是 class token ,这个新的张量与位置编码进行相加后输入到 ViT 网络中,最后学习到的 class token 即为图像的特征图,将这个特征通过全链接层得到类别的输出。

       ViT 网络的整体结构如下图所示。

vit-construct

2. 代码解析

       ViT 代码整体结构也比较简单,主要是注意张量在网络中传播的过程中维度的变化。下面我们从创建 ViT 网络开始。

       ViT 的代码实现中,没有使用像 view 、 reshape 等的函数进行维度的变换,而是使用 einops 库,其实也很简单,看下面代码注释就很好明白了。

from einops import rearrange, repeat
from einops.layers.torch import Rearrange
  • 首先要查创建一个 ViT 类
  • dim: patch embedding 的维度
  • depth: Transformer 结构中 Encoder 的个数
  • heads: MSA 的 head 个数
  • mlp_dim: Transformer 结构中 FFN 的输出维度
  • img: 随机初始化一个维度为 [4, 3, 256, 256] , 4 表示 batch size
  • mask: 可选的 mask ,用于委托哪个 patch 去使用 attend
  • preds: 调用 ViT 的前向传播函数,输出预测的类别,维度为 [4, 1000]
    if __name__ == "__main__":
      v = ViT(
          image_size = 256,
          patch_size = 32,
          num_classes = 1000,
          dim = 1024,
          depth = 6,
          heads = 16,
          mlp_dim = 2048,
          dropout = 0.1,
          emb_dropout = 0.1
      )
    
      img = torch.randn(4, 3, 256, 256)
      mask = torch.ones(1, 8, 8).bool() # optional mask, designating which patch to attend to
    
      # [batch, 1000]
      preds = v(img, mask = mask)
    

       下面看一下 ViT 类的实现,注释中标注了张量维度的变化。

  • 图像的大小是 [256, 256] ,每个 patch 的大小是 [32, 32] ,所以图像被分割成了 64 个 patch
    class ViT(nn.Module):
      def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
          super().__init__()
          assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
          num_patches = (image_size // patch_size) ** 2
          patch_dim = channels * patch_size ** 2
          assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'
    
          # img:[4,3,256,256]
          # patch_size=32 ,每个 patch 的大小是 (32, 32)
          # h=8
          # w=8
          # Rearrange: [4, 64, (32*32*3)]=[4, 64, 3072]
          # patch_dim: 3072
          # dim: 1024
          # 经过全链接层后,维度由 [4, 64, 3072] -> [4, 64, 1024]
          self.to_patch_embedding = nn.Sequential(
              Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = patch_size, p2 = patch_size),
              nn.Linear(patch_dim, dim),
          )
    
          # 位置编码
          # num_patches: (256//32) ** 2=64
          # (1, 65, 1024)
          self.pos_embedding = nn.Parameter(torch.randn(1, num_patches + 1, dim))
          # (1,1,1024)
          self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
          # emb_dropout=0.1
          self.dropout = nn.Dropout(emb_dropout)
    
          # dim=1024
          # depth=6
          # heads=16
          # dim_head=64
          # mlp_dim=2048
          self.transformer = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
    
          self.pool = pool
          # 这个函数建立一个输入模块,什么都不做,通常用在神经网络的输入层。这个可以用在残差学习中。
          self.to_latent = nn.Identity()
    
          # dim=1024
          self.mlp_head = nn.Sequential(
              nn.LayerNorm(dim),
              nn.Linear(dim, num_classes)
          )
    
      def forward(self, img, mask = None):
          # img:[b,3,256,256]
          # x:[b,64,1024]
          x = self.to_patch_embedding(img)
          b, n, _ = x.shape
    
          # self.cls_token: [1, 1, dim]
          # cls_tokens: [b, 1, dim]
          cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
          # x: [b, 65, 1024]
          x = torch.cat((cls_tokens, x), dim=1)
          # x 加上位置编码
          x += self.pos_embedding[:, :(n + 1)]
          x = self.dropout(x)
    
          # x: 输入输出维度都是 [b, 65, 1024]
          x = self.transformer(x, mask)
    
          # [b, 1024]
          x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
    
          x = self.to_latent(x)
          return self.mlp_head(x)
    

       下面我们看下 Transformer 的代码, ViT 中的 Transformer 结构中有 Attention 、FeedForward 两个结构。注意此处没有 Decoder 结构,只有 Encoder 。

# dim=1024
# depth=6
# heads=16
# dim_head=64
# mlp_dim=2048
class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout = 0.):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([
                Residual(PreNorm(dim, Attention(dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))
            ]))
    def forward(self, x, mask = None):
        for attn, ff in self.layers:
            x = attn(x, mask = mask)
            x = ff(x)
        return x

       先看 Attention 。

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        # project_out: True
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        # 计算 q k v 矩阵
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x, mask = None):
        # x: [b, 65, 1024]
        # h:16
        b, n, _, h = *x.shape, self.heads
        # chunk: 将 tensor 进行分割成 3 份,如果指定轴的元素被 3 除不尽,那么最后一块的元素个数变少
        # qkv: [b, 65, inner_dim * 3]
        qkv = self.to_qkv(x).chunk(3, dim = -1)
        # q k v: [b, 65, inner_dim]=[b, 65, 16*64]->[b, 16, 65, 64]
        # map: 取出 qkv 中的每个值,然后执行 lambda 表达式
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)

        # dots: 返回的维度为 [b, 16, 65, 65] ,相当于 q*k.T
        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale
        # 获取最小值,也就相当于 0
        mask_value = -torch.finfo(dots.dtype).max

        # mask: [1, 8, 8]
        if mask is not None:
            # pad: 矩阵填充函数
            # input: 需要扩充的 tensor ,可以是图像数据,抑或是特征矩阵数据
            # pad: 扩充维度,用于预先定义出某维度上的扩充参数
            # mode: 扩充方法,’constant‘, ‘reflect’ or ‘replicate’三种模式,分别表示常量,反射,复制
            # value: 扩充时指定补充值,但是 value 只在 mode='constant’ 有效,即使用 value 填充在扩充出的新维度位置,而在’reflect’和’replicate’模式下,value不可赋值
            # pad 定义:
            # 如果参数pad只定义两个参数,表示只对输入矩阵的最后一个维度进行扩充
            # 如果参数pad只定义四个参数,前两个参数对最后一个维度有效,后两个参数对倒数第二维有效。
            # 如果参数pad定义六个参数,前4个参数完成了在高和宽维度上的扩张,后两个参数则完成了对通道维度上的扩充。
            # p1d = (左边填充数, 右边填充数)
            # p2d = (左边填充数, 右边填充数, 上边填充数, 下边填充数)
            # p3d = (左边填充数, 右边填充数, 上边填充数, 下边填充数, 前边填充数,后边填充数)
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            a = rearrange(mask, 'b i -> b () i ()')
            b = rearrange(mask, 'b j -> b () () j')
            # [b, 1, 65, 1] * [b, 1, 1, 65]=[b, 1, 65, 65]
            mask = rearrange(mask, 'b i -> b () i ()') * rearrange(mask, 'b j -> b () () j')
            dots.masked_fill_(~mask, mask_value)
            del mask

        attn = dots.softmax(dim=-1)

        # out = attn.v
        # [b, 16, 65, 65]*[b, 16, 65, 64]->[b, 16, 65, 64]
        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        out =  self.to_out(out)
        return out

       再看 FeedForward 。

  • FFN: 两个全链接层,第一个使用激活函数,第二个不使用。
    class FeedForward(nn.Module):
      def __init__(self, dim, hidden_dim, dropout = 0.):
          super().__init__()
          self.net = nn.Sequential(
              nn.Linear(dim, hidden_dim),
              nn.GELU(),
              nn.Dropout(dropout),
              nn.Linear(hidden_dim, dim),
              nn.Dropout(dropout)
          )
      def forward(self, x):
          return self.net(x)
    

       残差网络的实现,先执行 fn 操作,然后与输入进行相加。

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

       归一化的实现,先对输入进行归一化,然后执行 fn 操作

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)