utils.py
"""
Arguments for configuration
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
import argparse
import io
import sys
import random
import numpy as np
import os

import paddle
import paddle.fluid as fluid


def str2bool(v):
    """
    Convert a string such as "true"/"false" to a Python boolean.
    """
    # argparse does not parse strings like "True"/"False" into booleans
    # directly, so treat the common truthy spellings as True.
    return v.lower() in ("true", "t", "1")
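
# For example, str2bool("True") and str2bool("1") return True, while
# str2bool("False") and str2bool("0") return False; add_arg below relies on
# this when an argument is declared with type=bool.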


class ArgumentGroup(object):
    """
    Thin wrapper around an argparse argument group.
    """

    def __init__(self, parser, title, des):
        self._group = parser.add_argument_group(title=title, description=des)

    def add_arg(self, name, type, default, help, **kwargs):
        """
        Add an argument to this group.
        """
        type = str2bool if type == bool else type
        self._group.add_argument(
            "--" + name,
            default=default,
            type=type,
            help=help + ' Default: %(default)s.',
            **kwargs)


def print_arguments(args):
    """
    Print all configuration arguments.
    """
    print('----------- Configuration Arguments -----------')
    for arg, value in sorted(six.iteritems(vars(args))):
        print('%s: %s' % (arg, value))
    print('------------------------------------------------')
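
# Illustrative usage sketch; the group name "model" and the arguments below are
# placeholders, not options defined by this module.
#
#   parser = argparse.ArgumentParser(__doc__)
#   model_g = ArgumentGroup(parser, "model", "model configuration and paths")
#   model_g.add_arg("use_cuda", bool, False, "Whether to run on GPU.")
#   model_g.add_arg("lr", float, 0.002, "Learning rate.")
#   args = parser.parse_args()
#   print_arguments(args)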


def init_checkpoint(exe, init_checkpoint_path, main_program):
    """
    Initialize the parameters of main_program from a checkpoint.
    """
    assert os.path.exists(
        init_checkpoint_path), "[%s] cannot be found." % init_checkpoint_path
    try:
        # prefer the nested "checkpoint" directory if it exists
        checkpoint_path = os.path.join(init_checkpoint_path, "checkpoint")
        fluid.load(main_program, checkpoint_path, exe)
    except Exception:
        # fall back to loading from the given path directly
        fluid.load(main_program, init_checkpoint_path, exe)
    print("Load model from {}".format(init_checkpoint_path))


def data_reader(file_path, word_dict, num_examples, phrase, epoch, max_seq_len):
    """
    Convert each word sequence in the file into ids, pad or truncate it to
    max_seq_len, and return a reader over (word_ids, label, seq_len) tuples.
    """
    unk_id = len(word_dict)
    pad_id = 0
    all_data = []
    with io.open(file_path, "r", encoding='utf8') as fin:
        for line in fin:
            if line.startswith('text_a'):
                # skip the header line
                continue
            cols = line.strip().split("\t")
            if len(cols) != 2:
                sys.stderr.write("[NOTICE] skipping line with invalid format\n")
                continue
            label = int(cols[1])
            wids = [
                word_dict[x] if x in word_dict else unk_id
                for x in cols[0].split(" ")
            ]
            seq_len = len(wids)
            if seq_len < max_seq_len:
                # pad short sequences up to max_seq_len
                for i in range(max_seq_len - seq_len):
                    wids.append(pad_id)
            else:
                # truncate long sequences
                wids = wids[:max_seq_len]
                seq_len = max_seq_len
            all_data.append((wids, label, seq_len))

    if phrase == "train":
        random.shuffle(all_data)

    num_examples[phrase] = len(all_data)

    def reader():
        """
        Yield every example once per epoch.
        """
        for epoch_index in range(epoch):
            for doc, label, seq_len in all_data:
                yield doc, label, seq_len

    return reader


def load_vocab(file_path):
    """
    Load the given vocabulary file into a word-to-id dict.
    """
    vocab = {}
    with io.open(file_path, 'r', encoding='utf8') as f:
        wid = 0
        for line in f:
            if line.strip() not in vocab:
                vocab[line.strip()] = wid
                wid += 1
    vocab["<unk>"] = len(vocab)
    return vocab
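
# Illustrative sketch of how these helpers fit together; "vocab.txt" and
# "train.tsv" below are placeholder paths, not files shipped with this module.
#
#   word_dict = load_vocab("vocab.txt")
#   num_examples = {}
#   train_reader = data_reader("train.tsv", word_dict, num_examples,
#                              "train", epoch=10, max_seq_len=128)
#   for wids, label, seq_len in train_reader():
#       ...  # feed padded word ids, label and true length to the model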


def init_pretraining_params(exe,
                            pretraining_params_path,
                            main_program,
                            use_fp16=False):
    """Load the params of a pretrained model, NOT including moment and learning_rate."""
    assert os.path.exists(pretraining_params_path
                          ), "[%s] cannot be found." % pretraining_params_path
    fluid.load(main_program, pretraining_params_path, exe)
    print("Load pretraining parameters from {}.".format(
        pretraining_params_path))
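
# Illustrative sketch of loading parameters before training or inference; the
# checkpoint and pretrained-model paths are placeholders, and the surrounding
# program/executor setup is assumed, not defined in this module.
#
#   place = fluid.CPUPlace()
#   exe = fluid.Executor(place)
#   exe.run(fluid.default_startup_program())
#   # either resume from a checkpoint ...
#   init_checkpoint(exe, "./checkpoints/step_100", fluid.default_main_program())
#   # ... or start from pretrained parameters
#   init_pretraining_params(exe, "./pretrained_model", fluid.default_main_program())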