728x90
728x90
학습할 때 dataloading 이 생각보다 시간을 많이 잡아먹습니다. 물론 dataloading을 가장 빠르게 하는 방법은 모든 data를 RAM에 올려서 작업하는 방식이지만 현실적으로 큰 데이터 셋을 다루다 보면 이것은 불가능하죠.
Dataloading을 빠르게 할 수 있는 방법 중 하나가 이번 포스트에서 소개할 memory map을 사용하는 것입니다.
class numpy.memmap(filename, dtype=<class 'numpy.ubyte'>, mode='r+', offset=0, shape=None, order='C')
- "Create a memory-map to an array stored in a binary file on disk. Memory-mapped files are used for accessing small segments of large files on disk, without reading the entire file into memory"
- 참고: https://numpy.org/doc/stable/reference/generated/numpy.memmap.html, https://en.wikipedia.org/wiki/Memory-mapped_file
memory map은 큰 데이터를 다 메모리로 읽어오지 않고, 디스크에 저장해둔 다음 빠르게 데이터를 읽어올 때 사용합니다.
memory map 파일을 만드는 코드:
from tqdm import tqdm
import numpy as np
# 학습을 위해서 이미지와 이미지에 대응되는 레이블에 대한 메모리 맵 파일을 만든다
mmap_img_path = 'img.dat' # 이미지에 대한 메모리맵 경로
mmap_lb_path = 'lb.dat' # 레이블에 대한 메모리맵 경로
num_data = 100 # 데이터 수; 임의로 100으로 설정함
img_size = (256,256,3) # 이미지 사이즈
img_mmap_shape = (num_data,*img_size) # 메모리맵으로 저장된 이미지 array 사이즈
lb_size = 10 # 레이블 사이즈; 임의로 이미지 당 10개의 레이블이 있다고 가정
lb_mmap_shape = (num_data,lb_size) # 메모리맵으로 저장된 레이블 array 사이즈
# 메모리맵 사이즈들을 저장
np.save('img_mmap_shape.npy',np.array(img_mmap_shape))
print('img_mmap_shape:',img_mmap_shape)
np.save('lb_mmap_shape.npy',np.array(lb_mmap_shape))
print('lb_mmap_shape:',lb_mmap_shape)
# 메모리맵을 만든다. 주의점:
# 1. 데이터 타입을 알맞게 설정해주기
# 2. mode: w+는 파일을 만들거나 파일에 write할 때 사용하는 모드
# 3. 메모리맵에 할당되는 총 사이즈를 미리 정해야함.
mmap_img = np.memmap(mmap_img_path,dtype=np.uint8,mode='w+',shape=img_mmap_shape)
mmap_lb = np.memmap(mmap_lb_path,dtype=np.float32,mode='w+',shape=lb_mmap_shape)
# 편의 상 data가 npz 형태로 저장되어있다고 가정하자: https://numpy.org/doc/stable/reference/generated/numpy.savez.html
npz_paths = ['0.npz', '1.npz',...,'99.npz'] # 100개의 npz 파일들 경로
for idx,npz_path in enumerate(tqdm(npz_paths)):
npz = np.load(npz_path) # npz 파일을 불러오고
img = npz['img']
mmap_img[idx][:]=img # memory map에 이미지 데이터 저장
lb = npz['label']
mmap_lb[idx][:]=lb # memory map에 레이블 데이터 저장
memory map 사용 코드 (pytorch Dataset)
import numpy as np
from torch.utils.data import Dataset
class MMDataset(Dataset):
def __init__(self):
"""
use memory map to load data
"""
super().__init__()
# 저장해둔 이미지와 레이블에 대한 memory map 사이즈를 불러온다
self.img_mmap_shape = np.load('img_mmap_shape.npy'))
self.lb_mmap_shape = np.load('lb_mmap_shape.npy'))
# 메모리맵을 불러온다. 메모리 맵을 만들 때와 비슷하게
# 데이터 타입, 모드, shape 설정 주의
self.img_mmap = np.memmap('img.dat',dtype=np.uint8,mode='r',shape=tuple(self.img_mmap_shape))
self.lb_mmap = np.memmap('lb.dat',dtype=np.float32,mode='r',shape=tuple(self.lb_mmap_shape))
def __getitem__(self,idx):
# 메모리맵에서 필요한 데이터 access하고
img = self.img_mmap[idx]
lb = self.lb_mmap[idx]
return img,lb # 데이터를 반환하면 끝
def __len__(self,):
return self.img_mmap_shape[0]
728x90
728x90
'코딩 > pytorch' 카테고리의 다른 글
PyTorch 와 PyTorch Lightning을 위한 간단 도커파일 (Dockerfile) 작성 & 사용 (0) | 2023.08.03 |
---|---|
[PyTorch] 간단한 pytorch lightning 설명과 템플릿 코드 (0) | 2023.07.30 |
[PyTorch] learning rate scheduler 직접 짜기 (0) | 2023.02.04 |
ddp 학습 중단 오류--한번에 프로세스 죽이기 (0) | 2023.01.23 |