# Source code for airlab.transformation.pairwise

# Copyright 2018 University of Basel, Center for medical Image Analysis and Navigation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch as th
from torch.nn.parameter import Parameter
import torch.nn.functional as F
import numpy as np

from ..utils import kernelFunction as utils

from . import utils as tu

"""
    Base class for a transformation
"""
class _Transformation(th.nn.Module):
    def __init__(self, image_size, diffeomorphic=False, dtype=th.float32, device='cpu'):
        super(_Transformation, self).__init__()

        self._dtype = dtype
        self._device = device
        self._dim = len(image_size)
        self._image_size = np.array(image_size)
        self._constant_displacement = None
        self._diffeomorphic = diffeomorphic
        self._constant_flow = None

        self._compute_flow = None

        if self._diffeomorphic:
            self._diffeomorphic_calculater = tu.Diffeomorphic(image_size, dtype=dtype, device=device)
        else:
            self._diffeomorphic_calculater = None

    def get_flow(self):

        if self._constant_flow is None:
            return self._compute_flow().detach()
        else:
            return self._compute_flow().detach() + self._constant_flow

    def set_constant_flow(self, flow):
        self._constant_flow = flow

    def get_displacement_numpy(self):

        if self._dim == 2:
            return th.unsqueeze(self().detach(), 0).cpu().numpy()
        elif self._dim == 3:
            return self().detach().cpu().numpy()

    def get_displacement(self):
            return self().detach()

    # def get_current_displacement(self):
    #
    #     if self._dim == 2:
    #         return th.unsqueeze(self().detach(), 0).cpu().numpy()
    #     elif self._dim == 3:
    #         return self().detach().cpu().numpy()

    # def set_constant_displacement(self, displacement):
    #
    #     self._constant_displacement = displacement

    # def get_inverse_transformation(self, displacement):
    #     if self._diffeomorphic:
    #         if self._dim == 2:
    #             inv_displacement = self._diffeomorphic_calculater.calculate(displacement * -1)
    #         else:
    #             inv_displacement = self._diffeomorphic_calculater.calculate(displacement * -1)
    #     else:
    #         print("error displacement ")
    #         inv_displacement = None
    #
    #     return inv_displacement

    def get_inverse_displacement(self):

        flow = self._concatenate_flows(self._compute_flow()).detach()

        if self._diffeomorphic:
                inv_displacement = self._diffeomorphic_calculater.calculate(flow * -1)
        else:
            print("error displacement ")
            inv_displacement = None

        return inv_displacement

    def _compute_diffeomorphic_displacement(self, flow):

        return self._diffeomorphic_calculater.calculate(flow)

    def _concatenate_flows(self, flow):

        if self._constant_flow is None:
            return flow
        else:
            return flow + self._constant_flow


