z Vision Transformer의 이해와 Swin Transformer
본문 바로가기

Computer Vision

Vision Transformer의 이해와 Swin Transformer

728x90

목차

  • Introduction
  • Attention
    • Seq2Seq
    • Attention
    • Dot-Product Attention
  • Attention All You Need
    • Positional Encoding
    • Multi=Head Attention
    • Why Self-Attention?
  • ViT
    • Input
    • Positional Encoding
    • Transformer Encoder
    • ViT Review
    • Hybrid Architecture
    • Fine Tuning and Higher Resolution
    • Setup
  • Swin Transformer
    • Introduction
    • Overall architecture
    • Shifted Window based Self-Attention
  • Reference

Introduction

현재 자연어 처리 (NLP) 부분에서 굉장히 높은 성능을 내고 있으며 흔히 잘 알려진 BERT, GPT3 는 Transformer 기반 모델입니다.


자연어 처리에서도 이렇게 좋은 성능을 내는 Transformer가 Computer Vision Task 에서도 잘 적용할 수 있는 방법에 대해 연구가 굉장히 많이 되고 있습니다. 이에 단순 분류 모델로는 ViT, Swin Transformer 등이 존재하고, Object Detection으로는 DETR, Deformable DETR 등의 모델이 존재합니다. 이때 ViT, Swin Transformer는 객체 탐지에서 feature를 추출하며 경우에 따라 예측이 진행되는 Backbone network로서 좋은 성능을 냅니다.

 

특히 단순 분류 task를 위한 Swin Transformer는 객체 탐지에서 HTC++의 backbone으로 사용되어 COCO DATASET, Box AP 기준 SOTA를 달성했습니다. (작성 당시 기준)

 

뒤에서도 말을 하겠지만 Transformer는 Convolutional computation 없이 굉장히 좋은 성능을 냅니다. 특히 dot product를 통한 서로(픽셀, 패치,,) 간의 연관성을 파악하는 형식 입니다.

 

딥러닝을 처음 공부할 때 즈음에 구글에서 ViT(Vision Transformer)라는 분류 모델을 냈었는데 굉장히 반응이 뜨거웠던걸로 기억합니다. 하지만 당시 처음 공부할 때는 CNN, RNN 같은 기본적인 모듈에 대해서 공부를 하고 있던 터라 Transformer에 대한 모델들에 "아 그냥 대단한게 나왔나보다" 하고 지나갔는데 이번 기회에 공부하게 되어 좋았습니다.

 

특히 이미지의 지역성(가까운 픽셀끼리는 연관성이 높고 멀리 있는 픽셀은 연관성이 떨어지는 특성 같은 것)을 파악하기 위해서는 무조건 CNN이다 라고 생각했었는데, Transformer라는 신박한 방법으로 파악하는게 정말 저에게는 신기했습니다.

Attention

원래의 자연어 처리 분야는 sequence 형태의 (순서가 중요한) 데이터를 처리해야 하기 때문에 이를 위해 RNN, LSTM, GRU 모듈을 통한 Encoder Decoder 형식의 seq2seq 구조를 주로 사용했습니다.

우리가 사용하는 문장들은 특정 문법에 따라 전달 됩니다. 이 때 문장의 순서를 문법이라 생각 할 수 있으니 문장에서 순서가 얼마나 중요한지는 이해가 되실거라 생각합니다!

http://incredible.ai/nlp/2020/02/20/Sequence-To-Sequence-with-Attention/ : Incredible.AI

 

위의 사진은 seq2seq 구조를 보여줍니다. 보면 Encoder 부분을 통해 정보를 압축하여 Encoder State(encoder의 마지막 hidden state)로 만들고 이를 decoder의 처음 hidden state로 사용하여 목적에 맞는 작업을 수행합니다.

Encoder : 문장의 정보를 압축함.

Decoder : 압축된 정보를 사용해서 특정 Task를 위한 결과를 내놓음

 

Seq2Seq 문제

하지만 만약 데이터의 sequence가 매우 길다면 encoder에서 처리해야 할 정보의 양은 굉장히 많은데 이를 encoder state로 압축해야해서 정보의 손실이 일어날 수 있습니다.

 

