Skip to content

Commit 062118e

Browse files
committed
Add option in eval.py to use different resnets. (Must match with what you use to preprocess.)
1 parent 879971b commit 062118e

File tree

2 files changed

+15
-8
lines changed

2 files changed

+15
-8
lines changed

dataloaderraw.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,6 @@
2222
from misc.resnet_utils import myResnet
2323
import misc.resnet as resnet
2424

25-
resnet = resnet.resnet101()
26-
resnet.load_state_dict(torch.load('./data/imagenet_weights/resnet101.pth'))
27-
my_resnet = myResnet(resnet)
28-
my_resnet.cuda()
29-
my_resnet.eval()
30-
3125
class DataLoaderRaw():
3226

3327
def __init__(self, opt):
@@ -38,6 +32,16 @@ def __init__(self, opt):
3832
self.batch_size = opt.get('batch_size', 1)
3933
self.seq_per_img = 1
4034

35+
# Load resnet
36+
self.cnn_model = opt.get('cnn_model', 'resnet101')
37+
resnet = getattr(resnet, self.cnn_model)()
38+
resnet.load_state_dict(torch.load('./data/imagenet_weights/'+self.cnn_model+'.pth'))
39+
self.my_resnet = myResnet(resnet)
40+
self.my_resnet.cuda()
41+
self.my_resnet.eval()
42+
43+
44+
4145
# load the json file which contains additional information about the dataset
4246
print('DataLoaderRaw loading images from folder: ', self.folder_path)
4347

@@ -106,7 +110,7 @@ def get_batch(self, split, batch_size=None):
106110
img = img.astype('float32')/255.0
107111
img = torch.from_numpy(img.transpose([2,0,1])).cuda()
108112
img = Variable(preprocess(img), volatile=True)
109-
tmp_fc, tmp_att = my_resnet(img)
113+
tmp_fc, tmp_att = self.my_resnet(img)
110114

111115
fc_batch[i] = tmp_fc.data.cpu().float().numpy()
112116
att_batch[i] = tmp_att.data.cpu().float().numpy()

eval.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,8 @@
2323
# Input paths
2424
parser.add_argument('--model', type=str, default='',
2525
help='path to model to evaluate')
26+
parser.add_argument('--cnn_model', type=str, default='resnet101',
27+
help='resnet101, resnet152')
2628
parser.add_argument('--infos_path', type=str, default='',
2729
help='path to infos to evaluate')
2830
# Basic options
@@ -108,7 +110,8 @@
108110
else:
109111
loader = DataLoaderRaw({'folder_path': opt.image_folder,
110112
'coco_json': opt.coco_json,
111-
'batch_size': opt.batch_size})
113+
'batch_size': opt.batch_size,
114+
'cnn_model': opt.cnn_model})
112115
loader.ix_to_word = infos['vocab']
113116

114117

0 commit comments

Comments
 (0)