코딩/pytorch

[PyTorch] learning rate scheduler 직접 짜기

curious_cat 2023. 2. 4. 23:36
728x90
728x90

PyTorch에 많은 learning rate scheduler이 있지만 직접 짜고 싶을 때도 있다. 

방법은 어렵지 않다.

torch.optim.lr_scheduler에 있는 _LRScheduler를 기반으로 class를 하나 만들면 된다. 미니멀하게는 __init__이랑 get_lr만 잘 정의해주면 된다.

linear하게 learning rate를 warmup해주는 scheduler을 예로 들어보자:

from torch.optim.lr_scheduler import _LRScheduler

class LinearWarmup(_LRScheduler):
    def __init__(self,optimizer,base_lr,warmup_steps,last_epoch=-1):
        self.base_lr = base_lr # warmup후 갖게되는 learning rate
        self.warmup_lr_init = 0.0001 # 처음에 갖게되는 learning rate
        self.warmup_steps = warmup_steps # warmup 할 총 step 수 (epoch x)
    	# optimizer을 사용해서 _LRScheduler을 initialize해주자
        # 보통 last_epoch는 -1로 두면 된다.
    	super().__init__(optimizer, last_epoch, False) 

    def get_warmup_lr(self):
    	""" learning rate 계산해주는 method"""
        # alpha: learning rate를 계산할 때 사용 할 multiplicative factor
        # alpha = (현제 step 수) / (warmup할 사용하는 총 step 수)
        # 밑에 식에서 self.last_epoch를 사용하지만 현제 step이라고 생각하면 된다 (나중에 추가 설명)
        alpha = float(self.last_epoch) / float(self.warmup_steps)
        _lr = self.base_lr * alpha
        return [_lr for _ in self.optimizer.param_groups] # optimizer에 있는 param group만큼 lr을 리턴

    def get_lr(self):
        if self.last_epoch == -1: # 처음 initialize됐을 때 self.warmup_lr_init 사용
            return [self.warmup_lr_init for _ in self.optimizer.param_groups]
        if self.last_epoch < self.warmup_steps: # 총 warmup할 step까지 get_warmup_lr 사용해서 lr 계산
            return self.get_warmup_lr()
        else:
            return [self.base_lr for _ in self.optimizer.param_groups] # 이후 base_lr 사용

 

사용 방법: optimizer이랑 scheduler을 정의했으면 같이 optimizer의 step()을 콜 하고 scheduler의 step()을 콜 하면 된다.

optimizer.step()
scheduler.step()

참고로 learning rate scheduler의 scheduler.step()을 콜 하면 last_epoch가 1씩 증가한다. 그래서 위에서 self.last_epoch를 기반으로 learning rate scheduler을 작성했다.

728x90
728x90