跳转至

An Image is Worth 16x16 Words:Transformers for Image Recognition at Scale.md

论文

https://arxiv.org/abs/2010.11929

代码

https://github.com/google-research/vision_transformer

摘要

Cite

虽然Transformer架构已成为自然语言处理任务的事实标准,但其在计算机视觉中的应用仍然有限。在视觉上,注意力要么与卷积网络结合使用, 要么用于替换卷积网络的某些组件,同时保持其整体结构不变。我们表明,这种对神经网络的依赖是不必要的,直接应用于图像补丁序列的纯变换器可以很好地执行图像分类任务。 当对大量数据进行预训练并将其传输到多个中型或小型图像识别基准(ImageNet、CIFAR-100、VTAB等)时,与最先进的卷积网络相比,视觉变换器(ViT) 获得了优异的结果,同时需要更少的计算资源来训练

谷歌团队,纯Transformer的视觉分类器。在前言章节给出的结论是,纯Transformer的结构相比于基于CNN的,是少了归纳偏置的,这使得在数据量较少时(ImageNet等) 表现不如CNN,但也正是由于没有归纳偏置更加自由,如果给了足够的数据(JFT-300M)效果将会超过CNN。

方法



流程2

  1. 输入图片例如256256,分成多个patch,每个patch例如3232,则有256*256/32/32=64个patch
  2. 对patch做embedding,每个patch投影成一个向量,长度例如1024,则有64个1024向量
  3. 加上位置编码position encoding
  4. 再拼接一个cls token,变成65个1024向量
  5. 输入到encoder进行自注意力提取特征
  6. 对cls token加MLP做分类


全文最主要的结论就是VIT需要大量的数据才能发挥出比CNN更好的性能

总结

Cite

我们已经探索了Transformer在图像识别中的直接应用。与先前在计算机视觉中使用自我关注的工作不同,除了初始的补丁提取步骤之外, 我们没有将图像特定的归纳偏置引入到架构中。相反,我们将图像解释为一系列补丁,并使用NLP中使用的标准Transformer编码器对其进行处理。 这种简单但可扩展的策略与大型数据集上的预训练相结合时,效果出奇地好。因此,Vision Transformer在许多图像分类数据集上达到或超过了最先进水平, 同时预训练成本相对较低。尽管这些初步成果令人鼓舞,但仍存在许多挑战。一种是将ViT应用于其他计算机视觉任务,例如检测和分割。我们的结果, 加上Carion等人(2020)的结果,表明了这种方法的前景。另一个挑战是继续探索自我监督的预训练方法。我们的初步实验表明,自监督预训练有所改进, 但自监督预训与大规模监督预训之间仍有很大差距。最后,ViT的进一步扩展可能会提高性能

实现代码2

#  !/usr/bin/env  python
#  -*- coding:utf-8 -*-
# @Time   :  2021.
# @Author :  绿色羽毛
# @Email  :  lvseyumao@foxmail.com
# @Blog   :  https://blog.csdn.net/ViatorSun
# @Note   :

import torch
from   torch import nn, einsum
import torch.nn.functional as F
from   einops import rearrange, repeat
# from   einops.layers.torch import Rearrang

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn
    def forward(self, x, **kwargs):
        return self.fn(x, **kwargs) + x

class PreNorm(nn.Module):
    def __init__(self, dim, fn):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        self.fn   = fn
    def forward(self, x, **kwargs):
        return self.fn(self.norm(x), **kwargs)

class FeedForward(nn.Module):
    def __init__(self, dim, hidden_dim, dropout = 0.):
        super().__init__()
        self.net = nn.Sequential(   nn.Linear(dim, hidden_dim),
                                    nn.GELU(),
                                    nn.Dropout(dropout),
                                    nn.Linear(hidden_dim, dim),
                                    nn.Dropout(dropout) )
    def forward(self, x):
        return self.net(x)