class RigidTransformation(_Transformation):
    r"""
    Rigid centred transformation for 2D and 3D.

    Parameters are a translation (``_t_x``, ``_t_y`` [, ``_t_z``]), rotation
    angles (``_phi_z`` [, ``_phi_x``, ``_phi_y``]) and a rotation centre that is
    initialised with the intensity-weighted centre of mass of the moving image.

    Args:
        moving_image (Image): moving image for the registration
        opt_cm (bool): using center of mass as parameter for the optimisation
    """
    def __init__(self, moving_image, opt_cm=False):
        super(RigidTransformation, self).__init__(image_size=moving_image.size,
                                                  dtype=moving_image.dtype,
                                                  device=moving_image.device)

        self._opt_cm = opt_cm

        grid = th.squeeze(tu.compute_grid(moving_image.size, dtype=self._dtype))

        # homogeneous coordinates: append a constant 1 to every grid point
        grid = th.cat((grid, th.ones(*[list(moving_image.size) + [1]], dtype=self._dtype)), self._dim)\
               .to(device=self._device)

        self.register_buffer("_grid", grid)

        # compute the initial center of mass of the moving image
        intensity_sum = th.sum(moving_image.image)

        self._center_mass_x = th.sum(moving_image.image.squeeze() * self._grid[..., 0]) / intensity_sum
        self._center_mass_y = th.sum(moving_image.image.squeeze() * self._grid[..., 1]) / intensity_sum

        self._phi_z = Parameter(th.tensor(0.0))
        self._t_x = Parameter(th.tensor(0.0))
        self._t_y = Parameter(th.tensor(0.0))

        self._trans_matrix_pos = None
        self._trans_matrix_cm = None
        self._trans_matrix_cm_rw = None
        self._rotation_matrix = None

        if self._opt_cm:
            self._center_mass_x = Parameter(self._center_mass_x)
            self._center_mass_y = Parameter(self._center_mass_y)

        if self._dim == 2:
            self._compute_transformation = self._compute_transformation_2d
        else:
            self._compute_transformation = self._compute_transformation_3d

            self._center_mass_z = th.sum(moving_image.image.squeeze() * self._grid[..., 2]) / intensity_sum

            self._t_z = Parameter(th.tensor(0.0))
            self._phi_x = Parameter(th.tensor(0.0))
            self._phi_y = Parameter(th.tensor(0.0))

            if self._opt_cm:
                self._center_mass_z = Parameter(self._center_mass_z)

    def init_translation(self, fixed_image):
        r"""
        Initialize the translation parameters with the difference between the center of mass of the
        fixed and the moving image

        Args:
            fixed_image (Image): Fixed image for the registration
        """
        intensity_sum = th.sum(fixed_image.image)

        fixed_image_center_mass_x = th.sum(fixed_image.image.squeeze() * self._grid[..., 0]) / intensity_sum
        fixed_image_center_mass_y = th.sum(fixed_image.image.squeeze() * self._grid[..., 1]) / intensity_sum

        self._t_x = Parameter(self._center_mass_x - fixed_image_center_mass_x)
        self._t_y = Parameter(self._center_mass_y - fixed_image_center_mass_y)

        if self._dim == 3:
            fixed_image_center_mass_z = th.sum(fixed_image.image.squeeze() * self._grid[..., 2]) / intensity_sum
            self._t_z = Parameter(self._center_mass_z - fixed_image_center_mass_z)

    @property
    def transformation_matrix(self):
        """The current ``dim x (dim + 1)`` transformation matrix."""
        return self._compute_transformation_matrix()

    def set_parameters(self, t, phi, rotation_center=None):
        """
        Set parameters manually

        t (array): 2 or 3 dimensional array specifying the spatial translation
        phi (array): 1 or 3 dimensional array specifying the rotation angles,
            ordered (phi_z,) in 2D and (phi_z, phi_x, phi_y) in 3D
        rotation_center (array): 2 or 3 dimensional array specifying the rotation center
            (if None, the current center - initially the moving image's center of mass - is kept)
        """
        self._t_x = Parameter(th.tensor(t[0]).to(dtype=self._dtype, device=self._device))
        self._t_y = Parameter(th.tensor(t[1]).to(dtype=self._dtype, device=self._device))
        self._phi_z = Parameter(th.tensor(phi[0]).to(dtype=self._dtype, device=self._device))

        if rotation_center is not None:
            self._center_mass_x = self._prepare_center(rotation_center[0])
            self._center_mass_y = self._prepare_center(rotation_center[1])

        if len(t) == 2:
            self._compute_transformation_2d()
        else:
            self._t_z = Parameter(th.tensor(t[2]).to(dtype=self._dtype, device=self._device))
            self._phi_x = Parameter(th.tensor(phi[1]).to(dtype=self._dtype, device=self._device))
            self._phi_y = Parameter(th.tensor(phi[2]).to(dtype=self._dtype, device=self._device))
            if rotation_center is not None:
                # the z coordinate is the third component of the centre
                self._center_mass_z = self._prepare_center(rotation_center[2])
            self._compute_transformation_3d()

    def _prepare_center(self, value):
        """
        Bring a rotation-centre coordinate into the form stored on the module.

        With ``opt_cm`` the centre attributes are registered Parameters and
        ``nn.Module.__setattr__`` rejects plain values for them, so the
        coordinate is wrapped in a Parameter; otherwise it is stored as given.
        """
        if self._opt_cm:
            return Parameter(th.as_tensor(value, dtype=self._dtype, device=self._device))
        return value

    def _compute_transformation_2d(self):
        """Build the homogeneous translation, centring and rotation matrices for 2D."""
        self._trans_matrix_pos = th.diag(th.ones(self._dim + 1, dtype=self._dtype, device=self._device))
        self._trans_matrix_cm = th.diag(th.ones(self._dim + 1, dtype=self._dtype, device=self._device))
        self._trans_matrix_cm_rw = th.diag(th.ones(self._dim + 1, dtype=self._dtype, device=self._device))
        self._rotation_matrix = th.zeros(self._dim + 1, self._dim + 1, dtype=self._dtype, device=self._device)
        self._rotation_matrix[-1, -1] = 1

        self._trans_matrix_pos[0, 2] = self._t_x
        self._trans_matrix_pos[1, 2] = self._t_y

        self._trans_matrix_cm[0, 2] = -self._center_mass_x
        self._trans_matrix_cm[1, 2] = -self._center_mass_y

        self._trans_matrix_cm_rw[0, 2] = self._center_mass_x
        self._trans_matrix_cm_rw[1, 2] = self._center_mass_y

        self._rotation_matrix[0, 0] = th.cos(self._phi_z)
        self._rotation_matrix[0, 1] = -th.sin(self._phi_z)
        self._rotation_matrix[1, 0] = th.sin(self._phi_z)
        self._rotation_matrix[1, 1] = th.cos(self._phi_z)

    def _compute_transformation_3d(self):
        """Build the homogeneous translation, centring and rotation (R_z R_y R_x) matrices for 3D."""
        self._trans_matrix_pos = th.diag(th.ones(self._dim + 1, dtype=self._dtype, device=self._device))
        self._trans_matrix_cm = th.diag(th.ones(self._dim + 1, dtype=self._dtype, device=self._device))
        self._trans_matrix_cm_rw = th.diag(th.ones(self._dim + 1, dtype=self._dtype, device=self._device))

        self._trans_matrix_pos[0, 3] = self._t_x
        self._trans_matrix_pos[1, 3] = self._t_y
        self._trans_matrix_pos[2, 3] = self._t_z

        self._trans_matrix_cm[0, 3] = -self._center_mass_x
        self._trans_matrix_cm[1, 3] = -self._center_mass_y
        self._trans_matrix_cm[2, 3] = -self._center_mass_z

        self._trans_matrix_cm_rw[0, 3] = self._center_mass_x
        self._trans_matrix_cm_rw[1, 3] = self._center_mass_y
        self._trans_matrix_cm_rw[2, 3] = self._center_mass_z

        R_x = th.diag(th.ones(self._dim + 1, dtype=self._dtype, device=self._device))
        R_x[1, 1] = th.cos(self._phi_x)
        R_x[1, 2] = -th.sin(self._phi_x)
        R_x[2, 1] = th.sin(self._phi_x)
        R_x[2, 2] = th.cos(self._phi_x)

        R_y = th.diag(th.ones(self._dim + 1, dtype=self._dtype, device=self._device))
        R_y[0, 0] = th.cos(self._phi_y)
        R_y[0, 2] = th.sin(self._phi_y)
        R_y[2, 0] = -th.sin(self._phi_y)
        R_y[2, 2] = th.cos(self._phi_y)

        R_z = th.diag(th.ones(self._dim + 1, dtype=self._dtype, device=self._device))
        R_z[0, 0] = th.cos(self._phi_z)
        R_z[0, 1] = -th.sin(self._phi_z)
        R_z[1, 0] = th.sin(self._phi_z)
        R_z[1, 1] = th.cos(self._phi_z)

        self._rotation_matrix = th.mm(th.mm(R_z, R_y), R_x)

    def _compute_transformation_matrix(self):
        """Compose the matrices and drop the homogeneous last row."""
        transformation_matrix = th.mm(th.mm(th.mm(self._trans_matrix_pos, self._trans_matrix_cm),
                                            self._rotation_matrix), self._trans_matrix_cm_rw)[0:self._dim, :]
        return transformation_matrix

    def _compute_dense_flow(self, transformation_matrix):
        """Apply the matrix to every grid point and return the per-point displacement."""
        displacement = th.mm(self._grid.view(np.prod(self._image_size).tolist(), self._dim + 1),
                             transformation_matrix.t()).view(*(self._image_size.tolist()), self._dim) \
                       - self._grid[..., :self._dim]
        return displacement

    def print(self):
        """Print every named parameter with its scalar value."""
        for name, param in self.named_parameters():
            print(name, param.item())

    def compute_displacement(self, transformation_matrix):
        """Return the dense displacement induced by ``transformation_matrix``."""
        return self._compute_dense_flow(transformation_matrix)

    def forward(self):
        """Rebuild the matrices from the parameters and return the dense displacement."""
        self._compute_transformation()
        transformation_matrix = self._compute_transformation_matrix()
        flow = self._compute_dense_flow(transformation_matrix)

        return self._concatenate_flows(flow)
