Skip to content

Commit a37fb85

Browse files
committed
Implemented passing of activation_output parameter to all Subfunctions. Added unittesting. Fixed buggy postprocessing for resize/resampling. Fixing #109
1 parent 83785a9 commit a37fb85

File tree

9 files changed

+61
-31
lines changed

9 files changed

+61
-31
lines changed

miscnn/processing/preprocessor.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -215,10 +215,9 @@ def postprocessing(self, sample, prediction, activation_output=False):
215215
else : prediction = np.squeeze(prediction, axis=0)
216216
# Transform probabilities to classes
217217
if not activation_output : prediction = np.argmax(prediction, axis=-1)
218-
219218
# Run Subfunction postprocessing on the prediction
220219
for sf in reversed(self.subfunctions):
221-
prediction = sf.postprocessing(sample, prediction)
220+
prediction = sf.postprocessing(sample, prediction, activation_output)
222221
# Return postprocessed prediction
223222
return prediction
224223

miscnn/processing/subfunctions/abstract_subfunction.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -77,11 +77,13 @@ def preprocessing(self, sample, training=True):
7777
It is possible to pass configurations through the initialization function of this class.
7878
7979
Parameter:
80-
sample (Sample Object): The sample object that was segmented. Contains metadata.
81-
prediction (numpy array): Numpy array of the predicted segmentation
80+
sample (Sample Object): The sample object that was segmented. Contains metadata.
81+
prediction (numpy array): Numpy array of the predicted segmentation
82+
activation_output (boolean): Parameter which decides, if model output (activation function, normally softmax) will
83+
be saved/outputed (if FALSE) or if the resulting class label (argmax) should be outputed.
8284
Return:
8385
prediction (numpy array): Numpy array of processed predicted segmentation
8486
"""
8587
@abstractmethod
86-
def postprocessing(self, sample, prediction):
88+
def postprocessing(self, sample, prediction, activation_output=False):
8789
return prediction

miscnn/processing/subfunctions/clipping.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,5 +56,5 @@ def preprocessing(self, sample, training=True):
5656
#---------------------------------------------#
5757
# Postprocessing #
5858
#---------------------------------------------#
59-
def postprocessing(self, sample, prediction):
59+
def postprocessing(self, sample, prediction, activation_output=False):
6060
return prediction

miscnn/processing/subfunctions/normalization.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,5 +83,5 @@ def preprocessing(self, sample, training=True):
8383
#---------------------------------------------#
8484
# Postprocessing #
8585
#---------------------------------------------#
86-
def postprocessing(self, sample, prediction):
86+
def postprocessing(self, sample, prediction, activation_output=False):
8787
return prediction

miscnn/processing/subfunctions/padding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def preprocessing(self, sample, training=True):
9494
#---------------------------------------------#
9595
# Postprocessing #
9696
#---------------------------------------------#
97-
def postprocessing(self, sample, prediction):
97+
def postprocessing(self, sample, prediction, activation_output=False):
9898
# Access original coordinates of the last sample and reset it
9999
original_coords = sample.extended["orig_crop_coords"]
100100
# Transform original shape to one-channel array for cropping

miscnn/processing/subfunctions/resampling.py

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,6 @@ class Resampling(Abstract_Subfunction):
4444
#---------------------------------------------#
4545
def __init__(self, new_spacing=(1,1,1)):
4646
self.new_spacing = new_spacing
47-
self.original_shape = None
4847

4948
#---------------------------------------------#
5049
# Preprocessing #
@@ -58,7 +57,7 @@ def preprocessing(self, sample, training=True):
5857
except AttributeError:
5958
print("'spacing' is not initialized in sample details!")
6059
# Cache current spacing for later postprocessing
61-
if not training : sample.extended["orig_spacing"] = (1,) + img_data.shape[0:-1]
60+
if not training : sample.extended["original_shape"] = img_data.shape[0:-1]
6261
# Calculate spacing ratio
6362
ratio = current_spacing / np.array(self.new_spacing)
6463
# Calculate new shape
@@ -79,24 +78,24 @@ def preprocessing(self, sample, training=True):
7978
#---------------------------------------------#
8079
# Postprocessing #
8180
#---------------------------------------------#
82-
def postprocessing(self, sample, prediction):
81+
def postprocessing(self, sample, prediction, activation_output=False):
8382
# Access original shape of the last sample and reset it
84-
original_shape = sample.get_extended_data()["orig_spacing"]
85-
# Handle resampling shape for activation output
86-
if len(prediction.shape) != (len(original_shape) - 1):
87-
original_shape = (prediction.shape[-1], ) + original_shape[1:]
83+
original_shape = sample.get_extended_data()["original_shape"]
8884
# Transform original shape to one-channel array for resampling
89-
else:
85+
if not activation_output:
86+
target_shape = (1,) + original_shape
9087
prediction = np.reshape(prediction, prediction.shape + (1,))
88+
# Handle resampling shape for activation output
89+
else : target_shape = (prediction.shape[-1], ) + original_shape
9190
# Transform prediction from channel-last to channel-first structure
9291
prediction = np.moveaxis(prediction, -1, 0)
9392
# Resample imaging data
94-
prediction = resize_segmentation(prediction, original_shape, order=1,
93+
prediction = resize_segmentation(prediction, target_shape, order=1,
9594
cval=0)
9695
# Transform data from channel-first back to channel-last structure
9796
prediction = np.moveaxis(prediction, 0, -1)
9897
# Transform one-channel array back to original shape
99-
if prediction.shape[-1] == 1:
100-
prediction = np.reshape(prediction, original_shape[1:])
98+
if not activation_output:
99+
prediction = np.reshape(prediction, original_shape)
101100
# Return postprocessed prediction
102101
return prediction

miscnn/processing/subfunctions/resize.py

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ def preprocessing(self, sample, training=True):
5151
img_data = sample.img_data
5252
seg_data = sample.seg_data
5353
# Cache current spacing for later postprocessing
54-
if not training : sample.extended["orig_resize_shape"] = (1,) + img_data.shape[0:-1]
54+
if not training : sample.extended["original_shape"] = img_data.shape[0:-1]
5555
# Transform data from channel-last to channel-first structure
5656
img_data = np.moveaxis(img_data, -1, 0)
5757
if training : seg_data = np.moveaxis(seg_data, -1, 0)
@@ -68,19 +68,24 @@ def preprocessing(self, sample, training=True):
6868
#---------------------------------------------#
6969
# Postprocessing #
7070
#---------------------------------------------#
71-
def postprocessing(self, sample, prediction):
71+
def postprocessing(self, sample, prediction, activation_output=False):
7272
# Access original shape of the last sample and reset it
73-
original_shape = sample.get_extended_data()["orig_resize_shape"]
73+
original_shape = sample.get_extended_data()["original_shape"]
7474
# Transform original shape to one-channel array
75-
prediction = np.reshape(prediction, prediction.shape + (1,))
75+
if not activation_output:
76+
target_shape = (1,) + original_shape
77+
prediction = np.reshape(prediction, prediction.shape + (1,))
78+
# Handle resampling shape for activation output
79+
else : target_shape = (prediction.shape[-1], ) + original_shape
7680
# Transform prediction from channel-last to channel-first structure
7781
prediction = np.moveaxis(prediction, -1, 0)
7882
# Resize imaging data
79-
prediction = resize_segmentation(prediction, original_shape, order=1,
83+
prediction = resize_segmentation(prediction, target_shape, order=1,
8084
cval=0)
8185
# Transform data from channel-first back to channel-last structure
8286
prediction = np.moveaxis(prediction, 0, -1)
8387
# Transform one-channel array back to original shape
84-
prediction = np.reshape(prediction, original_shape[1:])
88+
if not activation_output:
89+
prediction = np.reshape(prediction, original_shape)
8590
# Return postprocessed prediction
8691
return prediction

miscnn/processing/subfunctions/transform_HU.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,5 +110,5 @@ def normalize_HU(self, image):
110110
#---------------------------------------------#
111111
# Postprocessing #
112112
#---------------------------------------------#
113-
def postprocessing(self, sample, prediction):
113+
def postprocessing(self, sample, prediction, activation_output=False):
114114
return prediction

tests/test_subfunctions.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -266,6 +266,32 @@ def test_SUBFUNCTIONS_RESIZE_postprocessing(self):
266266
else : old_shape = (16,16,16)
267267
self.assertEqual(pred.shape, old_shape)
268268

269+
def test_SUBFUNCTIONS_RESIZE_postprocessing_activationOutput(self):
270+
# Test for 2D and 3D
271+
for dim in ["2D", "3D"]:
272+
# Initialize Subfunction
273+
if dim == "2D" : new_shape = (7,7)
274+
else : new_shape = (7,7,7)
275+
sf = Resize(new_shape=new_shape)
276+
# Create sample objects
277+
sample_pred = deepcopy(getattr(self, "sample" + dim))
278+
sample_train = deepcopy(getattr(self, "sample" + dim + "seg"))
279+
# Run preprocessing of the subfunction
280+
sf.preprocessing(sample_train, training=True)
281+
sf.preprocessing(sample_pred, training=False)
282+
# Transform segmentation data to simulate prediction data
283+
if dim == "2D":
284+
sample_pred.pred_data = np.random.rand(16, 16, 3)
285+
else:
286+
sample_pred.pred_data = np.random.rand(16, 16, 16, 3)
287+
# Run postprocessing of the subfunction
288+
pred = sf.postprocessing(sample_pred, sample_pred.pred_data,
289+
activation_output=True)
290+
# Check for correctness
291+
if dim == "2D" : old_shape = (16,16,3)
292+
else : old_shape = (16,16,16,3)
293+
self.assertEqual(pred.shape, old_shape)
294+
269295
#-------------------------------------------------#
270296
# Resampling #
271297
#-------------------------------------------------#
@@ -327,8 +353,6 @@ def test_SUBFUNCTIONS_RESAMPLING_postprocessing(self):
327353
def test_SUBFUNCTIONS_RESAMPLING_postprocessing_activationOutput(self):
328354
# Test for 2D and 3D
329355
for dim in ["2D", "3D"]:
330-
if dim == "2D" : continue
331-
332356
# Initialize Subfunction
333357
if dim == "2D" : spacing = (1,1)
334358
else : spacing = (1,1,1)
@@ -345,11 +369,12 @@ def test_SUBFUNCTIONS_RESAMPLING_postprocessing_activationOutput(self):
345369
sf.preprocessing(sample_pred, training=False)
346370
# Transform segmentation data to simulate prediction data
347371
if dim == "2D":
348-
sample_pred.pred_data = np.random.rand(16, 16, 3) * 3
372+
sample_pred.pred_data = np.random.rand(16, 16, 3)
349373
else:
350-
sample_pred.pred_data = np.random.rand(16, 16, 16, 3) * 3
374+
sample_pred.pred_data = np.random.rand(16, 16, 16, 3)
351375
# Run postprocessing of the subfunction
352-
pred = sf.postprocessing(sample_pred, sample_pred.pred_data)
376+
pred = sf.postprocessing(sample_pred, sample_pred.pred_data,
377+
activation_output=True)
353378
# Check for correctness
354379
if dim == "2D" : old_shape = (16,16,3)
355380
else : old_shape = (16,16,16,3)

0 commit comments

Comments
 (0)