이 문제 뿐만 아니라 Recurrent 종류의 모듈의 고질적인 문제인 Vanishing gradient도 존재합니다.

 

그래서 이 문제를 해결하기 위해 Attention이라는 메커니즘을 도입합니다.

Attention

Attention의 기본적인 메커니즘은 decoder에서 출력 단어를 예측하는 매 시점(sequence)마다, Encoder에서의 전체 입력 문장을 다시 한 번 참고한다는 것 입니다. 즉 decoder의 특정 output 위치에 대응되는 순서를 다시 환기시켜 줌을 의미합니다. 

Dot-Product Attention

https://wikidocs.net/22893 : 딥러닝을 이용한 자연어 처리 입문. Wikidocs. Attention Mechansim

 

Attention을 이해하기 위해서는 Dot-product attention을 알아야합니다.


위의 그림을 보면 decoder에서 한 모듈의 출력을 완전히 예측하기 전에 encoder의 모든 입력 단어를 다시 한번 참고를 하는 형태입니다. 위의 사진은 번역 작업을 예로 든건데, 데이터가 sequence이기 때문에 순서가 굉장히 중요합니다.

 

그래서 decoder의 해당 모듈의 출력이 encoder의 모든 입력 단어 중 어떤 것과 관련이 깊은지 dot product를 통해 파악합니다. 즉, 원래의 문장에서 대응되는 번역된 단어와 일치시키는 것인데 이를 Alignment라 합니다.

이 때 dot product는 내적으로 벡터간의 dot product는 \(a \cdot b = |a||b|\cos\theta\) 이기 때문에 두 벡터간 상관관계를 파악할 수 있습니다.
벡터간의 각도(유사성)을 포함하기 때문.

https://wikidocs.net/22893 : 딥러닝을 이용한 자연어 처리 입문. Wikidocs. Attention Mechansim

 

이렇게 각 벡터간의 내적을 통해 구한 값을 attention score라고 하고 이에 softmax 함수를 통해 attention distribution을 구합니다.
attention distribution의 값들을 가중치(weight)로 잡고, weight * encoder hidden state 를 해서 softmax score가 높을수록 살리고, 낮으면 무시합니다.

 

이 후 모든 hidden states를 다 더하거나 Concatenate으로 처리한 후 (논문마다 다양) context vector 를 만듭니다.
만들어진 context vector는 decoder의 input으로 들어가게 됩니다. (hidden state와 함께 사용된다.)

 

이렇게 간단하게 Attention의 필요성과 원리에 대해 알아봤습니다. 간단하게 정리하면 decoder의 output에 Alignment를 통해 연관성 있는 encoder의 정보를 찾고 이를 한번 더 고려하여 결과를 낸다. 라고 할 수 있습니다.


Attention All You Need

Attention All You Need는 Transformer가 처음 소개된 2017년의 논문 이름입니다. 이 제목 자체가 Transformer를 굉장히 잘 소개해주는 것 같아 목차로 지정했습니다.

 

이 이름 그대로 Transformer는 RNN, LSTM, GRU의 모듈 말고 Attention 메커니즘만 사용하는 모듈입니다.


간략하게 말하면 transformer는 encoder-decoder 형태를 띄며, encoder 에서는 Self-Attention이란 것을 사용하여 자기 자신 안에서의 특정 데이터가 다른 데이터들과 어떤 관계를 갖는지를 파악합니다. 이 정보를 decoder에서 attention을 통해 decoder의 data와 상관관계를 한번 더 파악합니다.

https://arxiv.org/pdf/1706.03762.pdf : Attention All You Need (2017) Vaswani et al.

 

transformer의 architecture를 제시하고 모듈에 대해 하나하나씩 설명하겠습니다. _기본적으로 Skip Connection을 통한 Residual Learning을 차용합니다.  Encoder 및 Decoder는 N번 반복 될 수 있습니다.

Postional Encoding

우선 Encoder에서의 Positional Encoding에 대해 보겠습니다.

 

RNN이나 LSTM같은 연속적인 모듈들은 input으로 들어가는 순서가 있기 때문에 각 데이터의 위치를 알 수 있습니다. 하지만 Transformer의 input으로 들어오는 값들은 순서대로 들어오는게 아니라 한번에 들어오기 때문에 위치에 대한 정보를 알 수 없습니다. 만약 위치에 대한 정보가 없다면 번역 뿐만 아니라 아무것도 진행할 수 없을 것입니다.

 

