Skip to content

Commit 63c101c

Browse files
committed
commented cross validation.
1 parent e5d2558 commit 63c101c

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

miscnn/model/cross_validation_group.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,13 +23,42 @@
2323
from tensorflow.keras.callbacks import ModelCheckpoint
2424
import os
2525

26+
#-----------------------------------------------------#
27+
# Cross Validation Model Group class #
28+
#-----------------------------------------------------#
29+
# Cross validation using a Model Group.
2630
class CrossValidationGroup(Model_Group):
2731

32+
""" Initialization function for creating a Model Group object. This object will train and predict sub models.
33+
The predictions are merged using an aggregation function.
34+
35+
Args:
36+
model (Model): Model that should be used for cross validation.
37+
preprocessor (Preprocessor): Preprocessor class that the Model Group should refer to for pipeline structure.
38+
This does not necissarily mean that all models share that preprocessor.
39+
folds (integer): the number of folds or models that should be used.
40+
verify_preprocessor (Boolean): Enable checking whether all models share the preprocessor of the model group.
41+
EWnabled by default. Disable to use models in combination with different preprocessing methods.
42+
"""
2843
def __init__(self, model, preprocessor, folds, verify_preprocessor=True):
2944
modelList = [model] + [model.copy() for i in range(folds)]
3045
Model_Group.__init__(self, modelList, preprocessor, verify_preprocessor)
3146
self.folds = folds
3247

48+
#---------------------------------------------#
49+
# Evaluation #
50+
#---------------------------------------------#
51+
""" Evaluation function for the model group using the provided lists of sample indices
52+
for training and validation. It is also possible to pass custom Callback classes in order to
53+
obtain more information.
54+
55+
Args:
56+
samples (list of indices): A list of sample indicies which will be used
57+
evaluation_path (string): The base path for the evaluation.
58+
epochs (integer): Number of epochs. A single epoch is defined as one iteration through the complete data set.
59+
iterations (integer): Number of iterations (batches) in a single epoch.
60+
callbacks (list of Callback classes): A list of Callback classes for custom evaluation
61+
"""
3362
def evaluate(self, samples, evaluation_path="evaluation", epochs=20, iterations=None, callbacks=[], *args, **kwargs):
3463
samples_permuted = np.random.permutation(samples)
3564
# Split sample list into folds

miscnn/model/model_group.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,6 @@ def predict(self, sample_list, aggregation_func, activation_output=False):
145145
epochs (integer): Number of epochs. A single epoch is defined as one iteration through the complete data set.
146146
iterations (integer): Number of iterations (batches) in a single epoch.
147147
callbacks (list of Callback classes): A list of Callback classes for custom evaluation
148-
Return:
149-
history (Keras history object): Gathered fitting information and evaluation results of the validation
150148
"""
151149
def evaluate(self, training_samples, validation_samples, evaluation_path="evaluation", epochs=20, iterations=None, callbacks=[]):
152150
for model in self.models:

0 commit comments

Comments
 (0)