Source code for airlab.utils.image

# Copyright 2018 University of Basel, Center for medical Image Analysis and Navigation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import SimpleITK as sitk
import torch as th
import torch.nn.functional as F
import numpy as np
import sys

from . import kernelFunction


class Image:
    """
    Class representing an image in airlab

    The pixel data is kept in ``self.image`` as a torch tensor with two leading
    singleton dimensions followed by the (squeezed) space dimensions. The
    attributes ``size``, ``spacing`` and ``origin`` are stored as passed in
    (tensor construction) or as reported by SimpleITK (image construction).
    """

    def __init__(self, *args, **kwargs):
        """
        Constructor for an image object where two cases are distinguished:

        - Construct airlab image from an array or tensor (4 arguments)
        - Construct airlab image from an SimpleITK image (less than 4 arguments)

        Positional and keyword arguments are counted together and forwarded to
        ``initializeForTensors`` or ``initializeForImages`` respectively.

        Raises:
            TypeError: if more than 4 arguments are passed
        """
        n_args = len(args) + len(kwargs)
        if n_args == 4:
            self.initializeForTensors(*args, **kwargs)
        elif n_args < 4:
            self.initializeForImages(*args, **kwargs)
        else:
            # previously this case silently produced an uninitialised object
            raise TypeError("Image expects at most 4 arguments. Got " + str(n_args))

    def initializeForTensors(self, tensor_image, image_size, image_spacing, image_origin):
        """
        Constructor for torch tensors and numpy ndarrays

        Args:
            tensor_image (np.ndarray | th.Tensor): n-dimensional tensor, where the last dimensions are the image
                dimensions while the preceding dimensions need to be empty
            image_size (array | list | tuple): number of pixels in each space dimension
            image_spacing (array | list | tuple): pixel size for each space dimension
            image_origin (array | list | tuple): physical coordinate of the first pixel

        Raises:
            TypeError: if tensor_image is neither a numpy ndarray nor a torch tensor
        """
        # distinguish between numpy array and torch tensors; isinstance also
        # accepts subclasses such as th.nn.Parameter
        if isinstance(tensor_image, np.ndarray):
            tensor_image = th.from_numpy(tensor_image)
        elif not isinstance(tensor_image, th.Tensor):
            raise TypeError("A numpy ndarray or a torch tensor was expected as argument. Got "
                            + str(type(tensor_image)))

        # drop all empty dimensions, then add the two leading singleton dimensions
        self.image = tensor_image.squeeze().unsqueeze(0).unsqueeze(0)

        self.size = image_size
        self.spacing = image_spacing
        self.origin = image_origin
        self.dtype = self.image.dtype
        self.device = self.image.device
        # take only non-empty dimensions to count space dimensions
        self.ndim = len(self.image.squeeze().shape)

    def initializeForImages(self, sitk_image, dtype=None, device='cpu'):
        """
        Constructor for SimpleITK image

        Note: the order of axis are flipped in order to follow the convention of numpy and torch

        Args:
            sitk_image (sitk.Image): SimpleITK image
            dtype: pixel type, the dtype of the SimpleITK pixel array is kept if None
            device ('cpu'|'cuda'): on which device the image should be allocated

        Raises:
            TypeError: if sitk_image is not a SimpleITK image
        """
        if not isinstance(sitk_image, sitk.Image):
            raise TypeError("A SimpleITK image was expected as argument. Got " + str(type(sitk_image)))

        self.image = th.from_numpy(sitk.GetArrayFromImage(sitk_image)).unsqueeze(0).unsqueeze(0)
        self.size = sitk_image.GetSize()
        self.spacing = sitk_image.GetSpacing()
        self.origin = sitk_image.GetOrigin()

        # to() keeps the current dtype when dtype is None
        self.to(dtype, device)

        self.ndim = len(self.image.squeeze().shape)
        self._reverse_axis()

    def to(self, dtype=None, device='cpu'):
        """
        Converts the image tensor to a specified dtype and moves it to the specified device

        Note: device defaults to 'cpu', so calling ``to`` with only a dtype moves the image to the cpu

        Args:
            dtype: target dtype, the current dtype is kept if None
            device: target device

        Returns:
            Image: the image itself
        """
        if dtype is not None:
            self.image = self.image.to(dtype=dtype, device=device)
        else:
            self.image = self.image.to(device=device)
        self.dtype = self.image.dtype
        self.device = self.image.device

        return self

    def itk(self):
        """
        Returns a SimpleITK image

        Note: the order of axis is flipped back to the convention of SimpleITK
        """
        # work on a detached cpu copy so that neither the stored tensor nor an
        # autograd graph attached to it is touched
        image = Image(self.image.detach().cpu().clone(), self.size, self.spacing, self.origin)
        image._reverse_axis()

        itk_image = sitk.GetImageFromArray(image.image.squeeze().numpy())
        itk_image.SetSpacing(spacing=self.spacing)
        itk_image.SetOrigin(origin=self.origin)
        return itk_image

    def numpy(self):
        """
        Returns a numpy array holding the image with all empty dimensions removed
        """
        # detach() is required because Tensor.numpy() refuses tensors that require grad
        return self.image.detach().cpu().squeeze().numpy()

    @staticmethod
    def read(filename, dtype=th.float32, device='cpu'):
        """
        Static method to directly read an image through the Image class

        Args:
            filename (str): filename of the image
            dtype: specific dtype for representing the tensor
            device: on which device the image has to be allocated

        Returns:
            Image: an airlab image
        """
        return Image(sitk.ReadImage(filename, sitk.sitkFloat32), dtype, device)

    def write(self, filename):
        """
        Write an image to hard drive

        Note: order of axis are flipped to have the representation of SimpleITK again

        Args:
            filename (str): filename where the image is written
        """
        sitk.WriteImage(self.itk(), filename)

    def _reverse_axis(self):
        """
        Flips the order of the axis representing the space dimensions (preceding dimensions are ignored)

        Note: the method is inplace
        """
        # reverse order of axis to follow the convention of SimpleITK
        reversed_order = tuple(reversed(range(self.ndim)))
        self.image = self.image.squeeze().permute(reversed_order)
        self.image = self.image.unsqueeze(0).unsqueeze(0)