class SimilarityTransformation(RigidTransformation):
    r"""
    Similarity centred transformation for 2D and 3D.

    Extends the rigid transformation by one scale factor per spatial axis
    (``_scale_x``, ``_scale_y`` [, ``_scale_z``]), all initialised to 1.

    Args:
        moving_image (Image): moving image for the registration
        opt_cm (bool): using center of mass as parameter for the optimisation
    """
    def __init__(self, moving_image, opt_cm=False):
        super(SimilarityTransformation, self).__init__(moving_image, opt_cm)

        self._scale_x = Parameter(th.tensor(1.0))
        self._scale_y = Parameter(th.tensor(1.0))

        self._scale_matrix = None

        planar = self._dim == 2
        self._compute_transformation = (self._compute_transformation_2d if planar
                                        else self._compute_transformation_3d)
        if not planar:
            self._scale_z = Parameter(th.tensor(1.0))

    def set_parameters(self, t, phi, scale, rotation_center=None):
        """
        Set parameters manually

        t (array): 2 or 3 dimensional array specifying the spatial translation
        phi (array): 1 or 3 dimensional array specifying the rotation angles
        scale (array): 2 or 3 dimensional array specifying the scale in each dimension
        rotation_center (array): 2 or 3 dimensional array specifying the rotation center
        """
        super(SimilarityTransformation, self).set_parameters(t, phi, rotation_center)

        def to_parameter(value):
            # scalar -> Parameter on this module's dtype/device
            return Parameter(th.tensor(value).to(dtype=self._dtype, device=self._device))

        self._scale_x = to_parameter(scale[0])
        self._scale_y = to_parameter(scale[1])

        if len(t) != 2:
            self._scale_z = to_parameter(scale[2])
            self._compute_transformation_3d()
        else:
            self._compute_transformation_2d()

    def _compute_transformation_2d(self):
        """Rigid matrices plus a diagonal 2D scale matrix."""
        super(SimilarityTransformation, self)._compute_transformation_2d()

        scaling = th.diag(th.ones(self._dim + 1, dtype=self._dtype, device=self._device))
        for axis, factor in enumerate((self._scale_x, self._scale_y)):
            scaling[axis, axis] = factor
        self._scale_matrix = scaling

    def _compute_transformation_3d(self):
        """Rigid matrices plus a diagonal 3D scale matrix."""
        super(SimilarityTransformation, self)._compute_transformation_3d()

        scaling = th.diag(th.ones(self._dim + 1, dtype=self._dtype, device=self._device))
        for axis, factor in enumerate((self._scale_x, self._scale_y, self._scale_z)):
            scaling[axis, axis] = factor
        self._scale_matrix = scaling

    def _compute_transformation_matrix(self):
        """Compose translation, centring, rotation and scale; drop the homogeneous row."""
        composed = (self._trans_matrix_pos @ self._trans_matrix_cm @ self._rotation_matrix
                    @ self._scale_matrix @ self._trans_matrix_cm_rw)
        return composed[0:self._dim, :]

    def forward(self):
        """Rebuild the matrices from the parameters and return the dense displacement."""
        self._compute_transformation()
        dense = self._compute_dense_flow(self._compute_transformation_matrix())
        return self._concatenate_flows(dense)
