# Source code for airlab.transformation.pairwise
# Copyright 2018 University of Basel, Center for medical Image Analysis and Navigation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import warnings

import numpy as np
import torch as th
import torch.nn.functional as F
from torch.nn.parameter import Parameter

from ..utils import kernelFunction as utils
from . import utils as tu
class _Transformation(th.nn.Module):
    """
    Base class for a transformation.

    Subclasses implement ``forward()`` (returning the displacement field) and,
    for flow-based transformations, assign a callable returning the current
    flow to ``self._compute_flow``.

    Args:
        image_size (sequence of int): spatial size of the image; its length is
            the dimensionality of the transformation
        diffeomorphic (bool): if True, a ``tu.Diffeomorphic`` calculator is
            created; it turns flows into displacements and is required for
            ``get_inverse_displacement``
        dtype (torch.dtype): data type used by the transformation
        device (str or torch.device): device used by the transformation
    """
    def __init__(self, image_size, diffeomorphic=False, dtype=th.float32, device='cpu'):
        super(_Transformation, self).__init__()
        self._dtype = dtype
        self._device = device
        self._dim = len(image_size)
        self._image_size = np.array(image_size)
        self._constant_displacement = None
        self._diffeomorphic = diffeomorphic
        # optional flow that is added to every computed flow (see _concatenate_flows)
        self._constant_flow = None
        # callable returning the current flow; assigned by subclasses
        self._compute_flow = None

        if self._diffeomorphic:
            self._diffeomorphic_calculater = tu.Diffeomorphic(image_size, dtype=dtype, device=device)
        else:
            self._diffeomorphic_calculater = None

    def get_flow(self):
        """Return the computed flow (detached) plus the constant flow, if one is set."""
        return self._concatenate_flows(self._compute_flow().detach())

    def set_constant_flow(self, flow):
        """Set a flow that is added to every computed flow; ``None`` disables it."""
        self._constant_flow = flow

    def get_displacement_numpy(self):
        """
        Return the current displacement (output of ``forward``) as a numpy array.

        In 2D a leading singleton axis is added; in 3D the array is returned
        as is. Other dimensionalities return ``None``.
        """
        if self._dim == 2:
            return th.unsqueeze(self().detach(), 0).cpu().numpy()
        elif self._dim == 3:
            return self().detach().cpu().numpy()
        return None

    def get_displacement(self):
        """Return the current displacement (output of ``forward``) detached from the graph."""
        return self().detach()

    def get_inverse_displacement(self):
        """
        Return the inverse of the current displacement.

        The inverse is computed by the diffeomorphic calculator from the negated
        (concatenated, detached) flow, so it is only available for diffeomorphic
        transformations. For non-diffeomorphic transformations a
        ``RuntimeWarning`` is issued and ``None`` is returned without evaluating
        the flow (which some subclasses do not define).
        """
        if not self._diffeomorphic:
            warnings.warn("error displacement: the inverse displacement is only available "
                          "for diffeomorphic transformations", RuntimeWarning, stacklevel=2)
            return None

        flow = self._concatenate_flows(self._compute_flow()).detach()
        return self._diffeomorphic_calculater.calculate(flow * -1)

    def _compute_diffeomorphic_displacement(self, flow):
        """Turn ``flow`` into a displacement with the diffeomorphic calculator."""
        return self._diffeomorphic_calculater.calculate(flow)

    def _concatenate_flows(self, flow):
        """Return ``flow`` plus the constant flow, if one is set."""
        if self._constant_flow is None:
            return flow
        return flow + self._constant_flow
