class DnnCrf(DnnCrfBase):
-  def __init__(self, *, config: DnnCrfConfig, data_path: str = '', dtype: type = tf.float32, mode: str = 'train', nn: str, model_path: str = ''):
+  def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: type = tf.float32, mode: str = 'train',
+               train: str = 'll', nn: str, model_path: str = ''):
    if mode not in ['train', 'predict']:
      raise Exception('mode error')
-    if nn not in ['mlp', 'lstm', 'gru']:
+    if nn not in ['mlp', 'rnn', 'lstm', 'gru']:
      raise Exception('name of neural network entered is not supported')

    DnnCrfBase.__init__(self, config, data_path, mode, model_path)
@@ -25,6 +26,8 @@ def __init__(self, *,config: DnnCrfConfig, data_path: str = '', dtype: type = tf
    # Input layer
    if mode == 'train':
      self.input = tf.placeholder(tf.int32, [self.batch_size, self.batch_length, self.windows_size])
+      self.real_indices = tf.placeholder(tf.int32, [self.batch_size, self.batch_length])
+      self.seq_length = tf.placeholder(tf.int32, [self.batch_size])
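+      # real_indices holds the gold tag ids and seq_length the unpadded length of
+      # each sequence; both are consumed by tf.contrib.crf.crf_log_likelihood below.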
    else:
      self.input = tf.placeholder(tf.int32, [None, self.windows_size])
    # Lookup table layer
@@ -33,14 +36,21 @@ def __init__(self, *,config: DnnCrfConfig, data_path: str = '', dtype: type = tf
    if nn == 'mlp':
      self.hidden_layer = self.get_mlp_layer(tf.transpose(self.embedding_layer))
    elif nn == 'lstm':
-      self.hidden_layer = self.get_lstm_layer(tf.transpose(self.embedding_layer))
-    else:
+      self.hidden_layer = self.get_lstm_layer(self.embedding_layer)
+    elif nn == 'gru':
      self.hidden_layer = self.get_gru_layer(tf.transpose(self.embedding_layer))
+    else:
+      self.hidden_layer = self.get_rnn_layer(tf.transpose(self.embedding_layer))
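+    # Note: get_lstm_layer now consumes the batch-major embedding directly, since
+    # tf.nn.dynamic_rnn defaults to time_major=False. The mlp/gru/rnn branches still
+    # work in the transposed layout, so only the lstm path appears to match the
+    # [batch, length, hidden] shape assumed by the reworked get_output_layer.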
    # Output layer
    self.output = self.get_output_layer(self.hidden_layer)

    if mode == 'predict':
      self.output = tf.squeeze(self.output, axis=2)
+    elif train == 'll':
+      self.ll_loss, _ = tf.contrib.crf.crf_log_likelihood(self.output, self.real_indices, self.seq_length,
+                                                          self.transition)
+      self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
+      self.train_ll = self.optimizer.minimize(-self.ll_loss)
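+      # crf_log_likelihood returns the log-likelihood of the gold tag sequences
+      # given the emission scores and transition matrix; minimizing its negation
+      # maximizes that likelihood, training emissions and transitions jointly.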
    else:
      # Build the training ops
      # Placeholders for training
@@ -56,6 +66,7 @@ def __init__(self, *,config: DnnCrfConfig, data_path: str = '', dtype: type = tf
      self.train = self.optimizer.minimize(self.loss)
      self.train_with_init = self.optimizer.minimize(self.loss_with_init)

+
  def fit(self, epochs: int = 100, interval: int = 20):
    with tf.Session() as sess:
      tf.global_variables_initializer().run()
@@ -119,6 +130,26 @@ def fit_batch(self, characters, labels, lengths, sess):
      feed_dict[self.trans_init_curr] = trans_init_neg_indices
      sess.run(self.train_with_init, feed_dict)

+  def fit_ll(self, epochs: int = 100, interval: int = 20):
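+    # Training loop for the train='ll' graph: feeds characters, gold tags and
+    # sequence lengths into the placeholders above, and saves a checkpoint after
+    # every epoch.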
+    with tf.Session() as sess:
+      tf.global_variables_initializer().run()
+      saver = tf.train.Saver(max_to_keep=epochs)
+      for epoch in range(1, epochs + 1):
+        print('epoch:', epoch)
+        for _ in range(self.batch_count):
+          characters, labels, lengths = self.get_batch()
+          feed_dict = {self.input: characters, self.real_indices: labels, self.seq_length: lengths}
+          sess.run(self.train_ll, feed_dict=feed_dict)
+        model_path = '../dnlp/models/cws{0}.ckpt'.format(epoch)
+        saver.save(sess, model_path)
+        self.save_config(model_path)
+
+  def fit_batch_ll(self):
+    pass
+
  def generate_transition_update_index(self, correct_labels, current_labels):
    if correct_labels.shape != current_labels.shape:
      print('sequence length is not equal')
@@ -176,11 +207,17 @@ def get_mlp_layer(self, layer: tf.Tensor) -> tf.Tensor:
      layer = tf.sigmoid(tf.tensordot(hidden_weight, layer, [[1], [0]]) + hidden_bias)
    return layer

+  def get_rnn_layer(self, layer: tf.Tensor) -> tf.Tensor:
+    # BasicRNNCell is the instantiable vanilla cell; tf.nn.rnn_cell.RNNCell is an
+    # abstract base class and cannot be constructed directly.
+    rnn = tf.nn.rnn_cell.BasicRNNCell(self.hidden_units)
+    rnn_output, rnn_out_state = tf.nn.dynamic_rnn(rnn, layer, dtype=self.dtype)
+    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
+    return tf.transpose(rnn_output)
+
  def get_lstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
-    lstm = tf.nn.rnn_cell.BasicLSTMCell(self.hidden_units)
+    lstm = tf.nn.rnn_cell.LSTMCell(self.hidden_units)
    lstm_output, lstm_out_state = tf.nn.dynamic_rnn(lstm, layer, dtype=self.dtype)
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
-    return tf.transpose(lstm_output)
+    return lstm_output

  def get_gru_layer(self, layer: tf.Tensor) -> tf.Tensor:
    gru = tf.nn.rnn_cell.GRUCell(self.hidden_units)
@@ -192,10 +229,10 @@ def get_dropout_layer(self, layer: tf.Tensor) -> tf.Tensor:
    return tf.layers.dropout(layer, self.dropout_rate)

  def get_output_layer(self, layer: tf.Tensor) -> tf.Tensor:
-    output_weight = self.__get_variable([self.tags_count, self.hidden_units], 'output_weight')
-    output_bias = self.__get_variable([self.tags_count, 1, 1], 'output_bias')
+    output_weight = self.__get_variable([self.hidden_units, self.tags_count], 'output_weight')
+    output_bias = self.__get_variable([1, 1, self.tags_count], 'output_bias')
    self.params += [output_weight, output_bias]
-    return tf.tensordot(output_weight, layer, [[1], [0]]) + output_bias
+    return tf.tensordot(layer, output_weight, [[2], [0]]) + output_bias
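+    # Contracts the hidden axis: [batch, length, hidden] x [hidden, tags]
+    # -> [batch, length, tags], with the bias broadcasting over batch and length.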

  def get_loss(self) -> (tf.Tensor, tf.Tensor):
    output_loss = tf.reduce_sum(tf.gather_nd(self.output, self.ll_curr) - tf.gather_nd(self.output, self.ll_corr))