class AffineTransformation(SimilarityTransformation):
    """
    Affine centred transformation for 2D and 3D.

    Extends the similarity transformation by shear terms, all initialised to 0.

    Args:
        moving_image (Image): moving image for the registration
        opt_cm (bool): using center of mass as parameter for the optimisation
    """
    def __init__(self, moving_image, opt_cm=False):
        super(AffineTransformation, self).__init__(moving_image, opt_cm)

        self._shear_y_x = Parameter(th.tensor(0.0))
        self._shear_x_y = Parameter(th.tensor(0.0))

        self._shear_matrix = None

        # forward() dispatches through _compute_transformation
        if self._dim == 2:
            self._compute_transformation = self._compute_transformation_2d
        else:
            self._compute_transformation = self._compute_transformation_3d

            self._shear_z_x = Parameter(th.tensor(0.0))
            self._shear_z_y = Parameter(th.tensor(0.0))
            self._shear_x_z = Parameter(th.tensor(0.0))
            self._shear_y_z = Parameter(th.tensor(0.0))

    def set_parameters(self, t, phi, scale, shear, rotation_center=None):
        """
        Set parameters manually

        t (array): 2 or 3 dimensional array specifying the spatial translation
        phi (array): 1 or 3 dimensional array specifying the rotation angles
        scale (array): 2 or 3 dimensional array specifying the scale in each dimension
        shear (array): 2 or 6 dimensional array specifying the shear in each dimension: yx, xy, zx, zy, xz, yz
        rotation_center (array): 2 or 3 dimensional array specifying the rotation center
        """
        super(AffineTransformation, self).set_parameters(t, phi, scale, rotation_center)

        self._shear_y_x = Parameter(th.tensor(shear[0]).to(dtype=self._dtype, device=self._device))
        self._shear_x_y = Parameter(th.tensor(shear[1]).to(dtype=self._dtype, device=self._device))

        if len(t) == 2:
            self._compute_transformation_2d()
        else:
            self._shear_z_x = Parameter(th.tensor(shear[2]).to(dtype=self._dtype, device=self._device))
            self._shear_z_y = Parameter(th.tensor(shear[3]).to(dtype=self._dtype, device=self._device))
            self._shear_x_z = Parameter(th.tensor(shear[4]).to(dtype=self._dtype, device=self._device))
            self._shear_y_z = Parameter(th.tensor(shear[5]).to(dtype=self._dtype, device=self._device))
            self._compute_transformation_3d()

    def _compute_transformation_2d(self):
        """Similarity matrices plus a 2D shear matrix."""
        super(AffineTransformation, self)._compute_transformation_2d()

        self._shear_matrix = th.diag(th.ones(self._dim + 1, dtype=self._dtype, device=self._device))
        self._shear_matrix[0, 1] = self._shear_y_x
        self._shear_matrix[1, 0] = self._shear_x_y

    def _compute_transformation_3d(self):
        """Similarity matrices plus a 3D shear matrix."""
        super(AffineTransformation, self)._compute_transformation_3d()

        self._shear_matrix = th.diag(th.ones(self._dim + 1, dtype=self._dtype, device=self._device))
        self._shear_matrix[0, 1] = self._shear_y_x
        self._shear_matrix[0, 2] = self._shear_z_x
        self._shear_matrix[1, 0] = self._shear_x_y
        self._shear_matrix[1, 2] = self._shear_z_y
        self._shear_matrix[2, 0] = self._shear_x_z
        self._shear_matrix[2, 1] = self._shear_y_z

    def _compute_transformation_matrix(self):
        """Compose translation, centring, rotation, scale and shear; drop the homogeneous row."""
        transformation_matrix = th.mm(th.mm(th.mm(th.mm(th.mm(self._trans_matrix_pos, self._trans_matrix_cm),
                                                        self._rotation_matrix), self._scale_matrix),
                                            self._shear_matrix), self._trans_matrix_cm_rw)[0:self._dim, :]
        return transformation_matrix

    def forward(self):
        """Rebuild the matrices from the parameters and return the dense displacement."""
        self._compute_transformation()
        transformation_matrix = self._compute_transformation_matrix()
        flow = self._compute_dense_flow(transformation_matrix)

        return self._concatenate_flows(flow)
