
Commit a793973

add evaluation functions
1 parent 5436518 commit a793973

2 files changed: +97 -2 lines


python/dnlp/utils/evaluation.py

Lines changed: 93 additions & 0 deletions
@@ -0,0 +1,93 @@
# -*- coding: UTF-8 -*-
import pickle
from dnlp.utils.constant import TAG_BEGIN, TAG_INSIDE, TAG_OTHER, TAG_END, TAG_SINGLE


def get_cws_statistics(correct_labels, predict_labels) -> (int, int, int):
  # Collect word spans (start index -> end index) from the gold and predicted
  # BIES tag sequences, then count the spans that match exactly.
  # Returns (true positives, number of predicted words, number of gold words).
  if len(correct_labels) != len(predict_labels):
    raise Exception('length of correct labels and predict labels is not equal')

  true_positive_count = 0
  corrects = {}
  predicts = {}
  correct_start = 0
  predict_start = 0

  for i, (correct_label, predict_label) in enumerate(zip(correct_labels, predict_labels)):
    if correct_label == TAG_BEGIN:
      correct_start = i
      corrects[correct_start] = correct_start
    elif correct_label == TAG_SINGLE:
      correct_start = i
      corrects[correct_start] = correct_start
    elif correct_label == TAG_INSIDE or correct_label == TAG_END:
      corrects[correct_start] = i

    if predict_label == TAG_BEGIN:
      predict_start = i
      predicts[predict_start] = predict_start
    elif predict_label == TAG_SINGLE:
      predict_start = i
      predicts[predict_start] = predict_start
    elif predict_label == TAG_INSIDE or predict_label == TAG_END:
      predicts[predict_start] = i

  for predict in predicts:
    if corrects.get(predict) is not None and corrects[predict] == predicts[predict]:
      true_positive_count += 1

  return true_positive_count, len(predicts), len(corrects)


def get_ner_statistics(correct_labels, predict_labels) -> (int, int, int):
  # Same span matching for BIO-style entity labels: TAG_BEGIN starts a span,
  # TAG_INSIDE extends it, and any other tag leaves the current span closed.
  if len(correct_labels) != len(predict_labels):
    raise Exception('length of correct labels and predict labels is not equal')

  true_positive_count = 0
  corrects = {}
  predicts = {}
  correct_start = 0
  predict_start = 0

  for i, (correct_label, predict_label) in enumerate(zip(correct_labels, predict_labels)):
    if correct_label == TAG_BEGIN:
      correct_start = i
      corrects[correct_start] = correct_start
    elif correct_label == TAG_INSIDE:
      corrects[correct_start] = i

    if predict_label == TAG_BEGIN:
      predict_start = i
      predicts[predict_start] = predict_start
    elif predict_label == TAG_INSIDE:
      predicts[predict_start] = i

  for predict in predicts:
    if corrects.get(predict) is not None and corrects[predict] == predicts[predict]:
      true_positive_count += 1

  return true_positive_count, len(predicts), len(corrects)


def evaluate_cws(model, data_path: str):
  # Load the pickled test set, predict labels for every sentence, and print the
  # overall precision (tp / predicted words) and recall (tp / gold words).
  with open(data_path, 'rb') as f:
    data = pickle.load(f)
  dictionary = data['dictionary']
  tags = data['tags']
  reversed_map = dict(zip(tags.values(), tags.keys()))
  characters = data['characters']
  labels_true = data['labels']
  c_count = 0
  p_count = 0
  r_count = 0
  for sentence, label in zip(characters, labels_true):
    words, labels_predict = model.predict(sentence, return_labels=True)
    seq = []
    for l in labels_predict:
      seq.append(reversed_map[l])
    c, p, r = get_cws_statistics(label, seq)
    c_count += c
    p_count += p
    r_count += r
  print(c_count / p_count)
  print(c_count / r_count)
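A short usage sketch (not part of this commit) showing how the three counts returned by get_cws_statistics map to precision, recall, and F1. The tag sequences below are made up for illustration; the tag constants are the ones evaluation.py already imports.

from dnlp.utils.constant import TAG_BEGIN, TAG_INSIDE, TAG_END, TAG_SINGLE
from dnlp.utils.evaluation import get_cws_statistics

# gold words: [0-1], [2], [3-5]; predicted words: [0-1], [2-3], [4-5]
correct = [TAG_BEGIN, TAG_END, TAG_SINGLE, TAG_BEGIN, TAG_INSIDE, TAG_END]
predict = [TAG_BEGIN, TAG_END, TAG_BEGIN, TAG_END, TAG_BEGIN, TAG_END]

tp, predicted_words, gold_words = get_cws_statistics(correct, predict)  # 1, 3, 3
precision = tp / predicted_words  # 1/3: only the span [0-1] matches exactly
recall = tp / gold_words          # 1/3
f1 = 2 * precision * recall / (precision + recall)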

python/scripts/cws_ner.py

Lines changed: 4 additions & 2 deletions
@@ -3,12 +3,13 @@
 import getopt
 from dnlp.config.config import DnnCrfConfig
 from dnlp.core.dnn_crf import DnnCrf
+from dnlp.utils.evaluation import get_cws_statistics, evaluate_cws


 def train_cws():
   data_path = '../dnlp/data/cws/pku_training.pickle'
   config = DnnCrfConfig()
-  dnncrf = DnnCrf(config=config, data_path=data_path,nn='lstm')
+  dnncrf = DnnCrf(config=config, data_path=data_path, nn='lstm')
   dnncrf.fit_ll()


@@ -17,8 +18,9 @@ def test_cws():
   model_path = '../dnlp/models/cws1.ckpt'
   config = DnnCrfConfig()
   dnncrf = DnnCrf(config=config, mode='predict', model_path=model_path, nn='lstm')
-  res = dnncrf.predict(sentence)
+  res, labels = dnncrf.predict(sentence, return_labels=True)
   print(res)
+  evaluate_cws(dnncrf, '../dnlp/data/cws/pku_test.pickle')


 if __name__ == '__main__':
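get_ner_statistics is added by this commit but not yet called from the script above; a minimal sketch (an illustration, not code from the repository) of how it counts BIO-style entity spans:

from dnlp.utils.constant import TAG_BEGIN, TAG_INSIDE, TAG_OTHER
from dnlp.utils.evaluation import get_ner_statistics

# gold entities: [0-1] and [3]; predicted entities: [0-1] and [4]
correct = [TAG_BEGIN, TAG_INSIDE, TAG_OTHER, TAG_BEGIN, TAG_OTHER]
predict = [TAG_BEGIN, TAG_INSIDE, TAG_OTHER, TAG_OTHER, TAG_BEGIN]

tp, predicted_entities, gold_entities = get_ner_statistics(correct, predict)  # 1, 2, 2
precision = tp / predicted_entities  # 0.5
recall = tp / gold_entities          # 0.5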