그래서 이를 해결하기 위한 방법으로 Positional Encoding을 제시합니다.

 

Positional Encoding의 기본적인 아이디어는 "sequence에서 토큰의 상대적 또는 절대적 위치에 대한 정보를 주입하자." 라는 것 입니다. 물론 이에 대한 방법으로는 뒷부분 모듈에 넣는 방법, 훈련 가능한 방법 등 여러 방법이 있지만 "Attention All You Need"에 기술된 sinusodial 함수에 대해 설명하고자 합니다.

 

Positional Encoding은 Input embedding과 동일한 차원으로 \(d\{model}\) 차원의 정보를 갖게 됩니다._

이를 위한 방법으로 sine과 cosine을 이용한 frequencies를 사용합니다. 이에 대한 식은 다음과 같습니다.

https://arxiv.org/pdf/1706.03762.pdf : Attention All You Need (2017) Vaswani et al.

위의 식의 i는 차원에서의 몇번째인지를 의미하고 pos는 정보의 위치 값을 의미합니다. 만약 \\(Length,d\_{model}\\)의 데이터라면 \\(L\\)는 pos이고 \\(d\_{model}\\)의 i번째를 i로 칭합니다.


아래의 그림은 주기함수를 띄는 위의 식을 시각화 한건데,

https://inmoonlight.github.io/2020/01/26/Positional-Encoding/ : Space Moon blog. Positional Encoding

 

position vector의 주기가 vector의 dimension마다 변화합니다.. 전체 벡터 크기\\(d\_{model}\\)가 128이라고 가정할 때, i가 작을수록 주기가 짧고 i가 클수록 주기도 길어집니다. 즉 각 고유의 값을 갖게 됩니다.


이 뿐만 아니라 PE vector 간의 distance는 대칭적이고 거리에 따라 일정한 비율로 감소합니다. 이는 Transformer의 self-attention 연산에서 빛을 발함을 나타내는 특징입니다. 아래 그림은 이에 대한 특징을 보여줍니다.

 

https://inmoonlight.github.io/2020/01/26/Positional-Encoding/ : Space Moon blog. Positional Encoding

의문이 들 수 있는 부분인데, 어차피 cos, sin 말고 들어온 순서대로 1,2,3,4 ... 값을 매겨서 더하면 되지 않을까? 라는 생각이 들 수 있습니다.

하지만 값의 크기에 있어서 차이가 나면 결과에 영향을 끼칠 수 있기 때문에 이런 형식은 지양해야 합니다.

 

Multi-Head Attention

Scaled Dot Product Attention

Multi Head Attention을 설명하기 위해서는 Scaled Dot Product Attention에 대한 이해가 필요합니다.

 

https://arxiv.org/pdf/1706.03762.pdf : Attention All You Need (2017) Vaswani et al.

Input으로 Q(Query),K(Key),V(Value)가 들어가는데 Q는 물어보는 주체, K는 대상, V는 이에 대한 가중치라 생각하시면 됩니다. Self Attention의 경우 Q, K, V는 서로 동일한 데이터 입니다(나 자신과 나 자신을 비교해야 하기 때문). (자기 자신의 데이터에 대한 연관성들)


Attention에서 봤던 것과 구조는 거의 비슷합니다.

 

  • MatMul(Dot Product)을 통해 Query와 Key의 관계 파악
  • Scale에서 key, query의 차원 \(d_k\)에 대해 \(d_k^2\)로 나누기 연산
  • Mask(선택)을 통해 중복될 수 있는 부분을 마스킹
  • Softmax 연산
  • 가중치 V와 Attention Distribution 값을 MatMul
    이를 간단히 식으로 나타내면 아래와 같습니다.

 

https://arxiv.org/pdf/1706.03762.pdf : Attention All You Need (2017) Vaswani et al.

Multi-Head Attention

https://arxiv.org/pdf/1706.03762.pdf : Attention All You Need (2017) Vaswani et al.

 

