Source code for deepobs.tensorflow.datasets.dataset

# -*- coding: utf-8 -*-
"""Base class for DeepOBS datasets."""

import tensorflow as tf


# pylint: disable=too-many-instance-attributes, too-few-public-methods
class DataSet(object):
    """Abstract foundation shared by all DeepOBS data sets.

    Sub-classes supply the actual data by overriding
    ``_make_train_datasets`` and ``_make_test_dataset``; this class wires
    the resulting datasets to a single reinitializable iterator and to the
    operations that switch between phases.

    Args:
      batch_size (int): The mini-batch size to use.

    Attributes:
      batch: A tuple of tensors, yielding batches of data from the dataset.
          Executing these tensors raises a ``tf.errors.OutOfRangeError``
          after one epoch.
      train_init_op: A tensorflow operation initializing the dataset for the
          training phase.
      train_eval_init_op: A tensorflow operation initializing the testproblem
          for evaluating on training data.
      valid_init_op: A tensorflow operation initializing the dataset for the
          validation phase.
      test_init_op: A tensorflow operation initializing the testproblem for
          evaluating on test data.
      phase: A string-value tf.Variable that is set to ``train``,
          ``train_eval``, ``valid``, or ``test``, depending on the current
          phase. This can be used by testproblems to adapt their behavior to
          this phase.
    """

    def __init__(self, batch_size):
        """Build the datasets, the shared iterator and the phase-switch ops.

        Args:
          batch_size (int): The mini-batch size to use.
        """
        self._batch_size = batch_size

        # The datasets come from the sub-class hooks; the training hook
        # yields three datasets at once.
        train_sets = self._make_train_datasets()
        (
            self._train_dataset,
            self._train_eval_dataset,
            self._valid_dataset,
        ) = train_sets
        self._test_dataset = self._make_test_dataset()

        # One reinitializable iterator serves every phase. Its structure is
        # taken from the training dataset, so all datasets must share the
        # same output types and shapes.
        reference = self._train_dataset
        self._iterator = tf.data.Iterator.from_structure(
            reference.output_types, reference.output_shapes
        )
        self.batch = self._iterator.get_next()

        # Records the currently active phase as a non-trainable variable.
        self.phase = tf.Variable("train", name="phase", trainable=False)

        def switch_to(dataset, phase_name):
            """Group 'point the iterator at dataset' with 'set phase'."""
            rewind = self._iterator.make_initializer(dataset)
            mark = tf.assign(self.phase, phase_name)
            return tf.group([rewind, mark], name=phase_name + "_init_op")

        self.train_init_op = switch_to(self._train_dataset, "train")
        self.train_eval_init_op = switch_to(
            self._train_eval_dataset, "train_eval"
        )
        self.valid_init_op = switch_to(self._valid_dataset, "valid")
        self.test_init_op = switch_to(self._test_dataset, "test")

    def _make_train_datasets(self):
        """Hook producing the train, train eval and validation datasets.

        Returns:
          A tf.data.Dataset instance with batches of training data.
          A tf.data.Dataset instance with batches of training eval data.
          A tf.data.Dataset instance with batches of validation data.

        Raises:
          NotImplementedError: Always, unless overridden by a sub-class.
        """
        message = """'DataSet' is an abstract base class, please use one of the sub-classes."""
        raise NotImplementedError(message)

    def _make_test_dataset(self):
        """Hook producing the test dataset.

        Returns:
          A tf.data.Dataset instance with batches of test data.

        Raises:
          NotImplementedError: Always, unless overridden by a sub-class.
        """
        message = """'DataSet' is an abstract base class, please use one of the sub-classes."""
        raise NotImplementedError(message)