PyTorch Lightning은 사람들이 PyTorch 코드들을 작성하는 게 다 스타일이 다르다 보니까 이런 스타일을 통일하고자 만든 라이브러리이다.
PyTorch의 기존 단점은 사람들마다 train 시키는 코드가 다 제각각이라는 것인데, 이것을 함수화 시켜서 train, validate, test까지 다 해결할 수 있게 만들었다.
그리고, 기존의 save, logging, checkpointing, hyperparameter logging, version management 같은 부분을 통일화시켜서 다른 사람들의 코드를 더 보기 쉽게 할 수 있다. 즉, 코드 자체가 재사용 가능하고 공유 가능해진다는 것이다.
게다가 제일 좋은 부분은, Lightning 자체가 PyTorch의 wrapper 형식으로 작동하기 때문에 full-scalable 해진다는 것이다. 이 말 뜻은, 어떤 부분을 Lightning으로 작성하기가 싫다면 PyTorch 형식으로 작성해도 작동한다는 것이다.
Lightning Lite라는 것도 있는데, 이는 Lightning의 부분적인 기능을 PyTorch에 가지고 오고 싶을 때 사용할 수 있다. 이를 사용하면 굳이 모든 코드를 Lightning 형식에 맞춰 짜지 않더라도 AMP, Apex 같은 mixed precision 같은 기능을 사용할 수 있다.
Lightning은 크게 보면 PyTorch 코드들의 train.py 부분을 보기 편하게 바꾸어 준다고 생각하면 편하다. DataLoader, Dataset, Loss 같은 부분들은 PyTorch 형식으로 짜서 Trainer에 넣어주면 간단하게 학습이 시작된다.
model = MyLightningModule()
trainer = Trainer()
trainer.fit(model, train_dataloader, val_dataloader)
. fit()이라는 메서드를 사용하여 scikit-learn과 비슷하게 학습을 시작할 수가 있다. LightningModule안에서는 학습 순서와 방법을 정의해주고, Trainer로 학습을 설정해준다(save, acceleration 등등).
trainer=pl.Trainer(
default_root_dir=HyperParameter.DEFAULT_ROOT,
accelerator='gpu',
callbacks=[checkpoint1,checkpoint2],
log_every_n_steps=HyperParameter.LOG_INTERVAL,
deterministic=True,
max_epochs=HyperParameter.EPOCH,
num_sanity_val_steps=0
)
.cuda() 같은 device 설정을 PyTorch처럼 안 해줘도 되고, 모든 seed 값을 설정해주는 함수도 있다. 터미널에서 tqdm으로 현재 훈련 상태를 자동으로 format도 해줘서 굳이 복잡하게 코드를 짤 필요도 없다.
그리고 tensorboard 지원이 기본적으로 들어있어서 self.log를 불러서 아주 쉽게 학습을 기록할 수가 있다. 여기서 필자가 제일 신기했던 기능이 있는데, self.save_hyperparameters()를 LightningModule 안에서 불러주기만 한다면 __init__으로 전달된 모든 hyperparameter들을 checkpoint에 저장해줘서 나중에 불러올 때 기록도 되고 바로 학습을 이어서 할 수가 있다.
checkpoint1=ModelCheckpoint(
monitor="Validation/f1",
filename='epoch{epoch:02d}-val_f1={Validation/f1:.3f}-val_acc={Validation/acc:.3f}',
save_top_k=3,
mode="max",
save_on_train_epoch_end=False,
auto_insert_metric_name=False
)
checkpoint2=ModelCheckpoint(
filename='last',
save_on_train_epoch_end=True
)
callback을 설정해주기만 하면 다양한 방식으로 체크포인트를 설정해줄 수 있고, 게다가 train을 돌릴 때마다 version이름을 바꾸어서 폴더를 새로 만들어준다. 그리고 어떤 log를 monitor 해주면서 상위 k개의 성능의 모델들만 저장할 수 있다. 이 것들은 기존의 PyTorch 코드로 짜면 조금 복잡하고 실수를 하기가 쉬운데 이렇게 모듈화를 하니까 인자 몇 개만 설정해주면 통일 되어있는 체계를 갖출 수가 있다.
LightningModule 안에서 설정해주는 다양한 training 순서들을 (training_step, training_epoch_end, validation_step, validation_epoch_end, configure_optimizers) 원하는 대로 custom이 가능해져서, 간단한 코드에 무언가를 추가 할 때 이것이 무엇을 하는지 쉽게 알 수 있다.
그리고, 이 회사가 만든 torchmetrics 라는 metric을 쉽게 계산할 수 있는 다른 라이브러리와 연동이 되어서 logging이 아주 쉬워진다.
자 이렇게 다양한 장점들을 말 해봤다. 그렇다면 PyTorch Lightning의 단점은 무엇일까?
첫째, 함수에 인자들이 너무 많아서 외우고 사용하는데 시간이 걸린다. 제공하는 기능이 많아서 이를 배우는데 시간이 조금 걸린다. 둘째, 아직 그렇게 주류는 아니기 때문에 이 코드를 이해 못하는 사람이 많을 수 있다.
결론: Lightning은 충분히 좋고 그냥 PyTorch보다 사용 가치가 좋은 것 같다. 게다가 짜기 귀찮은 부분들을 거의 다 해결해 줘서 그냥 그 자체로도 다른 라이브러리 없이 쓰기가 좋다. 아직 여러 코드들은 Lightning을 사용하지는 않지만 미래에는 꽤 좋은 라이브러리가 될 것이라 생각한다. 필자는 계속 Lightning을 배우는 여행을 떠나보려 한다.