본문 바로가기

AI/ML PyTorch Scikit-Learn

Ch3. Overfitting Prevent Method

<Underfitting & Overfitting>

머신러닝에서 모델 학습 중, 학습 환경에 따라서 Underfitting 과 Overfitting이 발생할 수 있다

Fitting

 

 

 

Underfitting이 발생할 경우 train set에서 부터 정확도가 떨어지며,

Overfitting이 발생할 경우 train set을 과하게 학습하여

train set의 정확도는 매우 높게 나오지만 validation set의 정확도는 현저히 떨어지는 현상이 발생한다

 

model 학습의 최종 목표는 validation set에 대한 loss가 가장 적도록 학습을 해야한다!

Loss by Epochs

 

 

<How to prevent Overfitting?>

  • Hold-out

dataset을 train, validation, test로 분리하는 것을 Hold-out이라 한다

overfitting과 underfitting을 확인하려면 validation set이 필요하기에, 빠르게 모델을 학습할 때 사용되는 전통적인 방식이다

train : validation 비율은 보통 7:3 or 6:4 정도의 비율을 사용한다

Split dataset

 

 

  • Cross-validation

Hold-out에서는 dataset을 고정적으로 train, validation, test으로 나눈다

shuffle 하여 선택하겠지만, 여전히 데이터의 편향 가능성이 존재하며 이는 데이터 수가 부족할 수록 심해진다

따라서, 모든 데이터셋을 두루 검증하는 cross-validation이 등장한다

 

Cross Validation

 

 

  • Data augmentation

dataset의 크기는 클 수록 좋다

data augementation을 이용하여 dataset의 크기를 임의로 늘릴 수 있다

Data augmentation

 

 

  • Feature selection

모델이 너무 많은 feature에 대해 학습하게 되면 과적합 될 가능성이 높다

예를 들어, 냉장고와 컴퓨터를 분류하는데 색깔을 feature로 사용하게 되면 overfitting을 일으킨다

Feature selection

 

 

  • L1/L2 Regularization

model weight가 너무 큰 값을 가지게 되면 과하게 구불구불한 형태의 함수가 만들어진다

따라서 모델의 복잡도(weight의 크기)를 낮추기 위해 Regularization을 사용한다

$$ E(w)=E_{in}(w)+\frac{\lambda}{N}\Omega (w) $$

 

과적합은 학습 데이터 수가 부족해서 발생하기 때문에 $ N $을 분모에 나눈다

 

(L1 regularization)

$$ \Omega (w)=||w||_1=\sum_q|w_q| $$

 

(L2 regularization)

$$ \Omega (w)=||w||_2=w^Tw=\sum_qw_q^2 $$

 

 

L2 weight update를 할 때 다음과 같은 rule을 따른다

$$ w\leftarrow w-\eta\nabla E_{aug} (w) $$

$$ =w-\eta\nabla E_{train}(w)-2\eta\frac{\lambda}{N}w $$

$$=(1-2\eta\frac{\lambda}{N})w-\eta\nabla E_{train} (w) $$

 

 

$ 2\eta\frac{\lambda}{N} $ 만큼 원래 있던 weight가 decay 되는 것을 확인할 수 있다

 

 

  • Remove layers / number of units per layer

Regularization에서 언급했듯이, 모델이 복잡하면 overfitting을 야기할 가능성이 커진다

뉴런의 개수를 줄이는 방법으로 이러한 overfitting을 피할 수 있다

 

 

  • Dropout

전체 모델에서 특정 비율만큼 뉴런을 활성화 시킨다

확률적으로 뉴런이 사용될수도, 사용되지 않을수도 있기 때문에

특정 뉴런이 결과를 결정짓는(overfitting) 현상이 줄어들게 된다

 

당연하게도 모델의 일부분만을 사용하기 때문에 수렴속도는 느려진다

Dropout

 

 

  • Early stopping

위에서 사용했던 그림을 한번 더 살펴보자

Loss by Epochs



위의 그림 처럼 일정 epoch 이후에 validation loss가 상승하는 것을 확인할 수 있다

Overfitting이 발생했기 때문이며 다시 상승하기 직전에 학습을 중단하는 것을 early stopping 이라고 한다