Skip to content

Commit bc86802

Browse files
committed
📚 wasserstein gan
1 parent da040fa commit bc86802

File tree

11 files changed

+1125
-435
lines changed

11 files changed

+1125
-435
lines changed

docs/cnn/utils/cv_train.html

Lines changed: 12 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -72,38 +72,7 @@
7272
<div class='section-link'>
7373
<a href='#section-0'>#</a>
7474
</div>
75-
<h1>Cross-Validation & Early Stopping</h1>
76-
<p>Implementation of fundamental techniques namely <em>Cross-Validation</em> and <em>Early Stopping</em>
77-
<h3>Cross-Validation</h3>
78-
<p>
79-
Getting data is expensive and in some cases, one has no option but to use a limited amount of data for training their machine learning models.
80-
This is where Cross-Validation is useful. Steps are as follows:
81-
<ol type = "1">
82-
<li> Split the data in K folds </li>
83-
<li> Use K-1 folds to train a set of models</li>
84-
<li> Validate the models on the remaining fold</li>
85-
<li> Repeat (1) and (2) for all the folds</li>
86-
<li> Average the performance over all runs</li>
87-
</ol>
88-
</p>
89-
<h3>Early-Stopping</h3>
90-
Deep Learning networks are prone to overfitting, that is although overfitted models have a good performance on train set, they have poor generalization capabilities.
91-
In other words, overfitted models have low bias and high variance. Lower the bias higher the capability of model to fit the data. Higher the variance higher the sensitivity with respect to training data.
92-
<br>Formally, it can be represented as: </br>
93-
<p><script type="math/tex; mode=display"> loss = {bias}^2 + {variance} + noise </script></p>
94-
<p>Therefore, user has to find a tradeoff between bias and variance.</p>
95-
<p> </p>
96-
<p> Early-Stopping is one of the way to find this tradeoff. It helps to find a good setting of parameters and preventing overfitting on dataset and saving computation time.
97-
This can be visualized through the following graph of train loss and validation loss over time: </p> <br>
98-
99-
100-
<a href="https://www.deeplearningbook.org/contents/regularization.html"><img src="Cross-validation.png" alt="Training v/s Validation set Loss"></a>
101-
<br>
102-
<p> It can be seen that train error continue to decrease but the validation error start to increase after around 40 epochs.
103-
Therefore, our goal is to stop the training after the validation loss increases </p>
104-
105-
</p>
106-
75+
10776
</div>
10877
<div class='code'>
10978
<div class="highlight"><pre><span class="lineno">3</span><span></span><span class="kn">import</span> <span class="nn">torch</span>
@@ -128,10 +97,7 @@ <h3>Early-Stopping</h3>
12897
<div class='section-link'>
12998
<a href='#section-1'>#</a>
13099
</div>
131-
<h3>Cross-Validation</h3>
132-
<p> Splitting of training set in folds can be represented as: </p>
133-
<img src="cv-folds.png" alt="CV folds">
134-
100+
135101
</div>
136102
<div class='code'>
137103
<div class="highlight"><pre><span class="lineno">21</span><span class="k">def</span> <span class="nf">cross_val_train</span><span class="p">(</span><span class="n">cost</span><span class="p">,</span> <span class="n">trainset</span><span class="p">,</span> <span class="n">epochs</span><span class="p">,</span> <span class="n">splits</span><span class="p">,</span> <span class="n">device</span><span class="o">=</span><span class="kc">None</span><span class="p">):</span>
@@ -190,7 +156,7 @@ <h3>Cross-Validation</h3>
190156
<div class='section-link'>
191157
<a href='#section-3'>#</a>
192158
</div>
193-
<p>Training steps</p>
159+
<p>training steps</p>
194160
</div>
195161
<div class='code'>
196162
<div class="highlight"><pre><span class="lineno">65</span> <span class="n">net</span><span class="o">.</span><span class="n">train</span><span class="p">()</span> <span class="c1"># Enable Dropout</span>
@@ -203,7 +169,6 @@ <h3>Cross-Validation</h3>
203169
<a href='#section-4'>#</a>
204170
</div>
205171
<p>Get the inputs; data is a list of [inputs, labels]</p>
206-
<p>Load the inputs in GPU if available else CPU</p>
207172
</div>
208173
<div class='code'>
209174
<div class="highlight"><pre><span class="lineno">68</span> <span class="k">if</span> <span class="n">device</span><span class="p">:</span>
@@ -242,7 +207,7 @@ <h3>Cross-Validation</h3>
242207
<div class='section-link'>
243208
<a href='#section-7'>#</a>
244209
</div>
245-
<p>Calculate loss</p>
210+
<p>Print loss</p>
246211
</div>
247212
<div class='code'>
248213
<div class="highlight"><pre><span class="lineno">82</span> <span class="n">running_loss</span> <span class="o">+=</span> <span class="n">loss</span><span class="o">.</span><span class="n">item</span><span class="p">()</span>
@@ -258,7 +223,7 @@ <h3>Cross-Validation</h3>
258223
<div class='section-link'>
259224
<a href='#section-8'>#</a>
260225
</div>
261-
<p>Validation and printing the metrics</p>
226+
<p>Validation</p>
262227
</div>
263228
<div class='code'>
264229
<div class="highlight"><pre><span class="lineno">90</span> <span class="n">loss_accuracy</span> <span class="o">=</span> <span class="n">Test</span><span class="p">(</span><span class="n">net</span><span class="p">,</span> <span class="n">cost</span><span class="p">,</span> <span class="n">valdata</span><span class="p">,</span> <span class="n">device</span><span class="p">)</span>
@@ -294,17 +259,7 @@ <h3>Cross-Validation</h3>
294259
<div class='section-link'>
295260
<a href='#section-10'>#</a>
296261
</div>
297-
<h3>Early stopping</h3>
298-
<p>Early stopping can be understood graphically - the way weights change during the course of training.</p>
299-
<ul>
300-
<li> Solid contour lines indicate the contours of the negative log-likelihood (train error)</li>
301-
<li> Dashed line indicates the trajectory taken by the optimizer</li>
302-
<li> w∗ denotes the weight setting correspoding to the minimum training error </li>
303-
<li> w denotes the final weights setting chosen by the model after early-stopping </li>
304-
</ul>
305-
<a href="https://www.deeplearningbook.org/contents/regularization.html"><img src="early-stopping.png" alt="early-stopping" hspace="100" ></a> <!--align="middle"-->
306-
<br>
307-
<a href="https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py"><em>code reference here</em></a>
262+
<p>Early stopping refered from https://github.com/Bjarten/early-stopping-pytorch/blob/master/pytorchtools.py</p>
308263
</div>
309264
<div class='code'>
310265
<div class="highlight"><pre><span class="lineno">110</span> <span class="k">if</span> <span class="n">losses</span><span class="p">[</span><span class="n">epoch</span><span class="p">]</span> <span class="o">&gt;</span> <span class="n">min_loss</span><span class="p">:</span>
@@ -358,7 +313,7 @@ <h3>Early stopping</h3>
358313
<div class='section-link'>
359314
<a href='#section-13'>#</a>
360315
</div>
361-
<p>Retrieve the model which has the best accuracy over the validation set </p>
316+
362317
</div>
363318
<div class='code'>
364319
<div class="highlight"><pre><span class="lineno">138</span><span class="k">def</span> <span class="nf">retreive_best_trial</span><span class="p">():</span>
@@ -412,7 +367,7 @@ <h3>Early stopping</h3>
412367
<div class='section-link'>
413368
<a href='#section-16'>#</a>
414369
</div>
415-
<p>Forward pass</p>
370+
<p>forward pass</p>
416371
</div>
417372
<div class='code'>
418373
<div class="highlight"><pre><span class="lineno">166</span> <span class="n">output</span> <span class="o">=</span> <span class="n">net</span><span class="p">(</span><span class="n">images</span><span class="p">)</span></pre></div>
@@ -423,7 +378,7 @@ <h3>Early stopping</h3>
423378
<div class='section-link'>
424379
<a href='#section-17'>#</a>
425380
</div>
426-
<p>Loss in batch</p>
381+
<p>loss in batch</p>
427382
</div>
428383
<div class='code'>
429384
<div class="highlight"><pre><span class="lineno">168</span> <span class="n">loss</span> <span class="o">=</span> <span class="n">cost</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span></pre></div>
@@ -434,7 +389,7 @@ <h3>Early stopping</h3>
434389
<div class='section-link'>
435390
<a href='#section-18'>#</a>
436391
</div>
437-
<p>Update validation loss</p>
392+
<p>update validation loss</p>
438393
</div>
439394
<div class='code'>
440395
<div class="highlight"><pre><span class="lineno">171</span> <span class="n">_</span><span class="p">,</span> <span class="n">preds</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">output</span><span class="p">,</span> <span class="n">dim</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
@@ -502,7 +457,7 @@ <h3>Early stopping</h3>
502457
<div class='section-link'>
503458
<a href='#section-23'>#</a>
504459
</div>
505-
<p>Loss in batch</p>
460+
<p>loss in batch</p>
506461
</div>
507462
<div class='code'>
508463
<div class="highlight"><pre><span class="lineno">197</span> <span class="n">loss</span> <span class="o">+=</span> <span class="n">cost</span><span class="p">(</span><span class="n">outputs</span><span class="p">,</span> <span class="n">labels</span><span class="p">)</span>
@@ -514,7 +469,7 @@ <h3>Early stopping</h3>
514469
<div class='section-link'>
515470
<a href='#section-24'>#</a>
516471
</div>
517-
<p>Calculate loss and accuracy over the validation set</p>
472+
<p>losses[epoch] += loss.item()</p>
518473
</div>
519474
<div class='code'>
520475
<div class="highlight"><pre><span class="lineno">201</span> <span class="n">_</span><span class="p">,</span> <span class="n">predicted</span> <span class="o">=</span> <span class="n">torch</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="n">outputs</span><span class="o">.</span><span class="n">data</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>

