Skip to content

Commit db78ca0

Browse files
committed
dataloader fixes
1 parent dbd4ced commit db78ca0

File tree

1 file changed

+52
-24
lines changed

1 file changed

+52
-24
lines changed

SwissKnife/dataloader.py

+52-24
Original file line numberDiff line numberDiff line change
@@ -215,24 +215,44 @@ def categorize_data(self, num_classes, recurrent=False):
215215
self.y_test_recurrent, num_classes=num_classes, dtype="int"
216216
)
217217

218-
def normalize_data(self):
219-
# TODO: double check this here
220-
# self.mean = self.x_train[1000:-1000].mean(axis=0)
221-
# self.std = np.std(self.x_train[1000:-1000], axis=0)
222-
self.mean = self.x_train.mean(axis=0)
223-
self.std = np.std(self.x_train, axis=0)
224-
self.x_train = self.x_train - self.mean
225-
self.x_train /= self.std
226-
self.x_test = self.x_test - self.mean
227-
self.x_test /= self.std
228-
229-
if not self.dlc_train is None:
230-
self.mean_dlc = self.dlc_train.mean(axis=0)
231-
self.std_dlc = self.dlc_train.std(axis=0)
232-
self.dlc_train -= self.mean_dlc
233-
self.dlc_test -= self.mean
234-
self.dlc_train /= self.std_dlc
235-
self.dlc_test /= self.std_dlc
218+
def normalize_data(self, mode="default"):
219+
if mode == "default":
220+
# TODO: double check this here
221+
# self.mean = self.x_train[1000:-1000].mean(axis=0)
222+
# self.std = np.std(self.x_train[1000:-1000], axis=0)
223+
self.mean = self.x_train.mean(axis=0)
224+
self.std = np.std(self.x_train, axis=0)
225+
self.x_train = self.x_train - self.mean
226+
self.x_train /= self.std
227+
self.x_test = self.x_test - self.mean
228+
self.x_test /= self.std
229+
230+
if not self.dlc_train is None:
231+
self.mean_dlc = self.dlc_train.mean(axis=0)
232+
self.std_dlc = self.dlc_train.std(axis=0)
233+
self.dlc_train -= self.mean_dlc
234+
self.dlc_test -= self.mean
235+
self.dlc_train /= self.std_dlc
236+
self.dlc_test /= self.std_dlc
237+
elif mode == "xception":
238+
self.x_train /= 127.5
239+
self.x_train -= 1.0
240+
self.x_test /= 127.5
241+
self.x_test -= 1.0
242+
243+
if not self.dlc_train is None:
244+
self.dlc_train /= 127.5
245+
self.dlc_train -= 1.0
246+
self.dlc_test /= 127.5
247+
self.dlc_test -= 1.0
248+
249+
else:
250+
self.x_train /= 255.0
251+
self.x_test /= 255.0
252+
if not self.dlc_train is None:
253+
self.dlc_train /= 255.0
254+
self.dlc_test /= 255.0
255+
236256

237257
def create_dataset(dataset, oneD, look_back=5):
238258
"""
@@ -438,9 +458,9 @@ def undersample_data(self):
438458

439459
# TODO: undersample recurrent
440460

441-
def change_dtype(self):
442-
self.x_train = np.asarray(self.x_train, dtype="uint8")
443-
self.x_test = np.asarray(self.x_test, dtype="uint8")
461+
def change_dtype(self, dtype="uint8"):
462+
self.x_train = np.asarray(self.x_train, dtype=dtype)
463+
self.x_test = np.asarray(self.x_test, dtype=dtype)
444464

445465
def get_input_shape(self, recurrent=False):
446466
"""
@@ -484,23 +504,30 @@ def downscale_frames(self, factor=0.5):
484504
self.x_test = np.asarray(im_re)
485505

486506
def prepare_data(
487-
self, downscale=0.5, remove_behaviors=[], flatten=False
507+
self, downscale=0.5, remove_behaviors=[], flatten=False, recurrent=False, normalization_mode='default'
488508
):
489509
print("preparing data")
510+
print("changing dtype")
490511
self.change_dtype()
491512

513+
print("removing behaviors")
492514
for behavior in remove_behaviors:
493515
self.remove_behavior(behavior=behavior)
516+
print("downscaling")
494517
if downscale:
495518
self.downscale_frames(factor=downscale)
519+
print("normalizing data")
496520
if self.config["normalize_data"]:
497521
self.normalize_data()
522+
print("doing flow")
498523
if self.config["do_flow"]:
499524
self.create_flow_data()
525+
print("encoding labels")
500526
if self.config["encode_labels"]:
501527
print("test")
502528
self.encode_labels()
503529
print("labels encoded")
530+
print("using class weights")
504531
if self.config["use_class_weights"]:
505532
print("calc class weights")
506533
self.class_weights = class_weight.compute_class_weight(
@@ -509,16 +536,17 @@ def prepare_data(
509536
if self.config["undersample_data"]:
510537
print("undersampling data")
511538
self.undersample_data()
539+
print("using generator")
512540
if self.config["use_generator"]:
513-
self.categorize_data(self.num_classes, recurrent=False)
541+
self.categorize_data(self.num_classes, recurrent=recurrent)
514542
else:
515543
print("preparing recurrent data")
516544
self.create_recurrent_data()
517545
print("preparing flattened data")
518546
if flatten:
519547
self.create_flattened_data()
520548
print("categorize data")
521-
self.categorize_data(self.num_classes, recurrent=True)
549+
self.categorize_data(self.num_classes, recurrent=recurrent)
522550

523551
print("data ready")
524552

0 commit comments

Comments
 (0)