""" Object representing a displacement image """
class Displacement(Image):
    """
    Object representing a displacement image

    The tensor is stored like in Image, with the last axis holding the
    components of the displacement vectors.
    """

    def __init__(self, *args, **kwargs):
        """
        Constructor for a displacement field object where two cases are distinguished:

        - Construct airlab displacement field from an array or tensor (4 arguments)
        - Construct airlab displacement field from an SimpleITK image (less than 4 arguments)

        Positional and keyword arguments are counted together and forwarded to
        ``initializeForTensors`` or ``initializeForImages`` respectively.

        Raises:
            TypeError: if more than 4 arguments are passed
        """
        n_args = len(args) + len(kwargs)
        if n_args == 4:
            self.initializeForTensors(*args, **kwargs)
        elif n_args < 4:
            self.initializeForImages(*args, **kwargs)
        else:
            # previously this case silently produced an uninitialised object
            raise TypeError("Displacement expects at most 4 arguments. Got " + str(n_args))

    def itk(self):
        """
        Returns a SimpleITK vector image

        Note: the order of axis is flipped back to the convention of SimpleITK

        Raises:
            ValueError: if the displacement field is neither 2 nor 3 dimensional
        """
        if len(self.size) not in (2, 3):
            # without this guard the method failed later with an UnboundLocalError
            raise ValueError("Only 2 and 3 dimensional displacement fields can be converted. Got "
                             + str(len(self.size)) + " dimensions")

        # flip axis on a detached copy so the stored tensor stays untouched
        df = Displacement(self.image.detach().clone(), self.size, self.spacing, self.origin)
        df._reverse_axis()
        vectors = df.image.squeeze().cpu().numpy()

        # the last array axis holds the vector components in both the 2d and 3d case
        itk_displacement = sitk.GetImageFromArray(vectors, isVector=True)
        itk_displacement.SetSpacing(spacing=self.spacing)
        itk_displacement.SetOrigin(origin=self.origin)
        return itk_displacement

    def magnitude(self):
        """
        Returns an Image holding the Euclidean norm of the displacement vectors (taken over the last axis)
        """
        norm = th.sqrt(th.sum(self.image.pow(2), -1))
        return Image(norm.squeeze(), self.size, self.spacing, self.origin)

    def numpy(self):
        """
        Returns a numpy array of the displacement field

        Note: in contrast to Image.numpy() the empty dimensions are kept
        """
        # detach() is required because Tensor.numpy() refuses tensors that require grad
        return self.image.detach().cpu().numpy()

    def _reverse_axis(self):
        """
        Flips the order of the axis representing the space dimensions (preceding dimensions are ignored).
        Respectively, the axis holding the vectors is flipped as well

        Note: the method is inplace
        """
        # reverse order of axis to follow the convention of SimpleITK,
        # the axis holding the vector components stays the last one
        vector_axis = self.ndim - 1
        order = list(reversed(range(vector_axis)))
        order.append(vector_axis)

        self.image = self.image.squeeze().permute(tuple(order))
        # reverse the order of the vector components to match the reversed space axes
        self.image = flip(self.image, vector_axis)
        self.image = self.image.unsqueeze(0).unsqueeze(0)

    @staticmethod
    def read(filename, dtype=th.float32, device='cpu'):
        """
        Static method to directly read a displacement field through the Image class

        Args:
            filename (str): filename of the displacement field
            dtype: specific dtype for representing the tensor
            device: on which device the displacement field has to be allocated

        Returns:
            Displacement: an airlab displacement field
        """
        return Displacement(sitk.ReadImage(filename, sitk.sitkVectorFloat32), dtype, device)
