3
3
author: Ye Hu
4
4
2016/12/18
5
5
"""
6
+ import os
6
7
import timeit
7
8
import numpy as np
8
9
import tensorflow as tf
10
+ from PIL import Image
11
+ from utils import tile_raster_images
12
+ import input_data
9
13
from rbm import RBM
10
14
11
15
@@ -56,39 +60,97 @@ def free_energy(self, v_sample):
56
60
hidden_term = tf .reduce_sum (tf .log (1.0 + tf .exp (wx_b )), axis = 1 )
57
61
return - hidden_term + vbias_term
58
62
63
def get_reconstruction_cost(self):
    """Return the reconstruction error of the model input.

    The input is passed through ``propup`` and the result through
    ``propdown``; the squared difference to the input is summed over
    axis 1 and then averaged over the remaining axis.
    """
    hidden_activation = self.propup(self.input)
    reconstruction = self.propdown(hidden_activation)
    squared_error = tf.square(self.input - reconstruction)
    per_example_error = tf.reduce_sum(squared_error, axis=1)
    return tf.reduce_mean(per_example_error)
69
+
70
+
59
71
60
72
if __name__ == "__main__":
    # Demo script: train the GBRBM on MNIST, save an image of the weight
    # matrix after each displayed epoch, then save a grid that alternates
    # rows of original test digits with rows of their reconstructions.
    mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
    # Input placeholder: one flattened 28x28 image (784 values) per row.
    x = tf.placeholder(tf.float32, shape=[None, 784])
    # Seed both TensorFlow and NumPy.
    tf.set_random_seed(seed=99999)
    np.random.seed(123)
    # The RBM model.
    n_visiable, n_hidden = 784, 500
    rbm = GBRBM(x, n_visiable=n_visiable, n_hidden=n_hidden)

    learning_rate = 0.01
    batch_size = 50
    cost = rbm.get_reconstruction_cost()
    # No persistent chain is used in this demo. To use one, pass e.g.
    # tf.Variable(tf.zeros([batch_size, n_hidden]), dtype=tf.float32).
    persistent_chain = None
    train_ops = rbm.get_train_ops(learning_rate=learning_rate, k=1,
                                  persistent=persistent_chain)
    init = tf.global_variables_initializer()

    # Build the auxiliary ops once. Creating them inside the loops below
    # would add new nodes to the graph on every iteration.
    filters_op = tf.transpose(rbm.W)
    reconstruct_op = rbm.reconstruct(x)

    # Write all images into the output folder through explicit paths
    # instead of changing the process working directory.
    output_folder = "rbm_plots"
    if not os.path.isdir(output_folder):
        os.makedirs(output_folder)

    training_epochs = 15
    display_step = 1
    batch_num = mnist.train.num_examples // batch_size
    print("Start training...")

    with tf.Session() as sess:
        start_time = timeit.default_timer()
        sess.run(init)
        for epoch in range(training_epochs):
            avg_cost = 0.0
            for _ in range(batch_num):
                x_batch, _labels = mnist.train.next_batch(batch_size)
                # Train on the batch.
                sess.run(train_ops, feed_dict={x: x_batch})
                # Accumulate the cost, evaluated after the update.
                avg_cost += sess.run(cost, feed_dict={x: x_batch}) / batch_num
            # Report progress.
            if epoch % display_step == 0:
                print("Epoch {0} cost: {1}".format(epoch, avg_cost))
            # Construct image from the weight matrix
            image = Image.fromarray(
                tile_raster_images(
                    X=sess.run(filters_op),
                    img_shape=(28, 28),
                    tile_shape=(10, 10),
                    tile_spacing=(1, 1)))
            image.save(os.path.join(
                output_folder,
                "test_filters_at_epoch_{0}.png".format(epoch)))

        end_time = timeit.default_timer()
        training_time = end_time - start_time
        print("Finished!")
        print("  The training ran for {0} minutes.".format(training_time / 60))

        # Pick a random contiguous run of n_chains * n_batch test examples.
        n_chains = 20
        n_batch = 10
        n_samples = n_batch * 2
        number_test_examples = mnist.test.num_examples
        test_start = np.random.randint(
            number_test_examples - n_chains * n_batch)
        test_samples = mnist.test.images[
            test_start:test_start + n_chains * n_batch]
        # Canvas size kept as in the original script. The rows written
        # below end at index 29 * n_samples - 1, so the remaining bottom
        # rows stay zero.
        image_data = np.zeros(
            (29 * (n_samples + 1) + 1, 29 * n_chains - 1), dtype="uint8")
        for i in range(n_batch):
            # Each grid row shows its own n_chains examples. The previous
            # slice start (i * n_batch) made consecutive rows overlap.
            originals = test_samples[i * n_chains:(i + 1) * n_chains]
            top = 2 * i * 29
            # Add the original images
            image_data[top:top + 28, :] = tile_raster_images(
                X=originals,
                img_shape=(28, 28),
                tile_shape=(1, n_chains),
                tile_spacing=(1, 1))
            # Add their reconstructions in the row directly below.
            samples = sess.run(reconstruct_op, feed_dict={x: originals})
            image_data[top + 29:top + 29 + 28, :] = tile_raster_images(
                X=samples,
                img_shape=(28, 28),
                tile_shape=(1, n_chains),
                tile_spacing=(1, 1))

        image = Image.fromarray(image_data)
        image.save(os.path.join(output_folder,
                                "original_and_reconstruct.png"))
154
+
155
+
156
+
0 commit comments