# -*- coding: utf-8 -*-
"""
Neural Transfer with PyTorch
============================

**Author**: `Alexis Jacq <https://alexis-jacq.github.io>`_

Introduction
------------

Welcome! This tutorial explains how to implement the
`Neural-Style <https://arxiv.org/abs/1508.06576>`__ algorithm developed
by Leon A. Gatys, Alexander S. Ecker and Matthias Bethge.

Neural what?
~~~~~~~~~~~~

The Neural-Style, or Neural-Transfer, is an algorithm that takes as
input a content-image (e.g. a turtle), a style-image (e.g. artistic
waves) and returns the content of the content-image as if it were
'painted' using the artistic style of the style-image:

.. figure:: /_static/img/neural-style/neuralstyle.png
   :alt: content1

How does it work?
~~~~~~~~~~~~~~~~~

The principle is simple: we define two distances, one for the content
(:math:`D_C`) and one for the style (:math:`D_S`). :math:`D_C` measures
how different the content is between two images, while :math:`D_S`
measures how different the style is between two images. Then, we take a
third image, the input (e.g. white noise), and we transform it in
order to both minimize its content-distance with the content-image and
its style-distance with the style-image.

OK. How does it work?
^^^^^^^^^^^^^^^^^^^^^

Well, going further requires some mathematics. Let :math:`C_{nn}` be a
pre-trained deep convolutional neural network and :math:`X` be any
image. :math:`C_{nn}(X)` is the network fed by :math:`X` (containing
feature maps at all layers). Let :math:`F_{XL} \in C_{nn}(X)` be the
feature maps at depth layer :math:`L`, all vectorized and concatenated
into one single vector. We simply define the content of :math:`X` at layer
:math:`L` by :math:`F_{XL}`. Then, if :math:`Y` is another image of the
same size as :math:`X`, we define the content distance at layer
:math:`L` as follows:

.. math:: D_C^L(X,Y) = \|F_{XL} - F_{YL}\|^2 = \sum_i (F_{XL}(i) - F_{YL}(i))^2

Where :math:`F_{XL}(i)` is the :math:`i^{th}` element of :math:`F_{XL}`.
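
For example, here is a minimal sketch (using random tensors in place of
real feature maps) of how this content distance could be computed::

    import torch

    # random stand-ins for the vectorized feature maps of X and Y at layer L
    F_XL = torch.randn(64 * 32 * 32)
    F_YL = torch.randn(64 * 32 * 32)

    # D_C^L(X, Y): sum of squared element-wise differences
    D_C = (F_XL - F_YL).pow(2).sum()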

The style is a bit less trivial to define. Let :math:`F_{XL}^k` with
:math:`k \leq K` be the vectorized :math:`k^{th}` feature map among the
:math:`K` feature maps at layer :math:`L`. The style :math:`G_{XL}` of
:math:`X` at layer :math:`L` is defined by the Gram product of all
vectorized feature maps :math:`F_{XL}^k` with :math:`k \leq K`. In other
words, :math:`G_{XL}` is a :math:`K`\ x\ :math:`K` matrix and the element
:math:`G_{XL}(k,l)` in the :math:`k^{th}` row and :math:`l^{th}` column
of :math:`G_{XL}` is the scalar product between :math:`F_{XL}^k` and
:math:`F_{XL}^l`:

.. math::

    G_{XL}(k,l) = \langle F_{XL}^k, F_{XL}^l\\rangle = \sum_i F_{XL}^k(i) \cdot F_{XL}^l(i)

Where :math:`F_{XL}^k(i)` is the :math:`i^{th}` element of
:math:`F_{XL}^k`. We can see :math:`G_{XL}(k,l)` as a measure of the
correlation between feature maps :math:`k` and :math:`l`. In that way,
:math:`G_{XL}` represents the correlation matrix of the feature maps of
:math:`X` at layer :math:`L`. Note that the size of :math:`G_{XL}` only
depends on the number of feature maps, not on the size of :math:`X`.
Then, if :math:`Y` is another image *of any size*, we define the
style distance at layer :math:`L` as follows:

.. math::

    D_S^L(X,Y) = \|G_{XL} - G_{YL}\|^2 = \sum_{k,l} (G_{XL}(k,l) - G_{YL}(k,l))^2
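
As a similar sketch (again with random stand-ins, here ``K = 64`` feature
maps of ``N = 32 * 32`` elements each), the Gram matrices and this style
distance could be computed as::

    import torch

    # K x N matrices whose rows are the vectorized feature maps of X and Y
    F_XL_hat = torch.randn(64, 32 * 32)
    F_YL_hat = torch.randn(64, 32 * 32)

    # K x K Gram matrices
    G_XL = torch.mm(F_XL_hat, F_XL_hat.t())
    G_YL = torch.mm(F_YL_hat, F_YL_hat.t())

    # D_S^L(X, Y): sum of squared element-wise differences
    D_S = (G_XL - G_YL).pow(2).sum()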

In order to minimize, in one shot, :math:`D_C(X,C)` between a variable
image :math:`X` and a target content-image :math:`C`, and :math:`D_S(X,S)`
between :math:`X` and a target style-image :math:`S`, both computed at
several layers, we compute and sum the gradients (derivatives with
respect to :math:`X`) of each distance at each chosen layer:

.. math::

    \\nabla_{\textit{total}}(X,S,C) = \sum_{L_C} w_{CL_C} \cdot \\nabla_{\textit{content}}^{L_C}(X,C) + \sum_{L_S} w_{SL_S} \cdot \\nabla_{\textit{style}}^{L_S}(X,S)

Where :math:`L_C` and :math:`L_S` are respectively the (arbitrarily
chosen) layers of content and style, and :math:`w_{CL_C}` and
:math:`w_{SL_S}` are the (also arbitrarily chosen) weights associated with
the style or the content at each of those layers. Then, we run a gradient
descent over :math:`X`:

.. math:: X \leftarrow X - \\alpha \\nabla_{\textit{total}}(X,S,C)
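
Conceptually, one step of this gradient descent could look like the
following sketch, where toy quadratic losses stand in for the real
content and style distances (the actual optimization in this tutorial
uses ``optim.LBFGS`` and the loss modules defined below)::

    import torch
    from torch.autograd import Variable

    X = Variable(torch.randn(3, 128, 128), requires_grad=True)  # image to optimize
    C = Variable(torch.randn(3, 128, 128))  # stand-in content target
    S = Variable(torch.randn(3, 128, 128))  # stand-in style target
    alpha, content_weight, style_weight = 0.01, 1, 1000

    total_loss = (content_weight * (X - C).pow(2).sum()
                  + style_weight * (X - S).pow(2).sum())
    total_loss.backward()          # autograd computes the gradient wrt X
    X.data -= alpha * X.grad.data  # X <- X - alpha * gradient
    X.grad.data.zero_()            # reset the gradient before the next step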

OK, that's enough math. If you want to go deeper (how to compute the
gradients), **we encourage you to read the original paper** by Leon
A. Gatys et al., where everything is explained much better and much
more clearly.

For our implementation in PyTorch, we already have everything
we need: indeed, with PyTorch, all the gradients are automatically and
dynamically computed for you (while you use functions from the library).
This is why the implementation of this algorithm becomes very
comfortable with PyTorch.

PyTorch implementation
----------------------

If you are not sure you understand all the mathematics above, you will
probably get it by implementing it. If you are discovering PyTorch, we
recommend you first read this :doc:`Introduction to
PyTorch </beginner/deep_learning_60min_blitz>`.

Packages
~~~~~~~~

We will use the following packages:

-  ``torch``, ``torch.nn``, ``numpy`` (indispensable packages for
   neural networks with PyTorch)
-  ``torch.autograd.Variable`` (dynamic computation of the gradient wrt
   a variable)
-  ``torch.optim`` (efficient gradient descents)
-  ``PIL``, ``PIL.Image``, ``matplotlib.pyplot`` (load and display
   images)
-  ``torchvision.transforms`` (transform PIL images into torch
   tensors)
-  ``torchvision.models`` (train or load pre-trained models)
-  ``copy`` (to deep copy the models; system package)
"""
from __future__ import print_function
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.optim as optim
from PIL import Image
import matplotlib.pyplot as plt
import torchvision.transforms as transforms
import torchvision.models as models
import copy
######################################################################
# Cuda
# ~~~~
#
# If you have a GPU on your computer, it is preferable to run the
# algorithm on it, especially if you want to try larger networks (like
# VGG). For this, we have ``torch.cuda.is_available()``, which returns
# ``True`` if your computer has an available GPU. Then, we can use the
# ``.cuda()`` method, which moves everything allocated by a module
# from the CPU to the GPU. When we want to move this module back to the
# CPU (e.g. to use numpy), we use the ``.cpu()`` method. Finally,
# ``.type(dtype)`` will be used to convert a ``torch.FloatTensor`` into a
# ``torch.cuda.FloatTensor`` to feed GPU processes.
#
use_cuda = torch.cuda.is_available()
dtype = torch.cuda.FloatTensor if use_cuda else torch.FloatTensor
######################################################################
# Load images
# ~~~~~~~~~~~
#
# In order to simplify the implementation, let's start by importing a
# style and a content image of the same dimensions. We then scale them to
# the desired output image size (128 or 512 in the example, depending on GPU
# availability) and transform them into torch tensors, ready to feed
# a neural network.
#
# .. Note::
#     Here are links to download the images required to run the tutorial:
#     `picasso.jpg <http://pytorch.org/tutorials/_static/img/neural-style/picasso.jpg>`__ and
#     `dancing.jpg <http://pytorch.org/tutorials/_static/img/neural-style/dancing.jpg>`__.
#     Download these two images and add them to a directory
#     named ``images``.
# desired size of the output image
imsize = 512 if use_cuda else 128 # use small size if no gpu
loader = transforms.Compose([
transforms.Scale(imsize), # scale imported image
transforms.ToTensor()]) # transform it into a torch tensor
def image_loader(image_name):
image = Image.open(image_name)
image = Variable(loader(image))
# fake batch dimension required to fit network's input dimensions
image = image.unsqueeze(0)
return image
style_img = image_loader("images/picasso.jpg").type(dtype)
content_img = image_loader("images/dancing.jpg").type(dtype)
assert style_img.size() == content_img.size(), \
"we need to import style and content images of the same size"
######################################################################
# Imported PIL images have values between 0 and 255. Transformed into torch
# tensors, their values are between 0 and 1. This is an important detail:
# neural networks from the torch library are trained with 0-1 tensor images.
# If you try to feed the networks with 0-255 tensor images, the activated
# feature maps will make no sense. This is not the case with pre-trained
# networks from the Caffe library: those are trained with 0-255 tensor
# images.
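#
# As a tiny sanity check (optional, using the images loaded above), you can
# verify this range directly:
#
# ::
#
#     # values of a ToTensor()-converted image lie in [0, 1]
#     print(style_img.data.min(), style_img.data.max())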
#
# Display images
# ~~~~~~~~~~~~~~
#
# We will use ``plt.imshow`` to display images, so we first need to
# reconvert them into PIL images:
#
unloader = transforms.ToPILImage() # reconvert into PIL image
plt.ion()
def imshow(tensor, title=None):
    image = tensor.clone().cpu()  # clone the tensor so we don't modify the original
image = image.view(3, imsize, imsize) # remove the fake batch dimension
image = unloader(image)
plt.imshow(image)
if title is not None:
plt.title(title)
plt.pause(0.001) # pause a bit so that plots are updated
plt.figure()
imshow(style_img.data, title='Style Image')
plt.figure()
imshow(content_img.data, title='Content Image')
######################################################################
# Content loss
# ~~~~~~~~~~~~
#
# The content loss is a function that takes as input the feature maps
# :math:`F_{XL}` at a layer :math:`L` in a network fed by :math:`X` and
# returns the weighted content distance :math:`w_{CL} \cdot D_C^L(X,C)` between
# this image and the content image. Hence, the weight :math:`w_{CL}` and
# the target content :math:`F_{CL}` are parameters of the function. We
# implement this function as a torch module with a constructor that takes
# these parameters as input. The distance :math:`\|F_{XL} - F_{YL}\|^2` is
# the Mean Square Error between the two sets of feature maps, which can be
# computed with the ``nn.MSELoss`` criterion created inside the module.
#
# We will add our content losses at each desired layer as additive modules
# of the neural network. That way, each time we feed the network with
# an input image :math:`X`, all the content losses will be computed at the
# desired layers and, thanks to autograd, all the gradients will be
# computed. For that, we just need to make the ``forward`` method of our
# module return the input: the module becomes a 'transparent layer'
# of the neural network. The computed loss is saved as a parameter of the
# module.
#
# Finally, we define a fake ``backward`` method that just calls the
# backward method of ``nn.MSELoss`` in order to reconstruct the gradient.
# This method returns the computed loss: this will be useful when running
# the gradient descent in order to display the evolution of the style and
# content losses.
#
class ContentLoss(nn.Module):
def __init__(self, target, weight):
super(ContentLoss, self).__init__()
        # we 'detach' the target content from the tree used
        # to dynamically compute the gradient: this is a stated value,
        # not a variable. Otherwise the forward method of the criterion
        # will throw an error.
        self.target = target.detach() * weight
        self.weight = weight
        self.criterion = nn.MSELoss()
def forward(self, input):
self.loss = self.criterion(input * self.weight, self.target)
self.output = input
return self.output
def backward(self, retain_variables=True):
self.loss.backward(retain_variables=retain_variables)
return self.loss
######################################################################
# .. Note::
#     **Important detail**: this module, although it is named ``ContentLoss``,
#     is not a true PyTorch Loss function. If you want to define your content
#     loss as a PyTorch Loss, you have to create a PyTorch autograd Function
#     and recompute/implement the gradient by hand in the ``backward``
#     method.
#
# Style loss
# ~~~~~~~~~~
#
# For the style loss, we first need to define a module that computes the
# Gram product :math:`G_{XL}` given the feature maps :math:`F_{XL}` of the
# neural network fed by :math:`X`, at layer :math:`L`. Let
# :math:`\hat{F}_{XL}` be the re-shaped version of :math:`F_{XL}` into a
# :math:`K`\ x\ :math:`N` matrix, where :math:`K` is the number of feature
# maps at layer :math:`L` and :math:`N` the length of any vectorized
# feature map :math:`F_{XL}^k`. The :math:`k^{th}` row of
# :math:`\hat{F}_{XL}` is :math:`F_{XL}^k`. We let you check that
# :math:`\hat{F}_{XL} \cdot \hat{F}_{XL}^T = G_{XL}`. Given that, it
# becomes easy to implement our module:
#
class GramMatrix(nn.Module):
def forward(self, input):
a, b, c, d = input.size() # a=batch size(=1)
# b=number of feature maps
# (c,d)=dimensions of a f. map (N=c*d)
        features = input.view(a * b, c * d)  # resize F_XL into \hat F_XL
G = torch.mm(features, features.t()) # compute the gram product
        # we 'normalize' the values of the gram matrix
        # by dividing by the number of elements in each feature map.
return G.div(a * b * c * d)
######################################################################
# The larger the dimension :math:`N` of the feature maps, the larger the
# values of the Gram matrix. Therefore, if we don't normalize by :math:`N`,
# the loss computed at the first layers (before pooling layers) will have
# much more importance during the gradient descent. We don't want that,
# since the most interesting style features are in the deepest layers!
#
# Then, the style loss module is implemented exactly the same way as the
# content loss module, but it also uses a ``GramMatrix`` module to compute
# the Gram matrix of its input:
#
class StyleLoss(nn.Module):
def __init__(self, target, weight):
super(StyleLoss, self).__init__()
self.target = target.detach() * weight
self.weight = weight
self.gram = GramMatrix()
self.criterion = nn.MSELoss()
def forward(self, input):
self.output = input.clone()
self.G = self.gram(input)
self.G.mul_(self.weight)
self.loss = self.criterion(self.G, self.target)
return self.output
def backward(self, retain_variables=True):
self.loss.backward(retain_variables=retain_variables)
return self.loss
######################################################################
# Load the neural network
# ~~~~~~~~~~~~~~~~~~~~~~~
#
# Now, we have to import a pre-trained neural network. As in the paper, we
# are going to use a pretrained VGG network with 19 layers (VGG19).
#
# PyTorch's implementation of VGG is a module divided into two child
# ``Sequential`` modules: ``features`` (containing convolution and pooling
# layers) and ``classifier`` (containing fully connected layers). We are
# only interested in ``features``:
#
cnn = models.vgg19(pretrained=True).features
# move it to the GPU if possible:
if use_cuda:
cnn = cnn.cuda()
######################################################################
# A ``Sequential`` module contains an ordered list of child modules. For
# instance, ``vgg19.features`` contains a sequence (Conv2d, ReLU,
# MaxPool2d, Conv2d, ReLU...) aligned in the right order of depth. As we
# said in the *Content loss* section, we want to add our style and content
# loss modules as additive 'transparent' layers in our network, at the
# desired depths. For that, we construct a new ``Sequential`` module, in
# which we are going to add modules from ``vgg19`` and our loss modules in
# the right order:
#
# desired depth layers to compute style/content losses :
content_layers_default = ['conv_4']
style_layers_default = ['conv_1', 'conv_2', 'conv_3', 'conv_4', 'conv_5']
def get_style_model_and_losses(cnn, style_img, content_img,
style_weight=1000, content_weight=1,
content_layers=content_layers_default,
style_layers=style_layers_default):
cnn = copy.deepcopy(cnn)
    # just in order to have iterable access to the lists of content/style
    # losses
content_losses = []
style_losses = []
model = nn.Sequential() # the new Sequential module network
gram = GramMatrix() # we need a gram module in order to compute style targets
# move these modules to the GPU if possible:
if use_cuda:
model = model.cuda()
gram = gram.cuda()
i = 1
for layer in list(cnn):
if isinstance(layer, nn.Conv2d):
name = "conv_" + str(i)
model.add_module(name, layer)
if name in content_layers:
# add content loss:
target = model(content_img).clone()
content_loss = ContentLoss(target, content_weight)
model.add_module("content_loss_" + str(i), content_loss)
content_losses.append(content_loss)
if name in style_layers:
# add style loss:
target_feature = model(style_img).clone()
target_feature_gram = gram(target_feature)
style_loss = StyleLoss(target_feature_gram, style_weight)
model.add_module("style_loss_" + str(i), style_loss)
style_losses.append(style_loss)
if isinstance(layer, nn.ReLU):
name = "relu_" + str(i)
model.add_module(name, layer)
if name in content_layers:
# add content loss:
target = model(content_img).clone()
content_loss = ContentLoss(target, content_weight)
model.add_module("content_loss_" + str(i), content_loss)
content_losses.append(content_loss)
if name in style_layers:
# add style loss:
target_feature = model(style_img).clone()
target_feature_gram = gram(target_feature)
style_loss = StyleLoss(target_feature_gram, style_weight)
model.add_module("style_loss_" + str(i), style_loss)
style_losses.append(style_loss)
i += 1
if isinstance(layer, nn.MaxPool2d):
name = "pool_" + str(i)
model.add_module(name, layer) # ***
return model, style_losses, content_losses
######################################################################
# .. Note::
#     In the paper they recommend changing the max pooling layers to
#     average pooling. With AlexNet, which is a small network compared to the
#     VGG19 used in the paper, we are not going to see any difference in the
#     quality of the result. However, you can use these lines instead if you
#     want to make this substitution:
#
#     ::
#
#         # avgpool = nn.AvgPool2d(kernel_size=layer.kernel_size,
#         #                        stride=layer.stride, padding=layer.padding)
#         # model.add_module(name, avgpool)
######################################################################
# Input image
# ~~~~~~~~~~~
#
# Again, in order to simplify the code, we take an image with the same
# dimensions as the content and style images. This image can be white
# noise, or it can be a copy of the content-image.
#
input_img = content_img.clone()
# if you want to use white noise instead, uncomment the line below:
# input_img = Variable(torch.randn(content_img.data.size())).type(dtype)
# add the original input image to the figure:
plt.figure()
imshow(input_img.data, title='Input Image')
######################################################################
# Gradient descent
# ~~~~~~~~~~~~~~~~
#
# As Leon Gatys, the author of the algorithm, suggested
# `here <https://discuss.pytorch.org/t/pytorch-tutorial-for-neural-transfert-of-artistic-style/336/20?u=alexis-jacq>`__,
# we will use the L-BFGS algorithm to run our gradient descent. Unlike
# training a network, we want to train the input image in order to
# minimize the content/style losses. We would like to simply create a
# PyTorch L-BFGS optimizer, passing our image as the variable to optimize.
# But ``optim.LBFGS`` takes as its first argument a list of PyTorch
# ``Variable`` objects that require gradients. Our input image is a
# ``Variable``, but it is not a leaf of the tree that requires computation
# of gradients. In order to indicate that this variable requires a gradient,
# one possibility is to construct a ``Parameter`` object from the input
# image. Then, we just give a list containing this ``Parameter`` to the
# optimizer's constructor:
#
def get_input_param_optimizer(input_img):
    # this line shows that the input is a parameter that requires a gradient
input_param = nn.Parameter(input_img.data)
optimizer = optim.LBFGS([input_param])
return input_param, optimizer
######################################################################
# **Last step**: the loop of gradient descent. At each step, we must feed
# the network with the updated input in order to compute the new losses,
# and we must run the ``backward`` methods of each loss to dynamically
# compute their gradients and perform the step of gradient descent. The
# optimizer requires as argument a "closure": a function that reevaluates
# the model and returns the loss.
#
# However, there's a small catch. The optimized image may take values
# between :math:`-\infty` and :math:`+\infty` instead of staying between 0
# and 1. In other words, the image might be well optimized yet have absurd
# values. In fact, we must perform an optimization under constraints in
# order to keep the values of our input image in the right range. There is
# a simple solution: at each step, clamp the image so that its values stay
# in the 0-1 interval.
#
def run_style_transfer(cnn, content_img, style_img, input_img, num_steps=300,
style_weight=1000, content_weight=1):
"""Run the style transfer."""
print('Building the style transfer model..')
model, style_losses, content_losses = get_style_model_and_losses(cnn,
style_img, content_img, style_weight, content_weight)
input_param, optimizer = get_input_param_optimizer(input_img)
print('Optimizing..')
run = [0]
while run[0] <= num_steps:
def closure():
# correct the values of updated input image
input_param.data.clamp_(0, 1)
optimizer.zero_grad()
model(input_param)
style_score = 0
content_score = 0
for sl in style_losses:
style_score += sl.backward()
for cl in content_losses:
content_score += cl.backward()
run[0] += 1
if run[0] % 50 == 0:
print("run {}:".format(run))
print('Style Loss : {:4f} Content Loss: {:4f}'.format(
style_score.data[0], content_score.data[0]))
print()
return style_score + content_score
optimizer.step(closure)
# a last correction...
input_param.data.clamp_(0, 1)
return input_param.data
######################################################################
# Finally, run the algorithm
output = run_style_transfer(cnn, content_img, style_img, input_img)
plt.figure()
imshow(output, title='Output Image')
# sphinx_gallery_thumbnail_number = 4
plt.ioff()
plt.show()