z [논문 리뷰] SAM Optimizer : Sharpness-Aware Minimization for Efficiently Improving Generalization
본문 바로가기

Others

[논문 리뷰] SAM Optimizer : Sharpness-Aware Minimization for Efficiently Improving Generalization

728x90

Abstract

오늘날 heavily overparameterized model에서, training loss의 값은 model generalization ability에 대해서 이점을 크게 주지 않는다.
이는 training loss에 대해서만 최적화를 시작하면 suboptimal model quality를 얻음을 뜻합니다.
논문에서는 loss landscape와 generalization의 connection에 영감을 받아, novel effective procedure를 제시 합니다. (SAM)
SAM은 낮은 loss를 갖는 neighborhoods를 찾고, 이 계산은 minmax optimization 결과를 갖습니다.

Introduction

현대 머신러닝의 발전으로, 모델들은 점점 더 무거운 overparameterization에 크게 의존합니다.
여러 부분에서의 SOTA를 살펴보면 보통 overparameterization이 요구됨을 알 수 있습니다.
이로 인해 generalize하는 절차가 필수적임을 알 수 있습니다.


오늘날 models의 training loss landscape는 보통 복잡하고 non convex합니다. 그래서 가능한 optimizer 사이의 선택이 중요한 design choice가 되었습니다.
이와 관련해서 training process를 modify하기 위한 여러 방법이 제안 되었습니다. (dropout, batch normalization, ...)

전부터, loss landscape와 minima의 평탄도, generalization 의 connection에 대해 연구가 되어왔습니다. 이 connection은 더 나은 generalization을 위한 새로운 approach에 대한 유망성을 가지고 있습니다.
논문에서는, 이를 바탕으로 기존에 존재하는 techniques에 대해 보충하는 효과적인 approach를 제시합니다.

SAM : loss value를 낮추고 동시에 sharpness도 최소화함으로써 model generalization 개선

논문에셔의 contributions

  • SAM : loss value를 낮추고 동시에 sharpness도 최소화함으로써 model generalization 개선.
  • 엄격한 경험적 연구를 통해 SAM을 사용하는 것이 Computer vision task의 넓은 방면에서 generalization ability를 개선시킴을 보여줌.
  • SAM은 noisy label에 대해서도 robustness를 보임.
  • SAM으로부터의 lens를 통해 "m-sharpness"라는 새로운 선명도 개념을 제시해서 loss sharpness와 generalization 사이의 connection을 설명.

Sharpness-Aware Minimization (SAM)

이 부분에서는 논문의 Appendix A의 PAC bayesian Generalization Bound에 대한 이해가 있어야하는데 아직 이해를 못해서,, 나중에 수정하겠습니다.

Probably Approximately Correct *"주어진 분류기는 test set에 대해 거의 정확할 것이다."*
Generalization에 대한 이야기.
이에 대해서 training data와 test data 모두 같은 분포에서 파생되었고, iid 관계를 갖는다고 가정.

Notation

\(S\triangleq U^n_{i=1}{(x_i,y_i)}\) : distribution D로부터의 training dataset
\(L_S(w)\triangleq {1 \above 1pt n}\sum^n_{i=1}l(w,x_i,y_i)\) : training loss
\(L_D(w)\triangleq E_{(x,y)\sim D}[l(w,x,y)]\) : population loss

일단은 D distribution으로부터의 training dataset \(S\triangleq U^n_{i=1}{(x_i,y_i)}\)를 고려할 때 잘 generalization 되는 모델을 학습하고자 합니다.

현대의 overparameterize된 models는 기본적인 train loss로 훈련하면 test시에 suboptimal 성능을 낼 수 있습니다.
그리고 이런 models는 \(L_S(w)\)는 w에 대해non convex 합니다. 이는 multiple local minima를 갖음을 의미하고, 심지어 여러 global minima를 갖기도 합니다.

이에 대해서 \(L_S(w)\)와 유사하지만 generalization performance에 대한 값 \(L_D(w)\) (population loss)를 갖을 수 있습니다.

loss landscape의 sharpness와 generalization 사이의 connection에 영감을 받아서 단순히 training loss \(L_s(w)\)를 최소화 하는 w를 구하는게 아니라 neighborhoods 전체가 낮은 loss를 갖는 w를 찾는 방법을 제시합니다. (neighborhoods는 low loss와 low curvature(곡률)을 갖습니다.)


