# -*- coding: utf-8 -*-
"""Base class for DeepOBS datasets."""
import tensorflow as tf
# pylint: disable=too-many-instance-attributes, too-few-public-methods
class DataSet(object):
    """Base class for DeepOBS data sets.

    Sub-classes provide the actual data by overriding
    ``_make_train_dataset``, ``_make_train_eval_dataset`` and
    ``_make_test_dataset``. The constructor calls these three methods, so
    instantiating ``DataSet`` itself raises ``NotImplementedError``.

    Args:
      batch_size (int): The mini-batch size to use.

    Attributes:
      batch: The result of ``get_next()`` on a reinitializable iterator that is
          shared by all three datasets. Executing these tensors raises a
          ``tf.errors.OutOfRangeError`` once the currently initialized dataset
          is exhausted (i.e. after one epoch, assuming the sub-class does not
          repeat its datasets).
      train_init_op: A tensorflow operation that initializes the iterator on
          the training dataset and sets ``phase`` to ``train``.
      train_eval_init_op: A tensorflow operation that initializes the iterator
          on the train eval dataset (for evaluating on training data) and sets
          ``phase`` to ``train_eval``.
      test_init_op: A tensorflow operation that initializes the iterator on
          the test dataset and sets ``phase`` to ``test``.
      phase: A string-value tf.Variable that is set to ``train``,
          ``train_eval`` or ``test``, depending on the current phase. This can
          be used by testproblems to adapt their behavior to this phase.
    """

    # Message raised by every abstract factory method. The embedded newline is
    # kept so the text is unchanged for anything that inspects it.
    _ABSTRACT_ERROR_MESSAGE = (
        "'DataSet' is an abstract base class, please use\n"
        "one of the sub-classes.")

    def __init__(self, batch_size):
        """Creates a new DataSet instance.

        Args:
          batch_size (int): The mini-batch size to use.

        Raises:
          NotImplementedError: If a sub-class does not override one of the
              ``_make_*_dataset`` methods.
        """
        # Must be stored first: the sub-class factory methods are called right
        # below and may read it.
        self._batch_size = batch_size
        self._train_dataset = self._make_train_dataset()
        self._train_eval_dataset = self._make_train_eval_dataset()
        self._test_dataset = self._make_test_dataset()

        # Reinitializable iterator given types and shapes of the outputs
        # (needs to be the same for train and test of course).
        self._iterator = tf.data.Iterator.from_structure(
            self._train_dataset.output_types,
            self._train_dataset.output_shapes)
        self.batch = self._iterator.get_next()

        # Operations to switch phases (reinitialize iterator and assign value
        # to phase variable).
        self.phase = tf.Variable("train", name="phase", trainable=False)
        self.train_init_op = self._make_phase_init_op(
            self._train_dataset, "train")
        self.train_eval_init_op = self._make_phase_init_op(
            self._train_eval_dataset, "train_eval")
        self.test_init_op = self._make_phase_init_op(
            self._test_dataset, "test")

    def _make_phase_init_op(self, dataset, phase_name):
        """Builds the operation that switches the data set to one phase.

        Requires ``self._iterator`` and ``self.phase`` to exist already.

        Args:
          dataset: The dataset the shared iterator is re-initialized on.
          phase_name (str): Value assigned to ``self.phase``; the resulting
              operation is named ``<phase_name>_init_op``.

        Returns:
          A grouped tensorflow operation that runs both the iterator
          initializer and the phase assignment.
        """
        return tf.group([
            self._iterator.make_initializer(dataset),
            tf.assign(self.phase, phase_name)
        ], name=phase_name + "_init_op")

    def _make_train_dataset(self):
        """Creates the training dataset.

        Returns:
          A tf.data.Dataset instance with batches of training data.

        Raises:
          NotImplementedError: Always, unless overridden by a sub-class.
        """
        raise NotImplementedError(self._ABSTRACT_ERROR_MESSAGE)

    def _make_train_eval_dataset(self):
        """Creates the train eval dataset.

        Returns:
          A tf.data.Dataset instance with batches of training eval data.

        Raises:
          NotImplementedError: Always, unless overridden by a sub-class.
        """
        raise NotImplementedError(self._ABSTRACT_ERROR_MESSAGE)

    def _make_test_dataset(self):
        """Creates the test dataset.

        Returns:
          A tf.data.Dataset instance with batches of test data.

        Raises:
          NotImplementedError: Always, unless overridden by a sub-class.
        """
        raise NotImplementedError(self._ABSTRACT_ERROR_MESSAGE)