# Source code for deepobs.pytorch.datasets.cifar10

# -*- coding: utf-8 -*-
"""CIFAR-10 DeepOBS dataset."""

from torchvision import datasets, transforms

from deepobs import config

from . import dataset

# Deterministic preprocessing used for evaluation data (and for training when
# augmentation is disabled): convert PIL images to tensors, then normalize each
# RGB channel — the constants are the standard CIFAR-10 per-channel mean and
# standard deviation (presumably computed over the 50 000 training images).
training_transform_not_augmented = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            (0.49139968, 0.48215841, 0.44653091), (0.24703223, 0.24348513, 0.26158784),
        ),
    ]
)

# Training-time augmentation pipeline: pad by 2 pixels on each side, take a
# random 32x32 crop (i.e. a random translation of up to 2 pixels), randomly
# flip horizontally, and jitter brightness/saturation/contrast, before the
# same tensor conversion and per-channel normalization as the non-augmented
# pipeline.
training_transform_augmented = transforms.Compose(
    [
        transforms.Pad(padding=2),
        transforms.RandomCrop(size=(32, 32)),
        transforms.RandomHorizontalFlip(),
        transforms.ColorJitter(
            brightness=63.0 / 255.0, saturation=[0.5, 1.5], contrast=[0.2, 1.8]
        ),
        transforms.ToTensor(),
        transforms.Normalize(
            (0.49139968, 0.48215841, 0.44653091), (0.24703223, 0.24348513, 0.26158784),
        ),
    ]
)


class cifar10(dataset.DataSet):
    r"""DeepOBS data set class for the `CIFAR-10\
    <https://www.cs.toronto.edu/~kriz/cifar.html>`_ data set.

    Args:
        batch_size (int): The mini-batch size to use. Note that, if
            ``batch_size`` is not a divider of the dataset size (``50 000`` for
            train, ``10 000`` for test) the remainder is dropped in each epoch
            (after shuffling).
        data_augmentation (bool): If ``True`` some data augmentation operations
            (random crop window, horizontal flipping, lighting augmentation)
            are applied to the training data (but not the test data).
        train_eval_size (int): Size of the train eval data set. Defaults to
            ``10 000`` the size of the test set.

    Methods:
        _make_dataloader: A helper that is shared by all three data loader
            methods.
    """

    def __init__(self, batch_size, data_augmentation=True, train_eval_size=10000):
        """Creates a new CIFAR-10 instance.

        Args:
            batch_size (int): The mini-batch size to use. Note that, if
                ``batch_size`` is not a divider of the dataset size (``50 000``
                for train, ``10 000`` for test) the remainder is dropped in
                each epoch (after shuffling).
            data_augmentation (bool): If ``True`` some data augmentation
                operations (random crop window, horizontal flipping, lighting
                augmentation) are applied to the training data (but not the
                test data).
            train_eval_size (int): Size of the train eval data set. Defaults
                to ``10 000`` the size of the test set.
        """
        self._name = "cifar10"
        self._data_augmentation = data_augmentation
        self._train_eval_size = train_eval_size
        # Base-class __init__ builds the actual data loaders via the
        # _make_*_dataloader hooks below, so the attributes they read must be
        # set first.
        super(cifar10, self).__init__(batch_size)

    def _make_train_and_valid_dataloader(self):
        """Builds the train and validation loaders.

        The training split is augmented when ``self._data_augmentation`` is
        set; the validation loader reuses the same underlying training data
        but always with the deterministic (non-augmented) transform.
        """
        if self._data_augmentation:
            transform = training_transform_augmented
        else:
            transform = training_transform_not_augmented
        train_dataset = datasets.CIFAR10(
            root=config.get_data_dir(),
            train=True,
            download=True,
            transform=transform,
        )
        valid_dataset = datasets.CIFAR10(
            root=config.get_data_dir(),
            train=True,
            download=True,
            transform=training_transform_not_augmented,
        )
        train_loader, valid_loader = self._make_train_and_valid_dataloader_helper(
            train_dataset, valid_dataset
        )
        return train_loader, valid_loader

    def _make_test_dataloader(self):
        """Builds the test loader; test data is never augmented."""
        transform = training_transform_not_augmented
        test_dataset = datasets.CIFAR10(
            root=config.get_data_dir(),
            train=False,
            download=True,
            transform=transform,
        )
        # sampler=None: the test set is iterated in its natural order.
        return self._make_dataloader(test_dataset, sampler=None)