class RigidTransformation(_Transformation):
    r"""
    Rigid centred transformation for 2D and 3D.

    The transformation consists of a translation, a rotation (one angle in 2D,
    three angles in 3D) and a rotation center. The rotation center is
    initialised with the intensity-weighted center of mass of the moving image.

    Args:
        moving_image (Image): moving image for the registration
        opt_cm (bool): if True, the center of mass is a parameter of the optimisation
    """
    def __init__(self, moving_image, opt_cm=False):
        super(RigidTransformation, self).__init__(image_size=moving_image.size,
                                                  dtype=moving_image.dtype,
                                                  device=moving_image.device)

        self._opt_cm = opt_cm

        # homogeneous coordinates: the grid gets an additional channel of ones
        grid = th.squeeze(tu.compute_grid(moving_image.size, dtype=self._dtype))
        grid = th.cat((grid, th.ones(list(moving_image.size) + [1], dtype=self._dtype)), self._dim)\
            .to(device=self._device)
        self.register_buffer("_grid", grid)

        # compute the initial center of mass of the moving image
        image = moving_image.image.squeeze()
        intensity_sum = th.sum(moving_image.image)
        self._center_mass_x = th.sum(image * self._grid[..., 0]) / intensity_sum
        self._center_mass_y = th.sum(image * self._grid[..., 1]) / intensity_sum

        # parameters are created with the dtype/device of the matrices they are written into
        self._phi_z = self._new_parameter(0.0)
        self._t_x = self._new_parameter(0.0)
        self._t_y = self._new_parameter(0.0)

        # filled by _compute_transformation
        self._trans_matrix_pos = None
        self._trans_matrix_cm = None
        self._trans_matrix_cm_rw = None
        self._rotation_matrix = None

        if self._opt_cm:
            self._center_mass_x = Parameter(self._center_mass_x)
            self._center_mass_y = Parameter(self._center_mass_y)

        if self._dim == 2:
            self._compute_transformation = self._compute_transformation_2d
        else:
            self._compute_transformation = self._compute_transformation_3d

            self._center_mass_z = th.sum(image * self._grid[..., 2]) / intensity_sum

            self._t_z = self._new_parameter(0.0)
            self._phi_x = self._new_parameter(0.0)
            self._phi_y = self._new_parameter(0.0)

            if self._opt_cm:
                self._center_mass_z = Parameter(self._center_mass_z)

    def _new_parameter(self, value):
        """Create a scalar parameter holding a copy of ``value`` with this transformation's dtype/device."""
        value = th.as_tensor(value, dtype=self._dtype).detach().clone()
        return Parameter(value.to(device=self._device))

    def _set_rotation_center(self, rotation_center):
        """
        Replace the rotation center (one entry per dimension).

        With ``opt_cm`` the centers are registered parameters, which can only be
        replaced by parameters; otherwise plain tensors are stored.
        """
        names = ("_center_mass_x", "_center_mass_y", "_center_mass_z")[:self._dim]
        for axis, name in enumerate(names):
            value = rotation_center[axis]
            if self._opt_cm:
                setattr(self, name, self._new_parameter(value))
            else:
                setattr(self, name, th.as_tensor(value, dtype=self._dtype, device=self._device))

    def init_translation(self, fixed_image):
        r"""
        Initialize the translation parameters with the difference between the center of mass of the
        fixed and the moving image

        Args:
            fixed_image (Image): Fixed image for the registration
        """
        image = fixed_image.image.squeeze()
        intensity_sum = th.sum(fixed_image.image)

        fixed_image_center_mass_x = th.sum(image * self._grid[..., 0]) / intensity_sum
        fixed_image_center_mass_y = th.sum(image * self._grid[..., 1]) / intensity_sum

        self._t_x = Parameter(self._center_mass_x - fixed_image_center_mass_x)
        self._t_y = Parameter(self._center_mass_y - fixed_image_center_mass_y)

        if self._dim == 3:
            fixed_image_center_mass_z = th.sum(image * self._grid[..., 2]) / intensity_sum
            self._t_z = Parameter(self._center_mass_z - fixed_image_center_mass_z)

    @property
    def transformation_matrix(self):
        """Homogeneous transformation matrix (``dim x dim+1``) for the current parameters."""
        # recompute so the matrix reflects the current parameters (and works before forward())
        self._compute_transformation()
        return self._compute_transformation_matrix()

    def set_parameters(self, t, phi, rotation_center=None):
        """
        Set parameters manually

        t (array): 2 or 3 dimensional array specifying the spatial translation (x, y[, z])
        phi (array): 1 or 3 dimensional array specifying the rotation angles (z[, x, y])
        rotation_center (array): 2 or 3 dimensional array specifying the rotation center
            (default: keep the current center, initially the center of mass of the moving image)
        """
        self._t_x = self._new_parameter(t[0])
        self._t_y = self._new_parameter(t[1])
        self._phi_z = self._new_parameter(phi[0])

        if self._dim != 2:
            self._t_z = self._new_parameter(t[2])
            self._phi_x = self._new_parameter(phi[1])
            self._phi_y = self._new_parameter(phi[2])

        if rotation_center is not None:
            self._set_rotation_center(rotation_center)

        self._compute_transformation()

    def _compute_transformation_2d(self):
        eye = th.eye(self._dim + 1, dtype=self._dtype, device=self._device)
        self._trans_matrix_pos = eye.clone()
        self._trans_matrix_cm = eye.clone()
        self._trans_matrix_cm_rw = eye.clone()
        self._rotation_matrix = eye.clone()

        self._trans_matrix_pos[0, 2] = self._t_x
        self._trans_matrix_pos[1, 2] = self._t_y

        self._trans_matrix_cm[0, 2] = -self._center_mass_x
        self._trans_matrix_cm[1, 2] = -self._center_mass_y

        self._trans_matrix_cm_rw[0, 2] = self._center_mass_x
        self._trans_matrix_cm_rw[1, 2] = self._center_mass_y

        self._rotation_matrix[0, 0] = th.cos(self._phi_z)
        self._rotation_matrix[0, 1] = -th.sin(self._phi_z)
        self._rotation_matrix[1, 0] = th.sin(self._phi_z)
        self._rotation_matrix[1, 1] = th.cos(self._phi_z)

    def _compute_transformation_3d(self):
        eye = th.eye(self._dim + 1, dtype=self._dtype, device=self._device)
        self._trans_matrix_pos = eye.clone()
        self._trans_matrix_cm = eye.clone()
        self._trans_matrix_cm_rw = eye.clone()

        self._trans_matrix_pos[0, 3] = self._t_x
        self._trans_matrix_pos[1, 3] = self._t_y
        self._trans_matrix_pos[2, 3] = self._t_z

        self._trans_matrix_cm[0, 3] = -self._center_mass_x
        self._trans_matrix_cm[1, 3] = -self._center_mass_y
        self._trans_matrix_cm[2, 3] = -self._center_mass_z

        self._trans_matrix_cm_rw[0, 3] = self._center_mass_x
        self._trans_matrix_cm_rw[1, 3] = self._center_mass_y
        self._trans_matrix_cm_rw[2, 3] = self._center_mass_z

        R_x = eye.clone()
        R_x[1, 1] = th.cos(self._phi_x)
        R_x[1, 2] = -th.sin(self._phi_x)
        R_x[2, 1] = th.sin(self._phi_x)
        R_x[2, 2] = th.cos(self._phi_x)

        R_y = eye.clone()
        R_y[0, 0] = th.cos(self._phi_y)
        R_y[0, 2] = th.sin(self._phi_y)
        R_y[2, 0] = -th.sin(self._phi_y)
        R_y[2, 2] = th.cos(self._phi_y)

        R_z = eye.clone()
        R_z[0, 0] = th.cos(self._phi_z)
        R_z[0, 1] = -th.sin(self._phi_z)
        R_z[1, 0] = th.sin(self._phi_z)
        R_z[1, 1] = th.cos(self._phi_z)

        self._rotation_matrix = R_z @ R_y @ R_x

    def _compute_transformation_matrix(self):
        transformation_matrix = (self._trans_matrix_pos @ self._trans_matrix_cm
                                 @ self._rotation_matrix @ self._trans_matrix_cm_rw)
        # drop the homogeneous row
        return transformation_matrix[0:self._dim, :]

    def _compute_dense_flow(self, transformation_matrix):
        # apply the matrix to every homogeneous grid point and subtract the identity grid
        num_points = np.prod(self._image_size).tolist()
        displacement = th.mm(self._grid.view(num_points, self._dim + 1),
                             transformation_matrix.t()).view(*(self._image_size.tolist()), self._dim) \
            - self._grid[..., :self._dim]
        return displacement

    def compute_displacement(self, transformation_matrix):
        """Return the dense displacement field induced by ``transformation_matrix``."""
        return self._compute_dense_flow(transformation_matrix)

    def forward(self):
        self._compute_transformation()
        transformation_matrix = self._compute_transformation_matrix()
        flow = self._compute_dense_flow(transformation_matrix)
        return self._concatenate_flows(flow)