docs/gan/cycle_gan.html renamed to docs/gan/cycle_gan/index.html

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
<meta name="twitter:site" content="@labmlai"/>
1313
<meta name="twitter:creator" content="@labmlai"/>
1414

15-
<meta property="og:url" content="https://nn.labml.ai/gan/cycle_gan.html"/>
15+
<meta property="og:url" content="https://nn.labml.ai/gan/cycle_gan/index.html"/>
1616
<meta property="og:title" content="Cycle GAN"/>
1717
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
1818
<meta property="og:site_name" content="LabML Neural Networks"/>
@@ -22,8 +22,8 @@
2222

2323
<title>Cycle GAN</title>
2424
<link rel="shortcut icon" href="/icon.png"/>
25-
<link rel="stylesheet" href="../pylit.css">
26-
<link rel="canonical" href="https://nn.labml.ai/gan/cycle_gan.html"/>
25+
<link rel="stylesheet" href="../../pylit.css">
26+
<link rel="canonical" href="https://nn.labml.ai/gan/cycle_gan/index.html"/>
2727
<!-- Global site tag (gtag.js) - Google Analytics -->
2828
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
2929
<script>
@@ -45,11 +45,12 @@
4545
<div class='docs'>
4646
<p>
4747
<a class="parent" href="/">home</a>
48-
<a class="parent" href="index.html">gan</a>
48+
<a class="parent" href="../index.html">gan</a>
49+
<a class="parent" href="index.html">cycle_gan</a>
4950
</p>
5051
<p>
5152