이제 이 과정을 시행할 때 사용자 설정 파라미터인 m 만큼 차원을 나눠서 넣습니다. 이 후 동일하게 Scaled Dot Product Attention을 진행하고 나눠져 나온 결과 값들을 Concat시킨 후 Linear Layer를 거치면 끝입니다.

 

이런 방식을 통하면 여러 방면, 시각에서 보는 것 같은 효과를 볼 수 있다고 합니다.

 

마지막으로 Transformer의 전체 Architecture를 보면 Decoder 부분의 두번째 MHA를 제외하고는 전부 Self Attention 입니다.

self attention은 자기 자신이 Query, Key, Value로 들어가기 때문에 딱히 껄끄러울 게 없지만 Decoder의 두번째 Multi Head Attention은 Value와 Key로 Encoder의 output을 받고 Query로 Decoder 과정의 데이터를 받아 둘 사이의 상관관계를 파악함에 주의 해야 합니다.

 

간단한 FFN, Activation에 대한 설명은 넘어가겠습니다.

Why Self-Attention

이번엔 Self-attention의 다양한 부분을 RNN, CNN과 비교해보려고 합니다.
일단 이에 대해서 세가지 사항을 고려했습니다.

 

  • layer 당 전체 계산 복잡도
  • 병렬화 할 수 있는 계산량
  • 최대 Path Length

 

https://arxiv.org/pdf/1706.03762.pdf : Attention All You Need (2017) Vaswani et al.

 

위의 Table은 이에 대한 결과를 보여줍니다.

 

이를 통해 Convolutional, Recurrent Module에 비해 효율적인 모듈임을 알 수 있습니다.

 

여기에서 주목 해야 할 점은 Computational Complexity 방면에서 sequence의 길이 n이 d보다 작아야 Recurrent 보다 효율적임을 알 수 있습니다. 생각해보면 기계번역에 있어서 SOTA의 대부분의 경우 sequence의 길이 n이 d보다 작습니다.

 

이처럼 매우 긴 sequences에서 Self attention의 Computational performance를 개선시키기 위해 각 출력 위치 기준 r 만큼의 주변만 고려하는 "restricted"한 방법이 있습니다. 이는 위 Table의 가장 하단에 해당합니다.


ViT (Vision Transformer)

An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale _ Dosovitskiy et al.
Code : https://github.com/yhy258/ViT-Simple

 

자 이제 어느정도 Transformer에 대한 이해를 가졌다고 할 수 있습니다. 이를 토대로 vision 분야에 transformer를 접목 시킨 분류 모델인 ViT를 봐봅시다!

 

https://arxiv.org/pdf/2010.11929.pdf : An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale(2020) Dosovitskiy et al.

 

우선 위의 그림에서 ViT 모델의 대략적인 아키텍쳐를 보여주고 있습니다. 위의 아키텍쳐를 기반으로 코드를 통해 설명 하겠습니다.

Input

Input으로 들어오는 이미지가 여러 patch로 나눠져서 들어가게 됩니다. 여기에서 각 patches는 Linear Projection of Flattened Patches를 통해 Transformer Encoder로 들어가게 됩니다.

 

Input image의 shape은 \(H \times W\times C\)이고, 이 이미지는 flatten된 2D patches의 sequence인 \(N \times (P^2C)\)가 됩니다. 이 때 \((P,P)\)는 각 image patch의 해상도, 즉 크기이며, \(N\)은 Patches의 갯수이며 \(N = HW/P^2\)가 됩니다.

 

Transformer에서는 상수인 latent vector size D에 대해 flatten patches에 매핑합니다.

 

코드 상 에서는 위의 과정을 아래 코드와 같은 Conv2d function 하나를 사용 합니다.

 

nn.Conv2d(C, D, kernel_size = P, stride = P)

https://medium.com/@bdhuma/6-basic-things-to-know-about-convolution-daef5e1bc411 : Conv Image

convolution 연산은 위와 같은 형태의 연산이라서 이 연산을 통해 패치(kernel_size)크기에 대한 정보를 한번에 담을 수 있습니다. 그래서 간단한 위 코드 한 줄을 통해 패치 형태를 나타낼 수 있습니다.

 

