Skip to content

Commit c9111a3

Browse files
committed
refactor(cifar): 利用sklearn库实现训练集和验证集分离
1 parent e524aaa commit c9111a3

File tree

1 file changed

+7
-23
lines changed

1 file changed

+7
-23
lines changed

pynet/vision/data/cifar.py

Lines changed: 7 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
import numpy as np
77
import os
8+
from sklearn.model_selection import train_test_split
89
from .utils import *
910

1011
train_list = ['data_batch_1', 'data_batch_2', 'data_batch_3', 'data_batch_4', 'data_batch_5']
@@ -40,44 +41,27 @@ def load_CIFAR10(file_dir):
4041
return x_train, y_train, x_test, y_test
4142

4243

43-
def get_CIFAR10_data(cifar_dir, num_validation=2000, normalize=True):
44+
def get_CIFAR10_data(cifar_dir, val_size=0.05, normalize=True):
4445
"""
4546
加载CIFAR10数据,从训练集中分类验证集数据
4647
:param cifar_dir: cifar解压文件路径
47-
:param num_validation: 验证集数量
48+
:param val_size: 浮点数,表示验证集占整个训练集的百分比
4849
:param normalize: 是否初始化为零均值,1方差
4950
:return: dict,保存训练集、验证集以及测试集的数据和标签
5051
"""
5152
x_train, y_train, x_test, y_test = load_CIFAR10(cifar_dir)
5253

53-
num_train = x_train.shape[0] - num_validation
54-
55-
# 打乱数据集
56-
np.random.shuffle(x_train)
57-
58-
mask = list(range(num_train, num_train + num_validation))
59-
x_val = x_train[mask]
60-
y_val = y_train[mask]
61-
mask = list(range(num_train))
62-
x_train = x_train[mask]
63-
y_train = y_train[mask]
54+
x_train, x_val, y_train, y_val = train_test_split(x_train, y_train, test_size=val_size, shuffle=False)
6455

6556
# Normalize the data: subtract the mean image and divide the variance
6657
if normalize:
67-
# eps = 1e-8
68-
# train_mean = np.mean(x_train, axis=0)
69-
# train_var = np.var(x_train, axis=0)
70-
# x_train = (x_train - train_mean) / np.sqrt(train_var + eps)
71-
# x_val = (x_val - train_mean) / np.sqrt(train_var + eps)
72-
# x_test = (x_test - train_mean) / np.sqrt(train_var + eps)
73-
7458
x_train = x_train / 255 - 0.5
7559
x_val = x_val / 255 - 0.5
7660
x_test = x_test / 255 - 0.5
7761

7862
# Package data into a dictionary
7963
return {
80-
'x_train': x_train, 'y_train': y_train,
81-
'x_val': x_val, 'y_val': y_val,
82-
'x_test': x_test, 'y_test': y_test,
64+
'X_train': x_train, 'y_train': y_train,
65+
'X_val': x_val, 'y_val': y_val,
66+
'X_test': x_test, 'y_test': y_test,
8367
}

0 commit comments

Comments
 (0)