# Source code for airlab.regulariser.displacement

# Copyright 2018 University of Basel, Center for medical Image Analysis and Navigation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch as th
import torch.nn.functional as F

import numpy as np

# Base class of all regularisers (a torch.nn.Module subclass)
class _Regulariser(th.nn.modules.Module):
    def __init__(self, pixel_spacing, size_average=True, reduce=True):
        super(_Regulariser, self).__init__()
        self._size_average = size_average
        self._reduce = reduce
        self._weight = 1
        self._dim = len(pixel_spacing)
        self._pixel_spacing = pixel_spacing
        self.name = "parent"
        self._mask = None

    def SetWeight(self, weight):
        print("SetWeight is deprecated. Use set_weight instead.")
        self.set_weight(weight)

    def set_weight(self, weight):
        self._weight = weight

    def set_mask(self, mask):
        self._mask = mask

    def _mask_2d(self, df):
        if not self._mask is None:
            nx, ny, d = df.shape
            return df * self._mask.image.squeeze()[:nx,:ny].unsqueeze(-1).repeat(1,1,d)
        else:
            return df

    def _mask_3d(self, df):
        if not self._mask is None:
            nx, ny, nz, d = df.shape
            return df * self._mask.image.squeeze()[:nx,:ny,:nz].unsqueeze(-1).repeat(1,1,1,d)
        else:
            return df

    # conditional return
    def return_loss(self, tensor):
        if self._size_average and self._reduce:
            return self._weight*tensor.mean()
        if not self._size_average and self._reduce:
            return self._weight*tensor.sum()
        if not self._reduce:
            return self._weight*tensor

"""
    Isotropic TV regularisation
"""
[docs]class IsotropicTVRegulariser(_Regulariser): def __init__(self, pixel_spacing, size_average=True, reduce=True): super(IsotropicTVRegulariser, self).__init__(pixel_spacing, size_average, reduce) self.name = "isoTV" if self._dim == 2: self._regulariser = self._isotropic_TV_regulariser_2d # 2d regularisation elif self._dim == 3: self._regulariser = self._isotropic_TV_regulariser_3d # 3d regularisation def _isotropic_TV_regulariser_2d(self, displacement): dx = (displacement[1:, 1:, :] - displacement[:-1, 1:, :]).pow(2)*self._pixel_spacing[0] dy = (displacement[1:, 1:, :] - displacement[1:, :-1, :]).pow(2)*self._pixel_spacing[1] return self._mask_2d(F.pad(dx + dy, (0,1,0,1))) def _isotropic_TV_regulariser_3d(self, displacement): dx = (displacement[1:, 1:, 1:, :] - displacement[:-1, 1:, 1:, :]).pow(2)*self._pixel_spacing[0] dy = (displacement[1:, 1:, 1:, :] - displacement[1:, :-1, 1:, :]).pow(2)*self._pixel_spacing[1] dz = (displacement[1:, 1:, 1:, :] - displacement[1:, 1:, :-1, :]).pow(2)*self._pixel_spacing[2] return self._mask_3d(F.pad(dx + dy + dz, (0,1,0,1,0,1)))
[docs] def forward(self, displacement): # set the supgradient to zeros value = self._regulariser(displacement) mask = value > 0 value[mask] = th.sqrt(value[mask]) return self.return_loss(value)
""" TV regularisation """
[docs]class TVRegulariser(_Regulariser): def __init__(self, pixel_spacing, size_average=True, reduce=True): super(TVRegulariser, self).__init__(pixel_spacing, size_average, reduce) self.name = "TV" if self._dim == 2: self._regulariser = self._TV_regulariser_2d # 2d regularisation elif self._dim == 3: self._regulariser = self._TV_regulariser_3d # 3d regularisation def _TV_regulariser_2d(self, displacement): dx = th.abs(displacement[1:, 1:, :] - displacement[:-1, 1:, :])*self._pixel_spacing[0] dy = th.abs(displacement[1:, 1:, :] - displacement[1:, :-1, :])*self._pixel_spacing[1] return self._mask_2d(F.pad(dx + dy, (0, 1, 0, 1))) def _TV_regulariser_3d(self, displacement): dx = th.abs(displacement[1:, 1:, 1:, :] - displacement[:-1, 1:, 1:, :])*self._pixel_spacing[0] dy = th.abs(displacement[1:, 1:, 1:, :] - displacement[1:, :-1, 1:, :])*self._pixel_spacing[1] dz = th.abs(displacement[1:, 1:, 1:, :] - displacement[1:, 1:, :-1, :])*self._pixel_spacing[2] return self._mask_3d(F.pad(dx + dy + dz, (0, 1, 0, 1, 0, 1)))
[docs] def forward(self, displacement): return self.return_loss(self._regulariser(displacement))
""" Diffusion regularisation """
[docs]class DiffusionRegulariser(_Regulariser): def __init__(self, pixel_spacing, size_average=True, reduce=True): super(DiffusionRegulariser, self).__init__(pixel_spacing, size_average, reduce) self.name = "L2" if self._dim == 2: self._regulariser = self._l2_regulariser_2d # 2d regularisation elif self._dim == 3: self._regulariser = self._l2_regulariser_3d # 3d regularisation def _l2_regulariser_2d(self, displacement): dx = (displacement[1:, 1:, :] - displacement[:-1, 1:, :]).pow(2) * self._pixel_spacing[0] dy = (displacement[1:, 1:, :] - displacement[1:, :-1, :]).pow(2) * self._pixel_spacing[1] return self._mask_2d(F.pad(dx + dy, (0, 1, 0, 1))) def _l2_regulariser_3d(self, displacement): dx = (displacement[1:, 1:, 1:, :] - displacement[:-1, 1:, 1:, :]).pow(2) * self._pixel_spacing[0] dy = (displacement[1:, 1:, 1:, :] - displacement[1:, :-1, 1:, :]).pow(2) * self._pixel_spacing[1] dz = (displacement[1:, 1:, 1:, :] - displacement[1:, 1:, :-1, :]).pow(2) * self._pixel_spacing[2] return self._mask_3d(F.pad(dx + dy + dz, (0, 1, 0, 1, 0, 1)))
[docs] def forward(self, displacement): return self.return_loss(self._regulariser(displacement))
""" Sparsity regularisation """
[docs]class SparsityRegulariser(_Regulariser): def __init__(self, size_average=True, reduce=True): super(SparsityRegulariser, self).__init__([0], size_average, reduce)
[docs] def forward(self, displacement): return self.return_loss(th.abs(displacement))