class SimilarityTransformation(RigidTransformation):
    r"""
    Similarity centred transformation for 2D and 3D.

    Extends the rigid transformation with one scale factor per dimension
    (initialised with 1). The scale matrix is placed between the rotation
    matrix and the back-translation to the center of mass in the matrix product.

    Args:
        moving_image (Image): moving image for the registration
        opt_cm (bool): if True, the center of mass is a parameter of the optimisation
    """
    def __init__(self, moving_image, opt_cm=False):
        super(SimilarityTransformation, self).__init__(moving_image, opt_cm)

        # created with the dtype/device of the scale matrix they are written into
        self._scale_x = Parameter(th.tensor(1.0, dtype=self._dtype, device=self._device))
        self._scale_y = Parameter(th.tensor(1.0, dtype=self._dtype, device=self._device))

        self._scale_matrix = None

        # self._compute_transformation is bound in RigidTransformation.__init__ to
        # self._compute_transformation_2d/_3d, which already resolve to the overrides below.
        if self._dim != 2:
            self._scale_z = Parameter(th.tensor(1.0, dtype=self._dtype, device=self._device))

    def set_parameters(self, t, phi, scale, rotation_center=None):
        """
        Set parameters manually

        t (array): 2 or 3 dimensional array specifying the spatial translation
        phi (array): 1 or 3 dimensional array specifying the rotation angles
        scale (array): 2 or 3 dimensional array specifying the scale in each dimension
        rotation_center (array): 2 or 3 dimensional array specifying the rotation center
            (default: keep the current center, initially the center of mass of the moving image)
        """
        super(SimilarityTransformation, self).set_parameters(t, phi, rotation_center)

        self._scale_x = Parameter(th.tensor(scale[0]).to(dtype=self._dtype, device=self._device))
        self._scale_y = Parameter(th.tensor(scale[1]).to(dtype=self._dtype, device=self._device))

        if self._dim == 2:
            self._compute_transformation_2d()
        else:
            self._scale_z = Parameter(th.tensor(scale[2]).to(dtype=self._dtype, device=self._device))
            self._compute_transformation_3d()

    def _compute_transformation_2d(self):
        super(SimilarityTransformation, self)._compute_transformation_2d()

        self._scale_matrix = th.eye(self._dim + 1, dtype=self._dtype, device=self._device)
        self._scale_matrix[0, 0] = self._scale_x
        self._scale_matrix[1, 1] = self._scale_y

    def _compute_transformation_3d(self):
        super(SimilarityTransformation, self)._compute_transformation_3d()

        self._scale_matrix = th.eye(self._dim + 1, dtype=self._dtype, device=self._device)
        self._scale_matrix[0, 0] = self._scale_x
        self._scale_matrix[1, 1] = self._scale_y
        self._scale_matrix[2, 2] = self._scale_z

    def _compute_transformation_matrix(self):
        transformation_matrix = (self._trans_matrix_pos @ self._trans_matrix_cm
                                 @ self._rotation_matrix @ self._scale_matrix
                                 @ self._trans_matrix_cm_rw)
        # drop the homogeneous row
        return transformation_matrix[0:self._dim, :]
