2222from misc .resnet_utils import myResnet
2323import misc .resnet as resnet
2424
25- resnet = resnet .resnet101 ()
26- resnet .load_state_dict (torch .load ('./data/imagenet_weights/resnet101.pth' ))
27- my_resnet = myResnet (resnet )
28- my_resnet .cuda ()
29- my_resnet .eval ()
30-
3125class DataLoaderRaw ():
3226
3327 def __init__ (self , opt ):
@@ -38,6 +32,16 @@ def __init__(self, opt):
3832 self .batch_size = opt .get ('batch_size' , 1 )
3933 self .seq_per_img = 1
4034
35+ # Load resnet
36+ self .cnn_model = opt .get ('cnn_model' , 'resnet101' )
37+ resnet = getattr (resnet , self .cnn_model )()
38+ resnet .load_state_dict (torch .load ('./data/imagenet_weights/' + self .cnn_model + '.pth' ))
39+ self .my_resnet = myResnet (resnet )
40+ self .my_resnet .cuda ()
41+ self .my_resnet .eval ()
42+
43+
44+
4145 # load the json file which contains additional information about the dataset
4246 print ('DataLoaderRaw loading images from folder: ' , self .folder_path )
4347
@@ -106,7 +110,7 @@ def get_batch(self, split, batch_size=None):
106110 img = img .astype ('float32' )/ 255.0
107111 img = torch .from_numpy (img .transpose ([2 ,0 ,1 ])).cuda ()
108112 img = Variable (preprocess (img ), volatile = True )
109- tmp_fc , tmp_att = my_resnet (img )
113+ tmp_fc , tmp_att = self . my_resnet (img )
110114
111115 fc_batch [i ] = tmp_fc .data .cpu ().float ().numpy ()
112116 att_batch [i ] = tmp_att .data .cpu ().float ().numpy ()
0 commit comments