코딩/pytorch

[numpy][pytorch] np.memorymap로 빠른 dataloading

curious_cat 2023. 1. 29. 13:50
728x90
728x90

학습할 때 dataloading 이 생각보다 시간을 많이 잡아먹습니다. 물론 dataloading을 가장 빠르게 하는 방법은 모든 data를 RAM에 올려서 작업하는 방식이지만 현실적으로 큰 데이터 셋을 다루다 보면 이것은 불가능하죠.

 

Dataloading을 빠르게 할 수 있는 방법 중 하나가 이번 포스트에서 소개할 memory map을 사용하는 것입니다.

class numpy.memmap(filename, dtype=<class 'numpy.ubyte'>, mode='r+', offset=0, shape=None, order='C')

memory map은 큰 데이터를 다 메모리로 읽어오지 않고, 디스크에 저장해둔 다음 빠르게 데이터를 읽어올 때 사용합니다. 

 

memory map 파일을 만드는 코드:

from tqdm import tqdm
import numpy as np

# 학습을 위해서 이미지와 이미지에 대응되는 레이블에 대한 메모리 맵 파일을 만든다
mmap_img_path = 'img.dat' # 이미지에 대한 메모리맵 경로
mmap_lb_path = 'lb.dat' # 레이블에 대한 메모리맵 경로

num_data = 100 # 데이터 수; 임의로 100으로 설정함

img_size = (256,256,3) # 이미지 사이즈
img_mmap_shape = (num_data,*img_size) # 메모리맵으로 저장된 이미지 array 사이즈

lb_size = 10 # 레이블 사이즈; 임의로 이미지 당 10개의 레이블이 있다고 가정
lb_mmap_shape = (num_data,lb_size) # 메모리맵으로 저장된 레이블 array 사이즈

# 메모리맵 사이즈들을 저장
np.save('img_mmap_shape.npy',np.array(img_mmap_shape))
print('img_mmap_shape:',img_mmap_shape)
np.save('lb_mmap_shape.npy',np.array(lb_mmap_shape))
print('lb_mmap_shape:',lb_mmap_shape)

# 메모리맵을 만든다. 주의점:
# 1. 데이터 타입을 알맞게 설정해주기
# 2. mode: w+는 파일을 만들거나 파일에 write할 때 사용하는 모드
# 3. 메모리맵에 할당되는 총 사이즈를 미리 정해야함.
mmap_img = np.memmap(mmap_img_path,dtype=np.uint8,mode='w+',shape=img_mmap_shape)
mmap_lb = np.memmap(mmap_lb_path,dtype=np.float32,mode='w+',shape=lb_mmap_shape)

# 편의 상 data가 npz 형태로 저장되어있다고 가정하자: https://numpy.org/doc/stable/reference/generated/numpy.savez.html
npz_paths = ['0.npz', '1.npz',...,'99.npz'] # 100개의 npz 파일들 경로

for idx,npz_path in enumerate(tqdm(npz_paths)):
    npz = np.load(npz_path) # npz 파일을 불러오고
    img = npz['img']
    mmap_img[idx][:]=img # memory map에 이미지 데이터 저장
    lb = npz['label']
    mmap_lb[idx][:]=lb # memory map에 레이블 데이터 저장

memory map 사용 코드 (pytorch Dataset)

import numpy as np
from torch.utils.data import Dataset

class MMDataset(Dataset):
    def __init__(self):
        """ 
        use memory map to load data
        """
        super().__init__()
        # 저장해둔 이미지와 레이블에 대한 memory map 사이즈를 불러온다
        self.img_mmap_shape = np.load('img_mmap_shape.npy'))
        self.lb_mmap_shape = np.load('lb_mmap_shape.npy'))
        # 메모리맵을 불러온다. 메모리 맵을 만들 때와 비슷하게
        # 데이터 타입, 모드, shape 설정 주의
        self.img_mmap = np.memmap('img.dat',dtype=np.uint8,mode='r',shape=tuple(self.img_mmap_shape))
        self.lb_mmap = np.memmap('lb.dat',dtype=np.float32,mode='r',shape=tuple(self.lb_mmap_shape))
    
    def __getitem__(self,idx):
    	# 메모리맵에서 필요한 데이터 access하고
        img = self.img_mmap[idx]
        lb = self.lb_mmap[idx]
        return img,lb # 데이터를 반환하면 끝

    def __len__(self,):
        return self.img_mmap_shape[0]

 

728x90
728x90