Abstract
오늘날 heavily overparameterized model에서, training loss의 값은 model generalization ability에 대해서 이점을 크게 주지 않는다.
이는 training loss에 대해서만 최적화를 시작하면 suboptimal model quality를 얻음을 뜻합니다.
논문에서는 loss landscape와 generalization의 connection에 영감을 받아, novel effective procedure를 제시 합니다. (SAM)
SAM은 낮은 loss를 갖는 neighborhoods를 찾고, 이 계산은 minmax optimization 결과를 갖습니다.
Introduction
현대 머신러닝의 발전으로, 모델들은 점점 더 무거운 overparameterization에 크게 의존합니다.
여러 부분에서의 SOTA를 살펴보면 보통 overparameterization이 요구됨을 알 수 있습니다.
이로 인해 generalize하는 절차가 필수적임을 알 수 있습니다.
오늘날 models의 training loss landscape는 보통 복잡하고 non convex합니다. 그래서 가능한 optimizer 사이의 선택이 중요한 design choice가 되었습니다.
이와 관련해서 training process를 modify하기 위한 여러 방법이 제안 되었습니다. (dropout, batch normalization, ...)
전부터, loss landscape와 minima의 평탄도, generalization 의 connection에 대해 연구가 되어왔습니다. 이 connection은 더 나은 generalization을 위한 새로운 approach에 대한 유망성을 가지고 있습니다.
논문에서는, 이를 바탕으로 기존에 존재하는 techniques에 대해 보충하는 효과적인 approach를 제시합니다.
SAM : loss value를 낮추고 동시에 sharpness도 최소화함으로써 model generalization 개선
논문에셔의 contributions
- SAM : loss value를 낮추고 동시에 sharpness도 최소화함으로써 model generalization 개선.
- 엄격한 경험적 연구를 통해 SAM을 사용하는 것이 Computer vision task의 넓은 방면에서 generalization ability를 개선시킴을 보여줌.
- SAM은 noisy label에 대해서도 robustness를 보임.
- SAM으로부터의 lens를 통해 "m-sharpness"라는 새로운 선명도 개념을 제시해서 loss sharpness와 generalization 사이의 connection을 설명.
Sharpness-Aware Minimization (SAM)
이 부분에서는 논문의 Appendix A의 PAC bayesian Generalization Bound에 대한 이해가 있어야하는데 아직 이해를 못해서,, 나중에 수정하겠습니다.
Probably Approximately Correct *"주어진 분류기는 test set에 대해 거의 정확할 것이다."*
Generalization에 대한 이야기.
이에 대해서 training data와 test data 모두 같은 분포에서 파생되었고, iid 관계를 갖는다고 가정.
Notation
일단은 D distribution으로부터의 training dataset
현대의 overparameterize된 models는 기본적인 train loss로 훈련하면 test시에 suboptimal 성능을 낼 수 있습니다.
그리고 이런 models는
이에 대해서
loss landscape의 sharpness와 generalization 사이의 connection에 영감을 받아서 단순히 training loss
$$
L_D(w)\leq\max_{||\epsilon||_2\leq \rho}L_S(w+\epsilon) + h({||w||_2^2 \above 1pt \rho^2})
$$
population loss는 위와 같이 나타낼 수 있습니다. 위 식에 따라, 우항을 minimize 한다면
참고로 우항의 두번째 h함수는 단조증가 function을 나타냅니다. (strictly increasing function)
그리고 위 식의 우항을 이해하기 좋게 아래와 같이 나타낼 수 있습니다.
위 식의
이제 의미적인 파악을 했으니 이를 다시 아래와 같이 적어보면,
where
으로 나타낼 수 있습니다.
이제 식에서 중요한
이를 통해 아래의 식을 봐보겠습니다.
위의 식에서
이제 여기에 대해 Dual norm problem을 적용하면,
이러한 식이 됩니다. training loss의 미분에 대한 식이기 때문에
이제
$
\nabla_w L_S(w)|_{w+\hat \epsilon (w)}
$
식이 위와 같이 자동미분으로 구할 수 있는 식으로 유도 됩니다. 이제 마지막 유도 부분의 뒷부분은 Hessian matrix 계산을 요구하는데, 이는 계산량을 너무 많이 필요로 해서 이 부분을 제외합니다.
$$
\nabla_w L_S(w)|_{w+\hat \epsilon (w)} + {d\hat \epsilon(w) \above 1pt dw}\nabla_w L_S(w)|_{w+ \hat \epsilon (w)}
$$
마치면서..
좀 계산적인 내용이 너무 많아서 공부하기 너무 힘들었습니다. Convex optimization의 내용인 dual norm이라던지.. 베이지안과 빈도주의를 모두 반영한 PAC bayesian Generalization Bound라던지..
이 논문은 좀 잡고 오래 꾸준히 보면서 공부해야 할 것 같습니다.
'Others' 카테고리의 다른 글
What is Local Optimality in Nonconvex-Nonconcave Minimax Optimization? - 내쉬 균형의 고찰과 새로운 local minimax의 정의 (0) | 2022.02.16 |
---|---|
최적화 공부 (0) | 2022.02.11 |
[논문 리뷰] : Manifold Mixup : Better Representations by Interpolating Hidden States (0) | 2021.05.20 |
[논문 리뷰] mixup : Beyond Empirical Risk Minimization (0) | 2021.05.20 |
Deep Neural Network 경량화 (Inception, Xception) (0) | 2021.05.19 |