# Source code for airlab.regulariser.demons

# Copyright 2018 University of Basel, Center for medical Image Analysis and Navigation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch as th
import torch.nn.functional as F
import numpy as np
from ..utils import graph as G
from ..utils import matrix as mat
from ..utils import kernelFunction as utils
from ..utils import image as iu

class _DemonsRegulariser():
    def __init__(self, pixel_spacing, dtype=th.float32, device='cpu'):
        super(_DemonsRegulariser, self).__init__()

        self._dtype = dtype
        self._device = device
        self._weight = 1
        self._dim = len(pixel_spacing)
        self._pixel_spacing = pixel_spacing
        self.name = "parent"



class GaussianRegulariser(_DemonsRegulariser):
    """Demons regulariser that smooths each displacement component with a Gaussian kernel.

    Args:
        pixel_spacing: per-axis spacing; its length sets the dimensionality
            (only 2 and 3 get a regularisation routine).
        sigma: scalar or sequence of standard deviations passed to
            ``utils.gaussian_kernel``. A sequence shorter than the dimensionality
            is padded by repeating its last entry; a scalar applies to all axes.
        dtype: dtype of the convolution kernel.
        device: device of the convolution kernel.

    Raises:
        ValueError: if ``sigma`` is empty or has more entries than dimensions.
    """

    def __init__(self, pixel_spacing, sigma, dtype=th.float32, device='cpu'):
        super(GaussianRegulariser, self).__init__(pixel_spacing, dtype=dtype, device=device)

        sigma = self._broadcast_sigma(sigma, self._dim)

        kernel = utils.gaussian_kernel(sigma, self._dim, asTensor=True, dtype=dtype, device=device)
        # half the kernel extent on each side: for an odd-sized kernel the
        # convolution output keeps the input size (odd size assumed — TODO confirm)
        self._padding = ((np.array(kernel.size()) - 1) / 2).astype(dtype=int).tolist()

        # conv weight layout is (out_channels, in_channels / groups, *spatial);
        # with groups == dim each displacement component gets its own kernel copy
        kernel = kernel.unsqueeze(0).unsqueeze(0)
        kernel = kernel.expand(self._dim, *([-1] * (self._dim + 1)))
        # expand() yields a broadcast (stride-0) view; materialise it once here
        # instead of calling .contiguous() on every regularisation step
        self._kernel = kernel.to(dtype=dtype, device=self._device).contiguous()

        if self._dim == 2:
            self._regulariser = self._regularise_2d
        elif self._dim == 3:
            self._regulariser = self._regularise_3d

    @staticmethod
    def _broadcast_sigma(sigma, dim):
        """Return ``sigma`` as a 1-D array with exactly ``dim`` entries.

        A scalar is promoted to 1-D first (indexing a 0-d array would fail);
        missing trailing entries repeat the last given value.
        """
        sigma = np.atleast_1d(np.array(sigma))
        if sigma.size == 0:
            raise ValueError("sigma must contain at least one value")
        if sigma.size > dim:
            # extra entries cannot be matched to an image axis
            raise ValueError("sigma has %d entries but the regulariser is %d-dimensional"
                             % (sigma.size, dim))
        if sigma.size < dim:
            sigma = np.append(sigma, np.repeat(sigma[-1], dim - sigma.size))
        return sigma

    def _convolve(self, data, conv):
        """Filter ``data`` (components first) in place with ``conv``.

        A batch axis is added for the convolution and only that axis is removed
        afterwards, so singleton spatial axes are preserved.
        """
        smoothed = conv(data.data.unsqueeze(0), self._kernel, padding=self._padding, groups=self._dim)
        data.data = smoothed.squeeze(0)

    def _regularise_2d(self, data):
        self._convolve(data, F.conv2d)

    def _regularise_3d(self, data):
        self._convolve(data, F.conv3d)

    def regularise(self, data):
        """Smooth every parameter in ``data`` in place.

        Args:
            data: iterable of tensors/parameters whose ``.data`` is replaced.
        """
        for parameter in data:
            # no gradient calculation for the demons regularisation
            with th.no_grad():
                self._regulariser(parameter)
