# Copyright 2018 University of Basel, Center for medical Image Analysis and Navigation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch as th
import torch.nn.functional as F
import numpy as np
from .. import transformation as T
from ..transformation import utils as tu
from ..utils import kernelFunction as utils
# Loss base class (standard from PyTorch)
class _PairwiseImageLoss(th.nn.modules.Module):
    """Base class of all pairwise (fixed image vs. moving image) similarity losses.

    Stores both images and their optional masks, builds the identity sampling
    grid of the moving image and offers the helpers shared by the concrete
    losses: the validity mask of a sampling grid and the final weighting /
    reduction of the loss values.

    Args:
        fixed_image (Image): Fixed image for the registration
        moving_image (Image): Moving image for the registration
        fixed_mask: Optional mask for the fixed image
        moving_mask: Optional mask for the moving image (its ``.image`` is warped)
        size_average (bool): Average loss function
        reduce (bool): Reduce loss function to a single value
    """

    def __init__(self, fixed_image, moving_image, fixed_mask=None, moving_mask=None, size_average=True, reduce=True):
        super(_PairwiseImageLoss, self).__init__()
        self._size_average = size_average
        self._reduce = reduce
        self._name = "parent"

        self._warped_moving_image = None
        self._warped_moving_mask = None
        self._weight = 1

        self._moving_image = moving_image
        self._moving_mask = moving_mask
        self._fixed_image = fixed_image
        self._fixed_mask = fixed_mask
        self._grid = None

        # identity checks: `!= None` could dispatch to an overloaded operator
        assert self._moving_image is not None and self._fixed_image is not None
        # TODO allow different image size for each image in the future
        assert self._moving_image.size == self._fixed_image.size
        assert self._moving_image.device == self._fixed_image.device
        # only 2d and 3d images are supported
        assert len(self._moving_image.size) == 2 or len(self._moving_image.size) == 3

        self._grid = T.utils.compute_grid(self._moving_image.size, dtype=self._moving_image.dtype,
                                          device=self._moving_image.device)

        self._dtype = self._moving_image.dtype
        self._device = self._moving_image.device

    @property
    def name(self):
        """Short identifier of the loss (set by each subclass)."""
        return self._name

    def GetWarpedImage(self):
        """Return element ``[0, 0, ...]`` of the last warped moving image, detached and on the CPU.

        Only valid after ``forward`` of a subclass has stored a warped image.
        """
        return self._warped_moving_image[0, 0, ...].detach().cpu()

    def GetCurrentMask(self, displacement):
        """
        Computes a mask defining if pixels are warped outside the image domain, or if they fall into
        a fixed image mask or a warped moving image mask.

        Args:
            displacement (Tensor): sampling grid (identity grid plus displacement); a point is
                treated as outside the image domain when any coordinate is > 1 or < -1

        return (Tensor): mask array, non-zero where a pixel is valid
        """
        # exclude points which are transformed outside the image domain
        mask = th.zeros_like(self._fixed_image.image, dtype=th.uint8, device=self._device)
        for dim in range(displacement.size()[-1]):
            mask += displacement[..., dim].gt(1) + displacement[..., dim].lt(-1)

        # a pixel is valid if no coordinate left the domain
        mask = mask == 0

        # and exclude points which are masked by the warped moving and the fixed mask
        if self._moving_mask is not None:
            self._warped_moving_mask = F.grid_sample(self._moving_mask.image, displacement)
            # binarise the interpolated mask again
            self._warped_moving_mask = self._warped_moving_mask >= 0.5

            # if either the warped moving mask or the fixed mask is zero take zero,
            # otherwise take the value of mask
            # NOTE(review): the fixed mask is compared directly (no `.image`), unlike the moving
            # mask above, and it is only consulted when a moving mask is given - confirm both
            # are intended.
            if self._fixed_mask is not None:
                mask = th.where(((self._warped_moving_mask == 0) | (self._fixed_mask == 0)), th.zeros_like(mask), mask)
            else:
                mask = th.where((self._warped_moving_mask == 0), th.zeros_like(mask), mask)

        return mask

    def set_loss_weight(self, weight):
        """Set the factor the loss is multiplied with in ``return_loss``."""
        self._weight = weight

    # conditional return
    def return_loss(self, tensor):
        """Weight and optionally reduce the per-element loss values.

        Args:
            tensor (Tensor): per-element loss values

        Returns:
            Tensor: ``tensor * weight`` when ``reduce`` is False, otherwise the weighted
            mean (``size_average`` True) or the weighted sum (``size_average`` False).
        """
        if not self._reduce:
            return tensor * self._weight

        if self._size_average:
            return tensor.mean() * self._weight

        return tensor.sum() * self._weight
