Skip to content

Commit beb2cb0

Browse files
authored
Update data_loader.py
1 parent d13b3de commit beb2cb0

File tree

1 file changed

+8
-4
lines changed

1 file changed

+8
-4
lines changed

data_provider/data_loader.py

+8-4
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,8 @@ def __init__(self, root_path, win_size, step=1, flag="train"):
386386
test_data = np.nan_to_num(test_data)
387387
self.test = self.scaler.transform(test_data)
388388
self.train = data
389-
self.val = self.test
389+
data_len = len(self.train)
390+
self.val = self.train[(int)(data_len * 0.8):]
390391
self.test_labels = pd.read_csv(os.path.join(root_path, 'test_label.csv')).values[:, 1:]
391392
print("test:", self.test.shape)
392393
print("train:", self.train.shape)
@@ -428,7 +429,8 @@ def __init__(self, root_path, win_size, step=1, flag="train"):
428429
test_data = np.load(os.path.join(root_path, "MSL_test.npy"))
429430
self.test = self.scaler.transform(test_data)
430431
self.train = data
431-
self.val = self.test
432+
data_len = len(self.train)
433+
self.val = self.train[(int)(data_len * 0.8):]
432434
self.test_labels = np.load(os.path.join(root_path, "MSL_test_label.npy"))
433435
print("test:", self.test.shape)
434436
print("train:", self.train.shape)
@@ -470,7 +472,8 @@ def __init__(self, root_path, win_size, step=1, flag="train"):
470472
test_data = np.load(os.path.join(root_path, "SMAP_test.npy"))
471473
self.test = self.scaler.transform(test_data)
472474
self.train = data
473-
self.val = self.test
475+
data_len = len(self.train)
476+
self.val = self.train[(int)(data_len * 0.8):]
474477
self.test_labels = np.load(os.path.join(root_path, "SMAP_test_label.npy"))
475478
print("test:", self.test.shape)
476479
print("train:", self.train.shape)
@@ -560,7 +563,8 @@ def __init__(self, root_path, win_size, step=1, flag="train"):
560563
test_data = self.scaler.transform(test_data)
561564
self.train = train_data
562565
self.test = test_data
563-
self.val = test_data
566+
data_len = len(self.train)
567+
self.val = self.train[(int)(data_len * 0.8):]
564568
self.test_labels = labels
565569
print("test:", self.test.shape)
566570
print("train:", self.train.shape)

0 commit comments

Comments
 (0)