
Commit 445e9a5

Set dtype consistently for all features.

1 parent 5e4be25  commit 445e9a5

4 files changed: +81 -76 lines changed

ocr_utils.py  (+14 -13)

@@ -156,7 +156,7 @@ class TruthedCharacters(object):
    Holds the training features and size information

    """
-    def __init__(self, features, output_feature_list, one_hot_map, engine_type,h,w):
+    def __init__(self, features, output_feature_list, one_hot_map, engine_type,h,w, dtype):

        self._num_examples = features[0].shape[0]
        self._nRows = h
@@ -168,6 +168,7 @@ def __init__(self, features, output_feature_list, one_hot_map, engine_type,h,w)
        self._num_features = len(features)
        self._one_hot_map = one_hot_map # list >0 for each feature that is one_hot
        self._engine_type= engine_type
+        self._dtype = dtype

        self._feature_width=[]
        for i in range(self._num_features ):
@@ -210,7 +211,7 @@ def get_features(self, i, start, end):
        if n_hots==0:
            rtn=self.engine_conversion(t1, self._feature_names[i])
        else:
-            rtn= self.engine_conversion(np.eye(n_hots )[t1], self._feature_names[i])
+            rtn= self.engine_conversion(np.eye(n_hots, dtype=self._dtype )[t1], self._feature_names[i])
        return rtn

    @property
@@ -222,7 +223,7 @@ def features(self):
                rtn.append(self.engine_conversion(t1, nm) )
                #assert(np.all(rtn[-1]==t1))
            else:
-                rtn.append( self.engine_conversion(np.eye(n_hots )[t1], nm) )
+                rtn.append( self.engine_conversion(np.eye(n_hots, dtype=self._dtype )[t1], nm) )
        return rtn

    @property
@@ -271,7 +272,7 @@ def next_batch(self, batch_size):
        outs = []
        for i in range(self._num_features):
            outs += [self.get_features(i,start,end)]
-
+
        return outs

    def dump_values(self):
@@ -500,19 +501,19 @@ class DataSets(object):
    feature_name=[]
    one_hot_map = []

-    for colName in output_feature_list:
+    for colName in output_feature_list:
        one_hot_map.append(0)
        if colName=="aspect_ratio":
-            t1 = np.array(df['originalW'] ,dtype=np.float32)
-            t2 = np.array(df['originalH'] ,dtype=np.float32)
+            t1 = np.array(df['originalW'] ,dtype=dtype)
+            t2 = np.array(df['originalH'] ,dtype=dtype)
            t1 = t1[:]/t2[:]
            feature_name.append(colName)

        elif colName=="upper_case":
            boolDF1 = df['m_label']>=64
            boolDF2 = df['m_label']<=90
            boolDF = boolDF1 & boolDF2
-            t1 = np.array(boolDF,dtype=np.float32)
+            t1 = np.array(boolDF,dtype=dtype)
            feature_name.append(colName)

        elif colName=='image':
@@ -521,7 +522,7 @@ class DataSets(object):
            feature_name.append(colName)

        elif colName=='m_label_one_hot':
-            t1 = np.array(df['m_label'] )
+            t1 = np.array(df['m_label'])
            t1 = convert_to_unique(t1)
            one_hot_map[-1] = len(np.unique(t1))
            feature_name.append(colName)
@@ -561,7 +562,7 @@ class DataSets(object):

        else:
            if colName in df.columns :
-                t1=np.array(df[colName])
+                t1=np.array(df[colName], dtype=dtype)
                feature_name.append(colName)
            else:
                raise ValueError('Invalid ouput_feature_name: {}: it is not in the the database'.format(colName))
@@ -576,9 +577,9 @@ class DataSets(object):
        outvars_test.append( ot[:nTestCount])
        outvars_evaluation.append(ot[nTestCount:nTestCount+nEvaluationCount])

-    data_sets.train = TruthedCharacters(outvars_train, feature_name, one_hot_map, engine_type, h, w)
-    data_sets.test = TruthedCharacters(outvars_test, feature_name, one_hot_map, engine_type, h, w)
-    data_sets.evaluation = TruthedCharacters(outvars_evaluation,feature_name, one_hot_map, engine_type, h, w)
+    data_sets.train = TruthedCharacters(outvars_train, feature_name, one_hot_map, engine_type, h, w, dtype)
+    data_sets.test = TruthedCharacters(outvars_test, feature_name, one_hot_map, engine_type, h, w, dtype)
+    data_sets.evaluation = TruthedCharacters(outvars_evaluation,feature_name, one_hot_map, engine_type, h, w, dtype)
    print ('feature results:')
    print ('\tnumber of train Images = ',nTrainCount)
    print ('\tnumber of test Images = ',nTestCount)
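The ocr_utils.py change threads the caller's dtype down into TruthedCharacters so the one-hot expansion no longer falls back to NumPy's float64 default. A minimal standalone sketch of that pattern (the label values below are made-up examples, not repository data):

import numpy as np

dtype = np.float32                          # same module-level choice the training scripts make
labels = np.array([2, 0, 1])                # hypothetical class indices
one_hot = np.eye(3, dtype=dtype)[labels]    # np.eye() would default to float64 without dtype=

print(one_hot.dtype)                        # float32, matching every other feature array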

q2_tensorflow_mnist.py  (+22 -20)

@@ -33,14 +33,16 @@
import pandas as pd

import tensorflow as tf
+dtype=np.float32
#with tf.device('/gpu:0'):
#with tf.device('/cpu:0'):
def train_a_font(input_filters_dict,output_feature_list, nEpochs=5000):

    ds = ocr_utils.read_data(input_filters_dict = input_filters_dict,
                             output_feature_list=output_feature_list,
                             test_size = .1,
-                             engine_type='tensorflow')
+                             engine_type='tensorflow',
+                             dtype=dtype)


    """# ==============================================================================
@@ -85,7 +87,7 @@ def train_a_font(input_filters_dict,output_feature_list, nEpochs=5000):
        nm = 'x_'+nm
        if i>1:
            extra_features_width += ds.train.feature_width[i]
-        lst.append(tf.placeholder(tf.float32, shape=[None, ds.train.feature_width[i]], name=nm))
+        lst.append(tf.placeholder(dtype, shape=[None, ds.train.feature_width[i]], name=nm))

    # ph is a named tuple with key names like 'image', 'm_label', and values that
    # are tensors. The display name on the Chrome graph are 'y_m_label', 'x_image,
@@ -110,13 +112,13 @@ def train_a_font(input_filters_dict,output_feature_list, nEpochs=5000):

    """# ==============================================================================

-    def weight_variable(shape):
-        initial = tf.truncated_normal(shape, stddev=0.1)
+    def weight_variable(shape, dtype):
+        initial = tf.truncated_normal(shape, stddev=0.1,dtype=dtype)
        return tf.Variable(initial)

-    def bias_variable(shape):
-        initial = tf.constant(0.1, shape=shape)
-        return tf.Variable(initial)
+    def bias_variable(shape, dtype):
+        initial = tf.constant(0.1, shape=shape, dtype=dtype)
+        return tf.Variable(initial)

    """# ==============================================================================

@@ -139,8 +141,8 @@ def max_pool_2x2(x):

    """# ==============================================================================
    with tf.name_scope("w_conv1") as scope:
-        W_conv1 = weight_variable([5, 5, 1, nConv1])
-        b_conv1 = bias_variable([nConv1])
+        W_conv1 = weight_variable([5, 5, 1, nConv1],dtype)
+        b_conv1 = bias_variable([nConv1],dtype)

    with tf.name_scope("reshape_x_image") as scope:
        x_image = tf.reshape(ph.image, [-1,nCols,nRows,1])
@@ -170,8 +172,8 @@ def max_pool_2x2(x):
    """# ==============================================================================

    with tf.name_scope("convolve_2") as scope:
-        W_conv2 = weight_variable([5, 5, nConv1, nConv2])
-        b_conv2 = bias_variable([64])
+        W_conv2 = weight_variable([5, 5, nConv1, nConv2],dtype)
+        b_conv2 = bias_variable([64],dtype)
        h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)

    with tf.name_scope("pool_2") as scope:
@@ -189,8 +191,8 @@ def max_pool_2x2(x):
    """# ==============================================================================

    with tf.name_scope("W_fc1_b") as scope:
-        W_fc1 = weight_variable([n_h_pool2_outputsx, nFc])
-        b_fc1 = bias_variable([nFc])
+        W_fc1 = weight_variable([n_h_pool2_outputsx, nFc],dtype)
+        b_fc1 = bias_variable([nFc],dtype)

        h_pool2_flat = tf.reshape(h_pool2, [-1, n_h_pool2_outputs])

@@ -204,7 +206,7 @@ def max_pool_2x2(x):
    Dropout

    """# ==============================================================================
-    keep_prob = tf.placeholder(tf.float32,name='keep_prob')
+    keep_prob = tf.placeholder(dtype,name='keep_prob')

    with tf.name_scope("drop") as scope:
        h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
@@ -215,8 +217,8 @@ def max_pool_2x2(x):

    """# ==============================================================================
    with tf.name_scope("softmax") as scope:
-        W_fc2 = weight_variable([nFc, nTarget])
-        b_fc2 = bias_variable([nTarget])
+        W_fc2 = weight_variable([nFc, nTarget],dtype)
+        b_fc2 = bias_variable([nTarget],dtype)
        y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)

    """# ==============================================================================
@@ -236,7 +238,7 @@ def max_pool_2x2(x):
    with tf.name_scope("test") as scope:
        correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(ph[0],1))

-        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+        accuracy = tf.reduce_mean(tf.cast(correct_prediction, dtype))
        accuracy_summary = tf.scalar_summary("accuracy", accuracy)

    merged = tf.merge_all_summaries()
@@ -340,7 +342,7 @@ def computeSize(s,tens):
    # input_filters_dict = {'font': ('OCRA','OCRB'), 'fontVariant':('scanned',)}

    # select everything; all fonts , font variants, etc.
-    input_filters_dict = {}
+    #input_filters_dict = {}

    # select the digits 0 through 9 in the E13B font
    # input_filters_dict = {'m_label': range(48,58), 'font': 'E13B'}
@@ -358,9 +360,9 @@ def computeSize(s,tens):
    #output_feature_list = ['font_one_hot','image','italic','aspect_ratio','upper_case']

    # train the digits 0-9 for all fonts
-    #input_filters_dict = {'m_label': range(48,58)}
+    input_filters_dict = {'m_label': range(48,58)}
    output_feature_list = ['m_label_one_hot','image','italic','aspect_ratio','upper_case']
-    train_a_font(input_filters_dict, output_feature_list, nEpochs = 50000)
+    train_a_font(input_filters_dict, output_feature_list, nEpochs = 5000)

else:
    # loop through all the fonts and train individually
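In q2_tensorflow_mnist.py the same module-level constant now reaches the graph-building helpers, so placeholders, weights, biases, and the accuracy cast all share one precision. A hedged sketch of how the parameterized helpers are used after this change (old graph-mode TensorFlow API, as in the script; the [5, 5, 1, 32] shape is illustrative only):

import numpy as np
import tensorflow as tf

dtype = np.float32     # flip this one constant to change the whole graph's precision

def weight_variable(shape, dtype):
    # truncated-normal initializer built in the requested precision
    initial = tf.truncated_normal(shape, stddev=0.1, dtype=dtype)
    return tf.Variable(initial)

def bias_variable(shape, dtype):
    initial = tf.constant(0.1, shape=shape, dtype=dtype)
    return tf.Variable(initial)

W_conv1 = weight_variable([5, 5, 1, 32], dtype)       # illustrative filter shape
b_conv1 = bias_variable([32], dtype)
keep_prob = tf.placeholder(dtype, name='keep_prob')   # dropout keep probability in the same dtype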

q5_tensorflow_residual.py  (+24 -22)

@@ -35,14 +35,16 @@
import pandas as pd

import tensorflow as tf
+dtype = np.float32
#with tf.device('/GPU:0'):
#with tf.device('/cpu:0'):
def train_a_font(input_filters_dict,output_feature_list, nEpochs=5000):

    ds = ocr_utils.read_data(input_filters_dict = input_filters_dict,
                             output_feature_list=output_feature_list,
                             test_size = .1,
-                             engine_type='tensorflow')
+                             engine_type='tensorflow',
+                             dtype=dtype)


    """# ==============================================================================
@@ -86,7 +88,7 @@ def train_a_font(input_filters_dict,output_feature_list, nEpochs=5000):
        nm = 'x_'+nm
        if i>1:
            extra_features_width += ds.train.feature_width[i]
-        lst.append(tf.placeholder(tf.float32, shape=[None, ds.train.feature_width[i]], name=nm))
+        lst.append(tf.placeholder(dtype, shape=[None, ds.train.feature_width[i]], name=nm))

    # ph is a named tuple with key names like 'image', 'm_label', and values that
    # are tensors. The display name on the Chrome graph are 'y_m_label', 'x_image,
@@ -113,13 +115,13 @@ def train_a_font(input_filters_dict,output_feature_list, nEpochs=5000):

    """# ==============================================================================

-    def weight_variable(shape):
-        initial = tf.truncated_normal(shape, stddev=0.1)
+    def weight_variable(shape, dtype):
+        initial = tf.truncated_normal(shape, stddev=0.1,dtype=dtype)
        return tf.Variable(initial)

-    def bias_variable(shape):
-        initial = tf.constant(0.1, shape=shape)
-        return tf.Variable(initial)
+    def bias_variable(shape, dtype):
+        initial = tf.constant(0.1, shape=shape, dtype=dtype)
+        return tf.Variable(initial)

    """# ==============================================================================

@@ -142,8 +144,8 @@ def max_pool_2x2(x):

    """# ==============================================================================
    with tf.name_scope("w_conv1") as scope:
-        W_conv1 = weight_variable([5, 5, 1, nConv1])
-        b_conv1 = bias_variable([nConv1])
+        W_conv1 = weight_variable([5, 5, 1, nConv1],dtype)
+        b_conv1 = bias_variable([nConv1],dtype)

    with tf.name_scope("reshape_x_image") as scope:
        x_image = tf.reshape(ph.image, [-1,nCols,nRows,1])
@@ -173,8 +175,8 @@ def max_pool_2x2(x):
    """# ==============================================================================

    with tf.name_scope("convolve_2") as scope:
-        W_conv2 = weight_variable([5, 5, nConv1, nConv2])
-        b_conv2 = bias_variable([64])
+        W_conv2 = weight_variable([5, 5, nConv1, nConv2],dtype)
+        b_conv2 = bias_variable([64],dtype)
        h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)

    with tf.name_scope("pool_2") as scope:
@@ -192,8 +194,8 @@ def max_pool_2x2(x):
    """# ==============================================================================

    with tf.name_scope("W_fc0_b") as scope:
-        W_fc0 = weight_variable([n_h_pool2_outputsx, nFc0])
-        b_fc0 = bias_variable([nFc0])
+        W_fc0 = weight_variable([n_h_pool2_outputsx, nFc0],dtype)
+        b_fc0 = bias_variable([nFc0],dtype)

        h_pool2_flat = tf.reshape(h_pool2, [-1, n_h_pool2_outputs])

@@ -213,8 +215,8 @@ def max_pool_2x2(x):
    """# ==============================================================================

    with tf.name_scope("W_fc1_b") as scope:
-        W_fc1 = weight_variable([nFc0, nFc1])
-        b_fc1 = bias_variable([nFc1])
+        W_fc1 = weight_variable([nFc0, nFc1],dtype)
+        b_fc1 = bias_variable([nFc1],dtype)

        h_fc1 = tf.nn.relu(tf.matmul(h_fc0, W_fc1) + b_fc1)

@@ -230,16 +232,16 @@ def max_pool_2x2(x):
    """# ==============================================================================

    with tf.name_scope("W_fc2_b") as scope:
-        W_fc2 = weight_variable([nFc1, nFc2])
-        b_fc2 = bias_variable([nFc2])
+        W_fc2 = weight_variable([nFc1, nFc2],dtype)
+        b_fc2 = bias_variable([nFc2],dtype)

        h_fc2 = tf.nn.relu(tf.matmul(h_fc1, W_fc2) + b_fc2)

    """# ==============================================================================
    Dropout

    """# ==============================================================================
-    keep_prob = tf.placeholder(tf.float32,name='keep_prob')
+    keep_prob = tf.placeholder(dtype,name='keep_prob')

    with tf.name_scope("drop") as scope:
        h_fc2_drop = tf.nn.dropout(h_fc2, keep_prob)
@@ -250,8 +252,8 @@ def max_pool_2x2(x):

    """# ==============================================================================
    with tf.name_scope("softmax") as scope:
-        W_fc3 = weight_variable([nFc2, nTarget])
-        b_fc3 = bias_variable([nTarget])
+        W_fc3 = weight_variable([nFc2, nTarget],dtype)
+        b_fc3 = bias_variable([nTarget],dtype)
        y_conv=tf.nn.softmax(tf.matmul(h_fc2_drop, W_fc3) + b_fc3)

    """# ==============================================================================
@@ -271,7 +273,7 @@ def max_pool_2x2(x):
    with tf.name_scope("test") as scope:
        correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(ph[0],1))

-        accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
+        accuracy = tf.reduce_mean(tf.cast(correct_prediction,dtype))
        accuracy_summary = tf.scalar_summary("accuracy", accuracy)

    merged = tf.merge_all_summaries()
@@ -399,7 +401,7 @@ def computeSize(s,tens):
    #output_feature_list = ['font_one_hot','image','italic','aspect_ratio','upper_case']

    # train the digits 0-9 for all fonts
-    input_filters_dict = {'m_label': list(range(48,58))+list(range(65,91))+list(range(97,123))}
+    input_filters_dict = {'m_label': list(range(48,58))+list(range(65,91))+list(range(97,123)),'fontVariant':'scanned'}
    #input_filters_dict = {}
    output_feature_list = ['m_label_one_hot','image','italic','aspect_ratio','upper_case']
    train_a_font(input_filters_dict, output_feature_list, nEpochs = 20000)
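Both training scripts now pass the dtype through ocr_utils.read_data, the single entry point this commit extends. A usage sketch based on the calls shown above (the filter and feature values are examples taken from the scripts; the only behavior assumed is that batches come back in the requested precision):

import numpy as np
import ocr_utils

dtype = np.float32
ds = ocr_utils.read_data(input_filters_dict={'m_label': range(48, 58)},   # digits 0-9, as in q2
                         output_feature_list=['m_label_one_hot', 'image'],
                         test_size=.1,
                         engine_type='tensorflow',
                         dtype=dtype)

batch = ds.train.next_batch(100)   # each feature array should now be float32, not float64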