class MSE(_PairwiseImageLoss):
    r""" The mean square error loss is a simple and fast to compute point-wise measure
    which is well suited for monomodal image registration.

    .. math::
         \mathcal{S}_{\text{MSE}} := \frac{1}{\vert \mathcal{X} \vert}\sum_{x\in\mathcal{X}}
          \Big(I_M\big(x+f(x)\big) - I_F\big(x\big)\Big)^2

    Args:
        fixed_image (Image): Fixed image for the registration
        moving_image (Image): Moving image for the registration
        fixed_mask: Optional mask for the fixed image
        moving_mask: Optional mask for the moving image
        size_average (bool): Average loss function
        reduce (bool): Reduce loss function to a single value

    """

    def __init__(self, fixed_image, moving_image, fixed_mask=None, moving_mask=None, size_average=True, reduce=True):
        super(MSE, self).__init__(fixed_image, moving_image, fixed_mask, moving_mask, size_average, reduce)

        self._name = "mse"

        # public alias of the warped image, kept for existing callers
        self.warped_moving_image = None

    def forward(self, displacement):
        """Return the (weighted, reduced) squared intensity difference for ``displacement``.

        Args:
            displacement (Tensor): displacement field, added to the identity grid
        """
        # compute displacement field
        displacement = self._grid + displacement

        # compute current mask
        mask = super(MSE, self).GetCurrentMask(displacement)

        # warp moving image with displacement field; store it in the base-class attribute
        # so that GetWarpedImage() works, and keep the public alias in sync
        self._warped_moving_image = F.grid_sample(self._moving_image.image, displacement)
        self.warped_moving_image = self._warped_moving_image

        # compute squared differences
        value = (self._warped_moving_image - self._fixed_image.image).pow(2)

        # keep only the values at valid pixels
        value = th.masked_select(value, mask)

        return self.return_loss(value)
class NCC(_PairwiseImageLoss):
    r""" The normalized cross correlation loss is a measure for image pairs with a linear
         intensity relation.

        .. math::
            \mathcal{S}_{\text{NCC}} := \frac{\sum I_F\cdot (I_M\circ f)
                   - \sum\text{E}(I_F)\text{E}(I_M\circ f)}
                   {\vert\mathcal{X}\vert\cdot\sum\text{Var}(I_F)\text{Var}(I_M\circ f)}

        Args:
            fixed_image (Image): Fixed image for the registration
            moving_image (Image): Moving image for the registration
            fixed_mask: Optional mask for the fixed image
            moving_mask: Optional mask for the moving image

    """

    def __init__(self, fixed_image, moving_image, fixed_mask=None, moving_mask=None):
        # the loss is already a single value, so no averaging / reduction is requested
        super(NCC, self).__init__(fixed_image, moving_image, fixed_mask, moving_mask, False, False)

        self._name = "ncc"

        self.warped_moving_image = th.empty_like(self._moving_image.image, dtype=self._dtype, device=self._device)

    def forward(self, displacement):
        """Return the negative normalized cross correlation over all valid pixels.

        The value is returned directly; the loss weight of the base class is not applied here.

        Args:
            displacement (Tensor): displacement field, added to the identity grid
        """
        # compute displacement field
        displacement = self._grid + displacement

        # compute current mask
        mask = super(NCC, self).GetCurrentMask(displacement)

        self._warped_moving_image = F.grid_sample(self._moving_image.image, displacement)

        moving_image_valid = th.masked_select(self._warped_moving_image, mask)
        fixed_image_valid = th.masked_select(self._fixed_image.image, mask)

        # centre both intensity vectors once and reuse them in numerator and denominator
        fixed_centered = fixed_image_valid - th.mean(fixed_image_valid)
        moving_centered = moving_image_valid - th.mean(moving_image_valid)

        covariance = th.sum(fixed_centered * moving_centered)
        energy = th.sum(fixed_centered ** 2) * th.sum(moving_centered ** 2)

        # 1e-10 guards against a division by zero for constant images
        value = -1. * covariance / th.sqrt(energy + 1e-10)

        return value
