Skip to content

Commit 30c0137

Browse files
committed
feat(accuracy): 计算模型准确率
1 parent 445f782 commit 30c0137

File tree

2 files changed

+63
-0
lines changed

2 files changed

+63
-0
lines changed

py/accuracy.py

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# -*- coding: utf-8 -*-
2+
3+
"""
4+
@date: 2020/4/28 上午10:32
5+
@file: accuracy.py
6+
@author: zj
7+
@description: 计算Top-1 correct rate
8+
"""
9+
10+
import torch
11+
from torch.utils.data import DataLoader
12+
from torchvision.models import alexnet
13+
from torchvision.datasets import CIFAR10
14+
import torchvision.transforms as transforms
15+
16+
from utils import util
17+
18+
19+
def accuracy(data_loader, model, device=None):
20+
if device:
21+
model = model.to(device)
22+
23+
running_corrects = 0
24+
for inputs, targets in data_loader:
25+
if device:
26+
inputs = inputs.to(device)
27+
targets = targets.to(device)
28+
29+
# forward
30+
# track history if only in train
31+
with torch.set_grad_enabled(False):
32+
outputs = model(inputs)
33+
# print(outputs.shape)
34+
_, preds = torch.max(outputs, 1)
35+
36+
# statistics
37+
running_corrects += torch.sum(preds == targets.data)
38+
39+
epoch_acc = running_corrects.double() / len(data_loader.dataset)
40+
return epoch_acc
41+
42+
43+
if __name__ == '__main__':
44+
transform = transforms.Compose([
45+
transforms.Resize(224, 224),
46+
transforms.ToTensor(),
47+
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
48+
])
49+
50+
# 提取测试集
51+
data_set = CIFAR10('./data', download=True, train=False, transform=transform)
52+
data_loader = DataLoader(data_set, shuffle=True, batch_size=128, num_workers=8)
53+
54+
num_classes = 10
55+
model = alexnet(num_classes=num_classes)
56+
57+
device = util.get_device()
58+
acc = accuracy(data_loader, model, device=device)
59+
print('acc: {:.3f}'.format(acc))

py/utils/util.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,10 @@
1010
import numpy as np
1111
import math
1212
import sys
13+
import torch
14+
15+
def get_device():
16+
return torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
1317

1418

1519
def error(msg):

0 commit comments

Comments
 (0)