Skip to content

Commit e524aaa

Browse files
committed
refactor(ThreeLayerNet): 添加forward操作中输入数据的变形
1 parent 7f2d6a6 commit e524aaa

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

pynet/models/ThreeLayerNet.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ def __call__(self, inputs):
4343
return self.forward(inputs)
4444

4545
def forward(self, inputs):
46-
# inputs.shape = [N, D_in]
47-
assert len(inputs.shape) == 2
46+
inputs = inputs.reshape(inputs.shape[0], -1)
47+
4848
self.z1, self.z1_cache = self.fc1(inputs, self.params['W1'], self.params['b1'])
4949
a1 = self.relu(self.z1)
5050
if self.use_dropout and self.dropout_param['mode'] == 'train':

0 commit comments

Comments
 (0)