이를 다시 말하면, kernel size와 stride 모두 P로 Convolution을 진행하면 P를 patch size로 갖는 patch를 각 이미지로 나눠서 Channel C를 D로 매핑하는 거라서 위와 동일한 과정이 됩니다.

 

그리고 이 후 축소된 이미지 크기 \(h ,w\)에 대해서 flatten을 적용해서 \(N\)으로 만들어 주면 됩니다.
이러한 과정을 거치면 결국 \((BS, D, N)\)의 형태로 output이 나오게 됩니다. (Pytorch Tensor 기준.)여기까지가 각 patches에 Linear Projection of Flattened Patches를 적용 한 것 입니다.

 

Input 부분의 마지막으로, 위 Figure 1을 보면 Class Token의 0 이 존재합니다. 이는 우리의 목적인 분류에 대한 정보를 담아주기 위한 토큰 입니다.

 

코드 상 에서는

 

self.cls_token = nn.Parameter(torch.zeros(1,1,hid_dim))

 

와 같이 0으로 초기화 되며 \\\\((1,1,D)\\\\) shape을 갖는 Parameter로 지정합니다. 이 때 모든 batch size에 적용하기 위해 실행부 에서 expand를 거치고 input과 concat 해줍니다.

Positional Embedding

이미지의 경우에도 자연어 처리와 같이 position에 대한 고려가 굉장히 중요합니다.

 

그래서 위 Input의 과정 후에 Positional Embedding을 적용하게 됩니다. 논문에 나와 있는 것에 따르면 2D position embeddings를 사용하는 것과 1D position embeddings를 사용하는 것 사이에 눈에 띄는 performance의 진전이 없다고 합니다.

그래서 Learnable 1D Positional Embedding을 사용했습니다.

 

self.pos_embedding = nn.Parameter(torch.zeros(1, n_patches+1, hid_dim))

 

위 input 과정의 class token을 고려하여 \\((1,N+1, D)\\) shape의 embedding matrix를 생성합니다. 이 embedding matrix를 input과 더해줍니다.

Transformer Encoder

https://arxiv.org/pdf/2010.11929.pdf : An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale(2020) Dosovitskiy et al.

 

Transformer Encoder는 위와 같은 형태로 이루어져 있습니다.

 

위에서 언급했던 Input과 Positional Embedding 과정을 거친 Embedded Patches를 받습니다. 이 데이터에 Layer Norm을 취해주고 이 후 MHA( Multi-Head Attention )에 보냅니다.

 

기본적으로 Residual Connection을 기반으로 이루어져 있고, 특정 파라미터 L에 의해서 반복되는 layer의 갯수가 정해집니다.

 

MHA에 대한 설명은 따로 하지 않겠습니다. (위 Transformer에서 설명했던 내용)

ViT Review

위의 과정에 대한 전체적인 식 입니다.

https://arxiv.org/pdf/2010.11929.pdf : An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale(2020) Dosovitskiy et al.

 

지금까지의 내용을 다시 곱씹어보면서 수식을 파악해보시길 바랍니다.

Hybrid Architecture

위에서의 설명은 초기의 Input이 raw image patches에 대한 feature map의 sequence로 들어가게 됩니다.


하지만 hybrid model에서는 patch embedding이 특정 backbone model에 의해서 추출된 CNN feature map을 기반으로 적용됩니다. (원래 이미지에 대해 size가 줄어들고, 응축된 정보를 갖는다.)


backbone model을 따로 사용하기 때문에 patches의 spatial size(resolution)는 1x1이 될 수 있습니다.

Fine Tuning And Higher Resolution

pretrain된 ViT를 미세 조정하는 것 또한 당연히 가능합니다. Fine tuning 시에는 pretrain된 ViT의 prediction head를 제거 후 zero-initialized DxK feedforward layer로 대체합니다.

 

이제 종종 fine tuning 할 때 higher resolution에 대해 진행해야 할 때가 있습니다.(즉 이미지의 크기가 다를 때)


이미지의 크기가 더 클 때, 동일한 patch size로 설정하고 더 긴 길이로 ViT 모델에 넣을 수 있기 때문에 ViT는 이미지의 크기에 대해 자유롭습니다.

 

