Centric Software

Batch Normalizaion(2015) 설명 본문

Paper Reviews

Batch Normalizaion(2015) 설명

jh-rrr 2023. 9. 3. 08:44

왜 사용할까?

batch normalization을 사용하는 이유는 1)학습 효율향상, 2)Regularization 때문입니다.

다시말해, batch normalization을 적용 시 1)모델이 훨씬 빠르게 학습되고 2)general feature또한 더 잘 찾아내 성능도 더 좋아진다고 생각해 볼 수 있을 것 같습니다.

 

어떻게 가능할까?

어떻게 이것이 가능할까요?

딥러닝의 학습 방식에는 한가지 문제가 있습니다. 바로 학습이 되면서 레이어들의 parameter가 계속해서 업데이트 되기 때문에, 각 레이어들의 output 또한 매번 다른 분포를 출력하게 되어, 그 값의 범위가 들쭉날쭉 해진다는 점 인데요.(internal covariate shift) 딥러닝은 레이어의 출력에 비선형함수를 달기 때문에 출력값들이 비선형 함수를 통해 다듬어지게(saturation) 됩니다. 예를들어, 한 레이어의 출력이 (10,20,30,1,2) 일 경우, sigmoid 함수를 통과하게 되면 (1,1,1,0.7,0.9)로 바뀌어 첫 10, 20, 30값들은 정보손실을 겪었다고 볼 수 있는것 입니다. 특히 이렇게 정보손실을 겪은 feature들이 중요한 feature일 경우, 학습의 비효율은 더 심각해 질 것 이며, 많은 값들이 다듬어질수록 모델은 학습과정에서 더 많은 정보손실을 겪게될 것 입니다.

0이하의 값에서 0에 수렴하는 대부분의 비선형 함수

저자들은 이 문제를 다음과 같이 접근합니다.

"비선형함수가 이런 학습과정에서 계속 변하는 출력분포를 편향 없이 공평하게 바라볼 수 있도록 도와줄 수 있다면, 모델은 어떤 feature가 중요하고 아닌지 더 빠르게 파악할 수 있지 않을까?"

 

그래서 다음과 같이 레이어의 output분포를 정규화 시켜주는 batch normalization 알고리즘을 제안합니다.

BN 알고리즘

Batch Normalization 의 핵심은 다음과 같습니다. "모델 학습 시, 어떤 Layer의 출력에 대해, mini-batch 내 출력값들의 분포로 각 데이터들을 정규화시켜준다."

하지만, output이 평균0, 표준편차1을 가지도록 정규화 시켜주는 것이 오히려 데이터 분포의 특징을 헤칠 수 있기 때문에 저자는 γ̂와 β라는 학습가능한 파라미터를 통해 그 결과를 다시 조정할 수 있도록 완충해주었습니다.

그리고 이 방식은 아시다 싶이 학습 효율을 크게 향상시키고, 심지어 Regularization기능 까지 수행하는 경우도 보이게 됩니다.

정량적 결과는 다음과 같습니다.

(a) MNIST에 대한 학습 iteration 별 test accuracy입니다. BN을 사용한 경우가 훨씬 빠르게 optimal한 위치에 도달합니다.

(b), (c) 모델 내 sigmoid의 input 분포입니다.BN을 사용한 경우가 더 안정된 분포를 보입니다.

1) 역전파 문제 없을까?

Chain-rule을 활용해, 각 Parameter에 대한 gradient를 구해줄 수 있습니다.

2) 한개의 sample만 들어오는 inference 상황에선 분포를 수정 못할텐데?

따라서 학습 시 사용된 μB와 σB의 평균을 취해 inference 시 활용하게 됩니다. 이렇게 하면 전체 BN계산이 선형계산이 되어, FC, Conv layer과 계산과정을 합쳐줄 수 있게 됩니다. 따라서 BN을 추가적으로 적용하는데 있어 inference 시 계산 부담이 없다는 점을 알 수 있습니다.

 

 

 

References

이미지 출처: https://devlifenote.co.kr/ai-dl-12%EA%B0%80%EC%A7%80-%ED%95%84%EC%88%98-activation-functions-%ED%99%9C%EC%84%B1%ED%99%94-%ED%95%A8%EC%88%98-%EC%A0%95%EB%A6%AC/

논문 출처: https://arxiv.org/abs/1502.03167