$$

L_D(w)\leq\max_{||\epsilon||_2\leq \rho}L_S(w+\epsilon) + h({||w||_2^2 \above 1pt \rho^2})

$$


population loss는 위와 같이 나타낼 수 있습니다. 위 식에 따라, 우항을 minimize 한다면 \(L_D(w)\)가 낮아져서 generalization이 증가하게 됩니다.
참고로 우항의 두번째 h함수는 단조증가 function을 나타냅니다. (strictly increasing function)

그리고 위 식의 우항을 이해하기 좋게 아래와 같이 나타낼 수 있습니다.
$$
[\max_{||\epsilon||_2\leq \rho} L_S(w+\epsilon)-L_S(w)] + L_S(w) + h({||w||_2^2 \above 1pt \rho^2})
$$

위 식의 \([\max_{||\epsilon||_2\leq \rho} L_S(w+\epsilon)-L_S(w)]\)는 주변의 loss에서 현재의 loss를 뺐으니까 sharpness를 의미하고, 논문에서는 \(h({||w||_2^2 \above 1pt \rho^2})\)를 단순한 일반적인 L2 정규화라 했습니다.

이제 의미적인 파악을 했으니 이를 다시 아래와 같이 적어보면,
$$
\min_w L^{SAM}_S(w) + \lambda ||w||^2_2
$$
where \(L^{SAM}_S(w) \triangleq \max_{||\epsilon||\leq \rho}L_S(w+\epsilon)\)
으로 나타낼 수 있습니다.

이제 식에서 중요한 \(\epsilon\) 부분에 대해서 봐보겠습니다.
\(\epsilon\) 은 0 주변의 값이기 때문에 1차 Taylor 근사식으로 근사할 수 있습니다.
이를 통해 아래의 식을 봐보겠습니다.

$$
\epsilon^{*} \triangleq \arg\max_{||\epsilon||_p \leq \rho}L_S(w+\epsilon)
$$

$$
\approx \arg\max_{||\epsilon||_p \leq \rho}L_S(w) + \epsilon^T \nabla_w L_S(w)
$$

$$
= \arg\max_{||\epsilon||_p \leq \rho}\epsilon^T \nabla_w L_S(w)
$$

위의 식에서 \(L_S (w)\)는 이 식에서 상관 없는 부분이기 때문에 제외해서 가장 우항의 식처럼 나옵니다.
이제 여기에 대해 Dual norm problem을 적용하면,

$$
\epsilon^{*} = \rho ~ sign(\nabla_w L_S(w)){|\nabla_w L_S(w)|^{q-1} \above 1pt (||\nabla_w L_S(w)||^q_q)^{1/p}}
$$

이러한 식이 됩니다. training loss의 미분에 대한 식이기 때문에 \(\epsilon\)을 사용할 수 있게 되었습니다.
이제 \(\epsilon\)을 사용할 수 있게 되었으니 본래 우리의 목표인 식으로 돌아가보겠습니다.

$$
\nabla_w L^{SAM}_S(w) \approx \nabla_w L_S(w+\hat \epsilon(w)) = {d(w + \hat \epsilon(w)) \above 1pt dw} \nabla_w L_S(w)|_{w+\hat \epsilon}
$$

 

$

\nabla_w L_S(w)|_{w+\hat \epsilon (w)} 

$

 

식이 위와 같이 자동미분으로 구할 수 있는 식으로 유도 됩니다. 이제 마지막 유도 부분의 뒷부분은 Hessian matrix 계산을 요구하는데, 이는 계산량을 너무 많이 필요로 해서 이 부분을 제외합니다.

 

$$

\nabla_w L_S(w)|_{w+\hat \epsilon (w)} + {d\hat \epsilon(w) \above 1pt dw}\nabla_w L_S(w)|_{w+ \hat \epsilon (w)}

$$

 

$$
\nabla_w L^{SAM}_S(w) \approx \nabla_w L_S(w)|_{w+\hat\epsilon (w)}
$$

마치면서..

좀 계산적인 내용이 너무 많아서 공부하기 너무 힘들었습니다. Convex optimization의 내용인 dual norm이라던지.. 베이지안과 빈도주의를 모두 반영한 PAC bayesian Generalization Bound라던지..
이 논문은 좀 잡고 오래 꾸준히 보면서 공부해야 할 것 같습니다.

728x90