|
5 | 5 |
|
6 | 6 | import numpy as np
|
7 | 7 | import os
|
| 8 | +from sklearn.model_selection import train_test_split |
8 | 9 | from .utils import *
|
9 | 10 |
|
10 | 11 | train_list = ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5']
|
@@ -40,44 +41,27 @@ def load_CIFAR10(file_dir):
|
40 | 41 | return x_train, y_train, x_test, y_test
|
41 | 42 |
|
42 | 43 |
|
43 |
| -def get_CIFAR10_data(cifar_dir, num_validation=2000, normalize=True): |
| 44 | +def get_CIFAR10_data(cifar_dir, val_size=0.05, normalize=True): |
44 | 45 | """
|
45 | 46 | 加载CIFAR10数据,从训练集中分类验证集数据
|
46 | 47 | :param cifar_dir: cifar解压文件路径
|
47 |
| - :param num_validation: 验证集数量 |
| 48 | + :param val_size: 浮点数,表示验证集占整个训练集的百分比 |
48 | 49 | :param normalize: 是否初始化为零均值,1方差
|
49 | 50 | :return: dict,保存训练集、验证集以及测试集的数据和标签
|
50 | 51 | """
|
51 | 52 | x_train, y_train, x_test, y_test = load_CIFAR10(cifar_dir)
|
52 | 53 |
|
53 |
| - num_train = x_train.shape[0] - num_validation |
54 |
| - |
55 |
| - # 打乱数据集 |
56 |
| - np.random.shuffle(x_train) |
57 |
| - |
58 |
| - mask = list(range(num_train, num_train + num_validation)) |
59 |
| - x_val = x_train[mask] |
60 |
| - y_val = y_train[mask] |
61 |
| - mask = list(range(num_train)) |
62 |
| - x_train = x_train[mask] |
63 |
| - y_train = y_train[mask] |
| 54 | + x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=val_size, shuffle=False) |
64 | 55 |
|
65 | 56 | # Normalize the data: subtract the mean image and divide the variance
|
66 | 57 | if normalize:
|
67 |
| - # eps = 1e-8 |
68 |
| - # train_mean = np.mean(x_train, axis=0) |
69 |
| - # train_var = np.var(x_train, axis=0) |
70 |
| - # x_train = (x_train - train_mean) / np.sqrt(train_var + eps) |
71 |
| - # x_val = (x_val - train_mean) / np.sqrt(train_var + eps) |
72 |
| - # x_test = (x_test - train_mean) / np.sqrt(train_var + eps) |
73 |
| - |
74 | 58 | x_train = x_train / 255 - 0.5
|
75 | 59 | x_val = x_val / 255 - 0.5
|
76 | 60 | x_test = x_test / 255 - 0.5
|
77 | 61 |
|
78 | 62 | # Package data into a dictionary
|
79 | 63 | return {
|
80 |
| - 'x_train': x_train, 'y_train': y_train, |
81 |
| - 'x_val': x_val, 'y_val': y_val, |
82 |
| - 'x_test': x_test, 'y_test': y_test, |
| 64 | + 'X_train': x_train, 'y_train': y_train, |
| 65 | + 'X_val': x_val, 'y_val': y_val, |
| 66 | + 'X_test': x_test, 'y_test': y_test, |
83 | 67 | }
|
0 commit comments