"""
Local Normalized Cross Correlation Image Loss
"""
class LCC(_PairwiseImageLoss):
    r""" Local normalized cross correlation image loss.

    Local means, variances and the cross term are obtained by convolving the images with a
    box or gaussian kernel; the loss is the negative squared local correlation.

    Args:
        fixed_image (Image): Fixed image for the registration
        moving_image (Image): Moving image for the registration
        fixed_mask: Optional mask for the fixed image
        moving_mask: Optional mask for the moving image
        sigma (sequence of int): Kernel sigma per dimension; a shorter sequence is padded
            with its last entry
        kernel_type (string): Type of kernel i.e. gaussian, box
        size_average (bool): Average loss function
        reduce (bool): Reduce loss function to a single value

    Raises:
        ValueError: if ``sigma`` has more entries than the image has dimensions or
            ``kernel_type`` is unknown
    """

    def __init__(self, fixed_image, moving_image,fixed_mask=None, moving_mask=None, sigma=(3,), kernel_type="box", size_average=True, reduce=True):
        super(LCC, self).__init__(fixed_image, moving_image, fixed_mask, moving_mask, size_average, reduce)

        self._name = "lcc"
        self.warped_moving_image = th.empty_like(self._moving_image.image, dtype=self._dtype, device=self._device)
        self._kernel = None

        dim = len(self._moving_image.size)
        # atleast_1d lets a scalar sigma through as a one-element array
        sigma = np.atleast_1d(np.array(sigma))

        if sigma.size > dim:
            # the former padding loop would never terminate in this case
            raise ValueError("sigma has %d entries but the image has only %d dimensions" % (sigma.size, dim))
        if sigma.size < dim:
            # repeat the last sigma for the remaining dimensions
            sigma = np.append(sigma, [sigma[-1]] * (dim - sigma.size))

        if kernel_type == "box":
            kernel_size = sigma*2 + 1
            # NOTE(review): the box kernel is divided by the *squared* number of kernel
            # elements, so it does not sum to one - confirm this is intended.
            self._kernel = th.ones(*kernel_size.tolist(), dtype=self._dtype, device=self._device) \
                           / float(np.prod(kernel_size)**2)
        elif kernel_type == "gaussian":
            self._kernel = utils.gaussian_kernel(sigma, dim, asTensor=True, dtype=self._dtype, device=self._device)
        else:
            raise ValueError("unknown kernel_type '%s', expected 'box' or 'gaussian'" % kernel_type)

        # add the two leading weight dimensions expected by the convolution functions
        self._kernel.unsqueeze_(0).unsqueeze_(0)

        # the statistics of the fixed image never change, compute them once
        if dim == 2:
            self._lcc_loss = self._lcc_loss_2d  # 2d lcc

            self._mean_fixed_image = F.conv2d(self._fixed_image.image, self._kernel)
            self._variance_fixed_image = F.conv2d(self._fixed_image.image.pow(2), self._kernel) \
                                         - (self._mean_fixed_image.pow(2))
        elif dim == 3:
            self._lcc_loss = self._lcc_loss_3d  # 3d lcc

            self._mean_fixed_image = F.conv3d(self._fixed_image.image, self._kernel)
            self._variance_fixed_image = F.conv3d(self._fixed_image.image.pow(2), self._kernel) \
                                         - (self._mean_fixed_image.pow(2))

    def _local_cc(self, conv, warped_image, mask):
        """Negative squared local correlation at all windows without an invalid pixel.

        Args:
            conv: convolution function matching the image dimension (``F.conv2d`` / ``F.conv3d``)
            warped_image (Tensor): warped moving image
            mask (Tensor): indicator of invalid pixels (non-zero where a pixel is invalid)
        """
        mean_moving_image = conv(warped_image, self._kernel)
        variance_moving_image = conv(warped_image.pow(2), self._kernel) - (mean_moving_image.pow(2))

        mean_fixed_moving_image = conv(self._fixed_image.image * warped_image, self._kernel)

        cc = (mean_fixed_moving_image - mean_moving_image*self._mean_fixed_image)**2 \
             / (variance_moving_image*self._variance_fixed_image + 1e-10)

        # a window is used only if the smoothed invalid-indicator is exactly zero in it
        mask = conv(mask, self._kernel)
        mask = mask == 0

        return -1.0*th.masked_select(cc, mask)

    def _lcc_loss_2d(self, warped_image, mask):
        """2d variant of the local cross correlation, see ``_local_cc``."""
        return self._local_cc(F.conv2d, warped_image, mask)

    def _lcc_loss_3d(self, warped_image, mask):
        """3d variant of the local cross correlation, see ``_local_cc``."""
        return self._local_cc(F.conv3d, warped_image, mask)

    def forward(self, displacement):
        """Return the (weighted, reduced) local cross correlation loss for ``displacement``.

        Args:
            displacement (Tensor): displacement field, added to the identity grid
        """
        # compute displacement field
        displacement = self._grid + displacement

        # compute current mask
        mask = super(LCC, self).GetCurrentMask(displacement)
        # invert to an indicator of invalid pixels; `== 0` works for bool and uint8 masks,
        # whereas `1 - mask` is rejected for bool tensors
        mask = (mask == 0).to(dtype=self._dtype, device=self._device)

        self._warped_moving_image = F.grid_sample(self._moving_image.image, displacement)

        return self.return_loss(self._lcc_loss(self._warped_moving_image, mask))