class _GraphEdgeWeightUpdater(): def __init__(self, pixel_spacing, edge_window=0.9, edge_mean=False): self._edge_window = edge_window self._edge_mean = edge_mean self._laplace_matrix = None self._dim = len(pixel_spacing) self._pixel_spacing = pixel_spacing self._collapse_threshold = 0 self._detect_node_collapse = True def detect_node_collapse(self, detect): self._detect_node_collapse = detect def set_laplace_matrix(self, laplace_matrix): self._laplace_matrix = laplace_matrix def remove_node_collapse(self): for i, diag in enumerate(self._laplace_matrix.diag_elements): node_value = self._laplace_matrix.main_diag[diag.edge_index[-1]] index = th.abs(node_value) < self._collapse_threshold diag.edge_values[index] = 1
class EdgeUpdaterIntensities(_GraphEdgeWeightUpdater):
    """Edge updater whose weights depend only on image intensities.

    The weight of the edge between a pixel ``a`` and its neighbour ``b`` one
    step back along the edge's axis is ``exp(-scale * |I(a) - I(b)|)``.

    Args:
        pixel_spacing: per-axis spacing; its length sets the dimensionality.
        image: image tensor, indexed as ``image[0, 0, *spatial]``.
        scale: factor applied to the absolute intensity difference.
        edge_window: forwarded to ``_GraphEdgeWeightUpdater``.
        edge_mean: forwarded to ``_GraphEdgeWeightUpdater``.
    """

    def __init__(self, pixel_spacing, image, scale=1, edge_window=0.9, edge_mean=False):
        super(EdgeUpdaterIntensities, self).__init__(pixel_spacing, edge_window, edge_mean)
        self._image = image
        self._scale = scale

    def set_scale(self, scale):
        """Set the factor applied to intensity differences."""
        self._scale = scale

    def _endpoint_indices(self, edge_index, axis):
        """Return the spatial index tuples of both endpoints of a set of edges.

        Endpoint ``a`` is ``edge_index`` itself; endpoint ``b`` is one step back
        along ``axis``.
        """
        index_a = tuple(edge_index[d] for d in range(self._dim))
        index_b = tuple(edge_index[d] - 1 if d == axis else edge_index[d]
                        for d in range(self._dim))
        return index_a, index_b

    def update(self, data):
        """Recompute all edge weights from the image and refresh the Laplace matrix.

        Args:
            data: current displacement; accepted for interface compatibility
                with the other updaters and not used here.
        """
        # diag_elements[axis] holds the edges along image axis ``axis``
        for axis, diag in enumerate(self._laplace_matrix.diag_elements):
            index_a, index_b = self._endpoint_indices(diag.edge_index, axis)
            intensities_a = self._image[(0, 0) + index_a]
            intensities_b = self._image[(0, 0) + index_b]
            diag.edge_values = th.exp(-self._scale * th.abs(intensities_a - intensities_b))

        # update the laplace matrix
        self._laplace_matrix.update()
class EdgeUpdaterDisplacementIntensities(_GraphEdgeWeightUpdater):
    """Edge updater combining intensity, displacement and image-gradient terms.

    For each edge (pixel ``a`` and its neighbour ``b`` one step back along the
    edge's axis) the weight mixes:
      * intensity similarity ``exp(-|I(a) - I(b)| * _scale_int_diff)``,
      * displacement similarity ``exp(-_scale_disp_diff * |u(a) - u(b)|)``,
      * the cosine between each displacement and the larger of the two image
        gradients.
    The result is blended with the displacement term by ``_scale_disp``. With
    the default ``_scale_disp = 1`` the final weight equals the displacement
    term alone.

    Args:
        pixel_spacing: per-axis spacing; its length sets the dimensionality.
        image: image tensor; ``image[0, 0, ...]`` is used.
        edge_window: blending factor for old vs. new weights when ``edge_mean`` is set.
        edge_mean: blend the rounded new weight into the previous weight instead
            of replacing it.
    """

    def __init__(self, pixel_spacing, image, edge_window=0.9, edge_mean=False):
        super(EdgeUpdaterDisplacementIntensities, self).__init__(pixel_spacing, edge_window, edge_mean)
        self._image = image[0, 0, ...]

        self._scale_int_diff = 1
        self._scale_disp_diff = 1
        self._scale_disp = 1

        # backward differences for every axis (previously only built, and with
        # a wrong second component, in 2D)
        self._image_gradient = self._backward_differences(self._image)

    @staticmethod
    def _backward_differences(image):
        """Return backward differences of ``image`` along each axis.

        The image is zero-padded by one sample at the start of every axis; the
        per-axis differences are stacked on a new last axis, giving shape
        ``image.shape + (image.dim(),)``.
        """
        n_axes = image.dim()
        padded = th.nn.functional.pad(image, pad=(1, 0) * n_axes)
        current = tuple(slice(1, None) for _ in range(n_axes))
        differences = []
        for axis in range(n_axes):
            previous = list(current)
            previous[axis] = slice(None, -1)
            differences.append(padded[current] - padded[tuple(previous)])
        return th.stack(differences, n_axes)

    def _endpoint_indices(self, edge_index, axis):
        """Return the spatial index tuples of both endpoints of a set of edges.

        Endpoint ``a`` is ``edge_index`` itself; endpoint ``b`` is one step back
        along ``axis``.
        """
        index_a = tuple(edge_index[d] for d in range(self._dim))
        index_b = tuple(edge_index[d] - 1 if d == axis else edge_index[d]
                        for d in range(self._dim))
        return index_a, index_b

    def _edge_weight(self, data, index_a, index_b):
        """Compute the new weight of every edge between ``index_a`` and ``index_b``."""
        intensity_diff = th.exp(-th.abs(self._image[index_a] - self._image[index_b]) * self._scale_int_diff)

        # displacement vectors at both endpoints: (components, n_edges)
        displacement_a = data[(slice(None),) + index_a]
        displacement_b = data[(slice(None),) + index_b]
        displacement_diff = th.sqrt(th.sum((displacement_a - displacement_b) ** 2, dim=0))
        displacement_diff = th.exp(-self._scale_disp_diff * displacement_diff)
        norm_disp_a = th.sqrt(th.sum(displacement_a ** 2, dim=0))
        norm_disp_b = th.sqrt(th.sum(displacement_b ** 2, dim=0))

        # image gradients at both endpoints: (n_edges, dim); keep the larger one
        gradient_a = self._image_gradient[index_a]
        gradient_b = self._image_gradient[index_b]
        norm_a = th.sqrt(th.sum(gradient_a ** 2, dim=1))
        norm_b = th.sqrt(th.sum(gradient_b ** 2, dim=1))
        max_norm = th.max(norm_a, norm_b)
        max_grad = th.where((norm_a == max_norm).unsqueeze(1), gradient_a, gradient_b)

        # cosine between each displacement and the dominant gradient (1e-10 avoids 0/0)
        phi_a = th.sum(max_grad * displacement_a.t(), dim=1) / (norm_disp_a * max_norm + 1e-10)
        phi_b = th.sum(max_grad * displacement_b.t(), dim=1) / (norm_disp_b * max_norm + 1e-10)

        weight = intensity_diff * displacement_diff + (1 - intensity_diff) * ((phi_a + phi_b) * 0.5)
        return weight * (1 - self._scale_disp) + displacement_diff * self._scale_disp

    def update(self, data):
        """Recompute all edge weights from the image and the current displacement.

        Args:
            data: displacement tensor indexed as ``data[component, *spatial]``.
        """
        # diag_elements[axis] holds the edges along image axis ``axis``
        for axis, diag in enumerate(self._laplace_matrix.diag_elements):
            index_a, index_b = self._endpoint_indices(diag.edge_index, axis)
            weight = self._edge_weight(data, index_a, index_b)
            if self._edge_mean:
                diag.edge_values = diag.edge_values * self._edge_window + th.round(weight) * (1. - self._edge_window)
            else:
                diag.edge_values = weight

        # remove collapsed nodes
        if self._detect_node_collapse:
            self.remove_node_collapse()

        # update the laplace matrix so it reflects the new edge values, as
        # EdgeUpdaterIntensities.update does
        self._laplace_matrix.update()
