@@ -386,7 +386,8 @@ def __init__(self, root_path, win_size, step=1, flag="train"):
386
386
test_data = np .nan_to_num (test_data )
387
387
self .test = self .scaler .transform (test_data )
388
388
self .train = data
389
- self .val = self .test
389
+ data_len = len (self .train )
390
+ self .val = self .train [(int )(data_len * 0.8 ):]
390
391
self .test_labels = pd .read_csv (os .path .join (root_path , 'test_label.csv' )).values [:, 1 :]
391
392
print ("test:" , self .test .shape )
392
393
print ("train:" , self .train .shape )
@@ -428,7 +429,8 @@ def __init__(self, root_path, win_size, step=1, flag="train"):
428
429
test_data = np .load (os .path .join (root_path , "MSL_test.npy" ))
429
430
self .test = self .scaler .transform (test_data )
430
431
self .train = data
431
- self .val = self .test
432
+ data_len = len (self .train )
433
+ self .val = self .train [(int )(data_len * 0.8 ):]
432
434
self .test_labels = np .load (os .path .join (root_path , "MSL_test_label.npy" ))
433
435
print ("test:" , self .test .shape )
434
436
print ("train:" , self .train .shape )
@@ -470,7 +472,8 @@ def __init__(self, root_path, win_size, step=1, flag="train"):
470
472
test_data = np .load (os .path .join (root_path , "SMAP_test.npy" ))
471
473
self .test = self .scaler .transform (test_data )
472
474
self .train = data
473
- self .val = self .test
475
+ data_len = len (self .train )
476
+ self .val = self .train [(int )(data_len * 0.8 ):]
474
477
self .test_labels = np .load (os .path .join (root_path , "SMAP_test_label.npy" ))
475
478
print ("test:" , self .test .shape )
476
479
print ("train:" , self .train .shape )
@@ -560,7 +563,8 @@ def __init__(self, root_path, win_size, step=1, flag="train"):
560
563
test_data = self .scaler .transform (test_data )
561
564
self .train = train_data
562
565
self .test = test_data
563
- self .val = test_data
566
+ data_len = len (self .train )
567
+ self .val = self .train [(int )(data_len * 0.8 ):]
564
568
self .test_labels = labels
565
569
print ("test:" , self .test .shape )
566
570
print ("train:" , self .train .shape )
0 commit comments