def flip(x, dim):
    """
    Flip order of a specific dimension dim

    Args:
        x (Tensor): input tensor
        dim (int): axis which should be flipped

    Returns:
        Tensor: a new tensor with the specified axis flipped, x is left unchanged
    """
    # torch provides this operation directly, no index tensor has to be built
    return th.flip(x, dims=(dim,))
""" Convert an image to tensor representation """
def read_image_as_tensor(filename, dtype=th.float32, device='cpu'):
    """
    Read an image file and convert it to the airlab tensor representation

    The file is read with SimpleITK as a float32 image and handed over to
    create_tensor_image_from_itk_image.

    Args:
        filename (str): filename of the image
        dtype: dtype of the resulting tensor
        device: device on which the tensor is allocated

    Returns:
        the result of create_tensor_image_from_itk_image
    """
    return create_tensor_image_from_itk_image(
        sitk.ReadImage(filename, sitk.sitkFloat32),
        dtype=dtype,
        device=device,
    )
""" Convert an image to tensor representation """
def create_image_from_image(tensor_image, image):
    """
    Wrap a tensor in a new Image that takes size, spacing and origin from an existing image

    Args:
        tensor_image: pixel data of the new image
        image: image providing the size, spacing and origin attributes

    Returns:
        Image: the new image
    """
    size, spacing, origin = image.size, image.spacing, image.origin
    return Image(tensor_image, size, spacing, origin)
""" Convert numpy image to AirlLab image format """
def image_from_numpy(image, pixel_spacing, image_origin, dtype=th.float32, device='cpu'):
    """
    Convert numpy image to airlab image format

    Args:
        image (np.ndarray): pixel data, its shape is used as the image size
        pixel_spacing: pixel size for each space dimension
        image_origin: physical coordinate of the first pixel
        dtype: dtype of the resulting tensor
        device: device on which the tensor is allocated

    Returns:
        Image: the airlab image
    """
    # indexing with None twice prepends the two leading singleton dimensions
    batched = th.from_numpy(image)[None, None]
    batched = batched.to(dtype=dtype, device=device)
    return Image(batched, image.shape, pixel_spacing, image_origin)