하지만 이 경우 pretrained Positional Encoding의 의미가 사라져서 해당 pretrained Positional Encoding에 2D interpolation을 적용하는 과정이 필요합니다.

Setup

ViT 모델은 크기에 대해 3가지로 나뉩니다. -> Base, Large, Huge
만약 ViT-L/16이라는 모델이 있다면 이는 크기가 Large이며 input patch size가 16x16인 ViT 모델을 뜻 합니다.

https://arxiv.org/pdf/2010.11929.pdf : An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale(2020) Dosovitskiy et al.

 

각 크기에 대한 parameters는 위 사진과 같고, 코드로는 아래와 같습니다.

 

import ml_collections

def get_ViT_B() :
    config = ml_collections.ConfigDict()
    config.num_layers = 12
    config.hid_dim = 768
    config.ff_dim = 3072
    config.n_heads = 12
    return config

def get_ViT_L() :
    config = ml_collections.ConfigDict()
    config.num_layers = 24
    config.hid_dim = 1024
    config.ff_dim = 4096
    config.n_heads = 16
    return config

def get_ViT_H() :
    config = ml_collections.ConfigDict()
    config.num_layers = 32
    config.hid_dim = 1280
    config.ff_dim = 5120
    config.n_heads = 16
    return config


def get_ViT_B_16():
    config = get_ViT_B()
    config.patch_size = 16
    return config

def get_ViT_B_32():
    config = get_ViT_B()
    config.patch_size = 32
    return config

def get_ViT_L_16():
    config = get_ViT_L()
    config.patch_size = 16
    return config

def get_ViT_L_32():
    config = get_ViT_L()
    config.patch_size = 32
    return config

def get_ViT_H_16():
    config = get_ViT_H()
    config.patch_size = 16
    return config

def get_ViT_H_32():
    config = get_ViT_H()
    config.patch_size = 32
    return config

Swin Transformer

Swin Transformer: Hierarchical Vision Transformer using Shifted Windows _ Ze Liu et al.
Code : https://github.com/yhy258/SwinTransformer_nonmask

 

자 어떤 물체가 있는 사진을 찍었다고 해봅시다. 근데 이 물체가 사진 전체에 있나요??

아니죠. 물체는 사진의 부분에 위치합니다.

 

즉, 어떤 픽셀이 있으면 그 픽셀과 거리가 멀어질 수록 당연히 관련도가 떨어지겠죠???

근데 굳이 ViT처럼 전체 사진에 대해 Self Attention을 진행 할 필요가 있을까요??(전체 사진에 대해 self attention을 진행하기 때문에 computational complexity가 높다.)

 

특히 사진 같은 vision task에 해당하는 data들은 resolution(해상도)가 자연어에 비해 굉장히 높기 때문에 계산 복잡도 또한 굉장히  높게됩니다. 따라서 Transformer를 그대로 가져다 사용하기에는 적합하지 않습니다 ㅜㅜ

 

이 단점을 해결하기 위한 방안으로 Swin Transformer라는 논문이 나오게 됩니다!

Introduction

언어 부분에서 좋은 성능을 내던 transformer는 vision에서 성능이 잘 안나오고는 합니다. 이에 대해 생각을 해보면 두 가지 정도 생각할 수 있는데

 

  • Vision은 scale을 포함한다. 이는 vision transformer에서 고정된 크기의 token을 사용하면 안됌을 의미한다.(resolution에 따라서 정보의 압축이 심해집니다.)
  • Vision의 경우는 text에 비해 high resolution이다. 이건 computational complexity 문제를 일으킨다.
    이 문제점을 기반으로해서 Swin Transformer를 제안합니다.

 

https://arxiv.org/pdf/2010.11929.pdf : Swin Transformer: Hierarchical Vision Transformer using Shifted Windows(2021) Ze Liu et al.

 

위 Figure 1은 Swin Transformer가 대략적으로 어떻게 진행되는지 보여주고 ViT와 비교해보여줍니다.
Swin Transformer의 경우 작은 패치에서 시작해서 주변과 merging 해 나갑니다.

 

이 때 window 단위로 self attention을 진행하기 때문에 complexity가 image size에 대해 linear해서, backbone으로 잘 역할을 수행할 수 있습니다.

 

