swin transformer源码解读

2020 年 5 月,Facebook AI 推出了DERT( Detection Transformer),用于目标检测和全景分割。

2020 年 10 月,谷歌提出了Vit(Vision Transformer),利用 Transformer 对图像进行分类,而不需要卷积网络。

2021年1月,OpenAI 提出两个模型:DALL·E 基于本文直接生成图像,CLIP将图像映射到文本描述的类别中。两个模型都利用 Transformer 。

2021年3月,微软提出Swin Transformer,把CV各大任务给屠榜了。。。。

我能放过它?我不能。。。总结下前段时间看了论文和代码梳理出来的swin_transformer框架和实现。

论文: https://arxiv.org/abs/2103.14030

代码: https://github.com/microsoft/Swin-Transformer

swin_transformer介绍

1. swin_transformer优化点

swin_transformer对比之前Vit有两个改进点:

1.引入了CNN里常用的多层次transformers结构

Vit的尺度是不变的,不易于接入到下游任务中,比如分割的encoder阶段可以方便的接入resnet等backbone网络,而Vit的特征图尺寸是不变的下图(b)。swin_transfomer通过合并image_patchesd的方式引入多层次结构,如下图(a)。

2. swin_transformer如何优化

针对第一个优化点,论文使用的网络架构如下:

代码模块逻辑:

patch_embed + pos_embed

stage1

-BasicLayer

--SwinTransformerBlock(*2)

---WindowAttention

stage2

-BasicLayer

--SwinTransformerBlock(*2)

---WindowAttention

stage3

-BasicLayer

--SwinTransformerBlock(*6)

---WindowAttention

stage4

-BasicLayer

--SwinTransformerBlock(*4)

---WindowAttention

主要模块的代码逻辑:

1.patch_embed:PatchEmbed

首先进行一次patch_embed,patch_embed就是把输入按patch进行一次向量映射。我认为就是卷积操作(标题swin_transfomer,第一步就是卷积~卷积yyds)

设定输入:(3,256,256),patch_size=4,embeding_dim=96

(1)分辨率不够4整除就pad到4的倍数

(2)通用卷积kernel=4,stride=4,将image映射为无重叠的4*4的patchs:(96,64,64)

(3)如果需要norm,再进行一次layerNorm

(4)(3,256,256) 通过patch_embed,特征为(96,64,64)

2.absolute_pos_embed

如果有position_embeding步骤,需要学习一个96,64,64的pos_emded参数。和patch_embed进行concat.

将emded矩阵进行flatten+transpose-->64*64, 96

3.stages

对分辨率缩小*4的特征图进行4个stage的-BasicLayer

BasicLayer

1.attn_mask

设定window_size=7,以stage1为例输入特征图大小为(64,64)。img_mask初始为(70,70),那么通过window_partition就把特征图切分为100个7*7的窗口。

img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device)
h_slices = (slice(0, -self.window_size),
 slice(-self.window_size, -self.shift_size),
 slice(-self.shift_size, None))
w_slices = (slice(0, -self.window_size),
 slice(-self.window_size, -self.shift_size),
 slice(-self.shift_size, None))
cnt = 0
for h in h_slices:
 for w in w_slices:
 img_mask:, h, w, : = cnt
 cnt += 1
mask_windows = window_partition(img_mask, self.window_size)
mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

以上代码目的是得到100个49*49的attn_mask。

这里的attn_mask是为后续的cyclic shift,也就是SW-MSA使用。

首先,对img_mask70*70的图进行切分9大块赋值

63*63=0 4*63=1 3*64=2

63*4=3 4*4=4 3*4=5

64*3=6 4*3=7 3*3=8

2.SwinTransformerBlock(*n)

(1)reshape+pad

对输入64*64, 96进行layer_norm+reshape+pad操作。pad作用是要FM的H,W是window_size的倍数。对stage1:64*64, 96-->70,70,96

(2)window_mask_self_attention(W-MSA/SW-MSA)

先看第一阶段W-MSA blcok,也就是不加入cyclic shift。

(a)进行window_partition,将特征图切分为window_size*window_size的patch,1,70*70,96切分为100,7,7,96,再reshape100,49,96

(b) WindowAttention

计算self_attention

然后在X和Y方向计算relative_coords。计算relative_coords第一步加(window_size-1)是为了让值都为正数,在X方向再*(2*window_size-1)是为了后续求和能区分(0,1)和(1,0)这类坐标。

(b)windowAttention

计算attention和上诉步骤一致,只是在步骤a中我们提到了,ABC区域在计算attention时需要mask掉,这里的mask就是我们BasicLayer的第一步获取的attn_mask(100,49,49)~

if mask is not None:
 nW = mask.shape0
 attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
 attn = attn.view(-1, self.num_heads, N, N)
 attn = self.softmax(attn)
else:
 attn = self.softmax(attn)

mask主要逻辑,attn假设目前是200,3,49,49,我们计算的attn_mask是(100,49,49),因为是针对窗口位置mask和bs和head_num无关,所以将attn和mask分别reshape到(2, 100, 3, 49, 49)和(1,100,1,49,49)就好了。

最后记得window_rever后,记得把shift_x给sereverse回去。

x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
以上就将最复杂的SwinTransformerBlock模块介绍完了~

3.down_sample

downsamp(最后一个stage不需要)使用的是PatchMerging.对FM进行间隔采样达到降采样的目的,再concat低分辨率FM后,通过全连接对C通道裁剪。很像pixelShuffle的反向操作。

self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
x = x.view(B, H, W, C)
padding
pad_input = (H % 2 == 1) or (W % 2 == 1)
if pad_input:
 x = F.pad(x, (0, 0, 0, W % 2, 0, H % 2))
x0 = x:, 0::2, 0::2, : # B H/2 W/2 C
x1 = x:, 1::2, 0::2, : # B H/2 W/2 C
x2 = x:, 0::2, 1::2, : # B H/2 W/2 C
x3 = x:, 1::2, 1::2, : # B H/2 W/2 C
x = torch.cat(x0, x1, x2, x3, -1) # B H/2 W/2 4*C
x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
x = self.norm(x)
x = self.reduction(x)

以上就是一个basicLayer的逻辑,通过四个stage得到不同尺度的特征图(Swin-T)

stage1-->96, 64, 64

stage2-->192, 32, 32

stage3-->384, 16, 16

stage4--> 768, 8, 8

有了这个四个特征图就可以和resnet等结构一样,接入到下游任务了~

本站文章资源均来源自网络,除非特别声明,否则均不代表站方观点,并仅供查阅,不作为任何参考依据!
如有侵权请及时跟我们联系,本站将及时删除!
如遇版权问题,请查看 本站版权声明
THE END
分享
二维码
海报
swin transformer源码解读
2020 年 5 月,Facebook AI 推出了DERT( Detection Transformer),用于目标检测和全景分割。
<<上一篇
下一篇>>