class Attention(nn.Module):
    def __init__(self, dim, heads = 8, dim_head = 64, dropout = 0.):
        super().__init__()
        inner_dim = dim_head *  heads
        self.heads = heads
        self.scale = dim ** -0.5

        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias = False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout) )


    def forward(self, x, mask = None):
        # b, 65, 1024, heads = 8
        b, n, _ = x.shape
        h = self.heads
        # self.to_qkv(x): b, 65, 64*8*3
        # qkv: b, 65, 64*8
        qkv = self.to_qkv(x).chunk(3, dim = -1)     # 沿-1轴分为3块
        # b, 65, 64, 8
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h = h), qkv)
        # dots:b, 65, 64, 64
        dots       =  torch.einsum('bhid,bhjd->bhij', q, k) * self.scale
        mask_value = -torch.finfo(dots.dtype).max
        if mask is not None:
            mask = F.pad(mask.flatten(1), (1, 0), value = True)
            assert mask.shape[-1] == dots.shape[-1], 'mask has incorrect dimensions'
            mask = mask[:, None, :] * mask[:, :, None]
            dots.masked_fill_(~mask, mask_value)
            del mask

        # attn:b, 65, 64, 64
        attn = dots.softmax(dim=-1)

        # 使用einsum表示矩阵乘法:
        # out:b, 65, 64, 8
        out = torch.einsum('bhij,bhjd->bhid', attn, v)

        # out:b, 64, 65*8
        out = rearrange(out, 'b h n d -> b n (h d)')

        # out:b, 64, 1024
        out =  self.to_out(out)
        return out

class Transformer(nn.Module):
    def __init__(self, dim, depth, heads, dim_head, mlp_dim, dropout):
        super().__init__()
        self.layers = nn.ModuleList([])
        for _ in range(depth):
            self.layers.append(nn.ModuleList([  Residual(PreNorm(dim, Attention(  dim, heads = heads, dim_head = dim_head, dropout = dropout))),
                                                Residual(PreNorm(dim, FeedForward(dim, mlp_dim, dropout = dropout)))    ]))
    def forward(self, x, mask = None):
        for attn, ff in self.layers:
            x = attn(x, mask = mask)
            x = ff(x)
        return x

class ViT(nn.Module):
    def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, mlp_dim, pool = 'cls', channels = 3, dim_head = 64, dropout = 0., emb_dropout = 0.):
        super().__init__()
        assert image_size % patch_size == 0, 'Image dimensions must be divisible by the patch size.'
        num_patches = (image_size // patch_size) ** 2
        patch_dim   = channels * patch_size ** 2

        # assert num_patches > MIN_NUM_PATCHES, f'your number of patches ({num_patches}) is way too small for attention to be effective (at least 16). Try decreasing your patch size'
        assert pool in {'cls', 'mean'}, 'pool type must be either cls (cls token) or mean (mean pooling)'

        self.patch_size         = patch_size
        self.pos_embedding      = nn.Parameter(torch.randn(1, num_patches + 1, dim))
        self.patch_to_embedding = nn.Linear(patch_dim, dim)
        self.cls_token          = nn.Parameter(torch.randn(1, 1, dim))
        self.dropout            = nn.Dropout(emb_dropout)
        self.transformer        = Transformer(dim, depth, heads, dim_head, mlp_dim, dropout)
        self.pool               = pool
        self.to_latent          = nn.Identity()
        self.mlp_head           = nn.Sequential(  nn.LayerNorm(dim), nn.Linear(dim, num_classes) )

    def forward(self, img, mask = None):
        p = self.patch_size

        # 图片分块
        # print(img.shape)
        x = rearrange(img, 'b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1 = p, p2 = p)    # 1,3,256,256  ->  1,64,3072

        # 降维(b,N,d)
        x       = self.patch_to_embedding(x)
        b, n, _ = x.shape

        # 多一个可学习的x_class,与输入concat在一起,一起输入Transformer的Encoder。(b,1,d)
        cls_tokens = repeat(self.cls_token, '() n d -> b n d', b = b)
        x = torch.cat((cls_tokens, x), dim=1)

        # Positional Encoding:(b,N+1,d)
        x += self.pos_embedding[:, :(n + 1)]
        x  = self.dropout(x)

        # Transformer的输入维度x的shape是:(b,N+1,d)
        x = self.transformer(x, mask)

        # (b,1,d)
        x = x.mean(dim = 1) if self.pool == 'mean' else x[:, 0]
        x = self.to_latent(x)
        return self.mlp_head(x) # (b,1,num_class)

if __name__ == '__main__':
    v = ViT(image_size=256, 
            patch_size=32, 
            num_classes=10, 
            dim=1024, 
            depth=6, 
            heads=16, 
            mlp_dim=2048, 
            dropout=0.1,
            emb_dropout=0.1)
    img = torch.randn(1, 3, 256, 256)
    mask = torch.ones(1, 8, 8).bool()  # optional mask, designating which patch to attend to
    preds = v(img, mask=mask)  # (1, 1000)
    print(preds)

评论