class AffineTransformation(SimilarityTransformation):
    """
    Affine centred transformation for 2D and 3D.

    Extends the similarity transformation with shear parameters (2 in 2D,
    6 in 3D, all initialised with 0).

    Args:
        moving_image (Image): moving image for the registration
        opt_cm (bool): if True, the center of mass is a parameter of the optimisation
    """
    def __init__(self, moving_image, opt_cm=False):
        super(AffineTransformation, self).__init__(moving_image, opt_cm)

        # created with the dtype/device of the shear matrix they are written into
        self._shear_y_x = Parameter(th.tensor(0.0, dtype=self._dtype, device=self._device))
        self._shear_x_y = Parameter(th.tensor(0.0, dtype=self._dtype, device=self._device))

        self._shear_matrix = None

        # self._compute_transformation is bound in RigidTransformation.__init__ to
        # self._compute_transformation_2d/_3d, which already resolve to the overrides below.
        if self._dim != 2:
            self._shear_z_x = Parameter(th.tensor(0.0, dtype=self._dtype, device=self._device))
            self._shear_z_y = Parameter(th.tensor(0.0, dtype=self._dtype, device=self._device))
            self._shear_x_z = Parameter(th.tensor(0.0, dtype=self._dtype, device=self._device))
            self._shear_y_z = Parameter(th.tensor(0.0, dtype=self._dtype, device=self._device))

    def set_parameters(self, t, phi, scale, shear, rotation_center=None):
        """
        Set parameters manually

        t (array): 2 or 3 dimensional array specifying the spatial translation
        phi (array): 1 or 3 dimensional array specifying the rotation angles
        scale (array): 2 or 3 dimensional array specifying the scale in each dimension
        shear (array): 2 or 6 dimensional array specifying the shear:
            yx, xy (2D) or yx, xy, zx, zy, xz, yz (3D)
        rotation_center (array): 2 or 3 dimensional array specifying the rotation center
            (default: keep the current center, initially the center of mass of the moving image)
        """
        super(AffineTransformation, self).set_parameters(t, phi, scale, rotation_center)

        self._shear_y_x = Parameter(th.tensor(shear[0]).to(dtype=self._dtype, device=self._device))
        self._shear_x_y = Parameter(th.tensor(shear[1]).to(dtype=self._dtype, device=self._device))

        if self._dim == 2:
            self._compute_transformation_2d()
        else:
            self._shear_z_x = Parameter(th.tensor(shear[2]).to(dtype=self._dtype, device=self._device))
            self._shear_z_y = Parameter(th.tensor(shear[3]).to(dtype=self._dtype, device=self._device))
            self._shear_x_z = Parameter(th.tensor(shear[4]).to(dtype=self._dtype, device=self._device))
            self._shear_y_z = Parameter(th.tensor(shear[5]).to(dtype=self._dtype, device=self._device))
            self._compute_transformation_3d()

    def _compute_transformation_2d(self):
        super(AffineTransformation, self)._compute_transformation_2d()

        self._shear_matrix = th.eye(self._dim + 1, dtype=self._dtype, device=self._device)
        self._shear_matrix[0, 1] = self._shear_y_x
        self._shear_matrix[1, 0] = self._shear_x_y

    def _compute_transformation_3d(self):
        super(AffineTransformation, self)._compute_transformation_3d()

        self._shear_matrix = th.eye(self._dim + 1, dtype=self._dtype, device=self._device)
        self._shear_matrix[0, 1] = self._shear_y_x
        self._shear_matrix[0, 2] = self._shear_z_x
        self._shear_matrix[1, 0] = self._shear_x_y
        self._shear_matrix[1, 2] = self._shear_z_y
        self._shear_matrix[2, 0] = self._shear_x_z
        self._shear_matrix[2, 1] = self._shear_y_z

    def _compute_transformation_matrix(self):
        transformation_matrix = (self._trans_matrix_pos @ self._trans_matrix_cm
                                 @ self._rotation_matrix @ self._scale_matrix
                                 @ self._shear_matrix @ self._trans_matrix_cm_rw)
        # drop the homogeneous row
        return transformation_matrix[0:self._dim, :]