52-
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/gan/cycle_gan.py">
53+
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/gan/cycle_gan/__init__.py">
5354
<img alt="Github"
5455
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
5556
style="max-width:100%;"/></a>
@@ -88,7 +89,7 @@ <h1>Cycle GAN</h1>
8889
The discriminators test whether the generated images look real.</p>
8990
<p>This file contains the model code as well as the training code.
9091
We also have a Google Colab notebook.</p>
91-
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/gan/cycle_gan.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
92+
<p><a href="https://colab.research.google.com/github/lab-ml/nn/blob/master/labml_nn/gan/cycle_gan/experiment.ipynb"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" /></a>
9293
<a href="https://app.labml.ai/run/93b11a665d6811ebaac80242ac1c0002"><img alt="View Run" src="https://img.shields.io/badge/labml-experiment-brightgreen" /></a></p>
9394
</div>
9495
<div class='code'>

docs/gan/dcgan.html renamed to docs/gan/dcgan/index.html

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
<meta name="twitter:site" content="@labmlai"/>
1313
<meta name="twitter:creator" content="@labmlai"/>
1414

15-
<meta property="og:url" content="https://nn.labml.ai/gan/dcgan.html"/>
15+
<meta property="og:url" content="https://nn.labml.ai/gan/dcgan/index.html"/>
1616
<meta property="og:title" content="Deep Convolutional Generative Adversarial Networks (DCGAN)"/>
1717
<meta property="og:image" content="https://avatars1.githubusercontent.com/u/64068543?s=400&amp;v=4"/>
1818
<meta property="og:site_name" content="LabML Neural Networks"/>
@@ -22,8 +22,8 @@
2222

