Skip to content

Commit 806d4d6

Browse files
authored
Merge branch 'multi_model' into multi_model
2 parents 63c101c + 1f5db37 commit 806d4d6

File tree

13 files changed

+14912
-22
lines changed

13 files changed

+14912
-22
lines changed

examples/BraTS2020.ipynb

Lines changed: 7032 additions & 0 deletions
Large diffs are not rendered by default.

examples/BraTS2020.multimodal.ipynb

Lines changed: 6818 additions & 0 deletions
Large diffs are not rendered by default.

miscnn/model/cross_validation_group.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,7 @@ def evaluate(self, samples, evaluation_path="evaluation", epochs=20, iterations=
8989
cb_list = callbacks
9090

9191
model.reset()
92+
9293
if (isinstance(model, Model_Group)):
9394
model.evaluate(training, validation, evaluation_path=out_dir, epochs=epochs, iterations=iterations, callbacks=cb_list)
9495
else:

miscnn/model/model_group.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def __init__ (self, models, preprocessor, verify_preprocessor=True):
5151
raise RuntimeError("Model Groups can only be comprised of objects inheriting from Model")
5252
if verify_preprocessor and not model.preprocessor == self.preprocessor:
5353
raise RuntimeError("not all models use the same preprocessor. This can have have unintended effects and instabilities. To disable this warning pass \"verify_preprocessor=False\"")
54-
54+
5555
#---------------------------------------------#
5656
# Training #
5757
#---------------------------------------------#

miscnn/neural_network/architecture/unet/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,5 @@
2222
from miscnn.neural_network.architecture.unet.dense import Architecture as UNet_dense
2323
from miscnn.neural_network.architecture.unet.multiRes import Architecture as UNet_multiRes
2424
from miscnn.neural_network.architecture.unet.compact import Architecture as UNet_compact
25+
from miscnn.neural_network.architecture.unet.attention import Architecture as UNet_attention
26+
from miscnn.neural_network.architecture.unet.attention_residual import Architecture as UNet_attention_residual
Lines changed: 303 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,303 @@
1+
# ==============================================================================#
2+
# Author: Dennis Hartmann #
3+
# Copyright: 2021 IT-Infrastructure for Translational Medical Research, #
4+
# University of Augsburg #
5+
# #
6+
# This program is free software: you can redistribute it and/or modify #
7+
# it under the terms of the GNU General Public License as published by #
8+
# the Free Software Foundation, either version 3 of the License, or #
9+
# (at your option) any later version. #
10+
# #
11+
# This program is distributed in the hope that it will be useful, #
12+
# but WITHOUT ANY WARRANTY; without even the implied warranty of #
13+
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the #
14+
# GNU General Public License for more details. #
15+
# #
16+
# You should have received a copy of the GNU General Public License #
17+
# along with this program. If not, see <http://www.gnu.org/licenses/>. #
18+
# ==============================================================================#
19+
# -----------------------------------------------------#
20+
# Reference: #
21+
# Ozan Oktay et al. #
22+
# 11 April 2018. #
23+
# Attention U-Net: Learning Where #
24+
# to Look for the Pancreas #
25+
# MIDL'18. #
26+
# -----------------------------------------------------#
27+
# Library imports #
28+
# -----------------------------------------------------#
29+
# External libraries
30+
from tensorflow.keras.models import Model
31+
from tensorflow.keras.layers import Input, concatenate, Activation, add, Lambda, multiply
32+
from tensorflow.keras.layers import Conv3D, MaxPooling3D, Conv3DTranspose, UpSampling3D
33+
from tensorflow.keras.layers import Conv2D, MaxPooling2D, Conv2DTranspose, UpSampling2D
34+
from tensorflow.keras.layers import BatchNormalization
35+
from tensorflow.keras import backend as k
36+
# Internal libraries/scripts
37+
from miscnn.neural_network.architecture.abstract_architecture import Abstract_Architecture
38+
39+
# -----------------------------------------------------#
40+
# Architecture class: Attention U-Net #
41+
# -----------------------------------------------------#
42+
""" The Standard variant of the popular U-Net architecture.
43+
44+
Methods:
45+
__init__ Object creation function
46+
create_model_2D: Creating the 2D Attention U-Net standard model using Keras
47+
create_model_3D: Creating the 3D Attention U-Net standard model using Keras
48+
"""
49+
50+
51+
class Architecture(Abstract_Architecture):
52+
# ---------------------------------------------#
53+
# Initialization #
54+
# ---------------------------------------------#
55+
def __init__(self, n_filters=32, depth=4, activation='softmax',
56+
batch_normalization=True):
57+
# Parse parameter
58+
self.n_filters = n_filters
59+
self.depth = depth
60+
self.activation = activation
61+
# Batch normalization settings
62+
self.ba_norm = batch_normalization
63+
self.ba_norm_momentum = 0.99
64+
65+
# ---------------------------------------------#
66+
# Create 2D Model #
67+
# ---------------------------------------------#
68+
def create_model_2D(self, input_shape, n_labels=2):
69+
# Input layer
70+
inputs = Input(input_shape)
71+
# Start the CNN Model chain with adding the inputs as first tensor
72+
cnn_chain = inputs
73+
# Cache contracting normalized conv layers
74+
# for later copy & concatenate links
75+
contracting_convs = []
76+
77+
# Contracting Layers
78+
for i in range(0, self.depth):
79+
neurons = self.n_filters * 2 ** i
80+
cnn_chain, last_conv = contracting_layer_2D(cnn_chain, neurons,
81+
self.ba_norm,
82+
self.ba_norm_momentum)
83+
contracting_convs.append(last_conv)
84+
85+
# Middle Layer
86+
neurons = self.n_filters * 2 ** self.depth
87+
cnn_chain = middle_layer_2D(cnn_chain, neurons, self.ba_norm,
88+
self.ba_norm_momentum)
89+
90+
# Expanding Layers
91+
for i in reversed(range(0, self.depth)):
92+
neurons = self.n_filters * 2 ** i
93+
cnn_chain = expanding_layer_2D(cnn_chain, neurons,
94+
contracting_convs[i], self.ba_norm,
95+
self.ba_norm_momentum)
96+
97+
# Output Layer
98+
conv_out = Conv2D(n_labels, (1, 1),
99+
activation=self.activation)(cnn_chain)
100+
# Create Model with associated input and output layers
101+
model = Model(inputs=[inputs], outputs=[conv_out])
102+
# Return model
103+
return model
104+
105+
# ---------------------------------------------#
106+
# Create 3D Model #
107+
# ---------------------------------------------#
108+
def create_model_3D(self, input_shape, n_labels=2):
109+
# Input layer
110+
inputs = Input(input_shape)
111+
# Start the CNN Model chain with adding the inputs as first tensor
112+
cnn_chain = inputs
113+
# Cache contracting normalized conv layers
114+
# for later copy & concatenate links
115+
contracting_convs = []
116+
117+
# Contracting Layers
118+
for i in range(0, self.depth):
119+
neurons = self.n_filters * 2 ** i
120+
cnn_chain, last_conv = contracting_layer_3D(cnn_chain, neurons,
121+
self.ba_norm,
122+
self.ba_norm_momentum)
123+
contracting_convs.append(last_conv)
124+
125+
# Middle Layer
126+
neurons = self.n_filters * 2 ** self.depth
127+
cnn_chain = middle_layer_3D(cnn_chain, neurons, self.ba_norm,
128+
self.ba_norm_momentum)
129+
130+
# Expanding Layers
131+
for i in reversed(range(0, self.depth)):
132+
neurons = self.n_filters * 2 ** i
133+
cnn_chain = expanding_layer_3D(cnn_chain, neurons,
134+
contracting_convs[i], self.ba_norm,
135+
self.ba_norm_momentum)
136+
137+
# Output Layer
138+
conv_out = Conv3D(n_labels, (1, 1, 1),
139+
activation=self.activation)(cnn_chain)
140+
# Create Model with associated input and output layers
141+
model = Model(inputs=[inputs], outputs=[conv_out])
142+
# Return model
143+
return model
144+
145+
146+
# -----------------------------------------------------#
147+
# Subroutines all #
148+
# -----------------------------------------------------#
149+
def repeat_elem(tensor, rep, axs=3):
150+
# lambda function to repeat Repeats the elements of a tensor along an axis
151+
# by a factor of rep.
152+
# If tensor has shape (None, 256,256,3), lambda will return a tensor of shape
153+
# (None, 256,256,6), if specified axis=3 and rep=2.
154+
155+
return Lambda(lambda x, repnum: k.repeat_elements(x, repnum, axis=axs),
156+
arguments={'repnum': rep})(tensor)
157+
158+
159+
# -----------------------------------------------------#
160+
# Subroutines 2D #
161+
# -----------------------------------------------------#
162+
def gating_signal2D(input, out_size, batch_norm=False):
163+
"""
164+
resize the down layer feature map into the same dimension as the up layer feature map
165+
using 1x1 conv
166+
:return: the gating feature map with the same dimension of the up layer feature map
167+
"""
168+
x = Conv2D(out_size, (1, 1), padding='same')(input)
169+
if batch_norm:
170+
x = BatchNormalization()(x)
171+
x = Activation('relu')(x)
172+
return x
173+
174+
175+
def attention_block2D(x, gating, inter_shape):
176+
shape_x = k.int_shape(x)
177+
178+
# Getting the x signal to the same shape as the gating signal
179+
theta_x = Conv2D(filters=inter_shape, kernel_size=3, strides=2, padding='same')(x)
180+
181+
# Getting the gating signal to the same number of filters as the inter_shape
182+
phi_g = Conv2D(filters=inter_shape, kernel_size=1, strides=1, padding='same')(gating)
183+
184+
concat_xg = add([phi_g, theta_x])
185+
act_xg = Activation('relu')(concat_xg)
186+
psi = Conv2D(filters=1, kernel_size=1, padding='same')(act_xg)
187+
sigmoid_xg = Activation('sigmoid')(psi)
188+
upsample_psi = UpSampling2D(size=2)(sigmoid_xg)
189+
190+
upsample_psi = repeat_elem(upsample_psi, shape_x[3])
191+
192+
y = multiply([upsample_psi, x])
193+
194+
# Final 1x1 convolution to consolidate attention signal to original x dimensions
195+
result = Conv2D(filters=shape_x[3], kernel_size=1, strides=1, padding='same')(y)
196+
result_bn = BatchNormalization()(result)
197+
return result_bn
198+
199+
200+
# Create a contracting layer
201+
def contracting_layer_2D(input, neurons, ba_norm, ba_norm_momentum):
202+
conv1 = Conv2D(filters=neurons, kernel_size=3, activation='relu', padding='same')(input)
203+
if ba_norm: conv1 = BatchNormalization(momentum=ba_norm_momentum)(conv1)
204+
conv2 = Conv2D(filters=neurons, kernel_size=3, activation='relu', padding='same')(conv1)
205+
if ba_norm: conv2 = BatchNormalization(momentum=ba_norm_momentum)(conv2)
206+
pool = MaxPooling2D(pool_size=2)(conv2)
207+
return pool, conv2
208+
209+
210+
# Create the middle layer between the contracting and expanding layers
211+
def middle_layer_2D(input, neurons, ba_norm, ba_norm_momentum):
212+
conv_m1 = Conv2D(filters=neurons, kernel_size=3, activation='relu', padding='same')(input)
213+
if ba_norm: conv_m1 = BatchNormalization(momentum=ba_norm_momentum)(conv_m1)
214+
conv_m2 = Conv2D(filters=neurons, kernel_size=3, activation='relu', padding='same')(conv_m1)
215+
if ba_norm: conv_m2 = BatchNormalization(momentum=ba_norm_momentum)(conv_m2)
216+
return conv_m2
217+
218+
219+
# Create an expanding layer
220+
def expanding_layer_2D(input, neurons, concatenate_link, ba_norm,
221+
ba_norm_momentum):
222+
gating = gating_signal2D(input, neurons, ba_norm)
223+
att = attention_block2D(concatenate_link, gating, neurons)
224+
up = concatenate([Conv2DTranspose(filters=neurons, kernel_size=2, strides=2,
225+
padding='same')(input), att], axis=-1)
226+
conv1 = Conv2D(filters=neurons, kernel_size=3, activation='relu', padding='same')(up)
227+
if ba_norm: conv1 = BatchNormalization(momentum=ba_norm_momentum)(conv1)
228+
conv2 = Conv2D(filters=neurons, kernel_size=3, activation='relu', padding='same')(conv1)
229+
if ba_norm: conv2 = BatchNormalization(momentum=ba_norm_momentum)(conv2)
230+
return conv2
231+
232+
233+
# -----------------------------------------------------#
234+
# Subroutines 3D #
235+
# -----------------------------------------------------#
236+
def gating_signal3D(input, out_size, batch_norm=False):
237+
"""
238+
resize the down layer feature map into the same dimension as the up layer feature map
239+
using 1x1 conv
240+
:return: the gating feature map with the same dimension of the up layer feature map
241+
"""
242+
x = Conv3D(out_size, kernel_size=1, padding='same')(input)
243+
if batch_norm:
244+
x = BatchNormalization()(x)
245+
x = Activation('relu')(x)
246+
return x
247+
248+
249+
def attention_block3D(x, gating, inter_shape):
250+
shape_x = k.int_shape(x)
251+
252+
# Getting the x signal to the same shape as the gating signal
253+
theta_x = Conv3D(filters=inter_shape, kernel_size=3, strides=2, padding='same')(x) # 16
254+
255+
# Getting the gating signal to the same number of filters as the inter_shape
256+
phi_g = Conv3D(filters=inter_shape, kernel_size=1, strides=1, padding='same')(gating)
257+
258+
concat_xg = add([phi_g, theta_x])
259+
act_xg = Activation('relu')(concat_xg)
260+
psi = Conv3D(filters=1, kernel_size=1, padding='same')(act_xg)
261+
sigmoid_xg = Activation('sigmoid')(psi)
262+
upsample_psi = UpSampling3D(size=2)(sigmoid_xg)
263+
264+
upsample_psi = repeat_elem(upsample_psi, shape_x[4], axs=4)
265+
266+
y = multiply([upsample_psi, x])
267+
268+
result = Conv3D(filters=shape_x[4], kernel_size=1, strides=1, padding='same')(y)
269+
result_bn = BatchNormalization()(result)
270+
return result_bn
271+
272+
273+
# Create a contracting layer
274+
def contracting_layer_3D(input, neurons, ba_norm, ba_norm_momentum):
275+
conv1 = Conv3D(filters=neurons, kernel_size=3, activation='relu', padding='same')(input)
276+
if ba_norm: conv1 = BatchNormalization(momentum=ba_norm_momentum)(conv1)
277+
conv2 = Conv3D(filters=neurons, kernel_size=3, activation='relu', padding='same')(conv1)
278+
if ba_norm: conv2 = BatchNormalization(momentum=ba_norm_momentum)(conv2)
279+
pool = MaxPooling3D(pool_size=2)(conv2)
280+
return pool, conv2
281+
282+
283+
# Create the middle layer between the contracting and expanding layers
284+
def middle_layer_3D(input, neurons, ba_norm, ba_norm_momentum):
285+
conv_m1 = Conv3D(filters=neurons, kernel_size=3, activation='relu', padding='same')(input)
286+
if ba_norm: conv_m1 = BatchNormalization(momentum=ba_norm_momentum)(conv_m1)
287+
conv_m2 = Conv3D(filters=neurons, kernel_size=3, activation='relu', padding='same')(conv_m1)
288+
if ba_norm: conv_m2 = BatchNormalization(momentum=ba_norm_momentum)(conv_m2)
289+
return conv_m2
290+
291+
292+
# Create an expanding layer
293+
def expanding_layer_3D(input, neurons, concatenate_link, ba_norm,
294+
ba_norm_momentum):
295+
gating = gating_signal3D(input, neurons, ba_norm)
296+
att = attention_block3D(concatenate_link, gating, neurons) # Neurons = Filter?
297+
up = concatenate([Conv3DTranspose(filters=neurons, kernel_size=2, strides=2,
298+
padding='same')(input), att], axis=-1)
299+
conv1 = Conv3D(filters=neurons, kernel_size=3, activation='relu', padding='same')(up)
300+
if ba_norm: conv1 = BatchNormalization(momentum=ba_norm_momentum)(conv1)
301+
conv2 = Conv3D(filters=neurons, kernel_size=3, activation='relu', padding='same')(conv1)
302+
if ba_norm: conv2 = BatchNormalization(momentum=ba_norm_momentum)(conv2)
303+
return conv2

0 commit comments

Comments
 (0)