# Source code for morpheus.core.base_model

# MIT License
# Copyright 2018 Ryan Hausen
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# ==============================================================================
"""A base class for building neural network models in TensorFlow."""

from types import FunctionType

import tensorflow as tf

from morpheus.core.helpers import OptionalFunc


class Model:
    """Base class for models.

    Attributes:
        dataset: Data source for training and testing. ``train()`` reads
            ``dataset.train`` and ``test()`` reads ``dataset.test``; each must
            unpack into an ``(inputs, labels)`` pair.
        data_format (str): 'channels_first' or 'channels_last'

    Required methods to override:
        model_fn: builds the network for an input tensor and returns the
            resulting output tensor

    Optional methods to override (each defaults to an ``OptionalFunc``
    placeholder constructed with the message shown in the class body):
        train_metrics: to add metrics during training
        test_metrics: to add metrics during testing, can be same as
            train_metrics
        optimizer: updates params based on a loss tensor
        loss_func: defines a loss value given an output tensor and labels
        inference: post-processes the tensor returned by model_fn
    """

    train_metrics = OptionalFunc("No training metrics set")
    test_metrics = OptionalFunc("No test metrics set")
    optimizer = OptionalFunc("No optimizer set")
    loss_func = OptionalFunc("No loss function set")
    inference = OptionalFunc("No inference function set")

    def __init__(self, dataset: tf.data.Dataset, data_format: str = "channels_last"):
        """Inits Model with dataset and data_format.

        Args:
            dataset: object exposing ``train`` and ``test`` attributes, each
                unpacking into ``(inputs, labels)``
            data_format (str): 'channels_first' or 'channels_last'
        """
        self.dataset = dataset
        self.data_format = data_format
        # Callable used by build_graph; bound to model_fn lazily on first use.
        self._graph = None

    def model_fn(self, inputs: tf.Tensor, is_training: bool) -> tf.Tensor:
        """Function that defines the model. Needs to be overridden!

        Args:
            inputs (tf.Tensor): the input tensor
            is_training (bool): boolean to indicate if in training phase

        Returns:
            The output tensor of the model evaluated on ``inputs``
            (``build_graph`` returns this value directly).

        Raises:
            NotImplementedError: if not overridden
        """
        raise NotImplementedError("Model.model_fn must be overridden")

    def build_graph(self, inputs: tf.Tensor, is_training: bool) -> tf.Tensor:
        """Returns model_fn evaluated on ``inputs``. Don't override!

        Args:
            inputs (tf.Tensor): The tensor to be processed, ie a placeholder
            is_training (bool): whether or not the model is training; useful
                for things like batch normalization or dropout

        Returns:
            The tensor that represents the result of model_fn evaluated on
            the input tensor

        Raises:
            NotImplementedError: if Model.model_fn() is not overridden
        """
        if self._graph is None:
            self._graph = self.model_fn
        # NOTE(review): train() and test() both call this, so model_fn runs
        # once per call; whether the two graphs share variables is up to
        # model_fn (e.g. variable-scope reuse) — confirm in subclasses.
        return self._graph(inputs, is_training)

    def train(self) -> (tf.Tensor, tf.Tensor):
        """Builds the training routine tensors. Don't override!

        Returns:
            (optimize, metrics): the result of self.optimizer and
            self.train_metrics respectively

        Raises:
            NotImplementedError: if Model.model_fn() is not overridden
        """
        data, labels = self.dataset.train
        logits = self.build_graph(data, True)
        optimize = self.optimizer(self.loss_func(logits, labels))
        metrics = self.train_metrics(logits, labels)
        return optimize, metrics

    def test(self) -> (tf.Tensor, tf.Tensor):
        """Builds the testing routine tensors. Don't override!

        Returns:
            (logits, metrics): the result of self.build_graph and
            self.test_metrics respectively

        Raises:
            NotImplementedError: if Model.model_fn() is not overridden
        """
        inputs, labels = self.dataset.test
        logits = self.build_graph(inputs, False)
        metrics = self.test_metrics(logits, labels)
        return logits, metrics