<Underfitting & Overfitting>
머신러닝에서 모델 학습 중, 학습 환경에 따라서 Underfitting 과 Overfitting이 발생할 수 있다
Underfitting이 발생할 경우 train set에서 부터 정확도가 떨어지며,
Overfitting이 발생할 경우 train set을 과하게 학습하여
train set의 정확도는 매우 높게 나오지만 validation set의 정확도는 현저히 떨어지는 현상이 발생한다
model 학습의 최종 목표는 validation set에 대한 loss가 가장 적도록 학습을 해야한다!
<How to prevent Overfitting?>
- Hold-out
dataset을 train, validation, test로 분리하는 것을 Hold-out이라 한다
overfitting과 underfitting을 확인하려면 validation set이 필요하기에, 빠르게 모델을 학습할 때 사용되는 전통적인 방식이다
train : validation 비율은 보통 7:3 or 6:4 정도의 비율을 사용한다
- Cross-validation
Hold-out에서는 dataset을 고정적으로 train, validation, test으로 나눈다
shuffle 하여 선택하겠지만, 여전히 데이터의 편향 가능성이 존재하며 이는 데이터 수가 부족할 수록 심해진다
따라서, 모든 데이터셋을 두루 검증하는 cross-validation이 등장한다
- Data augmentation
dataset의 크기는 클 수록 좋다
data augementation을 이용하여 dataset의 크기를 임의로 늘릴 수 있다
- Feature selection
모델이 너무 많은 feature에 대해 학습하게 되면 과적합 될 가능성이 높다
예를 들어, 냉장고와 컴퓨터를 분류하는데 색깔을 feature로 사용하게 되면 overfitting을 일으킨다
- L1/L2 Regularization
model weight가 너무 큰 값을 가지게 되면 과하게 구불구불한 형태의 함수가 만들어진다
따라서 모델의 복잡도(weight의 크기)를 낮추기 위해 Regularization을 사용한다
$$ E(w)=E_{in}(w)+\frac{\lambda}{N}\Omega (w) $$
과적합은 학습 데이터 수가 부족해서 발생하기 때문에 $ N $을 분모에 나눈다
(L1 regularization)
$$ \Omega (w)=||w||_1=\sum_q|w_q| $$
(L2 regularization)
$$ \Omega (w)=||w||_2=w^Tw=\sum_qw_q^2 $$
L2 weight update를 할 때 다음과 같은 rule을 따른다
$$ w\leftarrow w-\eta\nabla E_{aug} (w) $$
$$ =w-\eta\nabla E_{train}(w)-2\eta\frac{\lambda}{N}w $$
$$=(1-2\eta\frac{\lambda}{N})w-\eta\nabla E_{train} (w) $$
$ 2\eta\frac{\lambda}{N} $ 만큼 원래 있던 weight가 decay 되는 것을 확인할 수 있다
- Remove layers / number of units per layer
Regularization에서 언급했듯이, 모델이 복잡하면 overfitting을 야기할 가능성이 커진다
뉴런의 개수를 줄이는 방법으로 이러한 overfitting을 피할 수 있다
- Dropout
전체 모델에서 특정 비율만큼 뉴런을 활성화 시킨다
확률적으로 뉴런이 사용될수도, 사용되지 않을수도 있기 때문에
특정 뉴런이 결과를 결정짓는(overfitting) 현상이 줄어들게 된다
당연하게도 모델의 일부분만을 사용하기 때문에 수렴속도는 느려진다
- Early stopping
위에서 사용했던 그림을 한번 더 살펴보자
위의 그림 처럼 일정 epoch 이후에 validation loss가 상승하는 것을 확인할 수 있다
Overfitting이 발생했기 때문이며 다시 상승하기 직전에 학습을 중단하는 것을 early stopping 이라고 한다
'AI > ML PyTorch Scikit-Learn' 카테고리의 다른 글
Ch3. SVM(Support Vector Machine) (6) | 2024.07.23 |
---|---|
Ch3. Logistic Regression & Sigmoid (0) | 2024.07.11 |
Ch3. Scikit-learn (0) | 2024.07.08 |
Ch2. Adaline, 로스와 경사하강법(Adaline, Loss and Gradient Descent) (0) | 2024.07.02 |
Ch2. 퍼셉트론(Perceptron) (0) | 2024.06.27 |