Skip to content

Commit 38537dd

Browse files
add log likelihood train function
1 parent 72920a9 commit 38537dd

File tree

2 files changed

+48
-11
lines changed

2 files changed

+48
-11
lines changed

python/dnlp/core/dnn_crf.py

Lines changed: 46 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,10 +7,11 @@
77

88

99
class DnnCrf(DnnCrfBase):
10-
def __init__(self, *,config: DnnCrfConfig, data_path: str = '', dtype: type = tf.float32, mode: str = 'train',nn:str, model_path:str=''):
10+
def __init__(self, *, config: DnnCrfConfig = None, data_path: str = '', dtype: type = tf.float32, mode: str = 'train',
11+
train:str='ll',nn: str, model_path: str = ''):
1112
if mode not in ['train', 'predict']:
1213
raise Exception('mode error')
13-
if nn not in ['mlp','lstm','gru']:
14+
if nn not in ['mlp', 'rnn', 'lstm', 'gru']:
1415
raise Exception('name of neural network entered is not supported')
1516

1617
DnnCrfBase.__init__(self, config, data_path, mode, model_path)
@@ -25,6 +26,8 @@ def __init__(self, *,config: DnnCrfConfig, data_path: str = '', dtype: type = tf
2526
# 输入层
2627
if mode == 'train':
2728
self.input = tf.placeholder(tf.int32, [self.batch_size, self.batch_length, self.windows_size])
29+
self.real_indices = tf.placeholder(tf.int32, [self.batch_size, self.batch_length])
30+
self.seq_length = tf.placeholder(tf.int32, [self.batch_size])
2831
else:
2932
self.input = tf.placeholder(tf.int32, [None, self.windows_size])
3033
# 查找表层
@@ -33,14 +36,21 @@ def __init__(self, *,config: DnnCrfConfig, data_path: str = '', dtype: type = tf
3336
if nn == 'mlp':
3437
self.hidden_layer = self.get_mlp_layer(tf.transpose(self.embedding_layer))
3538
elif nn == 'lstm':
36-
self.hidden_layer = self.get_lstm_layer(tf.transpose(self.embedding_layer))
37-
else:
39+
self.hidden_layer = self.get_lstm_layer(self.embedding_layer)
40+
elif nn == 'gru':
3841
self.hidden_layer = self.get_gru_layer(tf.transpose(self.embedding_layer))
42+
else:
43+
self.hidden_layer = self.get_rnn_layer(tf.transpose(self.embedding_layer))
3944
# 输出层
4045
self.output = self.get_output_layer(self.hidden_layer)
4146

4247
if mode == 'predict':
4348
self.output = tf.squeeze(self.output, axis=2)
49+
elif train == 'll':
50+
self.ll_loss, _ = tf.contrib.crf.crf_log_likelihood(self.output, self.real_indices, self.seq_length,
51+
self.transition)
52+
self.optimizer = tf.train.AdagradOptimizer(self.learning_rate)
53+
self.train_ll = self.optimizer.minimize(-self.ll_loss)
4454
else:
4555
# 构建训练函数
4656
# 训练用placeholder
@@ -56,6 +66,7 @@ def __init__(self, *,config: DnnCrfConfig, data_path: str = '', dtype: type = tf
5666
self.train = self.optimizer.minimize(self.loss)
5767
self.train_with_init = self.optimizer.minimize(self.loss_with_init)
5868

69+
5970
def fit(self, epochs: int = 100, interval: int = 20):
6071
with tf.Session() as sess:
6172
tf.global_variables_initializer().run()
@@ -119,6 +130,26 @@ def fit_batch(self, characters, labels, lengths, sess):
119130
feed_dict[self.trans_init_curr] = trans_init_neg_indices
120131
sess.run(self.train_with_init, feed_dict)
121132

133+
def fit_ll(self, epochs: int = 100, interval: int = 20):
    """Train the model with the CRF log-likelihood objective.

    Each epoch pulls ``self.batch_count`` batches from ``self.get_batch()``
    and runs the ``self.train_ll`` op on them. After every epoch a
    checkpoint is written and the config is saved next to it.

    Args:
        epochs: Number of training epochs; also used as the Saver's
            ``max_to_keep`` so no checkpoint is discarded.
        interval: Currently unused. A checkpoint is saved after every
            epoch regardless of this value.
    """
    checkpoint_template = '../dnlp/models/cws{0}.ckpt'
    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        saver = tf.train.Saver(max_to_keep=epochs)
        for epoch in range(1, epochs + 1):
            print('epoch:', epoch)
            for _ in range(self.batch_count):
                characters, labels, lengths = self.get_batch()
                batch_feed = {
                    self.input: characters,
                    self.real_indices: labels,
                    self.seq_length: lengths,
                }
                sess.run(self.train_ll, feed_dict=batch_feed)
            # One checkpoint (plus config) per epoch.
            model_path = checkpoint_template.format(epoch)
            saver.save(sess, model_path)
            self.save_config(model_path)
149+
150+
def fit_batch_ll(self):
    """Do nothing; this is an unimplemented stub.

    NOTE(review): presumably intended as the per-batch counterpart of
    ``fit_ll`` -- confirm before relying on it.
    """
152+
122153
def generate_transition_update_index(self, correct_labels, current_labels):
123154
if correct_labels.shape != current_labels.shape:
124155
print('sequence length is not equal')
@@ -176,11 +207,17 @@ def get_mlp_layer(self, layer: tf.Tensor) -> tf.Tensor:
176207
layer = tf.sigmoid(tf.tensordot(hidden_weight, layer, [[1], [0]]) + hidden_bias)
177208
return layer
178209

210+
def get_rnn_layer(self, layer: tf.Tensor) -> tf.Tensor:
    """Pass ``layer`` through a basic (vanilla) recurrent layer.

    Args:
        layer: Input tensor handed to ``tf.nn.dynamic_rnn``.

    Returns:
        The transposed output sequence of the recurrent layer.
    """
    # tf.nn.rnn_cell.RNNCell is the abstract base class: it takes no
    # num_units argument and has no usable call(). BasicRNNCell is the
    # concrete vanilla RNN cell.
    rnn = tf.nn.rnn_cell.BasicRNNCell(self.hidden_units)
    rnn_output, _ = tf.nn.dynamic_rnn(rnn, layer, dtype=self.dtype)
    # Register the variables created under the default 'rnn' scope.
    self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
    # NOTE(review): get_lstm_layer now returns its output untransposed and
    # get_output_layer contracts axis 2 with the hidden dimension; this
    # transpose (and the transposed input from the caller) likely needs the
    # same treatment -- confirm against the caller before changing.
    return tf.transpose(rnn_output)
215+
179216
def get_lstm_layer(self, layer: tf.Tensor) -> tf.Tensor:
180-
lstm = tf.nn.rnn_cell.BasicLSTMCell(self.hidden_units)
217+
lstm = tf.nn.rnn_cell.LSTMCell(self.hidden_units)
181218
lstm_output, lstm_out_state = tf.nn.dynamic_rnn(lstm, layer, dtype=self.dtype)
182219
self.params += [v for v in tf.global_variables() if v.name.startswith('rnn')]
183-
return tf.transpose(lstm_output)
220+
return lstm_output
184221

185222
def get_gru_layer(self, layer: tf.Tensor) -> tf.Tensor:
186223
gru = tf.nn.rnn_cell.GRUCell(self.hidden_units)
@@ -192,10 +229,10 @@ def get_dropout_layer(self, layer: tf.Tensor) -> tf.Tensor:
192229
return tf.layers.dropout(layer, self.dropout_rate)
193230

194231
def get_output_layer(self, layer: tf.Tensor) -> tf.Tensor:
195-
output_weight = self.__get_variable([self.tags_count, self.hidden_units], 'output_weight')
196-
output_bias = self.__get_variable([self.tags_count, 1, 1], 'output_bias')
232+
output_weight = self.__get_variable([self.hidden_units,self.tags_count], 'output_weight')
233+
output_bias = self.__get_variable([1, 1, self.tags_count ], 'output_bias')
197234
self.params += [output_weight, output_bias]
198-
return tf.tensordot(output_weight, layer, [[1], [0]]) + output_bias
235+
return tf.tensordot( layer,output_weight, [[2], [0]]) + output_bias
199236

200237
def get_loss(self) -> (tf.Tensor, tf.Tensor):
201238
output_loss = tf.reduce_sum(tf.gather_nd(self.output, self.ll_curr) - tf.gather_nd(self.output, self.ll_corr))

python/dnlp/core/dnn_crf_base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77

88
class DnnCrfBase(object):
9-
def __init__(self, config: DnnCrfConfig, data_path: str = '', mode: str = 'train', model_path: str = ''):
9+
def __init__(self, config: DnnCrfConfig=None, data_path: str = '', mode: str = 'train', model_path: str = ''):
1010
# 加载数据
1111
self.data_path = data_path
1212
self.config_suffix = '.config.pickle'
@@ -81,7 +81,7 @@ def get_batch(self) -> (np.ndarray, np.ndarray, np.ndarray):
8181
else:
8282
ext_size = self.batch_length - len(chs)
8383
chs_batch[i] = chs + ext_size * [self.dictionary[BATCH_PAD]]
84-
lls_batch[i] = list(map(lambda t: self.tags_map[t], lls)) + ext_size * [self.tags_map[TAG_PAD]]
84+
lls_batch[i] = list(map(lambda t: self.tags_map[t], lls)) + ext_size * [0]#[self.tags_map[TAG_PAD]]
8585

8686
self.batch_start = new_start
8787
return self.indices2input(chs_batch), np.array(lls_batch, dtype=np.int32), np.array(len_batch, dtype=np.int32)

0 commit comments

Comments
 (0)