|
23 | 23 | from tensorflow.keras.callbacks import ModelCheckpoint
|
24 | 24 | import os
|
25 | 25 |
|
| 26 | +#-----------------------------------------------------# |
| 27 | +# Cross Validation Model Group class # |
| 28 | +#-----------------------------------------------------# |
| 29 | +# Cross validation using a Model Group. |
26 | 30 | class CrossValidationGroup(Model_Group):
|
27 | 31 |
|
| 32 | + """ Initialization function for creating a Model Group object. This object will train and predict sub models. |
| 33 | + The predictions are merged using an aggregation function. |
| 34 | +
|
| 35 | + Args: |
| 36 | + model (Model): Model that should be used for cross validation. |
| 37 | + preprocessor (Preprocessor): Preprocessor class that the Model Group should refer to for pipeline structure. |
| 38 | + This does not necissarily mean that all models share that preprocessor. |
| 39 | + folds (integer): the number of folds or models that should be used. |
| 40 | + verify_preprocessor (Boolean): Enable checking whether all models share the preprocessor of the model group. |
| 41 | + EWnabled by default. Disable to use models in combination with different preprocessing methods. |
| 42 | + """ |
28 | 43 | def __init__(self, model, preprocessor, folds, verify_preprocessor=True):
|
29 | 44 | modelList = [model] + [model.copy() for i in range(folds)]
|
30 | 45 | Model_Group.__init__(self, modelList, preprocessor, verify_preprocessor)
|
31 | 46 | self.folds = folds
|
32 | 47 |
|
| 48 | + #---------------------------------------------# |
| 49 | + # Evaluation # |
| 50 | + #---------------------------------------------# |
| 51 | + """ Evaluation function for the model group using the provided lists of sample indices |
| 52 | + for training and validation. It is also possible to pass custom Callback classes in order to |
| 53 | + obtain more information. |
| 54 | +
|
| 55 | + Args: |
| 56 | + samples (list of indices): A list of sample indicies which will be used |
| 57 | + evaluation_path (string): The base path for the evaluation. |
| 58 | + epochs (integer): Number of epochs. A single epoch is defined as one iteration through the complete data set. |
| 59 | + iterations (integer): Number of iterations (batches) in a single epoch. |
| 60 | + callbacks (list of Callback classes): A list of Callback classes for custom evaluation |
| 61 | + """ |
33 | 62 | def evaluate(self, samples, evaluation_path="evaluation", epochs=20, iterations=None, callbacks=[], *args, **kwargs):
|
34 | 63 | samples_permuted = np.random.permutation(samples)
|
35 | 64 | # Split sample list into folds
|
|
0 commit comments