딥러닝 모델을 callback 없이 개발하는 것은 브레이크 없는 차를 운전하는 것과 같다
만약 딥러닝 모델을 개발하고 있다고 해보자. 만약 머리 속에 시간이 많이 소요되는 훈련 중에 프로세스를 조정할 방법에 대해 아무 생각이 없다면 그 프로젝트는 바람직하지 않은 방향으로 끝날 가능성이 높다. 이번 호 에서는 Keras Callback를 통하여 ModelCheckpoint, EarlyStopping 등의 함수들을 사용하여 어떻게 모델을 모니터하고 개선하는지 알아보도록 한다.
Callback 이란?
Callback 이란 훈련 절차의 주어진 단계에서 적용되는 일련의
함수들을 말한다. Callback 을 통하여 훈련 중인 모델의 내부 상태나 통계를 볼 수 있다. 예를 들어, 1000 번의 epoch
가 있는 훈련을 매번 끝날 때까지 시간을 소요하며 돌려보고 그 정확도나 손실을 확인한다면 무척 시간이 아까울것이다. 만약 정확도나 손실이 어떤 단계에 다다르면 훈련을 중단하게 한다면 어떨까? 훈련이
성공적으로 끝난 모델을 저장하는 방법은 없을까? 혹은 learning
rate 를 물론 decay 를 설정하고 epoch 마다
조정하는 방법도 있지만 시간에 따라 조절되게 할 순 없을까?
이러한 예들 처럼,
매 단계의 훈련/epoch 과정에 훈련 프로세스를 조정할 수 있게 도와주도록 어떤 자동화된
작업을 정의하고 사용하는 것을 callback 이 맡아서 해준다.
과적합(overfitting)을 방지하는 EarlyStopping
과적합을 방지하기위해 EarlyStopping을 사용하여 훈련을 중지했을
때, 모델의 validation error 가 더 이상 낮아지지
않도록 조절할 수 있었지만, 중지된 상태가 최고의 모델은 아닐 수 있다. 따라서 가장 validation 성능이 좋은 모델을 저장하는 것이
필요한데, keras 에서는 이를 위해 ModelCheckpoint 라는
객체를 제공한다. 이 객체는 validation_error 를
모니터하면서 이전 epoch 에 비해 validation 성능이
좋으면 무조건 이때의 parameter 들을 저장한다. 이를
통해 훈련이 중지되었을 때, 가장 validation 성능이
높았던 모델을 저장할 수 있다.
이 ModelCheckpoint
인스턴스와 앞의 EarlyStopping 인스턴스를 Callbacks
파라미터에 넣어 줌으로써 가장 validation 성능이 좋았던 모델을 저장하게 된다.
Callbacks 라는 키워드 인자를 사용하여 여러 개의 callbacks 을 모델의 .fit() 메소드에 전달할 수 있다. 즉, EarlyStopping, ModelCheckpoint 그리고 TensorBoard callbacks 들을 my_callbacks 으로 정의하고 이를 model.fit 메소드에 callbacks=my_callbacks 인자로 넘겨줌으로써 3가지 callbacks 함수를 한꺼번에 사용할 수 있다.
텐서보드(TensorBoard)는 모델을 이해하고 디버깅 및 최적화를 돕기 위한 시각화도구로 아래와 같은 화면을 제공함으로써 epoch 가 진행되면 서의 정확도와 손실을 손쉽게 파악하게 도와준다.
텐서플로우 와 텐서플로우 callbacks에 대한 내용은 뉴스레터 2호 에서 다루었다. 보다 자세한 사항은 뉴스레터를 클릭하여 살펴보기 바란다.
결언
이상 살펴본 바와 같이, Keras 의callbacks 을 통해서 과적합을 방지하기위한 EarlyStopping 함수를 사용하면 validation loss 의 성능이 더 이상 증가하지 않는 시점에서 몇 번의 epoch 수를 기다리게 할 것인가를 정함으로써 과적합이 일어나기 전의 최고 성능의 weight를 저장하게 된다. 하지만 EarlyStopping 을 사용하여 훈련을 중지할 수 있지만, 이때가 최고의 모델을 아닐 수 있다.
이때, ModelCheckPoint 함수를
사용하여 validation error를 모니터 하면서 이전 epoch에
비하면 성능이 좋으면 무조건 parameter를 저장하는 기능을 사용할 수 있다. 그리고 tensorboard callback을 통하여 디버깅 및 최적화를 돕기 위한 시각화도구를 제공받는
것을 살펴보았다. 그리고 여러 개의 callbacks 를
모델의 .fit() 메소드를 통하여 전달하여 한꺼번에 사용하는 방법을 살펴보았다.