class MI(_PairwiseImageLoss):
    r""" Implementation of the Mutual Information image loss.

         .. math::
            \mathcal{S}_{\text{MI}} := H(F, M) - H(F|M) - H(M|F)

        Args:
            fixed_image (Image): Fixed image for the registration
            moving_image (Image): Moving image for the registration
            fixed_mask: Optional mask for the fixed image
            moving_mask: Optional mask for the moving image
            bins (int): Number of bins for the intensity distribution
            sigma (float): Kernel sigma for the intensity distribution approximation
            spatial_samples (float): Percentage of pixels used for the intensity distribution approximation
            background: Method to handle background pixels. None: Set background to the min value of image
                                                            "mean": Set the background to the mean value of the image
                                                            float: Set the background value to the input value
            size_average (bool): Average loss function
            reduce (bool): Reduce loss function to a single value

    """

    def __init__(self, fixed_image, moving_image, fixed_mask=None, moving_mask=None, bins=64, sigma=3,
                 spatial_samples=0.1, background=None, size_average=True, reduce=True):
        super(MI, self).__init__(fixed_image, moving_image, fixed_mask, moving_mask, size_average, reduce)

        self._name = "mi"

        self._dim = fixed_image.ndim
        self._bins = bins

        # note: _sigma holds the gaussian denominator 2*sigma^2, not sigma itself
        self._sigma = 2*sigma**2

        # normalisation constants of the 1d and 2d gaussian
        self._normalizer_1d = np.sqrt(2.0 * np.pi) * sigma
        self._normalizer_2d = 2.0 * np.pi*sigma**2

        # intensity treated as background for each image
        if background is None:
            self._background_fixed = th.min(fixed_image.image)
            self._background_moving = th.min(moving_image.image)
        elif background == "mean":
            self._background_fixed = th.mean(fixed_image.image)
            self._background_moving = th.mean(moving_image.image)
        else:
            self._background_fixed = background
            self._background_moving = background

        self._max_f = th.max(fixed_image.image)
        self._max_m = th.max(moving_image.image)

        self._spatial_samples = spatial_samples

        # bin centres as column vectors, running from the background value to the image maximum
        self._bins_fixed_image = th.linspace(self._background_fixed, self._max_f, self.bins,
                                             device=fixed_image.device, dtype=fixed_image.dtype).unsqueeze(1)

        self._bins_moving_image = th.linspace(self._background_moving, self._max_m, self.bins,
                                              device=fixed_image.device, dtype=fixed_image.dtype).unsqueeze(1)

    @property
    def sigma(self):
        """Stored kernel term ``2 * sigma**2``."""
        return self._sigma

    @property
    def bins(self):
        """Number of intensity bins."""
        return self._bins

    @property
    def bins_fixed_image(self):
        """Bin centres of the fixed image as a column vector."""
        return self._bins_fixed_image

    def _compute_marginal_entropy(self, values, bins):
        """Estimate the entropy of ``values`` from a gaussian-kernel histogram.

        Args:
            values (Tensor): 1d tensor of intensity samples
            bins (Tensor): bin centres as a column vector

        Returns:
            tuple: (entropy in bits, un-normalised kernel response per bin and sample)
        """
        squared_distance = (values - bins).pow(2)
        p = th.exp(-(squared_distance.div(self._sigma))).div(self._normalizer_1d)

        # average over the samples, then normalise the histogram to sum to one
        histogram = p.mean(dim=1)
        histogram = histogram/(th.sum(histogram) + 1e-10)

        entropy = -(histogram * th.log2(histogram + 1e-10)).sum()

        return entropy, p

    def forward(self, displacement):
        """Return the negative mutual information of a random subset of valid foreground pixels.

        The value is returned directly; the loss weight of the base class is not applied here.

        Args:
            displacement (Tensor): displacement field, added to the identity grid
        """
        # sampling grid = identity grid + displacement
        displacement = self._grid + displacement

        # pixels that stay inside the domain and inside the masks
        domain_mask = super(MI, self).GetCurrentMask(displacement)

        self._warped_moving_image = F.grid_sample(self._moving_image.image, displacement)

        moving_values = th.masked_select(self._warped_moving_image, domain_mask)
        fixed_values = th.masked_select(self._fixed_image.image, domain_mask)

        # drop pixels that are background in either image
        foreground = (fixed_values > self._background_fixed) & (moving_values > self._background_moving)

        fixed_values = th.masked_select(fixed_values, foreground)
        moving_values = th.masked_select(moving_values, foreground)

        pixel_count = moving_values.shape[0]

        # keep each remaining pixel with probability `spatial_samples`
        draws = th.zeros(pixel_count, device=self._fixed_image.device,
                         dtype=self._fixed_image.dtype).uniform_()
        chosen = draws < self._spatial_samples

        # marginal entropy of the fixed image
        fixed_samples = th.masked_select(fixed_values.view(-1), chosen)
        entropy_fixed, p_fixed = self._compute_marginal_entropy(fixed_samples, self._bins_fixed_image)

        # marginal entropy of the moving image
        moving_samples = th.masked_select(moving_values.view(-1), chosen)
        entropy_moving, p_moving = self._compute_marginal_entropy(moving_samples, self._bins_moving_image)

        # joint entropy from the product of both kernel responses
        joint = th.mm(p_fixed, p_moving.transpose(0, 1)).div(self._normalizer_2d)
        joint = joint / (th.sum(joint) + 1e-10)

        entropy_joint = -(joint * th.log2(joint + 1e-10)).sum()

        return -(entropy_fixed + entropy_moving - entropy_joint)
