PyTorch: A Practical Use Case

Let's dive into how a neural network in PyTorch can be employed to recognize the contents of an image.

PyTorch is an open-source machine learning library for Python, widely used for deep learning with neural networks in areas such as computer vision and natural language processing (NLP).

The first step is to download an image to the local file system:

  import urllib.request

  # Download a photo of a cat from Wikimedia Commons and save it locally
  url = 'https://upload.wikimedia.org/wikipedia/commons/9/9e/Ginger_european_cat.jpg'
  fpath = 'immagine.jpg'
  urllib.request.urlretrieve(url, fpath)

For our example, we'll use a photograph of a cat from Wikipedia, but any image will suffice.


Next, we import Pillow, the actively maintained fork of PIL (Python Imaging Library). Pillow can open, manipulate, and save a wide range of image formats.

Its Image.open() function loads an image from the file system into an Image object that we can work with in Python.

  from PIL import Image
  # convert('RGB') guarantees three color channels, as the model expects
  img = Image.open('immagine.jpg').convert('RGB')

The image is now stored in the variable img as a PIL Image object.
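A quick way to check what Image.open() returned is to inspect the image's format, size, and mode. The sketch below is illustrative only: it uses a synthetic 640 × 480 image saved to an in-memory buffer as a stand-in for immagine.jpg, so it runs without the download. The dimensions and color are arbitrary assumptions.

```python
import io

from PIL import Image

# Stand-in for the downloaded photo (assumption): a synthetic 640 x 480 RGB JPEG
buffer = io.BytesIO()
Image.new("RGB", (640, 480), color=(200, 120, 60)).save(buffer, format="JPEG")
buffer.seek(0)

# Image.open() accepts a file path or a file-like object
img = Image.open(buffer)

# Note that PIL reports size as (width, height)
print(img.format, img.size, img.mode)  # JPEG (640, 480) RGB
```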

Next, we prepare the image for the neural network.

Preprocessing may be tedious, but it is essential: the network expects input of a specific size and value range. We use torchvision's transforms module to resize and crop the image, convert it to a tensor, and normalize it.

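  from torchvision import transforms

  transform = transforms.Compose([
      transforms.Resize(256),        # scale the shorter side to 256 pixels
      transforms.CenterCrop(224),    # keep the central 224 x 224 region
      transforms.ToTensor(),         # PIL image -> float tensor (C x H x W), values in [0, 1]
      transforms.Normalize(          # standardize each channel with ImageNet statistics
          mean=[0.485, 0.456, 0.406],
          std=[0.229, 0.224, 0.225]
      )
  ])
  img_tensor = transform(img)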

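transforms.Compose chains these transformations so that they are applied in order. Resize(256) scales the image so that its shorter side is 256 pixels while keeping the aspect ratio, and CenterCrop(224) cuts out the central 224 × 224 region, the input size the network expects. ToTensor() then converts the PIL image into a tensor, a multi-dimensional array, with shape channels × height × width and pixel values scaled to the range [0, 1].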

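Normalize() then standardizes each color channel: it subtracts the channel mean and divides by the channel standard deviation, both computed on the ImageNet training set. The model's pretrained weights were learned on inputs normalized this way, so applying the same preprocessing is necessary for accurate predictions. After normalization, the values are no longer confined to [0, 1].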
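The arithmetic of ToTensor() followed by Normalize() can be worked out by hand for a single pixel. The function below is a hypothetical helper written only for illustration; it reuses the mean and std values from the transform above and shows that a pure white or pure black pixel ends up well outside [0, 1].

```python
# ImageNet channel statistics, as passed to transforms.Normalize above
MEAN = (0.485, 0.456, 0.406)
STD = (0.229, 0.224, 0.225)


def preprocess_pixel(rgb):
    """Hypothetical helper: mimic ToTensor + Normalize for one 8-bit RGB pixel."""
    scaled = [v / 255 for v in rgb]  # ToTensor: [0, 255] -> [0, 1]
    # Normalize: (x - mean) / std, channel by channel
    return [(x - m) / s for x, m, s in zip(scaled, MEAN, STD)]


white = preprocess_pixel((255, 255, 255))
black = preprocess_pixel((0, 0, 0))

print([round(v, 3) for v in white])  # [2.249, 2.429, 2.64]
print([round(v, 3) for v in black])  # [-2.118, -2.036, -1.804]
```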

The image is now stored as a tensor in the variable img_tensor.

This tensor, an instance of torch.Tensor, has 3 color channels and 224 × 224 pixels:

print(type(img_tensor), img_tensor.shape)

<class 'torch.Tensor'> torch.Size([3, 224, 224])

PyTorch models process data in batches: torchvision's image classification models expect a 4-D input of shape N × C × H × W, where N is the number of images in the batch.

With a single image, we simply create a batch that contains one item:

  batch = img_tensor.unsqueeze(0)

unsqueeze(0) inserts a new dimension of size 1 at position 0 of img_tensor, turning it into a one-image batch without copying the underlying data.

print(batch.shape)

torch.Size([1, 3, 224, 224])

The tensor's shape is now 1 × 3 × 224 × 224, where:

  • 1 is the batch size (one image)
  • 3 is the number of channels (RGB)
  • 224 × 224 are the image's height and width in pixels.
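The sketch below shows the same batching step on placeholder tensors (zeros and ones, standing in for preprocessed images), plus torch.stack, which builds a batch from several images of equal shape. It also checks that unsqueeze returns a view sharing memory with the original tensor.

```python
import torch

# Placeholder tensors with the same shape as img_tensor (assumption: 3 x 224 x 224)
img_a = torch.zeros(3, 224, 224)
img_b = torch.ones(3, 224, 224)

single = img_a.unsqueeze(0)         # one-image batch
pair = torch.stack([img_a, img_b])  # two-image batch, stacked along a new dim 0

print(single.shape)  # torch.Size([1, 3, 224, 224])
print(pair.shape)    # torch.Size([2, 3, 224, 224])

# unsqueeze returns a view: it points at the same memory as img_a
print(single.data_ptr() == img_a.data_ptr())  # True
```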

Now, let's load the model that will process our image:

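  import torch
  from torchvision import models
  # Pretrained ImageNet weights; torchvision < 0.13 uses pretrained=True instead
  model = models.alexnet(weights=models.AlexNet_Weights.IMAGENET1K_V1)
  device = "cuda" if torch.cuda.is_available() else "cpu"  # GPU if available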

Before running the model, we prepare it: model.eval() switches AlexNet to evaluation mode, which disables its dropout layers so that inference is deterministic, and model.to(device) moves its parameters to the GPU (if available) or the CPU.

Finally, calling model() on the batch runs the forward pass. The input must be on the same device as the model:

  model.eval()
  model.to(device)
  y = model(batch.to(device))

The results are stored in the variable 'y'.
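To see why model.eval() matters, consider a standalone dropout layer of the kind used in AlexNet's classifier. In the sketch below (the all-ones input and p=0.5 are arbitrary choices), dropout in training mode zeroes random elements and rescales the rest, while in evaluation mode it passes the input through unchanged. The last lines show torch.no_grad(), an optional addition not used in the code above, which skips gradient tracking during inference. That is why the output printed later still carries a grad_fn.

```python
import torch
from torch import nn

torch.manual_seed(0)
x = torch.ones(1, 8)
dropout = nn.Dropout(p=0.5)

dropout.train()         # training mode, the default for newly created modules
out_train = dropout(x)  # each element is 0.0 or 2.0 (kept values scaled by 1 / (1 - p))

dropout.eval()          # evaluation mode, which model.eval() sets on every submodule
out_eval = dropout(x)   # returned unchanged

print(out_train)
print(torch.equal(out_eval, x))  # True

# Optional: torch.no_grad() disables gradient tracking, saving memory at inference
layer = nn.Linear(8, 2)
print(layer(x).requires_grad)      # True: the output depends on trainable weights
with torch.no_grad():
    print(layer(x).requires_grad)  # False
```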

Let's examine the output dimensions with y.shape:

print(y.shape)

torch.Size([1, 1000])

The output has one row per image in the batch (here, just one) and 1000 columns: one raw score (logit) for each of the 1000 ImageNet classes the model was trained on.

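To find the most probable class, take the maximum score along dimension 1 (the class dimension). torch.max returns both the maximum values and their indices:

y_max, index = torch.max(y, 1)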

And now, let's reveal the classified result:

print(index, y_max)

tensor([285]) tensor([18.0390], grad_fn=<MaxBackward0>)

According to the network, the image most likely belongs to class 285. The value 18.0390 is a raw score (logit), not a probability.
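The behavior of torch.max(y, 1) is easier to see on a small batch. The logits below are made-up values for a hypothetical batch of two images and four classes: for each row, torch.max returns the largest score and its column index, and argmax gives the same indices when only the class is needed.

```python
import torch

# Made-up logits: 2 images (rows) x 4 classes (columns)
y = torch.tensor([[1.5, 3.2, -0.7, 0.1],
                  [0.3, -1.0, 2.4, 2.0]])

y_max, index = torch.max(y, 1)  # maximum over dim 1, one result per image
print(y_max)  # tensor([3.2000, 2.4000])
print(index)  # tensor([1, 2])

# argmax returns only the indices
print(torch.equal(index, y.argmax(dim=1)))  # True
```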

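What does class 285 represent? To find out, look it up in the list of ImageNet class labels, classes, which is loaded from a text file in the code further below:

print(classes[285])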

285: 'Egyptian cat'

The network suggests the image is likely an 'Egyptian cat'.

To list the most likely classes together with their probabilities, we extend the program:

  import urllib.request
  import torch

  # Download the ImageNet class labels (one label per line)
  url = 'https://raw.githubusercontent.com/joe-papa/pytorch-book/main/files/imagenet_class_labels.txt'
  fpath = 'imagenet_class_labels.txt'
  urllib.request.urlretrieve(url, fpath)

  with open(fpath) as f:
      classes = [line.strip() for line in f]

  # Convert the raw scores into percentages that sum to 100
  prob = torch.nn.functional.softmax(y, dim=1)[0] * 100

  # Sort class indices by score, highest first
  _, indices = torch.sort(y, descending=True)

  # Print the five most likely classes
  for idx in indices[0][:5]:
      print(classes[idx], prob[idx].item())

This code downloads the ImageNet class labels, converts the model's raw scores into percentages with the softmax function, sorts the classes by score, and prints the five most likely ones.
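The softmax and ranking steps can be checked on made-up logits for one image and five classes. The sketch confirms that the percentages sum to 100 and shows torch.topk, a swapped-in alternative to sorting that returns the k largest values and their indices directly. Because softmax preserves the order of the scores, ranking the raw logits or the probabilities gives the same classes.

```python
import torch

# Made-up logits for one image and five classes
y = torch.tensor([[2.0, 1.0, 0.1, 3.0, -1.0]])

prob = torch.nn.functional.softmax(y, dim=1)[0] * 100  # percentages
print(round(prob.sum().item(), 4))  # 100.0

# topk returns the k largest values and their indices, already sorted
top_prob, top_idx = torch.topk(prob, k=3)
print(top_idx.tolist())  # [3, 0, 1]

# Same order as sorting the raw logits, because softmax is monotonic
_, indices = torch.sort(y, descending=True)
print(torch.equal(top_idx, indices[0][:3]))  # True
```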

Let's look at the final results:

285: 'Egyptian cat', 59.271602630615234
281: 'tabby, tabby cat', 30.12925148010254
904: 'window screen', 2.8015429973602295
478: 'carton', 2.376310110092163
282: 'tiger cat', 1.066372036933899

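With a probability of about 59.3%, the image is identified as an 'Egyptian cat', and with about 30.1% as a 'tabby, tabby cat'. Together, these two cat classes account for about 89.4% of the predicted probability, or about 90.5% when 'tiger cat' (1.1%) is included, so the model is quite confident that the photo shows a cat.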
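The totals above are simple sums over the reported percentages. The sketch copies the five values from the output, assuming the standard ImageNet class order in which indices 281 to 285 are the domestic-cat classes. With the full prob tensor, the same total would be prob[281:286].sum().

```python
# Percentages reported above for the top five classes (class index -> probability)
top5 = {
    285: 59.271602630615234,  # Egyptian cat
    281: 30.12925148010254,   # tabby, tabby cat
    904: 2.8015429973602295,  # window screen
    478: 2.376310110092163,   # carton
    282: 1.066372036933899,   # tiger cat
}

# Assumption: ImageNet indices 281-285 are the domestic-cat classes
cat_classes = range(281, 286)

two_cats = top5[285] + top5[281]
all_cats = sum(p for idx, p in top5.items() if idx in cat_classes)

print(round(two_cats, 1))  # 89.4
print(round(all_cats, 1))  # 90.5
```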



