코딩/pytorch

[PyTorch] 간단한 pytorch lightning 설명과 템플릿 코드

curious_cat 2023. 7. 30. 01:02
728x90
728x90

간단한 pytorch lightning 설명

모델을 만들 때 PyTorch도 충분히 많은 기능을 제공해 주지만 그래도 직접 짜기 귀찮은 것들도 많습니다. 이런 것들을 pytorch lightning이 해줘서 개인적으로 요즘은 대부분 학습을 pytorch lightning을 사용해서 짜고 있습니다.

 

사용 방법

1. 우선 pytorch lightning의 LightningModule이라는 클래스를 작성한다.

  • configure_optimizer: 사용할 optimizer 세팅하고 리턴
  • training_step: 학습할 데이터 batch가 들어왔을 때 모델에 넣어서 loss를 계산하고 loss 값 리턴
  • validation_step: validation 데이터 batch가 들어왔을 때 loss 계산하고 loss 값 리턴
import pytorch_lightning as pl
class pl_model(pl.LightningModule):
	def __init__(self,):
    	...
    def configure_optimizers(self,):
    	...
    def training_step(self,batch,batch_idx):
    	...
    def validation_step(self,batch,batch_idx):
        ...
        
  model = pl_model()

2. 평소처럼 pytorch의 dataloader을 생성한다

trnDL = initDL(is_train=True) # initDL은 dataloader을 생성해주는 함수
valDL = initDL(is_train=False)

3. 그리고 pytorch lightning에 Trainer이라는 클래스를 생성해 주고 (Trainer을 통해서 logging, gpu, 등 다양한 세팅을 할 수 있다) Trainer.fit() 함수를 부르면 알아서 학습을 해주고 (training_step, validation_step 기반으로 loss를 계산하고 configure_optimizer에서 만든 optimizer로 model parameter을 업데이트 해줍니다)

trainer = pl.Trainer(...)
trainer.fit(model,trnDL,valDL)

 

추가 노트

  • 개인적으로 pytorch lightning을 사용하는 이유는 ddp와 amp를 직접 세팅하는 것이 귀찮아서인데, pytorch lightning을 사용하면 아주 간단하게 할 수 있습니다 (pl.Trainer 옵션에 추가하면 끝나는 수준입니다).
  • 자세한 설명을 달아둔 템플릿 코드를 zip 파일로 첨부하였으니 참고해주세요.
  • ddp (distributed data parallel), amp (automatic mixed precision), learning rate scheduler 도 사용을 하도록 세팅해두었습니다.

사용 방법

  1. pip install -r requirements.txt
  2. python train.py --config config/config.yml --device 0
    (GPU가 n개 있으면 --device 0 1 2 3 ... n-1)
  3. 학습이 시작하면
    tensorboard --logdir output/base/lightning_logs/
    로 로깅 파일을 볼 수 있습니다

참고로 이 코드를 그대로 실행시키면 learning rate scheduler의 step을 optimizer의 step보다 먼저 불렀다고 warning이 뜨는데, 이거는 pytorch lightning의 버그인 것 같습니다. 보통 문제를 일으키지 않기 때문에 무시했습니다. 하지만 없애고 싶으면 automatic optimization을 끄고 training_step에 직접 optimizer & scheduler을 순차적으로 불러주면 되기는 합니다.

더 자세한 사항들은 pytorch lightning의 documentation을 참고해주세요:

https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html

https://pytorch-lightning.readthedocs.io/en/stable/common/trainer.html

pl_template.zip
0.01MB

 

728x90
728x90