# -*- coding: utf-8 -*-
"""Base class for DeepOBS test problems."""
import abc
import torch
from .. import config
class TestProblem(abc.ABC):
    """Base class for DeepOBS test problems.

    Args:
        batch_size (int): Batch size to use.
        l2_reg (float): L2-regularization (weight decay) factor to use. If
            not specified, the test problems revert to their respective
            defaults. Note: Some test problems do not use regularization
            and this value will be ignored in such a case.

    Attributes:
        _batch_size: Batch size for the data of this test problem.
        _l2_reg: The regularization factor for this test problem.
        _device: The ``torch.device`` all batches are moved to.
        _batch_count: Number of batches drawn via ``_get_next_batch``.
        data: The dataset used by the test problem (datasets.DataSet instance).
        loss_function: The loss function (class/factory, called with a
            ``reduction`` keyword) for this test problem.
        net: The torch module (the neural network) that is trained.
        regularization_groups: Mapping of regularization factor to the list
            of parameters it applies to (see ``get_regularization_groups``).

    Methods:
        train_init_op: Initializes the test problem for the training phase.
        train_eval_init_op: Initializes the test problem for evaluating on
            training data.
        valid_init_op: Initializes the test problem for evaluating on
            validation data.
        test_init_op: Initializes the test problem for evaluating on test
            data.
        _get_next_batch: Returns the next batch of data of the current phase.
        get_batch_loss_and_accuracy: Calculates the loss and accuracy of net
            on the next batch of the current phase.
        set_up: Sets all public attributes.
    """

    def __init__(self, batch_size, l2_reg=None):
        """Creates a new test problem instance.

        Args:
            batch_size (int): Batch size to use.
            l2_reg (float): L2-regularization (weight decay) factor to use.
                If not specified, the test problems revert to their
                respective defaults. Note: Some test problems do not use
                regularization and this value will be ignored in such a
                case.
        """
        self._batch_size = batch_size
        self._l2_reg = l2_reg
        self._device = torch.device(config.get_default_device())
        # Fix: the original assigned _batch_count twice; once is enough.
        self._batch_count = 0

        # Public attributes by which to interact with test problems. These
        # have to be created by the set_up function of sub-classes.
        self.data = None
        self.loss_function = None
        self.net = None
        self.regularization_groups = None

    def train_init_op(self):
        """Initializes the testproblem instance to train mode, i.e. sets the
        iterator to the training set and sets the model to train mode.
        """
        self._iterator = iter(self.data._train_dataloader)
        self.phase = "train"
        self.net.train()

    def train_eval_init_op(self):
        """Initializes the testproblem instance to train eval mode, i.e. sets
        the iterator to the train evaluation set and sets the model to eval
        mode.
        """
        self._iterator = iter(self.data._train_eval_dataloader)
        self.phase = "train_eval"
        self.net.eval()

    def valid_init_op(self):
        """Initializes the testproblem instance to validation mode, i.e. sets
        the iterator to the validation set and sets the model to eval mode.
        """
        self._iterator = iter(self.data._valid_dataloader)
        self.phase = "valid"
        self.net.eval()

    def test_init_op(self):
        """Initializes the testproblem instance to test mode, i.e. sets the
        iterator to the test set and sets the model to eval mode.
        """
        self._iterator = iter(self.data._test_dataloader)
        self.phase = "test"
        self.net.eval()

    def _get_next_batch(self):
        """Returns the next batch from the current iterator and counts it.

        Raises:
            StopIteration: When the current phase's iterator is exhausted.
        """
        batch = next(self._iterator)
        self._batch_count += 1
        return batch

    def get_batch_loss_and_accuracy_func(
        self, reduction="mean", add_regularization_if_available=True
    ):
        """Get new batch and create forward function that calculates loss and
        accuracy (if available) on that batch. This is a default
        implementation for image classification. Testproblems with different
        calculation routines (e.g. RNNs) overwrite this method accordingly.

        Args:
            reduction (str): The reduction that is used for returning the
                loss. Can be 'mean', 'sum' or 'none' in which case each
                individual loss in the mini-batch is returned as a tensor.
            add_regularization_if_available (bool): If true, regularization
                is added to the loss.

        Returns:
            callable: The function that calculates the loss/accuracy on the
            current batch.
        """
        inputs, labels = self._get_next_batch()
        inputs = inputs.to(self._device)
        labels = labels.to(self._device)

        def forward_func():
            # In the evaluation phases no gradient is needed.
            if self.phase in ["train_eval", "test", "valid"]:
                with torch.no_grad():
                    outputs = self.net(inputs)
                    loss = self.loss_function(reduction=reduction)(
                        outputs, labels
                    )
            else:
                outputs = self.net(inputs)
                loss = self.loss_function(reduction=reduction)(outputs, labels)

            # Accuracy: fraction of argmax predictions matching the labels.
            _, predicted = torch.max(outputs.data, 1)
            total = labels.size(0)
            correct = (predicted == labels).sum().item()
            accuracy = correct / total

            if add_regularization_if_available:
                regularizer_loss = self.get_regularization_loss()
            else:
                # Fix: _device is already a torch.device (set in __init__);
                # the original wrapped it in torch.device(...) again.
                regularizer_loss = torch.tensor(0.0, device=self._device)

            return loss + regularizer_loss, accuracy

        return forward_func

    def get_batch_loss_and_accuracy(
        self, reduction="mean", add_regularization_if_available=True
    ):
        """Gets a new batch and calculates the loss and accuracy (if
        available) on that batch.

        Args:
            reduction (str): The reduction that is used for returning the
                loss. Can be 'mean', 'sum' or 'none' in which case each
                individual loss in the mini-batch is returned as a tensor.
            add_regularization_if_available (bool): If true, regularization
                is added to the loss.

        Returns:
            float/torch.tensor, float: loss and accuracy of the model on the
            current batch.
        """
        forward_func = self.get_batch_loss_and_accuracy_func(
            reduction=reduction,
            add_regularization_if_available=add_regularization_if_available,
        )
        return forward_func()

    def get_regularization_loss(self):
        """Returns the current regularization loss of the network based on
        the parameter groups.

        Returns:
            float or torch.tensor: If no regularization is applied (no group
            with a positive factor), the float ``0.0`` is returned. Else a
            torch.tensor that holds the regularization loss.
        """
        # Collect the weighted squared L2 norm of every regularized group.
        layer_norms = []
        for regularization, parameter_group in self.regularization_groups.items():
            if regularization > 0.0:
                # L2 regularization
                for parameters in parameter_group:
                    layer_norms.append(regularization * parameters.pow(2).sum())
        # The conventional 1/2 factor of L2 regularization.
        regularization_loss = 0.5 * sum(layer_norms)
        return regularization_loss

    @abc.abstractmethod
    def get_regularization_groups(self):
        """Creates regularization groups for the parameters.

        Returns:
            dict: A dictionary where the key is the regularization factor
            and the value is a list of parameters.
        """
        return

    # TODO get rid of setup structure by parsing individual loss func,
    # network and dataset
    @abc.abstractmethod
    def set_up(self):
        """Sets up the test problem."""
        pass
class UnregularizedTestproblem(TestProblem):
    """A test problem that applies no regularization.

    All network parameters are collected in a single regularization group
    with factor ``0.0``, so ``get_regularization_loss`` never penalizes
    any parameter.
    """

    def __init__(self, batch_size, l2_reg=None):
        """Creates a new unregularized test problem instance.

        Args:
            batch_size (int): Batch size to use.
            l2_reg (float): Accepted for interface compatibility with
                regularized test problems, but effectively ignored since
                all parameters get a regularization factor of 0.0.
        """
        # Python-3 style super() call (original used the legacy two-arg form).
        super().__init__(batch_size, l2_reg)

    def get_regularization_groups(self):
        """Creates regularization groups for the parameters.

        Returns:
            dict: A dictionary where the key is the regularization factor
            and the value is a list of parameters. Here every parameter is
            mapped to the factor ``0.0``, i.e. nothing is penalized.
        """
        no = 0.0
        # Penalize no parameters: one group with factor 0.0 holding them all.
        # (named_parameters() is used only for its values; parameters()
        # yields the same tensors in the same order.)
        return {no: list(self.net.parameters())}

    @abc.abstractmethod
    def set_up(self):
        """Sets up the test problem. Must be implemented by sub-classes."""
        pass