class NonParametricTransformation(_Transformation):
    r"""
    Non-parametric transformation: every pixel/voxel carries its own
    displacement parameters, stored as ``trans_parameters`` of shape
    ``[dim] + image_size``.
    """
    def __init__(self, image_size, diffeomorphic=False, dtype=th.float32, device='cpu'):
        super(NonParametricTransformation, self).__init__(image_size, diffeomorphic, dtype, device)

        self._tensor_size = [self._dim] + self._image_size.tolist()

        # zero-initialised, one component per axis and pixel/voxel
        self.trans_parameters = Parameter(th.zeros(*self._tensor_size))

        self.to(dtype=self._dtype, device=self._device)

        if self._dim == 2:
            self._compute_flow = self._compute_flow_2d
        else:
            self._compute_flow = self._compute_flow_3d

    def set_start_parameter(self, parameters):
        """
        Initialise the parameters from a dense flow field.

        Args:
            parameters: flow in the layout produced by ``_compute_flow`` -
                spatial axes first, displacement components last, i.e. shape
                ``image_size + [dim]``. Tensors and numpy arrays are accepted.
        """
        flow = th.as_tensor(parameters).detach()

        if self._dim == 2:
            # (*image_size, 2) -> (2, *image_size): exact inverse of _compute_flow_2d
            params = flow.permute(2, 0, 1)
        elif self._dim == 3:
            # (*image_size, 3) -> (3, *image_size): exact inverse of _compute_flow_3d
            params = flow.permute(3, 0, 1, 2)
        else:
            return

        # independent, contiguous copy on the module's dtype/device
        self.trans_parameters = Parameter(
            params.to(dtype=self._dtype, device=self._device).clone(memory_format=th.contiguous_format))

    def _compute_flow_2d(self):
        """Return the parameters as a ``(*image_size, 2)`` flow."""
        return self.trans_parameters.transpose(0, 2).transpose(0, 1)

    def _compute_flow_3d(self):
        """Return the parameters as a ``(*image_size, 3)`` flow."""
        return self.trans_parameters.transpose(0, 3).transpose(0, 2).transpose(0, 1)

    def forward(self):
        """Return the displacement (passed through the diffeomorphic helper if enabled)."""
        flow = self._concatenate_flows(self._compute_flow())

        if self._diffeomorphic:
            displacement = self._compute_diffeomorphic_displacement(flow)
        else:
            displacement = flow

        return displacement
