Issue 2338 #2458


Merged
merged 14 commits on Jun 12, 2023
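
This pull request updates several tutorials to set a global default device with ``torch.set_default_device(device)`` instead of passing ``device=`` / ``.to(device)`` on every tensor and module. A minimal sketch of the behavior being relied on (not part of the diff, assuming PyTorch 2.0+ where ``torch.set_default_device`` is available):

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)

# Factory functions now allocate on ``device`` when no explicit ``device=`` is given.
x = torch.randn(3)
print(x.device)  # ``cuda:0`` when a GPU is available, otherwise ``cpu``
```
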
13 changes: 7 additions & 6 deletions advanced_source/neural_style_tutorial.py
@@ -14,7 +14,7 @@
developed by Leon A. Gatys, Alexander S. Ecker and Matthias Bethge.
Neural-Style, or Neural-Transfer, allows you to take an image and
reproduce it with a new artistic style. The algorithm takes three images,
an input image, a content-image, and a style-image, and changes the input
an input image, a content-image, and a style-image, and changes the input
to resemble the content of the content-image and the artistic style of the style-image.


@@ -70,6 +70,7 @@
# method is used to move tensors or modules to a desired device.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.set_default_device(device)

######################################################################
# Loading the Images
@@ -261,7 +262,7 @@ def forward(self, input):
# network to evaluation mode using ``.eval()``.
#

cnn = models.vgg19(pretrained=True).features.to(device).eval()
cnn = models.vgg19(pretrained=True).features.eval()



@@ -271,8 +272,8 @@ def forward(self, input):
# We will use them to normalize the image before sending it into the network.
#

cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406]).to(device)
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225]).to(device)
cnn_normalization_mean = torch.tensor([0.485, 0.456, 0.406])
cnn_normalization_std = torch.tensor([0.229, 0.224, 0.225])

# create a module to normalize input image so we can easily put it in a
# ``nn.Sequential``
@@ -308,7 +309,7 @@ def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
content_layers=content_layers_default,
style_layers=style_layers_default):
# normalization module
normalization = Normalization(normalization_mean, normalization_std).to(device)
normalization = Normalization(normalization_mean, normalization_std)

# just in order to have an iterable access to or list of content/style
# losses
@@ -373,7 +374,7 @@ def get_style_model_and_losses(cnn, normalization_mean, normalization_std,
#
# ::
#
# input_img = torch.randn(content_img.data.size(), device=device)
# input_img = torch.randn(content_img.data.size())

# add the original input image to the figure:
plt.figure()
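
A note on why the removed ``.to(device)`` calls above are safe to drop (my reading of the ``torch.set_default_device`` semantics, not something stated in the diff): tensors and module parameters created after the call are placed on the default device, so the explicit moves become redundant. A small sketch:

```python
import torch
import torch.nn as nn

torch.set_default_device("cuda" if torch.cuda.is_available() else "cpu")

# Plain tensors and freshly constructed module parameters are created on the
# default device, with no trailing ``.to(device)`` required.
mean = torch.tensor([0.485, 0.456, 0.406])
conv = nn.Conv2d(3, 16, kernel_size=3)
print(mean.device, conv.weight.device)
```
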
14 changes: 7 additions & 7 deletions beginner_source/examples_autograd/polynomial_autograd.py
@@ -18,23 +18,23 @@
import math

dtype = torch.float
device = torch.device("cpu")
# device = torch.device("cuda:0") # Uncomment this to run on GPU
device = "cuda" if torch.cuda.is_available() else "cpu"
torch.set_default_device(device)

# Create Tensors to hold input and outputs.
# By default, requires_grad=False, which indicates that we do not need to
# compute gradients with respect to these Tensors during the backward pass.
x = torch.linspace(-math.pi, math.pi, 2000, device=device, dtype=dtype)
x = torch.linspace(-math.pi, math.pi, 2000, dtype=dtype)
y = torch.sin(x)

# Create random Tensors for weights. For a third order polynomial, we need
# 4 weights: y = a + b x + c x^2 + d x^3
# Setting requires_grad=True indicates that we want to compute gradients with
# respect to these Tensors during the backward pass.
a = torch.randn((), device=device, dtype=dtype, requires_grad=True)
b = torch.randn((), device=device, dtype=dtype, requires_grad=True)
c = torch.randn((), device=device, dtype=dtype, requires_grad=True)
d = torch.randn((), device=device, dtype=dtype, requires_grad=True)
a = torch.randn((), dtype=dtype, requires_grad=True)
b = torch.randn((), dtype=dtype, requires_grad=True)
c = torch.randn((), dtype=dtype, requires_grad=True)
d = torch.randn((), dtype=dtype, requires_grad=True)

learning_rate = 1e-6
for t in range(2000):
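
For context, the training loop that begins at the last context line above fills in roughly as follows (paraphrased from the standard autograd polynomial example, not part of this diff):

```python
for t in range(2000):
    # Forward pass: compute predicted y as a third-order polynomial of x.
    y_pred = a + b * x + c * x ** 2 + d * x ** 3

    # Scalar loss and backward pass; gradients accumulate in a.grad, b.grad, ...
    loss = (y_pred - y).pow(2).sum()
    loss.backward()

    # Manual gradient-descent update, done outside of autograd tracking.
    with torch.no_grad():
        a -= learning_rate * a.grad
        b -= learning_rate * b.grad
        c -= learning_rate * c.grad
        d -= learning_rate * d.grad
        a.grad = None
        b.grad = None
        c.grad = None
        d.grad = None
```
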
15 changes: 9 additions & 6 deletions recipes_source/recipes/amp_recipe.py
@@ -76,11 +76,14 @@ def make_model(in_size, out_size, num_layers):
num_batches = 50
epochs = 3

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.set_default_device(device)

# Creates data in default precision.
# The same data is used for both default and mixed precision trials below.
# You don't need to manually change inputs' ``dtype`` when enabling mixed precision.
data = [torch.randn(batch_size, in_size, device="cuda") for _ in range(num_batches)]
targets = [torch.randn(batch_size, out_size, device="cuda") for _ in range(num_batches)]
data = [torch.randn(batch_size, in_size) for _ in range(num_batches)]
targets = [torch.randn(batch_size, out_size) for _ in range(num_batches)]

loss_fn = torch.nn.MSELoss().cuda()

@@ -116,7 +119,7 @@ def make_model(in_size, out_size, num_layers):
for epoch in range(0): # 0 epochs, this section is for illustration only
for input, target in zip(data, targets):
# Runs the forward pass under ``autocast``.
with torch.autocast(device_type='cuda', dtype=torch.float16):
with torch.autocast(device_type=device, dtype=torch.float16):
output = net(input)
# output is float16 because linear layers ``autocast`` to float16.
assert output.dtype is torch.float16
@@ -151,7 +154,7 @@ def make_model(in_size, out_size, num_layers):

for epoch in range(0): # 0 epochs, this section is for illustration only
for input, target in zip(data, targets):
with torch.autocast(device_type='cuda', dtype=torch.float16):
with torch.autocast(device_type=device, dtype=torch.float16):
output = net(input)
loss = loss_fn(output, target)

@@ -184,7 +187,7 @@ def make_model(in_size, out_size, num_layers):
start_timer()
for epoch in range(epochs):
for input, target in zip(data, targets):
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=use_amp):
with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
output = net(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
@@ -202,7 +205,7 @@ def make_model(in_size, out_size, num_layers):

for epoch in range(0): # 0 epochs, this section is for illustration only
for input, target in zip(data, targets):
with torch.autocast(device_type='cuda', dtype=torch.float16):
with torch.autocast(device_type=device, dtype=torch.float16):
output = net(input)
loss = loss_fn(output, target)
scaler.scale(loss).backward()
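
Putting the pieces of this recipe together, a full AMP step with the device-agnostic ``device_type`` looks roughly like this (a sketch reusing the recipe's names ``net``, ``opt``, ``loss_fn``, ``data``, ``targets``, and ``use_amp``; ``GradScaler`` still assumes a CUDA device):

```python
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for input, target in zip(data, targets):
    with torch.autocast(device_type=device, dtype=torch.float16, enabled=use_amp):
        output = net(input)
        loss = loss_fn(output, target)
    scaler.scale(loss).backward()   # scale the loss to avoid float16 gradient underflow
    scaler.step(opt)                # unscale gradients, then run opt.step()
    scaler.update()                 # adjust the scale factor for the next iteration
    opt.zero_grad(set_to_none=True)
```
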
2 changes: 1 addition & 1 deletion recipes_source/recipes/tuning_guide.py
@@ -357,7 +357,7 @@ def fused_gelu(x):
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Instead of calling ``torch.rand(size).cuda()`` to generate a random tensor,
# produce the output directly on the target device:
# ``torch.rand(size, device=torch.device('cuda'))``.
# ``torch.rand(size, device='cuda')``.
#
# This is applicable to all functions which create new tensors and accept
# ``device`` argument:
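
As a short illustration of the guideline above (assuming a CUDA device is present):

```python
import torch

# Allocates on the CPU and then copies to the GPU:
slow = torch.rand(1024, 1024).cuda()

# Allocates directly on the GPU, skipping the host-to-device copy:
fast = torch.rand(1024, 1024, device="cuda")
```
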