Skip to content

Commit a58279c

Browse files
authored
Image prediction using trained model (#2392)
* Image prediction using trained model * Inference on custom images * Updated the PR following the PEP8 guidelines and made the requested changes --------- Co-authored-by: Svetlana Karslioglu <svekars@fb.com>
1 parent a5376f7 commit a58279c

File tree

1 file changed

+42
-0
lines changed

1 file changed

+42
-0
lines changed

Diff for: beginner_source/transfer_learning_tutorial.py

+42
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@
4444
import matplotlib.pyplot as plt
4545
import time
4646
import os
47+
from PIL import Image
4748
from tempfile import TemporaryDirectory
4849

4950
cudnn.benchmark = True
@@ -337,6 +338,47 @@ def visualize_model(model, num_images=6):
337338
plt.ioff()
338339
plt.show()
339340

341+
342+
######################################################################
343+
# Inference on custom images
344+
# --------------------------
345+
#
346+
# Use the trained model to make predictions on custom images and visualize
347+
# the predicted class labels along with the images.
348+
#
349+
350+
def visualize_model_predictions(model,img_path):
351+
was_training = model.training
352+
model.eval()
353+
354+
img = Image.open(img_path)
355+
img = data_transforms['val'](img)
356+
img = img.unsqueeze(0)
357+
img = img.to(device)
358+
359+
with torch.no_grad():
360+
outputs = model(img)
361+
_, preds = torch.max(outputs, 1)
362+
363+
ax = plt.subplot(2,2,1)
364+
ax.axis('off')
365+
ax.set_title(f'Predicted: {class_names[preds[0]]}')
366+
imshow(img.cpu().data[0])
367+
368+
model.train(mode=was_training)
369+
370+
######################################################################
371+
#
372+
373+
visualize_model_predictions(
374+
model_conv,
375+
img_path='data/hymenoptera_data/val/bees/72100438_73de9f17af.jpg'
376+
)
377+
378+
plt.ioff()
379+
plt.show()
380+
381+
340382
######################################################################
341383
# Further Learning
342384
# -----------------

0 commit comments

Comments
 (0)