Test Problems

DeepOBS currently includes twenty-six different test problems. A test problem is the combination of a data set and a model, and it is characterized by its loss function.

Each test problem inherits from the same base class with the following signature.

class deepobs.pytorch.testproblems.TestProblem(batch_size, weight_decay=None)[source]

Base class for DeepOBS test problems.

Parameters:
  • batch_size (int) -- Batch size to use.
  • weight_decay (float) -- Weight decay (L2-regularization) factor to use. If not specified, the test problems revert to their respective defaults. Note: Some test problems do not use regularization and this value will be ignored in such a case.
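As a pure-Python sketch of this constructor contract (the base class and subclass here are illustrative stand-ins, not the actual DeepOBS implementation), a test problem falls back to its own weight-decay default when none is given:

```python
# Minimal mock of the TestProblem constructor contract described above.
# ToyProblem and its DEFAULT_WEIGHT_DECAY value are hypothetical.

class TestProblem:
    def __init__(self, batch_size, weight_decay=None):
        self._batch_size = batch_size
        self._weight_decay = weight_decay  # None -> subclass default
        self.data = None
        self.loss_function = None
        self.net = None


class ToyProblem(TestProblem):
    DEFAULT_WEIGHT_DECAY = 0.0005  # hypothetical per-problem default

    def __init__(self, batch_size, weight_decay=None):
        super().__init__(batch_size, weight_decay)
        if self._weight_decay is None:
            # Revert to this problem's own default, as described above.
            self._weight_decay = self.DEFAULT_WEIGHT_DECAY


tp = ToyProblem(batch_size=128)
print(tp._weight_decay)  # 0.0005: the default, since none was specified
```

A problem that ignores regularization would simply never read `_weight_decay`, which is why the note above says the value can be ignored.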
_batch_size

Batch size used for the data of this test problem.

_weight_decay

The regularization factor for this test problem.

data

The dataset used by the test problem (datasets.DataSet instance).

loss_function

The loss function for this test problem.

net

The torch module (the neural network) that is trained.

train_init_op()[source]

Initializes the test problem for the training phase.

train_eval_init_op()[source]

Initializes the test problem for evaluating on training data.

test_init_op()[source]

Initializes the test problem for evaluating on test data.

_get_next_batch()[source]

Returns the next batch of data of the current phase.

get_batch_loss_and_accuracy()[source]

Calculates the loss and accuracy of net on the next batch of the current phase.

set_up()[source]

Sets all public attributes.

get_batch_loss_and_accuracy(reduction='mean', add_regularization_if_available=True)[source]

Gets a new batch and calculates the loss and accuracy (if available) on that batch.

Parameters:
  • reduction (str) -- The reduction used for returning the loss. Can be 'mean', 'sum', or 'none', in which case each individual loss in the mini-batch is returned as a tensor.
  • add_regularization_if_available (bool) -- If True, the regularization loss is added to the batch loss.
Returns:

Loss and accuracy of the model on the current batch.

Return type:

float/torch.tensor, float
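The effect of the `reduction` argument can be illustrated with a small stand-in function (a mock: the real method runs a forward pass of the network on the next mini-batch, and the helper names here are not DeepOBS API):

```python
# Illustrative mock of get_batch_loss_and_accuracy, showing how the
# `reduction` argument changes the returned loss. The per-example losses
# and correctness flags are given directly instead of computed by a net.

def get_batch_loss_and_accuracy(per_example_losses, correct, reduction="mean",
                                add_regularization_if_available=True,
                                regularization_loss=0.0):
    accuracy = sum(correct) / len(correct)
    if reduction == "none":
        # One loss per example, returned as-is (a tensor in DeepOBS).
        return per_example_losses, accuracy
    if reduction == "mean":
        loss = sum(per_example_losses) / len(per_example_losses)
    elif reduction == "sum":
        loss = sum(per_example_losses)
    else:
        raise ValueError(f"unknown reduction: {reduction!r}")
    if add_regularization_if_available:
        loss += regularization_loss
    return loss, accuracy


losses = [0.5, 1.5, 1.0, 1.0]
hits = [1, 0, 1, 1]
print(get_batch_loss_and_accuracy(losses, hits))         # (1.0, 0.75)
print(get_batch_loss_and_accuracy(losses, hits, "sum"))  # (4.0, 0.75)
```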

get_batch_loss_and_accuracy_func(reduction='mean', add_regularization_if_available=True)[source]

Gets a new batch and creates a forward function that calculates the loss and accuracy (if available) on that batch. This is the default implementation for image classification; test problems with different computation routines (e.g. RNNs) override this method accordingly.

Parameters:
  • reduction (str) -- The reduction used for returning the loss. Can be 'mean', 'sum', or 'none', in which case each individual loss in the mini-batch is returned as a tensor.
  • add_regularization_if_available (bool) -- If True, the regularization loss is added to the batch loss.
Returns:

The function that calculates the loss/accuracy on the current batch.

Return type:

callable
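The key property of this variant is that the batch is drawn once and the returned callable re-evaluates the loss on that same batch, which is what optimizers that need multiple forward passes per step (e.g. line searches) rely on. A toy sketch of this closure pattern, with illustrative names and a trivial regression "model" instead of a network:

```python
# Sketch of the closure pattern behind get_batch_loss_and_accuracy_func:
# the batch is fixed when the function is created, and each call of the
# returned callable re-computes the loss on that same batch.

def make_loss_and_accuracy_func(batch, model, reduction="mean"):
    def forward_func():
        losses = [(model(x) - y) ** 2 for x, y in batch]
        accuracy = None  # toy regression problem: no accuracy available
        if reduction == "mean":
            return sum(losses) / len(losses), accuracy
        return sum(losses), accuracy
    return forward_func


batch = [(1.0, 2.0), (2.0, 4.0)]
params = {"w": 1.0}
forward = make_loss_and_accuracy_func(batch, lambda x: params["w"] * x)

loss_before, _ = forward()   # evaluate once...
params["w"] = 2.0            # ...update the "model"...
loss_after, _ = forward()    # ...and re-evaluate on the *same* batch
print(loss_before, loss_after)  # 2.5 0.0
```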

get_regularization_groups()[source]

Creates regularization groups for the parameters.

Returns:

A dictionary where the keys are regularization factors and the values are lists of parameters.

Return type:

dict
get_regularization_loss()[source]

Returns the current regularization loss of the network based on the parameter groups.

Returns:

If no regularization is applied, the integer 0. Otherwise, a torch.tensor holding the regularization loss.

Return type:

int or torch.tensor
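The two methods fit together as follows: the groups map each factor to its parameters, and the loss sums the weighted penalties over all groups. A pure-Python sketch (the group split and the squared-L2 penalty are assumptions for illustration; check the source for the exact convention):

```python
# Illustrative mock of the regularization-group scheme: groups map a
# regularization factor to the parameters it applies to, and the loss
# accumulates factor * sum of squared parameters per group.

def get_regularization_groups(weight_decay, weights, biases):
    # Hypothetical split: decay the weights, leave the biases unregularized.
    return {weight_decay: weights, 0.0: biases}


def get_regularization_loss(groups):
    loss = 0  # stays the integer 0 if no group is regularized
    for factor, params in groups.items():
        if factor == 0.0:
            continue
        loss += factor * sum(p ** 2 for p in params)
    return loss


groups = get_regularization_groups(0.1, weights=[1.0, 2.0], biases=[3.0])
print(get_regularization_loss(groups))        # 0.1 * (1 + 4) = 0.5
print(get_regularization_loss({0.0: [1.0]}))  # 0, the integer
```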
set_up()[source]

Sets up the test problem.

test_init_op()[source]

Initializes the test problem instance for test mode, i.e. sets the iterator to the test set and puts the model into eval mode.

train_eval_init_op()[source]

Initializes the test problem instance for train eval mode, i.e. sets the iterator to the train evaluation set and puts the model into eval mode.

train_init_op()[source]

Initializes the test problem instance for train mode, i.e. sets the iterator to the training set and puts the model into train mode.

valid_init_op()[source]

Initializes the test problem instance for validation mode, i.e. sets the iterator to the validation set and puts the model into eval mode.
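The four init ops all follow the same pattern: pick the data iterator for the phase and toggle the model between train and eval mode. A toy mock of that behaviour (not DeepOBS code; the attribute names are illustrative):

```python
# Mock illustrating the four phase-switching init ops: each selects a
# data iterator and sets the model's train/eval flag accordingly.

class PhasedProblem:
    def __init__(self):
        self._iterators = {"train": iter([1, 2]), "train_eval": iter([3]),
                           "valid": iter([4]), "test": iter([5])}
        self.phase = None
        self.net_training = None  # stands in for net.train() / net.eval()

    def _init(self, phase, training):
        self.phase, self.net_training = phase, training
        self._iterator = self._iterators[phase]

    def train_init_op(self):
        self._init("train", True)        # training set, train mode

    def train_eval_init_op(self):
        self._init("train_eval", False)  # train eval set, eval mode

    def valid_init_op(self):
        self._init("valid", False)       # validation set, eval mode

    def test_init_op(self):
        self._init("test", False)        # test set, eval mode


tp = PhasedProblem()
tp.train_init_op()
assert tp.phase == "train" and tp.net_training is True
tp.test_init_op()
assert tp.phase == "test" and tp.net_training is False
```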

Note

Some of the test problems described here are based on more general implementations. For example, the Wide ResNet 40-4 network on CIFAR-100 is based on the general Wide ResNet architecture, which is also implemented. It is therefore straightforward to add new Wide ResNet variants if necessary.