class GraphDiffusionRegulariser(_DemonsRegulariser):
    """Demons regulariser that diffuses displacement components over an image graph.

    Each displacement component is passed through ``mat.expm_krylov`` with the
    graph's Laplace matrix; afterwards the edge updater recomputes the edge
    weights from the current displacement.

    Args:
        image_size: spatial size used to build the ``G.Graph``.
        pixel_spacing: per-axis spacing; its length sets the dimensionality.
        edge_updater: edge-weight updater; it is bound to this graph's Laplace matrix.
        phi: forwarded to ``mat.expm_krylov``.
        dtype: forwarded to ``G.Graph``.
        device: forwarded to ``G.Graph``.
    """

    def __init__(self, image_size, pixel_spacing, edge_updater, phi=1, dtype=th.float32, device='cpu'):
        super(GraphDiffusionRegulariser, self).__init__(pixel_spacing, dtype=dtype, device=device)
        self._image_size = image_size
        self._graph = G.Graph(image_size, dtype=dtype, device=device)

        self._edge_updater = edge_updater
        self._edge_updater.set_laplace_matrix(self._graph.laplace_matrix)

        self._phi = phi
        self._krylov_dim = 30

    def set_krylov_dim(self, krylov_dim):
        """Set the Krylov subspace dimension passed to ``mat.expm_krylov``."""
        self._krylov_dim = krylov_dim

    def get_edge_image(self):
        """Return the Laplace main diagonal reshaped to the image grid as an ``iu.Image``.

        The origin is always zero (only zero origin supported yet).
        """
        edge_image = th.reshape(self._graph.laplace_matrix.main_diag, self._image_size)
        edge_image.unsqueeze_(0)
        edge_image = edge_image.unsqueeze(0)
        zero_origin = th.zeros(len(self._image_size))
        return iu.Image(edge_image, self._image_size, self._pixel_spacing, zero_origin)

    def regularise(self, data):
        """Diffuse every component of each parameter in ``data``, then update edge weights.

        NOTE(review): the return value of ``mat.expm_krylov`` is discarded, so this
        relies on it writing into the flattened view of ``parameter.data`` in
        place — confirm against utils.matrix.
        """
        for parameter in data:
            # no gradient calculation for the demons regularisation
            with th.no_grad():
                # compute the graph diffusion regularisation for each component
                for component in range(parameter.size()[0]):
                    flat_component = parameter.data[component, ...].view(-1)
                    mat.expm_krylov(self._graph.laplace_matrix, flat_component,
                                    phi=self._phi, krylov_dim=self._krylov_dim)

                # update the edge weights on the current data
                self._edge_updater.update(parameter.data)