Dataset
사용자가 정의한 데이터셋을 다루기 위한 추상 클래스입니다. 이 클래스를 통해 데이터셋을 다루는데 있어 일관성을 유지하면서 효율적으로 처리할 수 있습니다.
Dataset은 일반적으로 다음과 같은 세 가지 요소를 포함합니다.
- __init__ : 데이터셋을 초기화하고 읽어들입니다.
- __getitem__ : 데이터셋에서 특정 인덱스의 샘플을 읽어들입니다.
- __len__ : 데이터셋의 샘플 개수를 반환합니다.
또한, 데이터셋을 사용하는 이유는 모델에 학습데이터를 제공하는 것 뿐만 아니라 데이터 전철, 증강, 로딩 등 다양한 작업을 수행합니다.
from torch.utils.data import Dataset
class CustomDataset(Dataset):
def __init__(self,):
pass
def __len__(self):
pass
def __getitem__(self, idx):
pass
DataLoader
데이터셋을 사용하여 모델 학습에 필요한 미니배치(mini-batch) 데이터를 제공하는 유틸리티입니다.
DataLoader는 데이터셋으로부터 데이터를 불러올 때 병목현상을 줄이기 위해 멀티스레딩 및 프로세스를 사용합니다.
DataLoader 객체는 데이터셋 객체와 함께 생성되며, 사용자가 지정한 미니배치 크기, 데이터 셔플 여부, 데이터 로드에 사용되는 스레드 및 프로세스 개수 등을 설정할 수 있습니다.
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
# 데이터셋 다운로드
train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transforms.ToTensor())
# DataLoader 생성
batch_size = 64
train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
- dataset
DataLoader가 불러올 데이터셋 객체입니다.
- batch_size
전체 트레이닝 데이터 셋을 여러 작은 그룹으로 나누었을 때 batch size는 하나의 미니배치의 데이터 수를 의미합니다.
- shuffle
샘플들을 섞어서 순서를 바꿀지 여부입니다.
- sampler
데이터를 배치 단위로 자르고 셔플링하는 등 다양한 옵션 제공 (shuffle과 함께 사용 가능)
- SequentialSampler : 항상 같은 순서로
- RandomSampler : 랜덤 순서, 개수 선택 가능
- SubsetRandomSampler : 인덱스 집합을 입력받아 해당 인덱스의 데이터를 랜덤하게 샘플링합니다.
- WeightedRandomSampler : 각 데이터마다 가중치를 부여하여 가중치에 따라 샘플링합니다.
- num_workers
데이터를 불러올 떄 사용할 프로세스의 개수입니다 .데이터 전처리가 CPU에서 이루어진다면 이 값을 높이면 속도를 높일 수 있습니다.
(윈도우에서는 멀티프로세서의 제한 때문에 num_worker의 수가 1 이상인 경우 에러가 발생합니다.)
또한 무작정 높인다고 성능이 좋아지진 않습니다. 오히려 CPU와 GPU의 병목현상때문에 성능이 저하할 수 있습니다.
- collate_fn
미니배치를 만들 때 각 샘플들을 어떻게 결합할지 지정하는 함수입니다.
- pin_memory
CUDA로 모델을 학습할 때 CPU 메모리에 데이터를 로딩한 후 복사하는 것보다 GPU 메모리로 바로 복사하는 것이 더 효율적입니다. 이 값을 True로 설정하면 로딩한 데이터를 곧바로 복사합니다.
- drop_last
마지막 미니배치의 크기가 batch_size 보다 작을 때 이를 무시할지 여부를 판단합니다.
- time_out
DataLoader가 data를 불러오는 데 시간을 지정합니다.
'ML & DL > PyTorch' 카테고리의 다른 글
[Pytroch] Multi GPU Training (0) | 2023.05.31 |
---|---|
[PyTorch] torchvision & transform (0) | 2023.03.15 |
[PyTorch] nn.Module (0) | 2023.03.15 |
[PyTorch] torch.nn (0) | 2023.03.15 |
[PyTorch] Optimization, 최적화 (0) | 2023.03.13 |