class NGF(_PairwiseImageLoss):
    r""" Implementation of the Normalized Gradient Fields image loss.

            Args:
                fixed_image (Image): Fixed image for the registration
                moving_image (Image): Moving image for the registration
                fixed_mask (Tensor): Mask for the fixed image
                moving_mask (Tensor): Mask for the moving image
                epsilon (float): Regulariser for the gradient amplitude; when None it is
                    estimated from the gradients of the fixed image
                size_average (bool): Average loss function
                reduce (bool): Reduce loss function to a single value

    """

    def __init__(self, fixed_image, moving_image, fixed_mask=None, moving_mask=None, epsilon=1e-5,
                 size_average=True,
                 reduce=True):
        super(NGF, self).__init__(fixed_image, moving_image, fixed_mask, moving_mask, size_average, reduce)

        self._name = "ngf"

        self._dim = fixed_image.ndim
        self._epsilon = epsilon

        # everything that is not 2d is handled by the 3d code path
        axes = 2 if self._dim == 2 else 3

        differences = self._forward_differences(fixed_image.image, fixed_image.spacing, axes)

        if self._epsilon is None:
            # estimate the regulariser as the mean absolute gradient of the fixed image
            with th.no_grad():
                magnitude = th.abs(differences[0])
                for difference in differences[1:]:
                    magnitude = magnitude + th.abs(difference)
                self._epsilon = th.mean(magnitude)

        # the normalized gradient of the fixed image never changes, compute it once
        self._ng_fixed_image = self._normalized_gradient(differences)

        self._ngf_loss = self._ngf_loss_2d if self._dim == 2 else self._ngf_loss_3d

    @staticmethod
    def _forward_differences(image, spacing, axes):
        """Finite differences along the last ``axes`` dimensions, each multiplied by its spacing.

        Every result omits the first element of each of those dimensions.
        """
        reference = (Ellipsis,) + (slice(1, None),) * axes

        differences = []
        for axis in range(axes):
            shifted = [slice(1, None)] * axes
            # neighbour one step back along the current axis
            shifted[axis] = slice(None, -1)
            neighbour = (Ellipsis,) + tuple(shifted)
            differences.append((image[reference] - image[neighbour]) * spacing[axis])

        return differences

    def _normalized_gradient(self, differences):
        """Stack the differences along dim 1, divide by the regularised norm and pad each
        spatial dimension by one trailing element."""
        squared = differences[0].pow(2)
        for difference in differences[1:]:
            squared = squared + difference.pow(2)

        norm = th.sqrt(squared + self._epsilon ** 2)

        return F.pad(th.cat(tuple(differences), dim=1) / norm, (0, 1) * len(differences))

    def _ngf_loss_2d(self, warped_image):
        """Normalized gradient field of a 2d warped image (uses the moving image spacing)."""
        differences = self._forward_differences(warped_image, self._moving_image.spacing, 2)
        return self._normalized_gradient(differences)

    def _ngf_loss_3d(self, warped_image):
        """Normalized gradient field of a 3d warped image (uses the moving image spacing)."""
        differences = self._forward_differences(warped_image, self._moving_image.spacing, 3)
        return self._normalized_gradient(differences)

    def forward(self, displacement):
        """Return the (weighted, reduced) normalized gradient field loss for ``displacement``.

        Args:
            displacement (Tensor): displacement field, added to the identity grid
        """
        # sampling grid = identity grid + displacement
        displacement = self._grid + displacement

        # pixels that stay inside the domain and inside the masks
        valid = super(NGF, self).GetCurrentMask(displacement)

        self._warped_moving_image = F.grid_sample(self._moving_image.image, displacement)

        # normalized gradient of the warped image
        ng_warped_image = self._ngf_loss(self._warped_moving_image)

        # inner product of both gradient fields, accumulated component by component
        inner = sum(ng_warped_image[:, component, ...] * self._ng_fixed_image[:, component, ...]
                    for component in range(self._dim))

        value = 0.5 * th.masked_select(-inner.pow(2), valid)

        return self.return_loss(value)
