Source code for airlab.utils.graph

# Copyright 2018 University of Basel, Center for medical image Analysis and Navigation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import torch as th
import numpy as np
from .matrix import MatrixDiagonalElement, LaplaceMatrix


class Graph:
    """Regular grid graph that carries a Laplace matrix description.

    The grid extents are given by ``graph_size``. For a two-dimensional
    grid every node is linked to its successor along x and along y, and the
    resulting edges are handed to ``LaplaceMatrix``. For a three-dimensional
    grid the construction is not implemented (the draft is disabled), so
    ``laplace_matrix`` stays ``None``. Any other dimensionality is ignored
    silently and also leaves ``laplace_matrix`` as ``None``.
    """

    def __init__(self, graph_size, dtype=th.float32, device='cpu'):
        """Store the grid description and build the graph if supported.

        Args:
            graph_size: sequence with the number of nodes per axis; its
                length selects the 2D or 3D construction.
            dtype: forwarded to ``MatrixDiagonalElement`` / ``LaplaceMatrix``.
            device: forwarded to ``MatrixDiagonalElement`` / ``LaplaceMatrix``.
        """
        self.dtype = dtype
        self.device = device

        self._dim = len(graph_size)
        self._graph_size = np.array(graph_size)

        # total node count is the product of all axis extents
        self._number_of_nodes = np.prod(self._graph_size)

        self.laplace_matrix = None

        # pick the construction routine by dimensionality; unsupported
        # dimensionalities fall through without building anything
        builders = {2: self._create_graph_2d, 3: self._create_graph_3d}
        build = builders.get(self._dim)
        if build is not None:
            build()

    def _create_graph_2d(self):
        """Enumerate the edges of the 2D grid and create the Laplace matrix.

        Sets ``_number_of_edges`` (a float: ``2 * nodes`` minus every axis
        extent), ``_edge_index_D1`` / ``_edge_index_D2`` (edges along x / y)
        and ``laplace_matrix``. Each edge index holds three parallel lists:
        x of the neighbour, y of the neighbour, and the flat index
        ``x + y * size_x`` of the node the edge starts from.
        """
        size_x = self._graph_size[0]
        size_y = self._graph_size[1]

        # compute the number of edges
        edge_count = self._number_of_nodes * 2.0
        for extent in self._graph_size:
            edge_count -= extent
        self._number_of_edges = edge_count

        along_x = self._edge_index_D1 = [[], [], []]
        along_y = self._edge_index_D2 = [[], [], []]

        for y in range(size_y):
            row_start = y * size_x
            for x in range(size_x):
                node = x + row_start

                # edge to the right-hand neighbour, if there is one
                if x + 1 < size_x:
                    for column, value in zip(along_x, (x + 1, y, node)):
                        column.append(value)

                # edge to the neighbour in the next row, if there is one
                if y + 1 < size_y:
                    for column, value in zip(along_y, (x, y + 1, node)):
                        column.append(value)

        # The offset passed with each edge set is the flat-index distance
        # between a node and its neighbour along that axis (1 along x,
        # size_x along y). All edge values start at one.
        diagonals = [
            MatrixDiagonalElement(np.asarray(edges),
                                  np.ones(len(edges[0])),
                                  th.tensor(offset, dtype=th.int64),
                                  dtype=self.dtype,
                                  device=self.device)
            for edges, offset in ((along_x, 1), (along_y, size_x))
        ]

        self.laplace_matrix = LaplaceMatrix(self._number_of_nodes,
                                            diagonals,
                                            dtype=self.dtype,
                                            device=self.device)

    def _create_graph_3d(self):
        """Placeholder for the 3D construction.

        Only ``_number_of_edges`` is set (to ``2.0 * nodes``); no edge
        indices are built and ``laplace_matrix`` is left untouched.
        """
        # compute the number of edges
        self._number_of_edges = self._number_of_nodes * 2.0

        # Disabled draft of the 3D construction, kept for reference.
        # NOTE(review): before enabling, check the following against the 2D
        # version: the x branch appends ``x-1`` where 2D appends ``x + 1``;
        # the y branch appends z to ``_edge_index_D1[2]`` rather than
        # ``_edge_index_D2[2]``; the final LaplaceMatrix call does not pass
        # dtype/device.
        #
        # for length in self._graph_size:
        #     self._number_of_edges -= length
        #
        # self._number_of_edges *= self._graph_size[2];
        # self._number_of_edges += self._graph_size[0]*self._graph_size[1]*(self._graph_size[2] - 1);
        #
        # self._edge_index_D1 = [[], [], [], []]
        # self._edge_index_D2 = [[], [], [], []]
        # self._edge_index_D3 = [[], [], [], []]
        #
        # for z in range(self._graph_size[2]):
        #     for y in range(self._graph_size[1]):
        #         for x in range(self._graph_size[0]):
        #
        #             if x + 1 < self._graph_size[0]:
        #                 self._edge_index_D1[0].append(x-1)
        #                 self._edge_index_D1[1].append(y)
        #                 self._edge_index_D1[2].append(z)
        #                 self._edge_index_D1[3].append(x + y * self._graph_size[0] + z*self._graph_size[0]* self._graph_size[1])
        #
        #             if y + 1 < self._graph_size[1]:
        #                 self._edge_index_D2[0].append(x)
        #                 self._edge_index_D2[1].append(y + 1)
        #                 self._edge_index_D1[2].append(z)
        #                 self._edge_index_D2[3].append(x + y * self._graph_size[0] + z*self._graph_size[0]* self._graph_size[1])
        #
        #             if z + 1 < self._graph_size[2]:
        #                 self._edge_index_D3[0].append(x)
        #                 self._edge_index_D3[1].append(y)
        #                 self._edge_index_D3[2].append(z + 1)
        #                 self._edge_index_D3[3].append(x + y * self._graph_size[0] + z * self._graph_size[0] * self._graph_size[1])
        #
        # element_1 = MatrixDiagonalElement(np.asarray(self._edge_index_D1), np.ones(len(self._edge_index_D1[0])),
        #                                   th.tensor(1, dtype=th.int64), dtype=self.dtype, device=self.device)
        #
        # element_2 = MatrixDiagonalElement(np.asarray(self._edge_index_D2), np.ones(len(self._edge_index_D2[0])),
        #                                   th.tensor(self._graph_size[0], dtype=th.int64), dtype=self.dtype,
        #                                   device=self.device)
        #
        # element_3 = MatrixDiagonalElement(np.asarray(self._edge_index_D3), np.ones(len(self._edge_index_D3[0])),
        #                                   th.tensor(self._graph_size[0]*self._graph_size[1], dtype=th.int64), dtype=self.dtype,
        #                                   device=self.device)
        #
        # self.laplace_matrix = LaplaceMatrix(self._number_of_nodes, [element_1, element_2, element_3])