diff --git a/beginner_source/fgsm_tutorial.py b/beginner_source/fgsm_tutorial.py
index fa23680496c..b58137ad738 100644
--- a/beginner_source/fgsm_tutorial.py
+++ b/beginner_source/fgsm_tutorial.py
@@ -160,25 +160,33 @@ class Net(nn.Module):
     def __init__(self):
         super(Net, self).__init__()
-        self.conv1 = nn.Conv2d(1, 10, kernel_size=5)
-        self.conv2 = nn.Conv2d(10, 20, kernel_size=5)
-        self.conv2_drop = nn.Dropout2d()
-        self.fc1 = nn.Linear(320, 50)
-        self.fc2 = nn.Linear(50, 10)
+        self.conv1 = nn.Conv2d(1, 32, 3, 1)
+        self.conv2 = nn.Conv2d(32, 64, 3, 1)
+        self.dropout1 = nn.Dropout(0.25)
+        self.dropout2 = nn.Dropout(0.5)
+        self.fc1 = nn.Linear(9216, 128)
+        self.fc2 = nn.Linear(128, 10)
 
     def forward(self, x):
-        x = F.relu(F.max_pool2d(self.conv1(x), 2))
-        x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))
-        x = x.view(-1, 320)
-        x = F.relu(self.fc1(x))
-        x = F.dropout(x, training=self.training)
+        x = self.conv1(x)
+        x = F.relu(x)
+        x = self.conv2(x)
+        x = F.relu(x)
+        x = F.max_pool2d(x, 2)
+        x = self.dropout1(x)
+        x = torch.flatten(x, 1)
+        x = self.fc1(x)
+        x = F.relu(x)
+        x = self.dropout2(x)
         x = self.fc2(x)
-        return F.log_softmax(x, dim=1)
+        output = F.log_softmax(x, dim=1)
+        return output
 
 
 # MNIST Test dataset and dataloader declaration
 test_loader = torch.utils.data.DataLoader(
     datasets.MNIST('../data', train=False, download=True, transform=transforms.Compose([
             transforms.ToTensor(),
+            transforms.Normalize((0.1307,), (0.3081,))
         ])),
         batch_size=1, shuffle=True)
 
@@ -190,7 +198,7 @@ def forward(self, x):
 model = Net().to(device)
 
 # Load the pretrained model
-model.load_state_dict(torch.load(pretrained_model, map_location='cpu'))
+model.load_state_dict(torch.load(pretrained_model, map_location=device))
 
 # Set the model in evaluation mode. In this case this is for the Dropout layers
 model.eval()
@@ -225,6 +233,26 @@ def fgsm_attack(image, epsilon, data_grad):
     # Return the perturbed image
     return perturbed_image
 
+# restores the tensors to their original scale
+def denorm(batch, mean=[0.1307], std=[0.3081]):
+    """
+    Convert a batch of tensors to their original scale.
+
+    Args:
+        batch (torch.Tensor): Batch of normalized tensors.
+        mean (torch.Tensor or list): Mean used for normalization.
+        std (torch.Tensor or list): Standard deviation used for normalization.
+
+    Returns:
+        torch.Tensor: batch of tensors without normalization applied to them.
+    """
+    if isinstance(mean, list):
+        mean = torch.tensor(mean).to(device)
+    if isinstance(std, list):
+        std = torch.tensor(std).to(device)
+
+    return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)
+
 
 ######################################################################
 # Testing Function
@@ -279,11 +307,17 @@ def test( model, device, test_loader, epsilon ):
         # Collect ``datagrad``
         data_grad = data.grad.data
 
+        # Restore the data to its original scale
+        data_denorm = denorm(data)
+
         # Call FGSM Attack
-        perturbed_data = fgsm_attack(data, epsilon, data_grad)
+        perturbed_data = fgsm_attack(data_denorm, epsilon, data_grad)
+
+        # Reapply normalization
+        perturbed_data_normalized = transforms.Normalize((0.1307,), (0.3081,))(perturbed_data)
 
         # Re-classify the perturbed image
-        output = model(perturbed_data)
+        output = model(perturbed_data_normalized)
 
         # Check for success
         final_pred = output.max(1, keepdim=True)[1] # get the index of the max log-probability
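
The core idea behind this patch: the model now trains and predicts on normalized inputs, but FGSM's epsilon budget and the `[0, 1]` clamp inside `fgsm_attack` are only meaningful in raw pixel space, so the perturbation must be applied to a denormalized copy and the result re-normalized before re-classification. Below is a minimal standalone sketch (not part of the patch) checking that round trip and the `nn.Linear(9216, 128)` shape; the tensor values and names like `MEAN`, `STD`, and `fake_grad` are illustrative stand-ins.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import transforms

MEAN, STD = 0.1307, 0.3081
normalize = transforms.Normalize((MEAN,), (STD,))

def denorm(batch, mean=[MEAN], std=[STD]):
    # Same inverse transform the patch adds: x_norm * std + mean
    mean, std = torch.tensor(mean), torch.tensor(std)
    return batch * std.view(1, -1, 1, 1) + mean.view(1, -1, 1, 1)

def fgsm_attack(image, epsilon, data_grad):
    # Perturb in pixel space and clamp to the valid [0, 1] range,
    # mirroring the tutorial's fgsm_attack
    return torch.clamp(image + epsilon * data_grad.sign(), 0, 1)

# Hypothetical batch: one 28x28 "MNIST" image in [0, 1]
x = torch.rand(1, 1, 28, 28)
x_norm = normalize(x)

# The denorm round trip should recover the original pixels
assert torch.allclose(denorm(x_norm), x, atol=1e-6)

# The attack sequence from the patched test(): denorm -> perturb -> renorm
fake_grad = torch.randn_like(x)     # stand-in for data.grad.data
x_adv = fgsm_attack(denorm(x_norm), epsilon=0.1, data_grad=fake_grad)
x_adv_norm = normalize(x_adv)       # what gets re-classified by the model

# Shape check for the new architecture: 28x28 -> 3x3 conv -> 26x26 ->
# 3x3 conv -> 24x24 -> 2x2 max-pool -> 12x12, so flattening gives
# 64 * 12 * 12 = 9216 features, matching the new nn.Linear(9216, 128).
feat = F.max_pool2d(nn.Conv2d(32, 64, 3, 1)(nn.Conv2d(1, 32, 3, 1)(x)), 2)
assert feat.flatten(1).shape[1] == 9216
```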