""" Base class for kernel transformations """ class _KernelTransformation(_Transformation): def __init__(self, image_size, diffeomorphic=False, dtype=th.float32, device='cpu'): super(_KernelTransformation, self).__init__(image_size, diffeomorphic, dtype, device) self._kernel = None self._stride = 1 self._padding = 0 self._displacement_tmp = None self._displacement = None assert self._dim == 2 or self._dim == 3 if self._dim == 2: self._compute_flow = self._compute_flow_2d else: self._compute_flow = self._compute_flow_3d def get_current_displacement(self): if self._dim == 2: return th.unsqueeze(self._compute_displacement().detach(), 0).cpu().numpy() elif self._dim == 3: return self._compute_displacement().detach().cpu().numpy() def _initialize(self): cp_grid = np.ceil(np.divide(self._image_size, self._stride)).astype(dtype=int) # new image size after convolution inner_image_size = np.multiply(self._stride, cp_grid) - (self._stride - 1) # add one control point at each side cp_grid = cp_grid + 2 # image size with additional control points new_image_size = np.multiply(self._stride, cp_grid) - (self._stride - 1) # center image between control points image_size_diff = inner_image_size - self._image_size image_size_diff_floor = np.floor((np.abs(image_size_diff)/2))*np.sign(image_size_diff) self._crop_start = image_size_diff_floor + np.remainder(image_size_diff, 2)*np.sign(image_size_diff) self._crop_end = image_size_diff_floor cp_grid = [1, self._dim] + cp_grid.tolist() # create transformation parameters self.trans_parameters = Parameter(th.Tensor(*cp_grid)) self.trans_parameters.data.fill_(0) # copy to gpu if needed self.to(dtype=self._dtype, device=self._device) # convert to integer self._padding = self._padding.astype(dtype=int).tolist() self._stride = self._stride.astype(dtype=int).tolist() self._crop_start = self._crop_start.astype(dtype=int) self._crop_end = self._crop_end.astype(dtype=int) size = [1, 1] + new_image_size.astype(dtype=int).tolist() 
self._displacement_tmp = th.empty(*size, dtype=self._dtype, device=self._device) size = [1, 1] + self._image_size.astype(dtype=int).tolist() self._displacement = th.empty(*size, dtype=self._dtype, device=self._device) def _compute_flow_2d(self): displacement_tmp = F.conv_transpose2d(self.trans_parameters, self._kernel, padding=self._padding, stride=self._stride, groups=2) # crop displacement return th.squeeze(displacement_tmp[:, :, self._stride[0] + self._crop_start[0]:-self._stride[0] - self._crop_end[0], self._stride[1] + self._crop_start[1]:-self._stride[1] - self._crop_end[1]].transpose_(1, 3).transpose(1, 2)) def _compute_flow_3d(self): # compute dense displacement displacement = F.conv_transpose3d(self.trans_parameters, self._kernel, padding=self._padding, stride=self._stride, groups=3) # crop displacement return th.squeeze(displacement[:, :, self._stride[0] + self._crop_start[0]:-self._stride[0] - self._crop_end[0], self._stride[1] + self._crop_start[1]:-self._stride[1] - self._crop_end[1], self._stride[2] + self._crop_start[2]:-self._stride[2] - self._crop_end[2] ].transpose_(1,4).transpose_(1,3).transpose_(1,2)) def forward(self): flow = self._concatenate_flows(self._compute_flow()) if self._diffeomorphic: displacement = self._compute_diffeomorphic_displacement(flow) else: displacement = flow return displacement """ bspline kernel transformation """
class BsplineTransformation(_KernelTransformation):
    """
    B-spline kernel transformation.

    Control points are spaced ``sigma`` pixels apart (``sigma`` is used as the
    stride of the transposed convolution); ``order`` is forwarded to the
    B-spline kernel construction.
    """
    def __init__(self, image_size, sigma, diffeomorphic=False, order=2, dtype=th.float32, device='cpu'):
        super(BsplineTransformation, self).__init__(image_size, diffeomorphic, dtype, device)

        self._stride = np.array(sigma)

        kernel = utils.bspline_kernel(sigma, dim=self._dim, order=order, asTensor=True, dtype=dtype)

        # padding is derived from the raw kernel shape, before channel axes are added
        self._padding = (np.array(kernel.size()) - 1) / 2

        # (1, 1, *k) -> (dim, 1, *k): one kernel per displacement component (groups=dim)
        kernel.unsqueeze_(0).unsqueeze_(0)
        keep_sizes = [-1] * (self._dim + 1)
        self._kernel = kernel.expand(self._dim, *keep_sizes).to(dtype=dtype, device=self._device)

        self._initialize()
""" Wendland kernel transformation """
[docs]class WendlandKernelTransformation(_KernelTransformation): """ Wendland Kernel Transform: Implements the kernel transform with the Wendland basis Parameters: sigma: specifies how many control points are used (each sigma pixels) cp_scale: specifies the extent of the kernel. how many control points are in the support of the kernel """ def __init__(self, image_size, sigma, cp_scale=2, diffeomorphic=False, ktype="C4", dtype=th.float32, device='cpu'): super(WendlandKernelTransformation, self).__init__(image_size, diffeomorphic, dtype, device) self._stride = np.array(sigma) # compute bspline kernel self._kernel = utils.wendland_kernel(np.array(sigma)*cp_scale, dim=self._dim, type=ktype, asTensor=True, dtype=dtype) self._padding = (np.array(self._kernel.size()) - 1) / 2 self._kernel.unsqueeze_(0).unsqueeze_(0) self._kernel = self._kernel.expand(self._dim, *((np.ones(self._dim + 1,dtype=int) * -1).tolist())) self._kernel = self._kernel.to(dtype=dtype, device=self._device) self._initialize()