class NonParametricTransformation(_Transformation):
    r"""
    Non-parametric transformation: one displacement vector per pixel/voxel.

    The parameters are stored with the vector components first, i.e. with
    shape ``(dim, *image_size)``; ``_compute_flow`` returns them as a flow of
    shape ``(*image_size, dim)``.

    Args:
        image_size (sequence of int): spatial size of the image (2D or 3D)
        diffeomorphic (bool): if True, ``forward`` turns the flow into a displacement
            with the diffeomorphic calculator
        dtype (torch.dtype): data type of the parameters
        device (str or torch.device): device of the parameters
    """
    def __init__(self, image_size, diffeomorphic=False, dtype=th.float32, device='cpu'):
        super(NonParametricTransformation, self).__init__(image_size, diffeomorphic, dtype, device)

        self._tensor_size = [self._dim] + self._image_size.tolist()

        # zero-initialised parameters
        self.trans_parameters = Parameter(th.zeros(self._tensor_size))
        self.to(dtype=self._dtype, device=self._device)

        if self._dim == 2:
            self._compute_flow = self._compute_flow_2d
        else:
            self._compute_flow = self._compute_flow_3d

    def set_start_parameter(self, parameters):
        """
        Initialise the parameters from a flow field.

        Args:
            parameters (Tensor or array-like): flow with the components in the last
                axis, i.e. shape ``(*image_size, dim)`` -- the layout returned by
                ``_compute_flow``. The data is copied and cast to the
                transformation's dtype/device.
        """
        flow = th.as_tensor(parameters).detach()
        # move the component axis to the front: (*image_size, dim) -> (dim, *image_size)
        permuted = flow.permute(self._dim, *range(self._dim))
        data = permuted.to(dtype=self._dtype, device=self._device).clone(memory_format=th.contiguous_format)
        self.trans_parameters = Parameter(data)

    def _compute_flow_2d(self):
        # (2, *image_size) -> (*image_size, 2)
        return self.trans_parameters.permute(1, 2, 0)

    def _compute_flow_3d(self):
        # (3, *image_size) -> (*image_size, 3)
        return self.trans_parameters.permute(1, 2, 3, 0)

    def forward(self):
        flow = self._concatenate_flows(self._compute_flow())

        if self._diffeomorphic:
            displacement = self._compute_diffeomorphic_displacement(flow)
        else:
            displacement = flow

        return displacement
