Skip to content

Commit 5b9a489

Browse files
authored
Modify the GBRBM (use mse cost)
1 parent bbd5cf8 commit 5b9a489

File tree

1 file changed

+89
-27
lines changed

1 file changed

+89
-27
lines changed

models/gbrbm.py

+89-27
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@
33
author: Ye Hu
44
2016/12/18
55
"""
6+
import os
67
import timeit
78
import numpy as np
89
import tensorflow as tf
10+
from PIL import Image
11+
from utils import tile_raster_images
12+
import input_data
913
from rbm import RBM
1014

1115

@@ -56,39 +60,97 @@ def free_energy(self, v_sample):
5660
hidden_term = tf.reduce_sum(tf.log(1.0 + tf.exp(wx_b)), axis=1)
5761
return -hidden_term + vbias_term
5862

63+
def get_reconstruction_cost(self):
64+
"""Compute the mse of the original input and the reconstruction"""
65+
activation_h = self.propup(self.input)
66+
activation_v = self.propdown(activation_h)
67+
mse = tf.reduce_mean(tf.reduce_sum(tf.square(self.input - activation_v), axis=1))
68+
return mse
69+
70+
5971

6072
if __name__ == "__main__":
61-
data = np.random.randn(1000, 6)
62-
x = tf.placeholder(tf.float32, shape=[None, 6])
63-
64-
gbrbm = GBRBM(x, n_visiable=6, n_hidden=5)
65-
66-
learning_rate = 0.1
67-
k = 1
68-
batch_size = 20
69-
n_epochs = 10
70-
71-
cost = gbrbm.get_reconstruction_cost()
73+
# mnist examples
74+
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
75+
# define input
76+
x = tf.placeholder(tf.float32, shape=[None, 784])
77+
# set random_seed
78+
tf.set_random_seed(seed=99999)
79+
np.random.seed(123)
80+
# the rbm model
81+
n_visiable, n_hidden = 784, 500
82+
rbm = GBRBM(x, n_visiable=n_visiable, n_hidden=n_hidden)
83+
84+
learning_rate = 0.01
85+
batch_size = 50
86+
cost = rbm.get_reconstruction_cost()
7287
# Create the persistent variable
7388
#persistent_chain = tf.Variable(tf.zeros([batch_size, n_hidden]), dtype=tf.float32)
7489
persistent_chain = None
75-
train_ops = gbrbm.get_train_ops(learning_rate=learning_rate, k=1, persistent=persistent_chain)
90+
train_ops = rbm.get_train_ops(learning_rate=learning_rate, k=1, persistent=persistent_chain)
7691
init = tf.global_variables_initializer()
7792

78-
sess = tf.Session()
79-
sess.run(init)
80-
for epoch in range(n_epochs):
81-
avg_cost = 0.0
82-
for i in range(len(data)//batch_size):
83-
sess.run(train_ops, feed_dict={x: data[i*batch_size:(i+1)*batch_size]})
84-
avg_cost += sess.run(cost, feed_dict={x: data[i*batch_size:(i+1)*batch_size]})/batch_size
85-
print(avg_cost)
86-
87-
# test
88-
v = np.random.randn(10, 6)
89-
print(v)
93+
output_folder = "rbm_plots"
94+
if not os.path.isdir(output_folder):
95+
os.makedirs(output_folder)
96+
os.chdir(output_folder)
9097

91-
preds = sess.run(gbrbm.reconstruct(x), feed_dict={x: v})
92-
print(preds)
98+
training_epochs = 15
99+
display_step = 1
100+
print("Start training...")
101+
102+
with tf.Session() as sess:
103+
start_time = timeit.default_timer()
104+
sess.run(init)
105+
for epoch in range(training_epochs):
106+
avg_cost = 0.0
107+
batch_num = int(mnist.train.num_examples / batch_size)
108+
for i in range(batch_num):
109+
x_batch, _ = mnist.train.next_batch(batch_size)
110+
# 训练
111+
sess.run(train_ops, feed_dict={x: x_batch})
112+
# 计算cost
113+
avg_cost += sess.run(cost, feed_dict={x: x_batch,}) / batch_num
114+
# 输出
115+
if epoch % display_step == 0:
116+
print("Epoch {0} cost: {1}".format(epoch, avg_cost))
117+
# Construct image from the weight matrix
118+
image = Image.fromarray(
119+
tile_raster_images(
120+
X=sess.run(tf.transpose(rbm.W)),
121+
img_shape=(28, 28),
122+
tile_shape=(10, 10),
123+
tile_spacing=(1, 1)))
124+
image.save("test_filters_at_epoch_{0}.png".format(epoch))
93125

94-
126+
end_time = timeit.default_timer()
127+
training_time = end_time - start_time
128+
print("Finished!")
129+
print(" The training ran for {0} minutes.".format(training_time/60,))
130+
131+
# Randomly select the 'n_chains' examples
132+
n_chains = 20
133+
n_batch = 10
134+
n_samples = n_batch*2
135+
number_test_examples = mnist.test.num_examples
136+
test_indexs = np.random.randint(number_test_examples - n_chains*n_batch)
137+
test_samples = mnist.test.images[test_indexs:test_indexs+n_chains*n_batch]
138+
image_data = np.zeros((29*(n_samples+1)+1, 29*(n_chains)-1),
139+
dtype="uint8")
140+
# Add the original images
141+
for i in range(n_batch):
142+
image_data[2*i*29:2*i*29+28,:] = tile_raster_images(X=test_samples[i*n_batch:(i+1)*n_chains],
143+
img_shape=(28, 28),
144+
tile_shape=(1, n_chains),
145+
tile_spacing=(1, 1))
146+
samples = sess.run(rbm.reconstruct(x), feed_dict={x:test_samples[i*n_batch:(i+1)*n_chains]})
147+
image_data[(2*i+1)*29:(2*i+1)*29+28,:] = tile_raster_images(X=samples,
148+
img_shape=(28, 28),
149+
tile_shape=(1, n_chains),
150+
tile_spacing=(1, 1))
151+
152+
image = Image.fromarray(image_data)
153+
image.save("original_and_reconstruct.png")
154+
155+
156+

0 commit comments

Comments
 (0)