Tensorflow

Model Checkpoint와 CSVLogger

728x90

Model Checkpoint와 CSVLogger는 모두 콜백함수중 하나이다.

 

 

Model Checkpoint - 모델을 저장할때 사용되는 콜백함수.

 

from tensorflow.keras.callbacks import ModelCheckpoint

cp = ModelCheckpoint(filepath=CHECKPOINT_PATH, monitor='val_accuracy',
                     save_best_only = True,
                     verbose = 1)

 

 

인자 설명

 

filepath - 모델을 저장할 경로를 입력함.
추가 설명으로 만약 monitor가 val_loss일 때,
모델 경로를 '{epoch:02d}-{val_loss:.5f}.h5' 라고 입력하면, 에폭-해당에폭에서의 val_loss.h5로 모델이 저장됨. 예: 01-0.39121.h5

monitor - 모델을 저장할 때, 기준이 되는 값을 지정함.
예를 들어, validation set의 loss가 가장 작을 때 저장하고 싶으면 'val_loss'를 입력하고
만약 train set의 loss가 가장 작을 때 모델을 저장하고 싶으면 'loss'를 입력한다.
이 외에도 다양한 값들을 기준으로 삼을 수 있다.

verbose - 0, 1
1일 경우 모델이 저장 될 때, '저장되었습니다' 라고 화면에 표시되고,
0일 경우 화면에 표시되는 것 없이 그냥 바로 모델이 저장됨.

save_best_only - True, False
True 인 경우, monitor 되고 있는 값을 기준으로 가장 좋은 값으로 모델이 저장됨.
False인 경우, 매 에폭마다 모델이 filepath{epoch}으로 저장된다. (model0, model1, model2....)

save_weights_only - True, False
True인 경우, 모델의 weights만 저장됨.
False인 경우, 모델 레이어 및 weights 모두 저장됨.

mode - 'auto', 'min', 'max'
val_acc 인 경우, 정확도이기 때문에 클수록 좋다. 따라서 이때는 max를 입력해줘야한다.
만약 val_loss 인 경우, loss 값이기 때문에 값이 작을수록 좋다. 따라서 이때는 min을 입력해줘야한다.
auto로 할 경우, 모델이 알아서 min, max를 판단하여 모델을 저장한다.

save_freq - 'epoch' 또는 integer(정수형 숫자)
'epoch'을 사용할 경우, 매 에폭마다 모델이 저장된다.
integer을 사용할 경우, 숫자만큼의 배치를 진행되면 모델이 저장된다.
예를 들어 숫자 8을 입력하면, 8번째 배치가 train 된 이후, 16번째 배치가 train 된 이후 ..... 모델이 저장된다.

options tf.train.CheckpointOptions를 옵션으로 줄 수 있다.
분산환경에서 다른 디렉토리에 모델을 저장하고 싶을 경우 사용한다. 자세한 내용은 아래 링크를 참조.

 

CSVLogger

jupyter notebook에서 모델학습을 시키다 페이지 상태 변경이 일어나면 output에서 더 이상 학습진행 상황을 못볼수가 있다.

아래와 같이 로깅 콜백을 걸어두면 매 epoch마다 파일 로깅을 진행하여, 브라우저를 닫아버린 뒤에도 학습 상황을 확인할 수 있다.

from keras.callbacks import CSVLogger
 
csv_logger = CSVLogger('./log.csv', append=True, separator=';')
 
hist = model.fit_generator(training_set,
                         steps_per_epoch = 20,
                         epochs = 1000,
                         validation_data = validation_set,
                         validation_steps = 10,
                         callbacks=[csv_logger])

 

 

www.tensorflow.org/api_docs/python/tf/keras/callbacks/ModelCheckpoint  

tf.keras.callbacks.ModelCheckpoint  |  TensorFlow Core v2.4.1

Callback to save the Keras model or model weights at some frequency.

www.tensorflow.org

archive.htrucci.com/1235/keras-jupyter-%EC%A3%BC%ED%94%BC%ED%84%B0%EC%97%90%EC%84%9C-%EC%BC%80%EB%9D%BC%EC%8A%A4-%ED%95%99%EC%8A%B5%EC%83%81%ED%99%A9-%EB%A1%9C%EA%B9%85%ED%95%98%EA%B8%B0/archive.htrucci.com/1235/keras-jupyter-%EC%A3%BC%ED%94%BC%ED%84%B0%EC%97%90%EC%84%9C-%EC%BC%80%EB%9D%BC%EC%8A%A4-%ED%95%99%EC%8A%B5%EC%83%81%ED%99%A9-%EB%A1%9C%EA%B9%85%ED%95%98%EA%B8%B0/

 

728x90