class _KernelTransformation(_Transformation):
    """
    Base class for kernel transformations.

    The dense flow is computed by a transposed convolution of the control point
    parameters ``trans_parameters`` with ``self._kernel``. Subclasses set
    ``_kernel``, ``_stride`` and ``_padding`` (the latter two as numpy arrays)
    and then call ``_initialize``.
    """
    def __init__(self, image_size, diffeomorphic=False, dtype=th.float32, device='cpu'):
        super(_KernelTransformation, self).__init__(image_size, diffeomorphic, dtype, device)

        self._kernel = None
        self._stride = 1
        self._padding = 0
        self._displacement_tmp = None
        self._displacement = None

        assert self._dim == 2 or self._dim == 3

        if self._dim == 2:
            self._compute_flow = self._compute_flow_2d
        else:
            self._compute_flow = self._compute_flow_3d

    def get_current_displacement(self):
        """
        Return the current displacement (output of ``forward``) as a numpy array.

        In 2D a leading singleton axis is added.
        """
        displacement = self().detach()
        if self._dim == 2:
            displacement = th.unsqueeze(displacement, 0)
        return displacement.cpu().numpy()

    def _initialize(self):
        """
        Create the zero-initialised control point parameters and the crop offsets.

        Converts ``_padding`` and ``_stride`` to lists of ints and the crop
        offsets to integer numpy arrays.
        """
        cp_grid = np.ceil(np.divide(self._image_size, self._stride)).astype(dtype=int)

        # new image size after convolution
        inner_image_size = np.multiply(self._stride, cp_grid) - (self._stride - 1)

        # add one control point at each side
        cp_grid = cp_grid + 2

        # image size with additional control points
        new_image_size = np.multiply(self._stride, cp_grid) - (self._stride - 1)

        # center image between control points
        image_size_diff = inner_image_size - self._image_size
        image_size_diff_floor = np.floor((np.abs(image_size_diff) / 2)) * np.sign(image_size_diff)

        self._crop_start = image_size_diff_floor + np.remainder(image_size_diff, 2) * np.sign(image_size_diff)
        self._crop_end = image_size_diff_floor

        cp_grid = [1, self._dim] + cp_grid.tolist()

        # create transformation parameters
        self.trans_parameters = Parameter(th.zeros(cp_grid))

        # copy to gpu if needed
        self.to(dtype=self._dtype, device=self._device)

        # convert to integer
        self._padding = self._padding.astype(dtype=int).tolist()
        self._stride = self._stride.astype(dtype=int).tolist()

        self._crop_start = self._crop_start.astype(dtype=int)
        self._crop_end = self._crop_end.astype(dtype=int)

        size = [1, 1] + new_image_size.astype(dtype=int).tolist()
        self._displacement_tmp = th.empty(*size, dtype=self._dtype, device=self._device)

        size = [1, 1] + self._image_size.astype(dtype=int).tolist()
        self._displacement = th.empty(*size, dtype=self._dtype, device=self._device)

    def _crop(self, displacement):
        """Remove the border control point region from the spatial axes (2, 3, ...) of ``displacement``."""
        spatial = tuple(slice(stride + start, -stride - end)
                        for stride, start, end in zip(self._stride, self._crop_start, self._crop_end))
        return displacement[(slice(None), slice(None)) + spatial]

    def _compute_flow_2d(self):
        # compute dense displacement
        displacement = F.conv_transpose2d(self.trans_parameters, self._kernel,
                                          padding=self._padding, stride=self._stride, groups=2)
        # crop, then (1, 2, *spatial) -> (*spatial, 2); permute avoids in-place transposes of a view
        return th.squeeze(self._crop(displacement).permute(0, 2, 3, 1))

    def _compute_flow_3d(self):
        # compute dense displacement
        displacement = F.conv_transpose3d(self.trans_parameters, self._kernel,
                                          padding=self._padding, stride=self._stride, groups=3)
        # crop, then (1, 3, *spatial) -> (*spatial, 3); permute avoids in-place transposes of a view
        return th.squeeze(self._crop(displacement).permute(0, 2, 3, 4, 1))

    def forward(self):
        flow = self._concatenate_flows(self._compute_flow())

        if self._diffeomorphic:
            displacement = self._compute_diffeomorphic_displacement(flow)
        else:
            displacement = flow

        return displacement
