Skip to content

Commit 4880150

Browse files
authored
Futher fix the sample_beam in CaptionModel.
1 parent c16f982 commit 4880150

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

models/CaptionModel.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -95,9 +95,10 @@ def sample_beam(self, fc_feats, att_feats, opt={}):
9595

9696
self.done_beams = [[] for _ in range(batch_size)]
9797
for k in range(batch_size):
98-
state = self.init_hidden(fc_feats[k:k+1]).expand(beam_size, self.rnn_size)
9998
tmp_fc_feats = fc_feats[k:k+1].expand(beam_size, self.fc_feat_size)
10099
tmp_att_feats = att_feats[k:k+1].expand(*((beam_size,)+att_feats.size()[1:]))
100+
101+
state = self.init_hidden(tmp_fc_feats)
101102

102103
beam_seq = torch.LongTensor(self.seq_length, beam_size).zero_()
103104
beam_seq_logprobs = torch.FloatTensor(self.seq_length, beam_size).zero_()

0 commit comments

Comments
 (0)