Source code for deepobs.pytorch.testproblems.mnist_vae

# -*- coding: utf-8 -*-
"""A VAE architecture for MNIST."""

import torch

from ..datasets.mnist import mnist
from .testproblem import UnregularizedTestproblem
from .testproblems_modules import net_vae
from .testproblems_utils import vae_loss_function_factory


class mnist_vae(UnregularizedTestproblem):
    """DeepOBS test problem class for a variational autoencoder (VAE) on MNIST.

    The network has been adapted from `here
    <https://towardsdatascience.com/teaching-a-variational-autoencoder-vae-to-draw-mnist-characters-978675c95776>`_
    and consists of an encoder:

        - With three convolutional layers with ``64`` filters each.
        - Using a leaky ReLU activation function with :math:`\\alpha = 0.3`.
        - Dropout layers after each convolutional layer with a rate of ``0.2``.

    and a decoder:

        - With two dense layers with ``24`` and ``49`` units and leaky ReLU
          activation.
        - With three deconvolutional layers with ``64`` filters each.
        - Dropout layers after the first two deconvolutional layers with a
          rate of ``0.2``.
        - A final dense layer with ``28 x 28`` units and sigmoid activation.

    No regularization is used.

    Args:
        batch_size (int): Batch size to use.
        l2_reg (float): No L2-regularization (weight decay) is used in this
            test problem. Defaults to ``None`` and any input here is ignored.

    Attributes:
        data: The DeepOBS data set class for MNIST.
        loss_function: The loss function for this test problem
            (``vae_loss_function`` as defined in ``testproblems_utils``).
        net: The DeepOBS subclass of ``torch.nn.Module`` that is trained for
            this test problem (``net_vae``).
    """

    def __init__(self, batch_size, l2_reg=None):
        """Create a new VAE test problem instance on MNIST.

        Args:
            batch_size (int): Batch size to use.
            l2_reg (float): No L2-regularization (weight decay) is used in
                this test problem. Defaults to ``None`` and any input here is
                ignored.
        """
        super(mnist_vae, self).__init__(batch_size, l2_reg)

        if l2_reg is not None:
            print(
                "WARNING: L2-regularization is non-zero but no L2-regularization is used",
                "for this model.",
            )

        self.loss_function = vae_loss_function_factory
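For orientation, here is a minimal sketch of the kind of callable that
``vae_loss_function_factory`` is expected to produce, given how it is invoked
below as ``self.loss_function(reduction=reduction)(outputs, inputs, means,
std_devs)``. This is a hypothetical reconstruction assuming the standard VAE
objective (binary cross-entropy reconstruction loss plus a closed-form
Gaussian KL term); the actual DeepOBS implementation lives in
``testproblems_utils`` and may differ in detail:

    # Hypothetical sketch, NOT the DeepOBS implementation: a factory that
    # returns a loss callable matching the call signature used in this module.
    import torch
    import torch.nn.functional as F

    def sketch_vae_loss_function_factory(reduction="mean"):
        def vae_loss(outputs, inputs, means, std_devs):
            # Reconstruction term: per-pixel binary cross-entropy, summed
            # over pixels so each example gets one reconstruction loss.
            bce = F.binary_cross_entropy(
                outputs.view(outputs.size(0), -1),
                inputs.view(inputs.size(0), -1),
                reduction="none",
            ).sum(dim=1)
            # KL divergence between N(means, std_devs^2) and N(0, I).
            kl = -0.5 * torch.sum(
                1 + torch.log(std_devs.pow(2)) - means.pow(2) - std_devs.pow(2),
                dim=1,
            )
            per_example = bce + kl
            if reduction == "mean":
                return per_example.mean()
            if reduction == "sum":
                return per_example.sum()
            return per_example  # 'none': one loss value per example
        return vae_loss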
    def set_up(self):
        """Sets up the VAE test problem on MNIST."""
        self.data = mnist(self._batch_size)
        self.net = net_vae(n_latent=8)
        self.net.to(self._device)
        self.regularization_groups = self.get_regularization_groups()
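As a usage sketch (hypothetical driver code, not part of this module), a test
problem is typically constructed, set up, and switched into a phase before
batches are drawn; ``train_init_op`` is the standard DeepOBS phase switch:

    # Hypothetical usage sketch.
    from deepobs.pytorch.testproblems import mnist_vae

    tp = mnist_vae(batch_size=64)
    tp.set_up()          # builds self.data and self.net, moves the net to the device
    tp.train_init_op()   # switch to the training phase before drawing batches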
    def get_batch_loss_and_accuracy_func(
        self, reduction="mean", add_regularization_if_available=True
    ):
        """Get new batch and create forward function that calculates loss and
        accuracy (if available) on that batch.

        Args:
            reduction (str): The reduction that is used for returning the
                loss. Can be 'mean', 'sum' or 'none' in which case each
                individual loss in the mini-batch is returned as a tensor.
            add_regularization_if_available (bool): If true, regularization is
                added to the loss.

        Returns:
            callable: The function that calculates the loss/accuracy on the
            current batch.
        """
        inputs, _ = self._get_next_batch()
        inputs = inputs.to(self._device)

        def forward_func():
            # no gradients are needed during the evaluation phases
            if self.phase in ["train_eval", "test", "valid"]:
                with torch.no_grad():
                    outputs, means, std_devs = self.net(inputs)
                    loss = self.loss_function(reduction=reduction)(
                        outputs, inputs, means, std_devs
                    )
            else:
                outputs, means, std_devs = self.net(inputs)
                loss = self.loss_function(reduction=reduction)(
                    outputs, inputs, means, std_devs
                )

            accuracy = 0

            if add_regularization_if_available:
                regularizer_loss = self.get_regularization_loss()
            else:
                regularizer_loss = torch.tensor(0.0, device=torch.device(self._device))

            return loss + regularizer_loss, accuracy

        return forward_func
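Continuing the hypothetical sketch above: the returned ``forward_func``
defers the forward pass, so a runner can evaluate it (inside an optimizer
closure if needed) and backpropagate through the combined VAE and
regularization loss. The Adam optimizer and learning rate here are
illustrative choices, not prescribed by DeepOBS:

    # Hypothetical training step, continuing the sketch above.
    import torch

    opt = torch.optim.Adam(tp.net.parameters(), lr=1e-3)

    forward_func = tp.get_batch_loss_and_accuracy_func()
    loss, accuracy = forward_func()   # accuracy is always 0 for this VAE problem
    opt.zero_grad()
    loss.backward()
    opt.step()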