Tensorflow

callback 함수

728x90

callback 함수란?

함수모델을 더 이상 학습을 못할 경우(loss, metric등의 개선이 없을 경우)나 오버피팅을 방지하기 위해 지정된 값에 도달하면 학습 도중 미리 학습을 종료시키는 함수.

class myCallback(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs = {}):
    if(logs.get('accuracy')>0.90):
      print('\n정확도가 90% 이상이면 학습을 멈춥니다.')
      self.model.stop_training = True
 my_cb = myCallback()

mycallback을 함수로 지정하여 사용.

model.fit(train_generator, steps_per_epoch=8, epochs = 15, verbose = 1, callbacks=[my_cb])

model의 accuracy가 90% 이상이 될 경우 학습이 종료된다.

 

728x90