Huggingface:TrainerCallback
TrainerCallback
Huggingface에서 Trainer라는 Class는 모델의 학습을 굉장히 쉽게 해줍니다. 하지만, logging 방식이 정해져 있기 때문에, 상황에 따라 logging을 다르게 할 수 있는 방법을 찾다가 TrainerCallback이라는 것을 알게 되었습니다. Huggingface에서 TrainerCallback이라는 Class를 제공해줍니다. TrainerCallback class를 subclassing해서 다양한 Callback Class들이 제공되고 있습니다. subclassing을 통해 다양한 callback method를 override 할 수 있기 때문에 inheritance에 대한 개념을 인지한다면 scratch부터 짜는 것처럼 짤 수 있다고 생각합니다. 이번 글에 나오는 코드는 huggingface의 docs에서 가져온 코드입니다. 지극히, 초보적인 Python Coder 관점에서 바라본 것이라 중급자나 고급자분들께서는 의아하실 수도 있다고 생각합니다. 좋게 봐주시면 감사하겠습니다. (자세한 정의는 나중에 좀 더 자세하게 다루도록 하겠습니다.)
Callback 이 어떻게 구현되었는지
Huggingface 에서 TrainerCallback이 어떻게 작동하는지 알아보도록 하겠습니다. TrainerCallback이라는 부모 클래스를 여러 개의 다른 callback 클래스들이 subcalssing을 해서 구현되어있습니다.
class DefaultFlowCallback(TrainerCallback):
"""
A :class:`~transformers.TrainerCallback` that handles the default flow of the training loop for logs, evaluation
and checkpoints.
"""
callback의 예시 중 하나인 DefaultFlowCallback을 가져와 보았습니다. TrainerCallback을 subclassing 했다는 것을 알 수 있습니다. Huggingface에서 제공하는 Callback은 저 개인적으로는 두 종류로 분류할 수 있다고 생각합니다. Huggingface에서 기본적으로 제공하는 Callback Class와 외부 서비스와 통합이 된 integration callback class들이 있다고 생각합니다. Huggingface repository를 들어가시면 trainer_callback.py 와 integrations.py 두 개로 나뉘어서 저장되어있음을 알 수 있습니다 . integrations.py에는 wadnb,mlfow,azure등 여러 metric logging 서비스와 통합이 되어있음을 알 수 있습니다.
- Default Callback
- DefaultFlowCallback
- ProgressCallback
- PrinterCallback
- EarlyStoppingCallback
- Integration Callback
- TensorboardCallback
- WandbCallback
- CometCallback
- AzureMLCallback
- MLFlowCallback
- CodeCarbonCallback
현재 repo 에 정의된 callback 클래스들입니다.
Callback 추가 및 삭제
Callback 클래스를 추가하게 되면 특정 조건을 만족할 때마다 callback이 호출되게 됩니다. Integration Callback의 경우 추가하는 방법은 Trainer를 instantiate 할 때 report_to argument에서 지정해주게 됩니다. 예를 들어, WandBCallback을 추가하고 싶으면 “wandb” 라는 argument를 넘겨주면 됩니다. Default Callback들의 경우 DefaultFlowCallback 은 자동으로 추가가 됩니다. 수동으로 Callback들을 추가하려면 Trainer의 add_callback이라는 method를 사용하게 된다면 Callback을 추가할 수 있습니다. Trainer의 remove_callback이라는 method를 사용하게 되면 Callback을 삭제할 수 있습니다.
Callback Method
Huggingface에서 한번 Trainer를 instantiate 하게 된다면 CallbackHandler,TrainerState,TrainerControl 가 Trainer instance의 attribute로 지정이 됩니다. callback handler로 여러 method를 호출하게 됩니다.다른 프레임워크는 잘 모르지만, Huggingface의 경우 on_init_end, on_train_begin , on_step_end , ..등등으로 callback의 method들이 훈련하는 도중에 특정 시점에 호출이 됩니다. TrainerState로 train 했을 때의 metric, best_checkpoint, epoch,step 등등을 넘기고 TrianserControl로 should_evaluation ,should_save, should_log 등등의 boolean 값으로 호출된 method를 수행할 지 정하게 됩니다.
예시를 보겠습니다.
model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
예시를 보면, 위 코드는 step이 끝날 때 마다 callback_handler 가 on_step_end 라는 method를 호출하는 것을 볼 수 있습니다 .
def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl):
return self.call_event("on_step_end", args, state, control)
이런 식으로 on_step_end method가 call_event라는 method를 호출하게 됩니다.
def call_event(self, event, args, state, control, **kwargs):
for callback in self.callbacks:
result = getattr(callback, event)(
args,
state,
control,
model=self.model,
tokenizer=self.tokenizer,
optimizer=self.optimizer,
lr_scheduler=self.lr_scheduler,
train_dataloader=self.train_dataloader,
eval_dataloader=self.eval_dataloader,
**kwargs,
)
# A Callback can skip the return of `control` if it doesn't change it.
if result is not None:
control = result
return control
call_event에서 Trainer에 추가된 모든 callback들에 대하여 getattr을 통해서 event 에 해당하는 method가 있는지 확인을 합니다 . 여기에서는 이 method가 ‘on_step_end’ 가 되겠습니다. TrainerControl의 boolean값들을 result를 통해 변경해줍니다. (모든 event가 그렇지는 않겟지만, on_step_end 의 경우 boolean 값을 변경해줍니다.)
model.zero_grad()
self.state.global_step += 1
self.state.epoch = epoch + (step + 1) / steps_in_epoch
self.control = self.callback_handler.on_step_end(args, self.state, self.control)
self._maybe_log_save_evaluate(tr_loss, model, trial, epoch, ignore_keys_for_eval)
def _maybe_log_save_evaluate(self, tr_loss, model, trial, epoch, ignore_keys_for_eval):
if self.control.should_log:
logs: Dict[str, float] = {}
# all_gather + mean() to get average loss over all processes
tr_loss_scalar = self._nested_gather(tr_loss).mean().item()
# reset tr_loss to zero
tr_loss -= tr_loss
logs["loss"] = round(tr_loss_scalar / (self.state.global_step - self._globalstep_last_logged), 4)
logs["learning_rate"] = self._get_learning_rate()
self._total_loss_scalar += tr_loss_scalar
self._globalstep_last_logged = self.state.global_step
self.store_flos()
self.log(logs)
metrics = None
if self.control.should_evaluate:
metrics = self.evaluate(ignore_keys=ignore_keys_for_eval)
self._report_to_hp_search(trial, epoch, metrics)
if self.control.should_save:
self._save_checkpoint(model, trial, metrics=metrics)
self.control = self.callback_handler.on_save(self.args, self.state, self.control)
이후 다시 trainer의 stackframe으로 돌아오게 됩니다. 그리고 ,maybe_log_save_evaluate 라는 method를 통해서 control의 attribute의 따라서 logging을 하게 됩니다. 아직 시도를 해보지는 않았지만, Trainer Class의 _maybe_log_save_evaluate 랑 train method를 overriding을 하면 logging을 상황에 맞게 할 수 있습니다. 하지만, 단순히 train data에 대한 metric이 궁금한 수준이라면 , 저는 eval set을 train의 subset으로 사용해서 overfitting 여부를 확인하는게 더 효율적이라고 생각합니다
Custom Callback
class MyCallback(TrainerCallback):
"A callback that prints a message at the beginning of training"
def on_train_begin(self, args, state, control, **kwargs):
print("Starting training")
class PrinterCallback(TrainerCallback):
"""
A bare :class:`~transformers.TrainerCallback` that just prints the logs.
"""
def on_log(self, args, state, control, logs=None, **kwargs):
print(logs)
print(state.log_history)
_ = logs.pop("total_flos", None)
if state.is_local_process_zero:
print(logs)
기존의 Callback들을 그대로 사용해도 되지만 customize해서 나만의 callback 클래스를 만들 수 있습니다. 두번째 예시는 , 필자가 PrinterCallback을 새로 정의한 예시입니다 .하지만, method를 override를 할 때는 callback handler class에 정의되어 있는 method를 해주셔야 합니다. 위에서 봤듯이, callblackhandler에서 정의된 method에서 call_event를 호출을 합니다. 만약에 ,callbackhandler에 정의되지 않은 method를 callback class에서 override를 한다면 호출되지 않을 것입니다.
Leave a comment