2323
<title>Deep Convolutional Generative Adversarial Networks (DCGAN)</title>
2424
<link rel="shortcut icon" href="/icon.png"/>
25-
<link rel="stylesheet" href="../pylit.css">
26-
<link rel="canonical" href="https://nn.labml.ai/gan/dcgan.html"/>
25+
<link rel="stylesheet" href="../../pylit.css">
26+
<link rel="canonical" href="https://nn.labml.ai/gan/dcgan/index.html"/>
2727
<!-- Global site tag (gtag.js) - Google Analytics -->
2828
<script async src="https://www.googletagmanager.com/gtag/js?id=G-4V3HC8HBLH"></script>
2929
<script>
@@ -45,11 +45,12 @@
4545
<div class='docs'>
4646
<p>
4747
<a class="parent" href="/">home</a>
48-
<a class="parent" href="index.html">gan</a>
48+
<a class="parent" href="../index.html">gan</a>
49+
<a class="parent" href="index.html">dcgan</a>
4950
</p>
5051
<p>
5152

52-
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/gan/dcgan.py">
53+
<a href="https://github.com/lab-ml/labml_nn/tree/master/labml_nn/gan/dcgan/__init__.py">
5354
<img alt="Github"
5455
src="https://img.shields.io/github/stars/lab-ml/nn?style=social"
5556
style="max-width:100%;"/></a>
@@ -82,7 +83,7 @@ <h1>Deep Convolutional Generative Adversarial Networks (DCGAN)</h1>
8283
<span class="lineno">17</span><span class="kn">from</span> <span class="nn">labml</span> <span class="kn">import</span> <span class="n">experiment</span>
8384
<span class="lineno">18</span><span class="kn">from</span> <span class="nn">labml.configs</span> <span class="kn">import</span> <span class="n">calculate</span>
8485
<span class="lineno">19</span><span class="kn">from</span> <span class="nn">labml_helpers.module</span> <span class="kn">import</span> <span class="n">Module</span>
85-
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.gan.simple_mnist_experiment</span> <span class="kn">import</span> <span class="n">Configs</span></pre></div>
86+
<span class="lineno">20</span><span class="kn">from</span> <span class="nn">labml_nn.gan.original.experiment</span> <span class="kn">import</span> <span class="n">Configs</span></pre></div>
8687
</div>
8788
</div>
8889
<div class='section' id='section-1'>
@@ -338,7 +339,7 @@ <h3>Convolutional Discriminator Network</h3>
338339
<div class='code'>
339340
<div class="highlight"><pre><span class="lineno">108</span><span class="k">def</span> <span class="nf">main</span><span class="p">():</span>
340341
<span class="lineno">109</span> <span class="n">conf</span> <span class="o">=</span> <span class="n">Configs</span><span class="p">()</span>
341-
<span class="lineno">110</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;mnist_dcgan&#39;</span><span class="p">,</span> <span class="n">comment</span><span class="o">=</span><span class="s1">&#39;test&#39;</span><span class="p">)</span>
342+
<span class="lineno">110</span> <span class="n">experiment</span><span class="o">.</span><span class="n">create</span><span class="p">(</span><span class="n">name</span><span class="o">=</span><span class="s1">&#39;mnist_dcgan&#39;</span><span class="p">)</span>
342343
<span class="lineno">111</span> <span class="n">experiment</span><span class="o">.</span><span class="n">configs</span><span class="p">(</span><span class="n">conf</span><span class="p">,</span>
343344
<span class="lineno">112</span> <span class="p">{</span><span class="s1">&#39;discriminator&#39;</span><span class="p">:</span> <span class="s1">&#39;cnn&#39;</span><span class="p">,</span>
344345
<span class="lineno">113</span> <span class="s1">&#39;generator&#39;</span><span class="p">:</span> <span class="s1">&#39;cnn&#39;</span><span class="p">,</span>

0 commit comments

Comments
 (0)