|
| 1 | +{ |
| 2 | + "cells": [ |
| 3 | + { |
| 4 | + "cell_type": "code", |
| 5 | + "execution_count": 1, |
| 6 | + "metadata": { |
| 7 | + "collapsed": true |
| 8 | + }, |
| 9 | + "outputs": [], |
| 10 | + "source": [ |
| 11 | + "import torch\n", |
| 12 | + "import torch.nn\n", |
| 13 | +    "import torch.nn.functional as F\n", |
| 14 | + "import torch.autograd as autograd\n", |
| 15 | + "import torch.optim as optim\n", |
| 16 | + "import numpy as np\n", |
| 17 | + "import matplotlib.pyplot as plt\n", |
| 18 | + "import matplotlib.gridspec as gridspec\n", |
| 19 | + "import os\n", |
| 20 | + "from torch.autograd import Variable\n", |
| 21 | + "from tensorflow.examples.tutorials.mnist import input_data" |
| 22 | + ] |
| 23 | + }, |
| 24 | + { |
| 25 | + "cell_type": "code", |
| 26 | + "execution_count": 2, |
| 27 | + "metadata": {}, |
| 28 | + "outputs": [ |
| 29 | + { |
| 30 | + "name": "stdout", |
| 31 | + "output_type": "stream", |
| 32 | + "text": [ |
| 33 | + "Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.\n", |
| 34 | + "Extracting ../../MNIST_data/train-images-idx3-ubyte.gz\n", |
| 35 | + "Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.\n", |
| 36 | + "Extracting ../../MNIST_data/train-labels-idx1-ubyte.gz\n", |
| 37 | + "Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.\n", |
| 38 | + "Extracting ../../MNIST_data/t10k-images-idx3-ubyte.gz\n", |
| 39 | + "Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.\n", |
| 40 | + "Extracting ../../MNIST_data/t10k-labels-idx1-ubyte.gz\n" |
| 41 | + ] |
| 42 | + } |
| 43 | + ], |
| 44 | + "source": [ |
| 45 | + "mnist = input_data.read_data_sets('../../MNIST_data', one_hot=True)" |
| 46 | + ] |
| 47 | + }, |
| 48 | + { |
| 49 | + "cell_type": "code", |
| 50 | + "execution_count": 3, |
| 51 | + "metadata": { |
| 52 | + "collapsed": true |
| 53 | + }, |
| 54 | + "outputs": [], |
| 55 | + "source": [ |
| 56 | +    "mb_size = 32   # minibatch size\n", |
| 57 | +    "z_dim = 5      # dimensionality of the latent code\n", |
| 58 | +    "X_dim = mnist.train.images.shape[1]   # 784 = 28 * 28 flattened pixels\n", |
| 59 | +    "y_dim = mnist.train.labels.shape[1]   # 10 one-hot label classes (not used below)\n", |
| 60 | +    "h_dim = 128    # hidden-layer width of Q, P and D\n", |
| 61 | +    "cnt = 0        # counter for numbering saved sample grids\n", |
| 62 | +    "lr = 1e-3      # learning rate shared by the three Adam optimizers\n", |
| 63 | + ] |
| 64 | + }, |
| 65 | + { |
| 66 | + "cell_type": "code", |
| 67 | + "execution_count": 4, |
| 68 | + "metadata": { |
| 69 | + "collapsed": true |
| 70 | + }, |
| 71 | + "outputs": [], |
| 72 | + "source": [ |
| 73 | + "# Encoder\n", |
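|  | +    "# Q maps a flattened 28x28 image (X_dim = 784) to a z_dim-dimensional latent code\n", |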
| 74 | + "Q = torch.nn.Sequential(\n", |
| 75 | + " torch.nn.Linear(X_dim, h_dim),\n", |
| 76 | + " torch.nn.ReLU(),\n", |
| 77 | + " torch.nn.Linear(h_dim, z_dim)\n", |
| 78 | + ")" |
| 79 | + ] |
| 80 | + }, |
| 81 | + { |
| 82 | + "cell_type": "code", |
| 83 | + "execution_count": 6, |
| 84 | + "metadata": {}, |
| 85 | + "outputs": [], |
| 86 | + "source": [ |
| 87 | + "# Decoder\n", |
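|  | +    "# P maps a latent code back to a 784-dimensional image; the final Sigmoid keeps pixels in (0, 1)\n", |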
| 88 | + "P = torch.nn.Sequential(\n", |
| 89 | + " torch.nn.Linear(z_dim, h_dim),\n", |
| 90 | + " torch.nn.ReLU(),\n", |
| 91 | + " torch.nn.Linear(h_dim, X_dim),\n", |
| 92 | + " torch.nn.Sigmoid())" |
| 93 | + ] |
| 94 | + }, |
| 95 | + { |
| 96 | + "cell_type": "code", |
| 97 | + "execution_count": 7, |
| 98 | + "metadata": { |
| 99 | + "collapsed": true |
| 100 | + }, |
| 101 | + "outputs": [], |
| 102 | + "source": [ |
| 103 | + "# Discriminator\n", |
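|  | +    "# D estimates the probability that a latent code was drawn from the prior N(0, I) rather than produced by Q\n", |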
| 104 | + "D = torch.nn.Sequential(\n", |
| 105 | + " torch.nn.Linear(z_dim, h_dim),\n", |
| 106 | + " torch.nn.ReLU(),\n", |
| 107 | + " torch.nn.Linear(h_dim, 1),\n", |
| 108 | + " torch.nn.Sigmoid()\n", |
| 109 | + ")" |
| 110 | + ] |
| 111 | + }, |
| 112 | + { |
| 113 | + "cell_type": "code", |
| 114 | + "execution_count": 8, |
| 115 | + "metadata": { |
| 116 | + "collapsed": true |
| 117 | + }, |
| 118 | + "outputs": [], |
| 119 | + "source": [ |
| 120 | + "# Reset Gradient\n", |
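|  | +    "# Clear accumulated gradients on all three networks so each update uses only the current loss\n", |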
| 121 | + "def reset_grad():\n", |
| 122 | + " Q.zero_grad()\n", |
| 123 | + " P.zero_grad()\n", |
| 124 | + " D.zero_grad()" |
| 125 | + ] |
| 126 | + }, |
| 127 | + { |
| 128 | + "cell_type": "code", |
| 129 | + "execution_count": 9, |
| 130 | + "metadata": { |
| 131 | + "collapsed": true |
| 132 | + }, |
| 133 | + "outputs": [], |
| 134 | + "source": [ |
| 135 | + "def sample_X(size, include_y=False):\n", |
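|  | +    "    # Draw a minibatch of flattened MNIST images (and, optionally, integer class labels) as Variables\n", |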
| 136 | + " X, y = mnist.train.next_batch(size)\n", |
| 137 | + " X = Variable(torch.from_numpy(X))\n", |
| 138 | + "\n", |
| 139 | + " if include_y:\n", |
| 140 | +    "        y = np.argmax(y, axis=1).astype(np.int64)\n", |
| 141 | + " y = Variable(torch.from_numpy(y))\n", |
| 142 | + " return X, y\n", |
| 143 | + "\n", |
| 144 | + " return X\n" |
| 145 | + ] |
| 146 | + }, |
| 147 | + { |
| 148 | + "cell_type": "code", |
| 149 | + "execution_count": 10, |
| 150 | + "metadata": {}, |
| 151 | + "outputs": [ |
| 152 | + { |
| 153 | + "name": "stdout", |
| 154 | + "output_type": "stream", |
| 155 | + "text": [ |
| 156 | + "Iter-0; D_loss: 1.439; G_loss: 0.7287; recon_loss: 0.6982\n", |
| 157 | + "Iter-1000; D_loss: 1.371; G_loss: 0.7533; recon_loss: 0.2727\n", |
| 158 | + "Iter-2000; D_loss: 1.449; G_loss: 0.6482; recon_loss: 0.2138\n", |
| 159 | + "Iter-3000; D_loss: 1.408; G_loss: 0.69; recon_loss: 0.1909\n", |
| 160 | + "Iter-4000; D_loss: 1.383; G_loss: 0.7024; recon_loss: 0.1826\n", |
| 161 | + "Iter-5000; D_loss: 1.39; G_loss: 0.6931; recon_loss: 0.177\n", |
| 162 | + "Iter-6000; D_loss: 1.372; G_loss: 0.7667; recon_loss: 0.1709\n", |
| 163 | + "Iter-7000; D_loss: 1.395; G_loss: 0.7129; recon_loss: 0.1825\n", |
| 164 | + "Iter-8000; D_loss: 1.389; G_loss: 0.6816; recon_loss: 0.1665\n", |
| 165 | + "Iter-9000; D_loss: 1.39; G_loss: 0.6768; recon_loss: 0.1914\n", |
| 166 | + "Iter-10000; D_loss: 1.387; G_loss: 0.6906; recon_loss: 0.1478\n", |
| 167 | + "Iter-11000; D_loss: 1.379; G_loss: 0.7249; recon_loss: 0.167\n", |
| 168 | + "Iter-12000; D_loss: 1.393; G_loss: 0.6823; recon_loss: 0.1833\n", |
| 169 | + "Iter-13000; D_loss: 1.386; G_loss: 0.6821; recon_loss: 0.1486\n", |
| 170 | + "Iter-14000; D_loss: 1.393; G_loss: 0.6952; recon_loss: 0.1572\n", |
| 171 | + "Iter-15000; D_loss: 1.386; G_loss: 0.7; recon_loss: 0.1638\n", |
| 172 | + "Iter-16000; D_loss: 1.383; G_loss: 0.696; recon_loss: 0.1668\n", |
| 173 | + "Iter-17000; D_loss: 1.391; G_loss: 0.6997; recon_loss: 0.163\n", |
| 174 | + "Iter-18000; D_loss: 1.388; G_loss: 0.6924; recon_loss: 0.1619\n", |
| 175 | + "Iter-19000; D_loss: 1.388; G_loss: 0.6861; recon_loss: 0.1596\n" |
| 176 | + ] |
| 177 | + }, |
| 178 | + { |
| 179 | + "ename": "KeyboardInterrupt", |
| 180 | + "evalue": "", |
| 181 | + "output_type": "error", |
| 182 | + "traceback": [ |
| 183 | + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", |
| 184 | + "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", |
| 185 | + "\u001b[0;32m<ipython-input-10-35e7d869027e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mG_loss\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0mQ_solver\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0mreset_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", |
| 186 | + "\u001b[0;32m~/anaconda/lib/python3.6/site-packages/torch/optim/adam.py\u001b[0m in \u001b[0;36mstep\u001b[0;34m(self, closure)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[0mstep_size\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mgroup\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m'lr'\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mmath\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msqrt\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbias_correction2\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m/\u001b[0m \u001b[0mbias_correction1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 73\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 74\u001b[0;31m \u001b[0mp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0maddcdiv_\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0mstep_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mexp_avg\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdenom\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 75\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mloss\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", |
| 187 | + "\u001b[0;31mKeyboardInterrupt\u001b[0m: " |
| 188 | + ] |
| 189 | + } |
| 190 | + ], |
| 191 | + "source": [ |
| 192 | + "Q_solver = optim.Adam(Q.parameters(), lr=lr)\n", |
| 193 | + "P_solver = optim.Adam(P.parameters(), lr=lr)\n", |
| 194 | + "D_solver = optim.Adam(D.parameters(), lr=lr)\n", |
| 195 | + "\n", |
| 196 | + "\n", |
| 197 | + "for it in range(1000000):\n", |
| 198 | + " X = sample_X(mb_size)\n", |
| 199 | + "\n", |
| 200 | + " \"\"\" Reconstruction phase \"\"\"\n", |
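|  | +    "    # Update Q and P to minimize the pixel-wise binary cross-entropy between X and its reconstruction\n", |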
| 201 | + " z_sample = Q(X)\n", |
| 202 | + " X_sample = P(z_sample)\n", |
| 203 | + "\n", |
| 204 | +    "    recon_loss = F.binary_cross_entropy(X_sample, X)\n", |
| 205 | + "\n", |
| 206 | + " recon_loss.backward()\n", |
| 207 | + " P_solver.step()\n", |
| 208 | + " Q_solver.step()\n", |
| 209 | + " reset_grad()\n", |
| 210 | + "\n", |
| 211 | + " \"\"\" Regularization phase \"\"\"\n", |
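|  | +    "    # Adversarially push the distribution of codes Q(X) toward the prior p(z) = N(0, I)\n", |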
| 212 | + " # Discriminator\n", |
| 213 | + " z_real = Variable(torch.randn(mb_size, z_dim))\n", |
| 214 | + " z_fake = Q(X)\n", |
| 215 | + "\n", |
| 216 | + " D_real = D(z_real)\n", |
| 217 | + " D_fake = D(z_fake)\n", |
| 218 | + "\n", |
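|  | +    "    # Standard GAN discriminator objective: maximize log D(z_real) + log(1 - D(Q(X)))\n", |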
| 219 | + " D_loss = -torch.mean(torch.log(D_real) + torch.log(1 - D_fake))\n", |
| 220 | + "\n", |
| 221 | + " D_loss.backward()\n", |
| 222 | + " D_solver.step()\n", |
| 223 | + " reset_grad()\n", |
| 224 | + "\n", |
| 225 | + " # Generator\n", |
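|  | +    "    # The encoder Q plays the generator: train it so D classifies its codes as prior samples\n", |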
| 226 | + " z_fake = Q(X)\n", |
| 227 | + " D_fake = D(z_fake)\n", |
| 228 | + "\n", |
| 229 | + " G_loss = -torch.mean(torch.log(D_fake))\n", |
| 230 | + "\n", |
| 231 | + " G_loss.backward()\n", |
| 232 | + " Q_solver.step()\n", |
| 233 | + " reset_grad()\n", |
| 234 | + "\n", |
| 235 | + " # Print and plot every now and then\n", |
| 236 | + " if it % 1000 == 0:\n", |
| 237 | + " print('Iter-{}; D_loss: {:.4}; G_loss: {:.4}; recon_loss: {:.4}'\n", |
| 238 | + " .format(it, D_loss.data[0], G_loss.data[0], recon_loss.data[0]))\n", |
| 239 | + "\n", |
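|  | +    "        # Decode 16 prior samples into images and save them as a 4x4 grid\n", |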
| 240 | + " samples = P(z_real).data.numpy()[:16]\n", |
| 241 | + "\n", |
| 242 | + " fig = plt.figure(figsize=(4, 4))\n", |
| 243 | + " gs = gridspec.GridSpec(4, 4)\n", |
| 244 | + " gs.update(wspace=0.05, hspace=0.05)\n", |
| 245 | + "\n", |
| 246 | + " for i, sample in enumerate(samples):\n", |
| 247 | + " ax = plt.subplot(gs[i])\n", |
| 248 | + " plt.axis('off')\n", |
| 249 | + " ax.set_xticklabels([])\n", |
| 250 | + " ax.set_yticklabels([])\n", |
| 251 | + " ax.set_aspect('equal')\n", |
| 252 | + " plt.imshow(sample.reshape(28, 28), cmap='Greys_r')\n", |
| 253 | + "\n", |
| 254 | + " if not os.path.exists('out/'):\n", |
| 255 | + " os.makedirs('out/')\n", |
| 256 | + "\n", |
| 257 | + " plt.savefig('out/{}.png'\n", |
| 258 | + " .format(str(cnt).zfill(3)), bbox_inches='tight')\n", |
| 259 | + " cnt += 1\n", |
| 260 | + " plt.close(fig)" |
| 261 | + ] |
| 262 | + } |
| 263 | + ], |
| 264 | + "metadata": { |
| 265 | + "kernelspec": { |
| 266 | + "display_name": "Python 3", |
| 267 | + "language": "python", |
| 268 | + "name": "python3" |
| 269 | + }, |
| 270 | + "language_info": { |
| 271 | + "codemirror_mode": { |
| 272 | + "name": "ipython", |
| 273 | + "version": 3 |
| 274 | + }, |
| 275 | + "file_extension": ".py", |
| 276 | + "mimetype": "text/x-python", |
| 277 | + "name": "python", |
| 278 | + "nbconvert_exporter": "python", |
| 279 | + "pygments_lexer": "ipython3", |
| 280 | + "version": "3.6.1" |
| 281 | + } |
| 282 | + }, |
| 283 | + "nbformat": 4, |
| 284 | + "nbformat_minor": 2 |
| 285 | +} |