Source code for morpheus.core.model

# MIT License
# Copyright 2018 Ryan Hausen
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.
# ==============================================================================
"""Contains model code for Morpheus."""

import collections
import json
import os
from typing import List, Tuple

import tensorflow.compat.v1 as tf

import morpheus.core.unet
from morpheus.core.hparams import HParams


class Morpheus(morpheus.core.unet.Model):
    """The main class for the Morpheus model.

    This class takes an HParams object as an argument; the properties it
    should contain are listed below.

    Note: if you are using pretrained weights for inference only, you need to
    mock the dataset object and use the default hparams. You can mock the
    dataset object by calling Morpheus.mock_dataset(), and you can get the
    default HParams by calling Morpheus.inference_hparams().

    An example call for inference only:

    >>> dataset = Morpheus.mock_dataset()
    >>> hparams = Morpheus.inference_hparams()
    >>> data_format = 'channels_last'
    >>> morph = Morpheus(hparams, dataset, data_format)

    Required HParams:
        * inference (bool): true if using the pretrained model
        * down_filters (list): number of filters for each down conv section
        * num_down_convs (int): number of conv ops per down conv section
        * up_filters (list): number of filters for each up conv section
        * num_up_convs (int): number of conv ops per up conv section
        * batch_norm (bool): use batch normalization
        * dropout (bool): use dropout

    Optional HParams:
        * learning_rate (float): learning rate for training; required if
          inference is set to false
        * dropout_rate (float): the fraction of neurons to drop, in [0.0, 1.0]

    Args:
        hparams (morpheus.core.hparams.HParams): model hyperparameters
        dataset (tf.data.Dataset): dataset to use for training
        data_format (str): 'channels_first' or 'channels_last'

    TODO:
        * Make the optimizer a parameter
    """

    def __init__(
        self, hparams: HParams, dataset: tf.data.Dataset, data_format: str
    ):
        """Inits Morpheus with hparams, dataset, and data_format."""
        super().__init__(hparams, dataset, data_format)

        if not hparams.inference:
            self.opt = tf.train.AdamOptimizer(hparams.learning_rate)

    def loss_func(self, logits: tf.Tensor, labels: tf.Tensor) -> tf.Tensor:
        """Defines the loss function used in training.

        The loss function combines a cross-entropy loss calculated against
        all 5 classes with a dice loss calculated against just the background
        class.

        Args:
            logits (tf.Tensor): output tensor from the graph; should be
                [batch_size, width, height, 5]
            labels (tf.Tensor): labels used in training; should be
                [batch_size, width, height, 5]

        Returns:
            tf.Tensor: Tensor representing the loss function.
        """
        flat_logits = tf.reshape(logits, [-1, 5])
        flat_y = tf.reshape(labels, [-1, 5])

        # Calculate weighted cross entropy =====================================
        # The class weights are normally calculated by counting the pixels
        # assigned to each class, but because we have continuous values for
        # each class we sum the probabilities for each class in the pixels
        # instead.
        xentropy_loss = tf.nn.softmax_cross_entropy_with_logits_v2(
            logits=flat_logits, labels=flat_y
        )

        dominant_class = tf.argmax(flat_y, axis=1, output_type=tf.int32)
        p_dominant_class = tf.reduce_max(flat_y, axis=1)

        class_coefficient = tf.zeros_like(xentropy_loss)
        for output_class_idx in range(5):
            class_pixels = tf.cast(
                tf.equal(output_class_idx, dominant_class), tf.float32
            )
            coef = tf.reduce_mean(class_pixels * p_dominant_class)
            class_coefficient = tf.add(class_coefficient, coef * class_pixels)

        class_coefficient = 1 / class_coefficient

        weighted_xentropy_loss = tf.reduce_mean(xentropy_loss * class_coefficient)
        # Calculate weighted cross entropy =====================================

        # Calculate dice loss ==================================================
        if self.data_format == "channels_first":
            yh_background = tf.nn.sigmoid(logits[:, -1, :, :])
            y_background = labels[:, -1, :, :]
        else:
            yh_background = tf.nn.sigmoid(logits[:, :, :, -1])
            y_background = labels[:, :, :, -1]

        dice_numerator = tf.reduce_sum(y_background * yh_background, axis=[1, 2])
        dice_denominator = tf.reduce_sum(y_background + yh_background, axis=[1, 2])
        dice_loss = tf.reduce_mean(2 * dice_numerator / dice_denominator)
        # Calculate dice loss ==================================================

        total_loss = weighted_xentropy_loss + (1 - dice_loss)

        return total_loss
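
    # A worked sketch of the weighting in loss_func (illustrative numbers
    # only): `coef` for a class is roughly the fraction of the batch's
    # probability mass whose dominant class is that class, and each pixel's
    # cross entropy is scaled by 1 / coef of its own dominant class. So if
    # ~90% of pixels are confidently background and ~1% are confidently
    # point sources, background pixels are scaled by ~1 / 0.9 ≈ 1.1 while
    # point-source pixels are scaled by ~1 / 0.01 = 100, keeping rare
    # classes from being drowned out by the background.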

    def optimizer(self, loss: tf.Tensor) -> tf.Tensor:
        """Overrides the optimizer func in morpheus.core.unet.

        Args:
            loss (tf.Tensor): the loss function tensor to pass to the optimizer

        Returns:
            tf.Tensor: the Tensor result of optimizer.minimize()
        """
        # Run the update ops (e.g. the batch normalization moving-average
        # updates) before each optimization step.
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            optimize = self.opt.minimize(loss)

        return optimize

    def train_metrics(
        self, logits: tf.Tensor, labels: tf.Tensor
    ) -> Tuple[Tuple[List[str], List[tf.Tensor]], List[tf.Tensor]]:
        # overrides base method pylint: disable=no-self-use
        """Overrides the train_metrics func in morpheus.core.unet.

        Args:
            logits (tf.Tensor): the output logits from the model
            labels (tf.Tensor): the labels used during training

        Returns:
            Tuple(
                Tuple(
                    List(str): names of metrics,
                    List(tf.Tensor): tensors for metrics
                ),
                List(tf.Tensor): tensors for updating running metrics
            )
        """
        with tf.name_scope("train_metrics"):
            metrics_dict = Morpheus.eval_metrics(logits, labels)

            names, finalize, running = [], [], []
            for key in sorted(metrics_dict):
                names.append(key)
                finalize.append(metrics_dict[key][0])
                running.append(metrics_dict[key][1])

            return ([names, finalize], running)

    def test_metrics(
        self, logits: tf.Tensor, labels: tf.Tensor
    ) -> Tuple[Tuple[List[str], List[tf.Tensor]], List[tf.Tensor]]:
        # overrides base method pylint: disable=no-self-use
        """Overrides the test_metrics func in morpheus.core.unet.

        Args:
            logits (tf.Tensor): the output logits from the model
            labels (tf.Tensor): the labels used during testing

        Returns:
            Tuple(
                Tuple(
                    List(str): names of metrics,
                    List(tf.Tensor): tensors for metrics
                ),
                List(tf.Tensor): tensors for updating running metrics
            )
        """
        with tf.name_scope("test_metrics"):
            metrics_dict = Morpheus.eval_metrics(logits, labels)

            names, finalize, running = [], [], []
            for key in sorted(metrics_dict):
                names.append(key)
                finalize.append(metrics_dict[key][0])
                running.append(metrics_dict[key][1])

            return ([names, finalize], running)

    def inference(self, inputs: tf.Tensor) -> tf.Tensor:
        """Performs inference on the input.

        Args:
            inputs (tf.Tensor): input tensor with shape
                [batch_size, width, height, 5]

        Returns:
            A tf.Tensor of shape [batch_size, width, height, 5] representing
            the output of the model, with the softmax function applied.
        """
        return tf.nn.softmax(self.build_graph(inputs, False))

    @staticmethod
    def eval_metrics(yh: tf.Tensor, y: tf.Tensor) -> dict:
        """Generates the metrics used for evaluation during training.

        Args:
            yh (tf.Tensor): network output [n, h, w, c]
            y (tf.Tensor): labels [n, h, w, c]

        Returns:
            A dict of (tf.Tensor, tf.Tensor) values, where the keys are the
            names of the metrics and the values are running metric pairs.
            More info on running accuracy metrics here:
            https://www.tensorflow.org/api_docs/python/tf/metrics/accuracy
        """
        metrics_dict = {}
        thresholds = [0.5, 0.6, 0.7, 0.8, 0.9]
        classes = ["spheroid", "disk", "irregular", "point_source", "background"]

        yh_bkg = tf.reshape(tf.nn.sigmoid(yh[:, :, :, -1]), [-1])
        y_bkg = tf.reshape(y[:, :, :, -1], [-1])

        # Mean IOU of the background class at several thresholds.
        for threshold in thresholds:
            name = "iou-{}".format(threshold)
            with tf.name_scope(name):
                preds = tf.cast(tf.greater_equal(yh_bkg, threshold), tf.int32)
                metric, update_op = tf.metrics.mean_iou(y_bkg, preds, 2, name=name)
                metrics_dict[name] = (metric, update_op)

        # Calculate the per-pixel accuracy, overall and per class.
        y = tf.reshape(y, [-1, 5])
        yh = tf.reshape(yh, [-1, 5])

        lbls = tf.argmax(y, 1)
        preds = tf.argmax(yh, 1)

        name = "overall"
        metric, update_op = tf.metrics.accuracy(lbls, preds, name=name)
        metrics_dict[name] = (metric, update_op)

        for i, name in enumerate(classes):
            in_c = tf.equal(lbls, i)
            metric, update_op = tf.metrics.accuracy(
                lbls, preds, weights=in_c, name=name
            )
            metrics_dict[name] = (metric, update_op)

        return metrics_dict
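
    # A usage sketch for the running metrics above (the standard tf.metrics
    # pattern, not specific to this class; `sess`, `num_batches`, `yh`, and
    # `y` are assumed): each value in the returned dict is a
    # (metric, update_op) pair backed by TF local variables, so evaluation
    # looks roughly like:
    #
    #     metrics = Morpheus.eval_metrics(yh, y)
    #     sess.run(tf.local_variables_initializer())
    #     for _ in range(num_batches):
    #         sess.run([update for _, update in metrics.values()])
    #     results = sess.run({name: m for name, (m, _) in metrics.items()})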

    @staticmethod
    def mock_dataset() -> collections.namedtuple:
        """Generates a mock dataset for inference.

        Returns:
            A collections.namedtuple object that can be passed in place of a
            tf.data.Dataset for the 'dataset' argument in the constructor.
        """
        MockDataset = collections.namedtuple("Dataset", ["num_labels"])
        return MockDataset(5)
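
    # The mock exposes only `num_labels` (the 5 output classes), which is
    # presumably the only dataset attribute the base class reads when
    # building the graph for inference:
    #
    #     >>> Morpheus.mock_dataset().num_labels
    #     5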

    @staticmethod
    def inference_hparams() -> HParams:
        """Generates the default HParams for inference.

        Returns:
            A morpheus.core.hparams.HParams object with the settings for
            inference.
        """
        config_path = os.path.join(os.path.dirname(__file__), "model_config.json")
        with open(config_path, "r") as f:
            return HParams(**json.load(f))
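
    # `model_config.json` ships alongside this module and presumably encodes
    # the required HParams listed in the class docstring (inference,
    # down_filters, num_down_convs, up_filters, num_up_convs, batch_norm,
    # dropout), with `inference` set to true for the pretrained model.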

    @staticmethod
    def get_weights_dir() -> str:
        """Returns the location of the weights for tf.train.Saver."""
        return os.path.join(os.path.dirname(__file__), "model_weights")
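
# A minimal restore sketch (an assumption about usage, not part of this
# module): the bundled weights directory is a standard TensorFlow checkpoint
# directory, so after building the inference graph the pretrained weights can
# be loaded with a tf.train.Saver:
#
#     saver = tf.train.Saver()
#     with tf.Session() as sess:
#         checkpoint = tf.train.latest_checkpoint(Morpheus.get_weights_dir())
#         saver.restore(sess, checkpoint)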