@@ -215,24 +215,44 @@ def categorize_data(self, num_classes, recurrent=False):
215
215
self .y_test_recurrent , num_classes = num_classes , dtype = "int"
216
216
)
217
217
218
- def normalize_data (self ):
219
- # TODO: double check this here
220
- # self.mean = self.x_train[1000:-1000].mean(axis=0)
221
- # self.std = np.std(self.x_train[1000:-1000], axis=0)
222
- self .mean = self .x_train .mean (axis = 0 )
223
- self .std = np .std (self .x_train , axis = 0 )
224
- self .x_train = self .x_train - self .mean
225
- self .x_train /= self .std
226
- self .x_test = self .x_test - self .mean
227
- self .x_test /= self .std
228
-
229
- if not self .dlc_train is None :
230
- self .mean_dlc = self .dlc_train .mean (axis = 0 )
231
- self .std_dlc = self .dlc_train .std (axis = 0 )
232
- self .dlc_train -= self .mean_dlc
233
- self .dlc_test -= self .mean
234
- self .dlc_train /= self .std_dlc
235
- self .dlc_test /= self .std_dlc
218
+ def normalize_data (self , mode = "default" ):
219
+ if mode == "default" :
220
+ # TODO: double check this here
221
+ # self.mean = self.x_train[1000:-1000].mean(axis=0)
222
+ # self.std = np.std(self.x_train[1000:-1000], axis=0)
223
+ self .mean = self .x_train .mean (axis = 0 )
224
+ self .std = np .std (self .x_train , axis = 0 )
225
+ self .x_train = self .x_train - self .mean
226
+ self .x_train /= self .std
227
+ self .x_test = self .x_test - self .mean
228
+ self .x_test /= self .std
229
+
230
+ if not self .dlc_train is None :
231
+ self .mean_dlc = self .dlc_train .mean (axis = 0 )
232
+ self .std_dlc = self .dlc_train .std (axis = 0 )
233
+ self .dlc_train -= self .mean_dlc
234
+ self .dlc_test -= self .mean
235
+ self .dlc_train /= self .std_dlc
236
+ self .dlc_test /= self .std_dlc
237
+ elif mode == "xception" :
238
+ self .x_train /= 127.5
239
+ self .x_train -= 1.0
240
+ self .x_test /= 127.5
241
+ self .x_test -= 1.0
242
+
243
+ if not self .dlc_train is None :
244
+ self .dlc_train /= 127.5
245
+ self .dlc_train -= 1.0
246
+ self .dlc_test /= 127.5
247
+ self .dlc_test -= 1.0
248
+
249
+ else :
250
+ self .x_train /= 255.0
251
+ self .x_test /= 255.0
252
+ if not self .dlc_train is None :
253
+ self .dlc_train /= 255.0
254
+ self .dlc_test /= 255.0
255
+
236
256
237
257
def create_dataset (dataset , oneD , look_back = 5 ):
238
258
"""
@@ -438,9 +458,9 @@ def undersample_data(self):
438
458
439
459
# TODO: undersample recurrent
440
460
441
- def change_dtype (self ):
442
- self .x_train = np .asarray (self .x_train , dtype = "uint8" )
443
- self .x_test = np .asarray (self .x_test , dtype = "uint8" )
461
+ def change_dtype (self , dtype = "uint8" ):
462
+ self .x_train = np .asarray (self .x_train , dtype = dtype )
463
+ self .x_test = np .asarray (self .x_test , dtype = dtype )
444
464
445
465
def get_input_shape (self , recurrent = False ):
446
466
"""
@@ -484,23 +504,30 @@ def downscale_frames(self, factor=0.5):
484
504
self .x_test = np .asarray (im_re )
485
505
486
506
def prepare_data (
487
- self , downscale = 0.5 , remove_behaviors = [], flatten = False
507
+ self , downscale = 0.5 , remove_behaviors = [], flatten = False , recurrent = False , normalization_mode = 'default'
488
508
):
489
509
print ("preparing data" )
510
+ print ("changing dtype" )
490
511
self .change_dtype ()
491
512
513
+ print ("removing behaviors" )
492
514
for behavior in remove_behaviors :
493
515
self .remove_behavior (behavior = behavior )
516
+ print ("downscaling" )
494
517
if downscale :
495
518
self .downscale_frames (factor = downscale )
519
+ print ("normalizing data" )
496
520
if self .config ["normalize_data" ]:
497
521
self .normalize_data ()
522
+ print ("doing flow" )
498
523
if self .config ["do_flow" ]:
499
524
self .create_flow_data ()
525
+ print ("encoding labels" )
500
526
if self .config ["encode_labels" ]:
501
527
print ("test" )
502
528
self .encode_labels ()
503
529
print ("labels encoded" )
530
+ print ("using class weights" )
504
531
if self .config ["use_class_weights" ]:
505
532
print ("calc class weights" )
506
533
self .class_weights = class_weight .compute_class_weight (
@@ -509,16 +536,17 @@ def prepare_data(
509
536
if self .config ["undersample_data" ]:
510
537
print ("undersampling data" )
511
538
self .undersample_data ()
539
+ print ("using generator" )
512
540
if self .config ["use_generator" ]:
513
- self .categorize_data (self .num_classes , recurrent = False )
541
+ self .categorize_data (self .num_classes , recurrent = recurrent )
514
542
else :
515
543
print ("preparing recurrent data" )
516
544
self .create_recurrent_data ()
517
545
print ("preparing flattened data" )
518
546
if flatten :
519
547
self .create_flattened_data ()
520
548
print ("categorize data" )
521
- self .categorize_data (self .num_classes , recurrent = True )
549
+ self .categorize_data (self .num_classes , recurrent = recurrent )
522
550
523
551
print ("data ready" )
524
552
0 commit comments