# Source code for deepobs.pytorch.testproblems.fmnist_mlp

# -*- coding: utf-8 -*-
"""A vanilla MLP architecture for Fashion-MNIST."""

import warnings

from torch import nn

from ..datasets.fmnist import fmnist
from .testproblem import UnregularizedTestproblem
from .testproblems_modules import net_mlp


class fmnist_mlp(UnregularizedTestproblem):
    """DeepOBS test problem class for a multi-layer perceptron neural network
    on Fashion-MNIST.

    The network is built as follows:

      - Four fully-connected layers with ``1000``, ``500``, ``100`` and ``10``
        units per layer.
      - The first three layers use ReLU activation, and the last one a softmax
        activation.
      - The biases are initialized to ``0.0`` and the weight matrices with
        truncated normal (standard deviation of ``3e-2``)
      - The model uses a cross entropy loss.
      - No regularization is used.

    Args:
        batch_size (int): Batch size to use.
        l2_reg (float): No L2-Regularization (weight decay) is used in this
            test problem. Defaults to ``None`` and any input here is ignored.

    Attributes:
        data: The DeepOBS data set class for Fashion-MNIST.
        loss_function: The loss function for this testproblem is
            torch.nn.CrossEntropyLoss()
        net: The DeepOBS subclass of torch.nn.Module that is trained for this
            testproblem (net_mlp).
    """

    def __init__(self, batch_size, l2_reg=None):
        """Create a new multi-layer perceptron test problem instance on
        Fashion-MNIST.

        Args:
            batch_size (int): Batch size to use.
            l2_reg (float): No L2-Regularization (weight decay) is used in
                this test problem. Defaults to ``None`` and any input here
                is ignored.
        """
        super(fmnist_mlp, self).__init__(batch_size, l2_reg)

        if l2_reg is not None:
            # l2_reg is ignored by this unregularized test problem; warn so
            # the caller is not misled into thinking it has any effect.
            warnings.warn(
                "L2-Regularization is non-zero but no L2-regularization is used for this model.",
                RuntimeWarning,
            )

    def set_up(self):
        """Sets up the vanilla MLP test problem on Fashion-MNIST.

        Instantiates the data set, loss function, and network, moves the
        network to the configured device, and collects the regularization
        groups (empty for an unregularized problem).
        """
        self.data = fmnist(self._batch_size)
        # NOTE: the loss *class* (not an instance) is stored; DeepOBS
        # instantiates it later when the loss is evaluated.
        self.loss_function = nn.CrossEntropyLoss
        self.net = net_mlp(num_outputs=10)
        self.net.to(self._device)
        self.regularization_groups = self.get_regularization_groups()