32
32
import numpy as np
33
33
import pandas as pd
34
34
35
-
35
+ import tensorflow as tf
36
+ #with tf.device('/gpu:0'):
37
+ #with tf.device('/cpu:0'):
36
38
def train_a_font (input_filters_dict ,output_feature_list , nEpochs = 5000 ):
37
39
38
40
ds = ocr_utils .read_data (input_filters_dict = input_filters_dict ,
@@ -47,7 +49,7 @@ def train_a_font(input_filters_dict,output_feature_list, nEpochs=5000):
47
49
48
50
""" # ==============================================================================
49
51
50
- import tensorflow as tf
52
+
51
53
sess = tf .InteractiveSession ()
52
54
53
55
"""# ==============================================================================
@@ -243,7 +245,7 @@ def max_pool_2x2(x):
243
245
for i in range (4 ):
244
246
tm += str (tp [i ])+ '-'
245
247
tm += str (tp [4 ])
246
- writer = tf .train .SummaryWriter ("/tmp/ds_logs/" + tm , sess .graph_def )
248
+ writer = tf .train .SummaryWriter ("/tmp/ds_logs/" + tm , sess .graph )
247
249
248
250
# To see the results in Chrome,
249
251
# Run the following in terminal to activate server.
@@ -338,7 +340,7 @@ def computeSize(s,tens):
338
340
# input_filters_dict = {'font': ('OCRA','OCRB'), 'fontVariant':('scanned',)}
339
341
340
342
# select everything; all fonts , font variants, etc.
341
- # input_filters_dict = {}
343
+ input_filters_dict = {}
342
344
343
345
# select the digits 0 through 9 in the E13B font
344
346
# input_filters_dict = {'m_label': range(48,58), 'font': 'E13B'}
@@ -356,9 +358,9 @@ def computeSize(s,tens):
356
358
#output_feature_list = ['font_one_hot','image','italic','aspect_ratio','upper_case']
357
359
358
360
# train the digits 0-9 for all fonts
359
- input_filters_dict = {'m_label' : range (48 ,58 )}
361
+ # input_filters_dict = {'m_label': range(48,58)}
360
362
output_feature_list = ['m_label_one_hot' ,'image' ,'italic' ,'aspect_ratio' ,'upper_case' ]
361
- train_a_font (input_filters_dict , output_feature_list , nEpochs = 5000 )
363
+ train_a_font (input_filters_dict , output_feature_list , nEpochs = 50000 )
362
364
363
365
else :
364
366
# loop through all the fonts and train individually
@@ -377,6 +379,6 @@ def computeSize(s,tens):
377
379
input_filters_dict = {'font' : (l [0 ],)}
378
380
train_a_font (input_filters_dict ,output_feature_list , nEpochs = 500 )
379
381
380
-
382
+
381
383
print ('\n ########################### No Errors ####################################' )
382
384
0 commit comments