DeepLearining

Learning Curve - 텐서플로우 딥러닝

macro 2021. 3. 12. 12:42
반응형

딥러닝 학습을 하면서, 각 에포크 마다의 밸리데이션 값을 확인해야 오버핏팅 ( Overfitting )이 있는지 눈으로 확인할 수 있다.

 

에포크(epoch)마다의 학습(training)과 밸리데이션(validation) 히스토리(history)를 저장하여, 학습 커브를 그려본다.

 

분류의 문제에서 loss 와 accuracy 를 확인하는데, train loss / validataion loss 와 train accuracy / validation accuracy 를 확인한다.

 

이를 함수로 작성해 놓고 사용하면 된다.

 

def learning_curve(history, epoch):
  plt.figure(figsize=(10,5))
  # 정확도 차트  
  epoch_range = np.arange(1, epoch + 1)

  plt.subplot(1, 2, 1)

  plt.plot(epoch_range, history.history['accuracy'])
  plt.plot(epoch_range, history.history['val_accuracy'])
  plt.title('Model Accuracy')
  plt.xlabel('Epoch')
  plt.ylabel("Accurach")
  plt.legend( ['Train', 'Val']  )
  # plt.show()

  # loss 차트
  plt.subplot(1, 2, 2)

  plt.plot(epoch_range, history.history['loss'])
  plt.plot(epoch_range, history.history['val_loss'])
  plt.title('Model Loss')
  plt.xlabel('Epoch')
  plt.ylabel("Loss")
  plt.legend( ['Train', 'Val']  )

  plt.show()

 

 

위의 함수를 실행하면, 다음과 같은 차트를 볼 수 있다.

 

 

반응형