https://arxiv.org/pdf/2010.11929.pdf : Swin Transformer: Hierarchical Vision Transformer using Shifted Windows(2021) Ze Liu et al.

 

Swin Transformer에서 가장 중요한 부분은 shifted window 입니다. 이를 통해 나눠진 windows간 연결을 해줄 수 있습니다.


위 Figure 2.에 보면 잘 나와있는데, 처음 레이어에서는 4개로 나눠진 window 단위로 self attention을 진행합니다. 만약 이대로만 진행한다면 나눠진 윈도우 사이의 connection을 표현 할 수 없습니다.

 

그래서 Fig 2.의 우측과 같은 shifted window scheme을 제안했습니다. window를 window_size//2 만큼 우측 하단으로 이동시켜 window를 새로 만들면 됩니다.

 

Overall Architecture

https://arxiv.org/pdf/2010.11929.pdf : Swin Transformer: Hierarchical Vision Transformer using Shifted Windows(2021) Ze Liu et al.

 

Figure 3에 Swin Architecture의 전체적인 구조가 나와있습니다.

 

Patch Splitting Module Patch Partition
RGB 이미지를 여러 Patches로 나눠줍니다. 이 때 Patches는 token으로 생각하면 됩니다.
patch size는 H, W 를 4로 나눈 값이고, 이로 인해 channel은 3x4x4가 됩니다. 이후 Linear Embedding을 통해 사용자 정의 channel 값으로 바꿔줍니다.

 

Swin Transformer Block
Figure 3.의 오른쪽 부분을 보면 Swin Transformer Block 구조에 대해 나와있습니다.
우선 window size에 맞게 분할해줘야 합니다.

 

def window_partition(x, window_size):
    # B, H, W, C : x.size -> B*Window_num, window_size, window_size, C
    B, H, W, C = x.size() 
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows

이 과정을 거친 후 Fig 3의 우측과 같은 과정에 들어가게 됩니다.

 

여기 Figure 3 에서 헷갈렸던 부분이 있었는데, 좌측 부분에서 Swin TransformerBlock x2 이렇게 되어있습니다. 여기에서 Figure 3의 우측 부분은 그 연속된 x2를 가리키는 것 이었습니다.

 

x2의 첫번째는 W-MSA를 사용하고, 두번째는 SW-MSA를 사용하는 것 입니다.

 

여기에서 주목해야할 건 W-MSA, SW-MSA인데, 앞에서 언급했던 shifted window scheme을 적용한 SW-MSA 모듈이 이 Swin Transformer의 핵심이라 생각 할 수 있습니다.

 

대략적인 Swin Transformer Block의 구조는 residual connection을 동반하고 있고, 각 Attention Module, MLP에 들어가기 전 LayerNorm을 적용해줍니다.

 

어차피 짝수로 Swin Transformer block이 들어가기 때문에, 저는 Figure 3 우측의 연속된 블럭을 한번에 나타냈습니다. 그래서 선언 할 때는 [2,2,6,2] 이런 식이 아닌 [1,1,3,1] 이런 식으로 선언했습니다.

Shifted Window based Self-Attention

이전의 standard transformer의 구조를 살펴보면 global computation을 동반합니다.
해당 query와 전역적인 key들과의 연산.

 

이 경우 quadratic complexity를 발생시키기 때문에, local computation을 진행 할 필요가 있습니다.
그래서 M 만큼 window size를 정하고 이에 대해서 self attention을 진행합니다.

 

Shifted window partitioning in succesive blocks
만약 shifted window scheme 없이 W-MSA를 진행한다면 이 떄 connection의 결함이 발생하게 됩니다.
이 결함을 극복하기 위해 cross-window connection을 사용합니다.

 

Figure 2.에 나와있는 것처럼 window size가 M일 때 ([M/2], [M/2])만큼 옮겨서 partition 해줍니다.

 

https://arxiv.org/pdf/2010.11929.pdf : Swin Transformer: Hierarchical Vision Transformer using Shifted Windows(2021) Ze Liu et al.

swin transformer block에서의 계산을 나타내는 equation 입니다.

 

Efficient batch computation for shifted configuration

