|
1 | 1 | # -*- coding: utf-8 -*-
|
2 | 2 |
|
3 |
| -# @Time : 19-6-20 下午2:29 |
| 3 | +# @Time : 19-7-2 上午9:53 |
4 | 4 | # @Author : zj
|
5 | 5 |
|
6 | 6 |
|
|
9 | 9 | from .pool2row import *
|
10 | 10 | from .Layer import *
|
11 | 11 |
|
12 |
| -__all__ = ['Conv2d'] |
| 12 | +__all__ = ['Conv2d2'] |
13 | 13 |
|
14 | 14 |
|
15 |
| -class Conv2d(Layer): |
| 15 | +class Conv2d2: |
16 | 16 | """
|
17 | 17 | convolutional layer
|
18 | 18 | 卷积层
|
19 | 19 | """
|
20 | 20 |
|
21 |
| - def __init__(self, in_c, filter_h, filter_w, filter_num, stride=1, padding=0, momentum=0, nesterov=False): |
22 |
| - super(Conv2d, self).__init__() |
| 21 | + def __init__(self, in_c, filter_h, filter_w, filter_num, stride=1, padding=0, weight_scale=1e-2): |
| 22 | + """ |
| 23 | + :param in_c: 输入数据体通道数 |
| 24 | + :param filter_h: 滤波器长 |
| 25 | + :param filter_w: 滤波器宽 |
| 26 | + :param filter_num: 滤波器个数 |
| 27 | + :param stride: 步长 |
| 28 | + :param padding: 零填充 |
| 29 | + :param weight_scale: |
| 30 | + """ |
| 31 | + super(Conv2d2, self).__init__() |
23 | 32 | self.in_c = in_c
|
24 | 33 | self.filter_h = filter_h
|
25 | 34 | self.filter_w = filter_w
|
26 | 35 | self.filter_num = filter_num
|
27 | 36 | self.stride = stride
|
28 | 37 | self.padding = padding
|
| 38 | + self.weight_scale = weight_scale |
29 | 39 |
|
30 |
| - self.W = \ |
31 |
| - {'val': 0.01 * np.random.normal(loc=0, scale=1.0, size=(filter_h * filter_w * in_c, filter_num)), |
32 |
| - 'grad': 0, |
33 |
| - 'v': 0, |
34 |
| - 'momentum': momentum, |
35 |
| - 'nesterov': nesterov} |
36 |
| - self.b = {'val': 0.01 * np.random.normal(loc=0, scale=1.0, size=(1, filter_num)), 'grad': 0} |
37 |
| - self.a = None |
38 |
| - self.input_shape = None |
| 40 | + def __call__(self, inputs, w, b): |
| 41 | + return self.forward(inputs, w, b) |
39 | 42 |
|
40 |
| - def __call__(self, inputs): |
41 |
| - return self.forward(inputs) |
42 |
| - |
43 |
| - def forward(self, inputs): |
| 43 | + def forward(self, inputs, w, b): |
44 | 44 | # input.shape == [N, C, H, W]
|
45 | 45 | assert len(inputs.shape) == 4
|
46 | 46 | N, C, H, W = inputs.shape[:4]
|
47 | 47 | out_h = int((H - self.filter_h + 2 * self.padding) / self.stride + 1)
|
48 | 48 | out_w = int((W - self.filter_w + 2 * self.padding) / self.stride + 1)
|
49 | 49 |
|
50 | 50 | a = im2row_indices(inputs, self.filter_h, self.filter_w, stride=self.stride, padding=self.padding)
|
51 |
| - z = a.dot(self.W['val']) + self.b['val'] |
52 |
| - |
53 |
| - self.input_shape = inputs.shape |
54 |
| - self.a = a.copy() |
| 51 | + z = a.dot(w) + b |
55 | 52 |
|
56 | 53 | out = conv_fc2output(z, N, out_h, out_w)
|
57 |
| - return out |
| 54 | + cache = (a, inputs.shape, w, b) |
| 55 | + return out, cache |
58 | 56 |
|
59 |
| - def backward(self, grad_out): |
| 57 | + def backward(self, grad_out, cache): |
60 | 58 | assert len(grad_out.shape) == 4
|
61 | 59 |
|
62 |
| - dz = conv_output2fc(grad_out) |
63 |
| - self.W['grad'] = self.a.T.dot(dz) |
64 |
| - self.b['grad'] = np.sum(dz, axis=0, keepdims=True) / dz.shape[0] |
65 |
| - |
66 |
| - da = dz.dot(self.W['val'].T) |
67 |
| - return row2im_indices(da, self.input_shape, field_height=self.filter_h, |
68 |
| - field_width=self.filter_w, stride=self.stride, padding=self.padding) |
69 |
| - |
70 |
| - def update(self, learning_rate=0, regularization_rate=0): |
71 |
| - v_prev = self.W['v'] |
72 |
| - self.W['v'] = self.W['momentum'] * self.W['v'] - learning_rate * ( |
73 |
| - self.W['grad'] + regularization_rate * self.W['val']) |
74 |
| - if self.W['nesterov']: |
75 |
| - self.W['val'] += (1 + self.W['momentum']) * self.W['v'] - self.W['momentum'] * v_prev |
76 |
| - else: |
77 |
| - self.W['val'] += self.W['v'] |
78 |
| - self.b['val'] -= learning_rate * (self.b['grad']) |
| 60 | + a, input_shape, w, b = cache |
79 | 61 |
|
80 |
| - def get_params(self): |
81 |
| - return {'W': self.W['val'], 'momentum': self.W['momentum'], 'nesterov': self.W['nesterov'], 'b': self.b['val']} |
| 62 | + dz = conv_output2fc(grad_out) |
| 63 | + grad_W = a.T.dot(dz) |
| 64 | + grad_b = np.sum(dz, axis=0, keepdims=True) / dz.shape[0] |
82 | 65 |
|
83 |
| - def set_params(self, params): |
84 |
| - self.W['val'] = params.get('W') |
85 |
| - self.b['val'] = params.get('b') |
| 66 | + da = dz.dot(w.T) |
| 67 | + return grad_W, grad_b, row2im_indices(da, input_shape, field_height=self.filter_h, |
| 68 | + field_width=self.filter_w, stride=self.stride, padding=self.padding) |
86 | 69 |
|
87 |
| - self.W['momentum'] = params.get('momentum', 0.0) |
88 |
| - self.W['nesterov'] = params.get('nesterov', False) |
| 70 | + def get_params(self): |
| 71 | + return self.weight_scale * np.random.normal(loc=0, scale=1.0, size=( |
| 72 | + self.filter_h * self.filter_w * self.in_c, self.filter_num)), \ |
| 73 | + self.weight_scale * np.random.normal(loc=0, scale=1.0, size=(1, self.filter_num)) |
0 commit comments