# Source code for deepobs.pytorch.datasets.dataset

# -*- coding: utf-8 -*-
import abc
from random import shuffle

from torch.utils import data as dat
from torch.utils.data.sampler import SubsetRandomSampler

from .. import config
from .datasets_utils import train_eval_sampler

"""Base class for DeepOBS datasets."""

# pylint: disable=too-many-instance-attributes, too-few-public-methods


class DataSet(abc.ABC):
    """Base class for DeepOBS data sets.

    Args:
        batch_size (int): The mini-batch size to use.

    Methods:
        _make_train_and_valid_dataloader: Creates a torch data loader for the
            training and validation data with batches of size batch_size.
        _make_train_eval_dataloader: Creates a torch data loader for the
            training evaluation data with batches of size batch_size.
        _make_test_dataloader: Creates a torch data loader for the test data
            with batches of size batch_size.

    Attributes:
        _pin_memory: Whether to pin memory for the dataloaders. Defaults to
            ``False`` if 'cuda' is not the current device.
        _num_workers: The number of workers used for the dataloaders. Its
            value is set to the global variable NUM_WORKERS.
        _train_dataloader: A torch.utils.data.DataLoader instance that holds
            the training data.
        _valid_dataloader: A torch.utils.data.DataLoader instance that holds
            the validation data.
        _train_eval_dataloader: A torch.utils.data.DataLoader instance that
            holds the training evaluation data.
        _test_dataloader: A torch.utils.data.DataLoader instance that holds
            the test data.
    """

    def __init__(self, batch_size):
        """Creates a new DataSet instance.

        Args:
            batch_size (int): The mini-batch size to use.
        """
        self._batch_size = batch_size
        # Pin host memory only when batches go to a CUDA device, where
        # pinned memory speeds up host-to-device transfers.
        self._pin_memory = "cuda" in config.get_default_device()
        self._num_workers = config.get_num_workers()
        (
            self._train_dataloader,
            self._valid_dataloader,
        ) = self._make_train_and_valid_dataloader()
        self._train_eval_dataloader = self._make_train_eval_dataloader()
        self._test_dataloader = self._make_test_dataloader()

    def _make_dataloader(self, dataset, sampler=None, shuffle=False):
        """Wraps ``dataset`` in a DataLoader configured for this data set.

        Args:
            dataset: A torch Dataset instance.
            sampler: Optional sampler; mutually exclusive with ``shuffle``.
            shuffle (bool): Whether the loader should shuffle; must stay
                ``False`` whenever ``sampler`` is given.

        Returns:
            A torch.utils.data.DataLoader over ``dataset``. Note that
            ``drop_last=True``, so a trailing incomplete batch is dropped.
        """
        loader = dat.DataLoader(
            dataset,
            batch_size=self._batch_size,
            drop_last=True,
            pin_memory=self._pin_memory,
            num_workers=self._num_workers,
            sampler=sampler,
            shuffle=shuffle,
        )
        return loader

    def _make_train_eval_split_sampler(self, train_dataset):
        """Generates SubsetRandomSamplers that can be used for splitting the
        training set.

        The indices are shuffled once, then the first ``_train_eval_size``
        indices form the validation split and the remainder the train split.

        Returns:
            A tuple ``(train_sampler, valid_sampler)`` of
            SubsetRandomSampler instances over disjoint index sets.
        """
        # NOTE(review): relies on self._train_eval_size, which this base
        # class never sets — presumably concrete subclasses assign it
        # before calling __init__; confirm against the subclasses.
        indices = list(range(len(train_dataset)))
        shuffle(indices)
        train_indices, valid_indices = (
            indices[self._train_eval_size :],
            indices[: self._train_eval_size],
        )
        train_sampler = SubsetRandomSampler(train_indices)
        valid_sampler = SubsetRandomSampler(valid_indices)
        return train_sampler, valid_sampler

    def _make_train_and_valid_dataloader_helper(self, train_dataset, valid_dataset):
        """Builds the train/valid loaders from two (typically identical)
        dataset objects using disjoint random index splits."""
        train_sampler, valid_sampler = self._make_train_eval_split_sampler(
            train_dataset
        )
        # Since the samplers already draw randomly, shuffle is useless here
        # (and DataLoader forbids shuffle=True together with a sampler).
        train_loader = self._make_dataloader(train_dataset, sampler=train_sampler)
        valid_loader = self._make_dataloader(valid_dataset, sampler=valid_sampler)
        return train_loader, valid_loader

    def _make_train_eval_dataloader(self):
        """Creates the training evaluation data loader.

        Returns:
            A torch.utils.data.DataLoader instance with batches of training
            evaluation data.
        """
        size = len(self._train_dataloader.dataset)
        sampler = train_eval_sampler(size, self._train_eval_size)
        return self._make_dataloader(self._train_dataloader.dataset, sampler=sampler)

    @abc.abstractmethod
    def _make_train_and_valid_dataloader(self):
        """Creates the training and validation data loader.

        Returns:
            A torch.utils.data.DataLoader instance with batches of training
            data.
            A torch.utils.data.DataLoader instance with batches of validation
            data.
        """
        pass

    @abc.abstractmethod
    def _make_test_dataloader(self):
        """Creates the test data loader.

        Returns:
            A torch.utils.data.DataLoader instance with batches of test data.
        """
        pass