class BsplineTransformation(_KernelTransformation):
    """
    B-spline kernel transformation.

    Control points are placed every ``sigma`` pixels. The dense flow is
    obtained by a transposed convolution of the control point parameters with
    a B-spline kernel.

    Args:
        image_size (sequence of int): spatial size of the image (2D or 3D)
        sigma (sequence of int): control point spacing per dimension (used as stride)
        diffeomorphic (bool): compute the displacement with the diffeomorphic calculator
        order (int): order of the B-spline kernel
        dtype (torch.dtype): data type of the kernel and parameters
        device (str or torch.device): device of the kernel and parameters
    """
    def __init__(self, image_size, sigma, diffeomorphic=False, order=2, dtype=th.float32, device='cpu'):
        super(BsplineTransformation, self).__init__(image_size, diffeomorphic, dtype, device)

        # one control point every sigma pixels
        self._stride = np.array(sigma)

        kernel = utils.bspline_kernel(sigma, dim=self._dim, order=order, asTensor=True, dtype=dtype)
        # padding of (kernel_size - 1) / 2 per dimension
        self._padding = (np.array(kernel.size()) - 1) / 2

        # (1, 1, *k) -> (dim, 1, *k): one kernel copy per flow component
        kernel.unsqueeze_(0).unsqueeze_(0)
        keep_sizes = [-1] * (self._dim + 1)
        self._kernel = kernel.expand(self._dim, *keep_sizes).to(dtype=dtype, device=self._device)

        self._initialize()
class WendlandKernelTransformation(_KernelTransformation):
    """
    Wendland Kernel Transform:

    Implements the kernel transform with the Wendland basis.

    Parameters:
        sigma: specifies how many control points are used (one every sigma pixels)
        cp_scale: specifies the extent of the kernel, i.e. how many control points
            are in the support of the kernel
        ktype: type of the Wendland kernel passed to ``utils.wendland_kernel``
    """
    def __init__(self, image_size, sigma, cp_scale=2, diffeomorphic=False, ktype="C4", dtype=th.float32, device='cpu'):
        super(WendlandKernelTransformation, self).__init__(image_size, diffeomorphic, dtype, device)

        # one control point every sigma pixels
        self._stride = np.array(sigma)

        # compute the Wendland kernel; its extent is cp_scale control point spacings
        support = np.array(sigma) * cp_scale
        kernel = utils.wendland_kernel(support, dim=self._dim, type=ktype, asTensor=True, dtype=dtype)
        # padding of (kernel_size - 1) / 2 per dimension
        self._padding = (np.array(kernel.size()) - 1) / 2

        # (1, 1, *k) -> (dim, 1, *k): one kernel copy per flow component
        kernel.unsqueeze_(0).unsqueeze_(0)
        keep_sizes = [-1] * (self._dim + 1)
        self._kernel = kernel.expand(self._dim, *keep_sizes).to(dtype=dtype, device=self._device)

        self._initialize()