https://arxiv.org/pdf/2010.11929.pdf : Swin Transformer: Hierarchical Vision Transformer using Shifted Windows(2021) Ze Liu et al.

 

W-MSA, SW-MSA를 진행하면서 여러 문제가 생길 수 있습니다.

 

만약 input size가 MxM보다 작거나 나누어 떨어지지 않는경우 naive하게 padding을 통해 해결 할 수 있습니다. (나누어 떨어지도록)

그리고 SW-MSA 에서는 cyclic shift를 사용해서 쉽게 나타낼 수 있습니다.

 

이를 기반으로 아래와 같이 SwinTransformerLayer를 구성했습니다.

 

class SwinTransformerLayer(nn.Module):
    def __init__(self, C, num_heads, window_size, ffn_dim, act_layer = nn.GELU, dropout = 0.1):
        super().__init__()
        self.mlp1 = Mlp(C, ffn_dim, act_layer=nn.GELU, drop=dropout)
        self.mlp2 = Mlp(C, ffn_dim, act_layer=nn.GELU, drop=dropout)

        self.norm1 = nn.LayerNorm(C)
        self.norm2 = nn.LayerNorm(C)
        self.norm3 = nn.LayerNorm(C)
        self.norm4 = nn.LayerNorm(C)


        self.shift_size = window_size // 2
        self.window_size = window_size
        self.W_MSA = SwinAttention(num_heads=num_heads, C=C, dropout=dropout )
        self.SW_MSA = SwinAttention(num_heads=num_heads, C=C, dropout=dropout )

    def forward(self, x): # BS, L, C
        BS, L, C = x.shape 
        S = int(math.sqrt(L))

        shortcut = x

        x = self.norm1(x) # BS, L, C

        x_windows = self.window_to_attention(x, S, C)

        attn_x = self.W_MSA(x_windows)

        x = self.attention_to_og(attn_x, S, C)

        x = x + shortcut

        shorcut = x

        x = self.norm2(x)
        x = self.mlp1(x)

        x = x + shortcut

        shortcut = x

        x = self.norm3(x)

        x_windows = self.window_to_attention(x, S, C ,shift=True) # cyclic shift for SW_MSA

        x_attn = self.SW_MSA(x_windows)

        x = self.attention_to_og(x, S, C ,shift=True) # reverse cyclic shift for SW_MSA

        x = x+ shortcut

        shortcut = x

        x = self.norm4(x)
        x = self.mlp2(x)

        return x + shortcut

    def window_to_attention(self, x, S, C, shift = False):
        x = x.view(-1, S, S, C)
        if shift :
            x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        x_windows = window_partition(x, self.window_size)
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)
        return x_windows

    def attention_to_og(self, attn_x, S, C,shift=False):
        attn_x = attn_x.view(-1, self.window_size, self.window_size, C)
        x = window_reverse(attn_x, self.window_size, S, S)
        if shift :
            x = torch.roll(x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        x = x.view(-1, S*S, C)
        return x
class SwinAttention(nn.Module):
    def __init__(self, num_heads, C, dropout):
        super().__init__()

        self.scale = C ** -0.5

        self.qkv = nn.Linear(C, C * 3, bias=True)
        self.num_heads = num_heads

        self.softmax = nn.Softmax(dim=-1)

        self.attn_drop = nn.Dropout(0.1)

        self.proj = nn.Linear(C, C)
        self.proj_drop = nn.Dropout(0.1)

    def forward(self ,x):# BS, L, C
        # x = [B, H, W, C]
        B, L, C = x.shape


        qkv = self.qkv(x).reshape(B, L, 3, self.num_heads, C//self.num_heads).permute(2,0,3,1,4) # 3, B, Head, L, C_v

        q, k, v= qkv[0], qkv[1], qkv[2]

        q = q*self.scale

        attn = (q @ k.transpose(-1,-2)) # dot product


        """
        여기서부터 attention 작업
        """

        attn_score = self.softmax(attn)
        attn_score = self.attn_drop(attn_score) # L, L
        # B, Head, L, C_v

        out = (attn @ v).transpose(1,2).flatten(-2) # B, L, C 


        out = self.proj(out)
        out = self.proj_drop(out)

        return out

 

Reference

Attention

Transformer

ViT

Swin Transformer

728x90