class SSIM(_PairwiseImageLoss):
    r""" Implementation of the Structural Similarity Image Measure loss.

        Args:
                fixed_image (Image): Fixed image for the registration
                moving_image (Image): Moving image for the registration
                fixed_mask (Tensor): Mask for the fixed image
                moving_mask (Tensor): Mask for the moving image
                sigma (sequence of int): Sigma for the kernel, per dimension; a shorter sequence
                    is padded with its last entry
                dim (int): Number of spatial dimensions of the kernel (3 selects 3d convolutions,
                    any other value 2d convolutions)
                kernel_type (string): Type of kernel i.e. gaussian, box
                alpha (float): Controls the influence of the luminance value
                beta (float): Controls the influence of the contrast value
                gamma (float): Controls the influence of the structure value
                c1 (float): Numerical constant for the luminance value
                c2 (float): Numerical constant for the contrast value
                c3 (float): Numerical constant for the structure value
                size_average (bool): Average loss function
                reduce (bool): Reduce loss function to a single value

        Raises:
                ValueError: if ``sigma`` has more than ``dim`` entries or ``kernel_type`` is unknown
    """

    def __init__(self, fixed_image, moving_image, fixed_mask=None, moving_mask=None,
                 sigma=(3,), dim=2, kernel_type="box", alpha=1, beta=1, gamma=1, c1=0.00001, c2=0.00001,
                 c3=0.00001, size_average=True, reduce=True, ):
        super(SSIM, self).__init__(fixed_image, moving_image, fixed_mask, moving_mask, size_average, reduce)

        self._alpha = alpha
        self._beta = beta
        self._gamma = gamma

        self._c1 = c1
        self._c2 = c2
        self._c3 = c3

        self._name = "sim"
        self._kernel = None

        # atleast_1d lets a scalar sigma through as a one-element array
        sigma = np.atleast_1d(np.array(sigma))

        if sigma.size > dim:
            # the former padding loop would never terminate in this case
            raise ValueError("sigma has %d entries but dim is %d" % (sigma.size, dim))
        if sigma.size < dim:
            # repeat the last sigma for the remaining dimensions
            sigma = np.append(sigma, [sigma[-1]] * (dim - sigma.size))

        if kernel_type == "box":
            kernel_size = sigma * 2 + 1
            # NOTE(review): the box kernel is divided by the *squared* number of kernel
            # elements, so it does not sum to one - confirm this is intended.
            self._kernel = th.ones(*kernel_size.tolist()) \
                           / float(np.prod(kernel_size) ** 2)
        elif kernel_type == "gaussian":
            self._kernel = utils.gaussian_kernel(sigma, dim, asTensor=True)
        else:
            raise ValueError("unknown kernel_type '%s', expected 'box' or 'gaussian'" % kernel_type)

        # add the two leading weight dimensions expected by the convolution functions
        self._kernel.unsqueeze_(0).unsqueeze_(0)

        self._kernel = self._kernel.to(dtype=self._dtype, device=self._device)

        # convolution matching the kernel dimension (previously always conv2d)
        self._conv = F.conv3d if dim == 3 else F.conv2d

        # calculate mean and variance of the fixed image
        self._mean_fixed_image = self._conv(self._fixed_image.image, self._kernel)
        self._variance_fixed_image = self._conv(self._fixed_image.image.pow(2), self._kernel) \
                                     - (self._mean_fixed_image.pow(2))

    def forward(self, displacement):
        """Return the (weighted, reduced) negative structural similarity for ``displacement``.

        Args:
            displacement (Tensor): displacement field, added to the identity grid
        """
        # compute displacement field
        displacement = self._grid + displacement

        # compute current mask
        mask = super(SSIM, self).GetCurrentMask(displacement)
        # invert to an indicator of invalid pixels; `== 0` works for bool and uint8 masks,
        # whereas `1 - mask` is rejected for bool tensors
        mask = (mask == 0).to(dtype=self._dtype, device=self._device)

        self._warped_moving_image = F.grid_sample(self._moving_image.image, displacement)

        # a window is used only if the smoothed invalid-indicator is exactly zero in it
        mask = self._conv(mask, self._kernel)
        mask = mask == 0

        # local statistics of the warped moving image
        mean_moving_image = self._conv(self._warped_moving_image, self._kernel)

        variance_moving_image = self._conv(self._warped_moving_image.pow(2), self._kernel) - (
            mean_moving_image.pow(2))

        mean_fixed_moving_image = self._conv(self._fixed_image.image * self._warped_moving_image, self._kernel)

        covariance_fixed_moving = (mean_fixed_moving_image - mean_moving_image * self._mean_fixed_image)

        # local standard deviations; 1e-10 keeps the square root differentiable at zero
        std_fixed_image = th.sqrt(self._variance_fixed_image + 1e-10)
        std_moving_image = th.sqrt(variance_moving_image + 1e-10)

        luminance = (2 * self._mean_fixed_image * mean_moving_image + self._c1) / \
                    (self._mean_fixed_image.pow(2) + mean_moving_image.pow(2) + self._c1)

        contrast = (2 * std_fixed_image * std_moving_image + self._c2) / \
                   (self._variance_fixed_image + variance_moving_image + self._c2)

        structure = (covariance_fixed_moving + self._c3) / \
                    (std_fixed_image * std_moving_image + self._c3)

        sim = luminance.pow(self._alpha) * contrast.pow(self._beta) * structure.pow(self._gamma)

        value = -1.0 * th.masked_select(sim, mask)

        return self.return_loss(value)