Skip to content

Commit a2daa18

Browse files
committed
allow any feature to be the target. Only m_label_one_hot was being
handled.
1 parent 44fda8a commit a2daa18

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

q2_tensorflow_mnist.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def train_a_font(input_filters_dict,output_feature_list, nEpochs=5000):
6868

6969
for i,nm in enumerate(output_feature_list):
7070

71-
# features[0], is the target, 'm_label_one_hot'
71+
# features[0], is always the target. For instance it may be m_label_one_hot
7272
# the second features[1] is the 'image' that is passed to the convolution layers
7373
# Any additional features bypass the convolution layers and go directly
7474
# into the fully connected layer.
@@ -225,14 +225,14 @@ def max_pool_2x2(x):
225225

226226
with tf.name_scope("xent") as scope:
227227
# 1e-8 added to eliminate the crash of training when taking log of 0
228-
cross_entropy = -tf.reduce_sum(ph.m_label_one_hot*tf.log(y_conv+1e-8))
228+
cross_entropy = -tf.reduce_sum(ph[0]*tf.log(y_conv+1e-8))
229229
ce_summ = tf.scalar_summary("cross entropy", cross_entropy)
230230

231231
with tf.name_scope("train") as scope:
232232
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
233233

234234
with tf.name_scope("test") as scope:
235-
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(ph.m_label_one_hot,1))
235+
correct_prediction = tf.equal(tf.argmax(y_conv,1), tf.argmax(ph[0],1))
236236

237237
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
238238
accuracy_summary = tf.scalar_summary("accuracy", accuracy)
@@ -351,6 +351,9 @@ def computeSize(s,tens):
351351

352352
# output only the character label and the image
353353
# output_feature_list = ['m_label_one_hot','image']
354+
355+
# identify the font given the input images
356+
#output_feature_list = ['font_one_hot','image','italic','aspect_ratio','upper_case']
354357

355358
# train the digits 0-9 for all fonts
356359
input_filters_dict = {'m_label': range(48,58)}

0 commit comments

Comments
 (0)