""" Convert an image to tensor representation """
def create_displacement_image_from_image(tensor_displacement, image):
    """
    Wrap a tensor in a new Displacement that takes size, spacing and origin from an existing image

    Args:
        tensor_displacement: displacement data of the new displacement field
        image: image providing the size, spacing and origin attributes

    Returns:
        Displacement: the new displacement field
    """
    size, spacing, origin = image.size, image.spacing, image.origin
    return Displacement(tensor_displacement, size, spacing, origin)
""" Create tensor image representation """
def create_tensor_image_from_itk_image(itk_image, dtype=th.float32, device='cpu'):
    """
    Create tensor image representation of a SimpleITK image

    Note: the direction of itk_image is overwritten in place with the unit direction

    Args:
        itk_image: SimpleITK image
        dtype: dtype of the resulting tensor
        device: device on which the tensor is allocated

    Returns:
        Image: airlab image whose size is the shape of the squeezed pixel array
    """
    # transform image in a unit direction
    image_dim = itk_image.GetDimension()
    if image_dim == 2:
        itk_image.SetDirection(sitk.VectorDouble([1, 0, 0, 1]))
    else:
        # NOTE(review): every dimension other than 2 gets a 3x3 direction - confirm only 2d/3d images arrive here
        itk_image.SetDirection(sitk.VectorDouble([1, 0, 0, 0, 1, 0, 0, 0, 1]))

    image_spacing = itk_image.GetSpacing()
    image_origin = itk_image.GetOrigin()

    np_image = np.squeeze(sitk.GetArrayFromImage(itk_image))
    # NOTE(review): size follows the numpy axis order while spacing and origin
    # are taken as returned by SimpleITK - confirm this is intended
    image_size = np_image.shape

    # adjust image spacing and origin vector size if image contains empty dimension,
    # otherwise the origin would keep more entries than the image has dimensions
    if len(image_size) != image_dim:
        image_spacing = image_spacing[0:len(image_size)]
        image_origin = image_origin[0:len(image_size)]

    tensor_image = th.tensor(np_image, dtype=dtype, device=device).unsqueeze(0).unsqueeze(0)

    return Image(tensor_image, image_size, image_spacing, image_origin)
""" Create an image pyramide """
def create_image_pyramid(image, down_sample_factor):
    """
    Create an image pyramid

    For every level the image is smoothed with a Gaussian kernel whose sigma is
    half the down-sampling factor and sub-sampled by using the factor as stride
    of the convolution.

    Args:
        image (Image): 2 or 3 dimensional image
        down_sample_factor (list): one entry per level, each entry holding the
            down-sampling factor for each space dimension

    Returns:
        list: the down-sampled images in the order of down_sample_factor, followed by the original image

    Raises:
        ValueError: if the image is neither 2 nor 3 dimensional
    """
    image_dim = len(image.size)

    # pick the dimension specific kernel and convolution once instead of duplicating the loop
    if image_dim == 2:
        gaussian_kernel, convolve = kernelFunction.gaussian_kernel_2d, F.conv2d
    elif image_dim == 3:
        gaussian_kernel, convolve = kernelFunction.gaussian_kernel_3d, F.conv3d
    else:
        # raise instead of terminating the interpreter from library code
        raise ValueError("{} is not supported with create_image_pyramide()".format(image_dim))

    image_pyramide = []
    for level in down_sample_factor:
        # build the float tensor first so that odd factors are never halved with integer division
        sigma = th.tensor(level, dtype=th.float32) / 2
        kernel = gaussian_kernel(sigma.numpy(), asTensor=True)
        # pad by half the kernel extent in every dimension
        padding = [(k - 1) // 2 for k in kernel.size()]
        kernel = kernel.unsqueeze(0).unsqueeze(0)
        kernel = kernel.to(dtype=image.dtype, device=image.device)

        image_sample = convolve(image.image, kernel, stride=level, padding=padding)
        image_size = image_sample.size()[-image_dim:]
        image_spacing = [x * y for x, y in zip(image.spacing, level)]
        image_pyramide.append(Image(image_sample, image_size, image_spacing, image.origin))

    image_pyramide.append(image)

    return image_pyramide