Skip to content

Commit cb6eb1d

Browse files
committed
Adversarial Autoencoder Example with PyTorch
Trying out MNIST dataset…
1 parent f8f6ff7 commit cb6eb1d

File tree

1 file changed

+285
-0
lines changed

1 file changed

+285
-0
lines changed

pytorch/Adversarial Autoencoder.ipynb

+285
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,285 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "code",
5+
"execution_count": 1,
6+
"metadata": {
7+
"collapsed": true
8+
},
9+
"outputs": [],
10+
"source": [
11+
"import torch\n",
12+
"import torch.nn\n",
13+
"import torch.nn.functional as nn\n",
14+
"import torch.autograd as autograd\n",
15+
"import torch.optim as optim\n",
16+
"import numpy as np\n",
17+
"import matplotlib.pyplot as plt\n",
18+
"import matplotlib.gridspec as gridspec\n",
19+
"import os\n",
20+
"from torch.autograd import Variable\n",
21+
"from tensorflow.examples.tutorials.mnist import input_data"
22+
]
23+
},
24+
{
25+
"cell_type": "code",
26+
"execution_count": 2,
27+
"metadata": {},
28+
"outputs": [
29+
{
30+
"name": "stdout",
31+
"output_type": "stream",
32+
"text": [
33+
"Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.\n",
34+
"Extracting ../../MNIST_data/train-images-idx3-ubyte.gz\n",
35+
"Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.\n",
36+
"Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz\n",
37+
"Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.\n",
38+
"Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz\n",
39+
"Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.\n",
40+
"Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz\n"
41+
]
42+
}
43+
],
44+
"source": [
45+
"mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)"
46+
]
47+
},
48+
{
49+
"cell_type": "code",
50+
"execution_count": 3,
51+
"metadata": {
52+
"collapsed": true
53+
},
54+
"outputs": [],
55+
"source": [
56+
mb_size = 32                            # minibatch size
z_dim = 5                               # latent code dimensionality
X_dim = mnist.train.images.shape[1]     # flattened image size (784 for 28x28 MNIST)
y_dim = mnist.train.labels.shape[1]     # number of classes (10, since one_hot=True)
h_dim = 128                             # hidden-layer width shared by Q, P and D
cnt = 0                                 # running counter used to name saved sample images
lr = 1e-3                               # Adam learning rate for all three optimizers
63+
]
64+
},
65+
{
66+
"cell_type": "code",
67+
"execution_count": 4,
68+
"metadata": {
69+
"collapsed": true
70+
},
71+
"outputs": [],
72+
"source": [
73+
# Encoder Q: flattened image (X_dim) -> latent code (z_dim) via one hidden layer.
# The final layer is linear (no activation) so the code is unconstrained,
# matching the Gaussian prior used in the regularization phase.
_encoder_layers = [
    torch.nn.Linear(X_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, z_dim),
]
Q = torch.nn.Sequential(*_encoder_layers)
79+
]
80+
},
81+
{
82+
"cell_type": "code",
83+
"execution_count": 6,
84+
"metadata": {},
85+
"outputs": [],
86+
"source": [
87+
# Decoder P: latent code (z_dim) -> reconstructed image (X_dim).
# Sigmoid output keeps pixels in [0, 1], as required by the BCE
# reconstruction loss used in the training loop.
_decoder_layers = [
    torch.nn.Linear(z_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, X_dim),
    torch.nn.Sigmoid(),
]
P = torch.nn.Sequential(*_decoder_layers)
93+
]
94+
},
95+
{
96+
"cell_type": "code",
97+
"execution_count": 7,
98+
"metadata": {
99+
"collapsed": true
100+
},
101+
"outputs": [],
102+
"source": [
103+
# Discriminator D: latent code (z_dim) -> probability that the code was drawn
# from the prior N(0, I) rather than produced by the encoder Q.
_discriminator_layers = [
    torch.nn.Linear(z_dim, h_dim),
    torch.nn.ReLU(),
    torch.nn.Linear(h_dim, 1),
    torch.nn.Sigmoid(),
]
D = torch.nn.Sequential(*_discriminator_layers)
110+
]
111+
},
112+
{
113+
"cell_type": "code",
114+
"execution_count": 8,
115+
"metadata": {
116+
"collapsed": true
117+
},
118+
"outputs": [],
119+
"source": [
120+
# Reset Gradient
def reset_grad():
    """Zero the accumulated gradients of encoder, decoder and discriminator.

    Called after each optimizer step in the training loop so gradients from
    one phase (reconstruction / discriminator / generator) never leak into
    the next.
    """
    for net in (Q, P, D):
        net.zero_grad()
125+
]
126+
},
127+
{
128+
"cell_type": "code",
129+
"execution_count": 9,
130+
"metadata": {
131+
"collapsed": true
132+
},
133+
"outputs": [],
134+
"source": [
135+
def sample_X(size, include_y=False):
    """Draw a random minibatch of MNIST images, optionally with labels.

    Args:
        size: number of examples to draw via mnist.train.next_batch.
        include_y: when True, also return integer class labels.

    Returns:
        X as a Variable (one flattened image per row), or (X, y) when
        include_y is True, where y holds the argmax of the one-hot labels.
    """
    X, y = mnist.train.next_batch(size)
    X = Variable(torch.from_numpy(X))

    if include_y:
        # Collapse one-hot rows to class indices. The builtin `int` replaces
        # `np.int`, which was deprecated in NumPy 1.20 and removed in 1.24.
        y = np.argmax(y, axis=1).astype(int)
        y = Variable(torch.from_numpy(y))
        return X, y

    return X
145+
]
146+
},
147+
{
148+
"cell_type": "code",
149+
"execution_count": 10,
150+
"metadata": {},
151+
"outputs": [
152+
{
153+
"name": "stdout",
154+
"output_type": "stream",
155+
"text": [
156+
"Iter-0; D_loss: 1.439; G_loss: 0.7287; recon_loss: 0.6982\n",
157+
"Iter-1000; D_loss: 1.371; G_loss: 0.7533; recon_loss: 0.2727\n",
158+
"Iter-2000; D_loss: 1.449; G_loss: 0.6482; recon_loss: 0.2138\n",
159+
"Iter-3000; D_loss: 1.408; G_loss: 0.69; recon_loss: 0.1909\n",
160+
"Iter-4000; D_loss: 1.383; G_loss: 0.7024; recon_loss: 0.1826\n",
161+
"Iter-5000; D_loss: 1.39; G_loss: 0.6931; recon_loss: 0.177\n",
162+
"Iter-6000; D_loss: 1.372; G_loss: 0.7667; recon_loss: 0.1709\n",
163+
"Iter-7000; D_loss: 1.395; G_loss: 0.7129; recon_loss: 0.1825\n",
164+
"Iter-8000; D_loss: 1.389; G_loss: 0.6816; recon_loss: 0.1665\n",
165+
"Iter-9000; D_loss: 1.39; G_loss: 0.6768; recon_loss: 0.1914\n",
166+
"Iter-10000; D_loss: 1.387; G_loss: 0.6906; recon_loss: 0.1478\n",
167+
"Iter-11000; D_loss: 1.379; G_loss: 0.7249; recon_loss: 0.167\n",
168+
"Iter-12000; D_loss: 1.393; G_loss: 0.6823; recon_loss: 0.1833\n",
169+
"Iter-13000; D_loss: 1.386; G_loss: 0.6821; recon_loss: 0.1486\n",
170+
"Iter-14000; D_loss: 1.393; G_loss: 0.6952; recon_loss: 0.1572\n",
171+
"Iter-15000; D_loss: 1.386; G_loss: 0.7; recon_loss: 0.1638\n",
172+
"Iter-16000; D_loss: 1.383; G_loss: 0.696; recon_loss: 0.1668\n",
173+
"Iter-17000; D_loss: 1.391; G_loss: 0.6997; recon_loss: 0.163\n",
174+
"Iter-18000; D_loss: 1.388; G_loss: 0.6924; recon_loss: 0.1619\n",
175+
"Iter-19000; D_loss: 1.388; G_loss: 0.6861; recon_loss: 0.1596\n"
176+
]
177+
},
178+
{
179+
"ename": "KeyboardInterrupt",
180+
"evalue": "",
181+
"output_type": "error",
182+
"traceback": [
183+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
184+
"\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
185+
"\u001b[0;32m<ipython-input-10-35e7d869027e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mG_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0mQ_solver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0mreset_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
186+
"\u001b[0;32m~/anaconda/lib/python3.6/site-packages/torch/optim/adam.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0mstep_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'lr'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mbias_correction1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maddcdiv_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mstep_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexp_avg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdenom\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
187+
"\u001b[0;31mKeyboardInterrupt\u001b[0m: "
188+
]
189+
}
190+
],
191+
"source": [
192+
# One Adam optimizer per network so each phase below can step only the
# parameters it is meant to update.
Q_solver = optim.Adam(Q.parameters(), lr=lr)
P_solver = optim.Adam(P.parameters(), lr=lr)
D_solver = optim.Adam(D.parameters(), lr=lr)


# Adversarial-autoencoder training: each iteration runs three phases —
# (1) autoencoder reconstruction, (2) discriminator update, (3) encoder
# ("generator") update. reset_grad() after every step keeps the phases'
# gradients from mixing; the statement order here is load-bearing.
for it in range(1000000):
    X = sample_X(mb_size)

    """ Reconstruction phase """
    # Encode then decode; minimize pixelwise binary cross-entropy.
    # NOTE: `nn` is torch.nn.functional (aliased that way in the import cell).
    z_sample = Q(X)
    X_sample = P(z_sample)

    recon_loss = nn.binary_cross_entropy(X_sample, X)

    recon_loss.backward()
    P_solver.step()
    Q_solver.step()
    reset_grad()

    """ Regularization phase """
    # Discriminator
    # Real codes come from the N(0, I) prior; fake codes from the encoder.
    z_real = Variable(torch.randn(mb_size, z_dim))
    z_fake = Q(X)

    D_real = D(z_real)
    D_fake = D(z_fake)

    # Standard GAN discriminator loss.
    # NOTE(review): torch.log is unclamped here — if D saturates at 0 or 1
    # this produces inf/NaN; adding a small epsilon would harden it.
    D_loss = -torch.mean(torch.log(D_real) + torch.log(1 - D_fake))

    D_loss.backward()
    D_solver.step()
    reset_grad()

    # Generator
    # Re-encode X (fresh graph) and push the encoder to fool D.
    z_fake = Q(X)
    D_fake = D(z_fake)

    G_loss = -torch.mean(torch.log(D_fake))

    G_loss.backward()
    Q_solver.step()   # only the encoder is updated in this phase
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        # NOTE(review): `loss.data[0]` is the pre-0.4 PyTorch idiom; on
        # modern PyTorch use `loss.item()` instead.
        print('Iter-{}; D_loss: {:.4}; G_loss: {:.4}; recon_loss: {:.4}'
              .format(it, D_loss.data[0], G_loss.data[0], recon_loss.data[0]))

        # Decode the prior samples drawn above into 16 images for inspection.
        samples = P(z_real).data.numpy()[:16]

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        if not os.path.exists('out/'):
            os.makedirs('out/')

        # Save as out/000.png, out/001.png, … and close the figure so
        # memory does not grow across the (very long) loop.
        plt.savefig('out/{}.png'
                    .format(str(cnt).zfill(3)), bbox_inches='tight')
        cnt += 1
        plt.close(fig)
261+
]
262+
}
263+
],
264+
"metadata": {
265+
"kernelspec": {
266+
"display_name": "Python 3",
267+
"language": "python",
268+
"name": "python3"
269+
},
270+
"language_info": {
271+
"codemirror_mode": {
272+
"name": "ipython",
273+
"version": 3
274+
},
275+
"file_extension": ".py",
276+
"mimetype": "text/x-python",
277+
"name": "python",
278+
"nbconvert_exporter": "python",
279+
"pygments_lexer": "ipython3",
280+
"version": "3.6.1"
281+
}
282+
},
283+
"nbformat": 4,
284+
"nbformat_minor": 2
285+
}

0 commit comments

Comments
 (0)