How to Use Grad-CAM to Interpret Your Convolutional Neural Network

Introduction

Deep learning has become increasingly popular in the past few years. While deep learning models may show promising results, their lack of interpretability means that when a modern deep network fails, practitioners are often unable to determine why the model made a wrong prediction. Consequently, stakeholders may quickly lose trust in such systems.

To overcome this potential barrier to the mass adoption of modern artificial intelligence systems, various techniques have been developed to interpret the seemingly uninterpretable.

The purpose of this article is to introduce one such method, Gradient-weighted Class Activation Mapping (Grad-CAM), which is used to explain how modern Convolutional Neural Networks (CNNs) make their decisions.


Grad-CAM

Grad-CAM is a visualization technique that enables users to “see” why a CNN made a particular prediction for a certain class c. It uses the gradients of the logit of class c with respect to the feature maps of the last convolutional layer to highlight the image regions most responsible for that prediction.

For example, in the figure below, we pass a cat and dog image (a) into our CNN. If we want to know where the CNN is “looking” for the class cat, we can employ Grad-CAM, which uses the gradients flowing back into the last convolutional layer to produce a heatmap (b) showing where the cat class is maximally activated. Overlaying the heatmap onto the original image (c) gives a coarse localization that highlights the regions representing the cat.

Dog and cat image, with Grad-CAM overlaid for the class cat.
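Under the hood, Grad-CAM computes the gradient of the class-c logit with respect to each feature map of the chosen convolutional layer, global-average-pools these gradients over the spatial dimensions to obtain one importance weight per channel, takes the weighted sum of the feature maps, and applies a ReLU so that only regions with a positive influence on the class remain. The sketch below is a minimal, simplified version of that computation; the function name grad_cam_heatmap and the already-captured activations and gradients are illustrative (in practice they are captured with forward/backward hooks, which the library used later in this article handles for us).

import torch
import torch.nn.functional as F

def grad_cam_heatmap(activations: torch.Tensor, gradients: torch.Tensor) -> torch.Tensor:
    """Compute a Grad-CAM heatmap for a single image.

    activations: feature maps of the target conv layer, shape (1, K, H, W).
    gradients:   d(class-c logit)/d(activations),       shape (1, K, H, W).
    """
    # One importance weight per feature map: global average pool of the gradients.
    alphas = gradients.mean(dim=(2, 3), keepdim=True)   # (1, K, 1, 1)
    # Weighted combination of the feature maps, summed over the channel axis.
    cam = (alphas * activations).sum(dim=1)             # (1, H, W)
    # Keep only the regions that push the class score up.
    cam = F.relu(cam)
    # Scale to [0, 1] so the map can be upsampled and overlaid on the image.
    cam = cam / (cam.max() + 1e-8)
    return cam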

Grad-CAM Example Walkthrough

We will use an out-of-the-box library, PyTorch Grad-CAM, to see how Grad-CAM can be applied to a pretrained model. The main framework used in this article is PyTorch.

Before we start, we define a global constant named device to indicate whether we are running on a GPU or a CPU.

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

Initialize Model

We first get our pretrained model, a ResNet-50 with ImageNet weights.

import torch.nn as nn
import torchvision
from torchvision.models import ResNet50_Weights

model_name = "resnet50"
model_weights = ResNet50_Weights.IMAGENET1K_V2

def get_model(model_name: str, model_weights: ResNet50_Weights, device: torch.device) -> nn.Module:
    # Look up the constructor by name, load pretrained weights, set eval mode and move to device.
    model = getattr(torchvision.models, model_name)(weights=model_weights)
    model = model.eval()
    model = model.to(device)
    return model

model = get_model(model_name, model_weights, device)

The function get_model does the following:

  1. Initializes a ResNet-50 network with pretrained ImageNet weights.
  2. Sets the model to eval mode, which switches off the training-time behaviour of layers such as dropout and batch norm. Failing to do this can yield inconsistent inference results (a quick check follows this list).
  3. Moves the model to the device.
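For instance, step 2 can be verified quickly: eval() flips the module-level training flag on the model and all of its submodules, which is the flag that dropout and batch norm layers consult at run time. A minimal illustrative check, using the model returned by get_model above:

print(model.training)     # False, because get_model called model.eval()
print(model.fc.training)  # False, submodules inherit the flag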

Readers who use TensorFlow/Keras’ model.summary() will find the following summary familiar. Note that the output shapes below correspond to the (1, 3, 28, 28) input size passed to torchinfo in the snippet that follows the summary.

==========================================================================================
Layer (type:depth-idx)                   Output Shape              Param #
==========================================================================================
ResNet                                   [1, 1000]                 --
├─Conv2d: 1-1                            [1, 64, 14, 14]           9,408
├─BatchNorm2d: 1-2                       [1, 64, 14, 14]           128
├─ReLU: 1-3                              [1, 64, 14, 14]           --
├─MaxPool2d: 1-4                         [1, 64, 7, 7]             --
├─Sequential: 1-5                        [1, 256, 7, 7]            --
│    └─Bottleneck: 2-1                   [1, 256, 7, 7]            --
│    │    └─Conv2d: 3-1                  [1, 64, 7, 7]             4,096
│    │    └─BatchNorm2d: 3-2             [1, 64, 7, 7]             128
│    │    └─ReLU: 3-3                    [1, 64, 7, 7]             --
│    │    └─Conv2d: 3-4                  [1, 64, 7, 7]             36,864
│    │    └─BatchNorm2d: 3-5             [1, 64, 7, 7]             128
│    │    └─ReLU: 3-6                    [1, 64, 7, 7]             --
│    │    └─Conv2d: 3-7                  [1, 256, 7, 7]            16,384
│    │    └─BatchNorm2d: 3-8             [1, 256, 7, 7]            512
│    │    └─Sequential: 3-9              [1, 256, 7, 7]            16,896
│    │    └─ReLU: 3-10                   [1, 256, 7, 7]            --
│    └─Bottleneck: 2-2                   [1, 256, 7, 7]            --
│    │    └─Conv2d: 3-11                 [1, 64, 7, 7]             16,384
│    │    └─BatchNorm2d: 3-12            [1, 64, 7, 7]             128
│    │    └─ReLU: 3-13                   [1, 64, 7, 7]             --
│    │    └─Conv2d: 3-14                 [1, 64, 7, 7]             36,864
│    │    └─BatchNorm2d: 3-15            [1, 64, 7, 7]             128
│    │    └─ReLU: 3-16                   [1, 64, 7, 7]             --
│    │    └─Conv2d: 3-17                 [1, 256, 7, 7]            16,384
│    │    └─BatchNorm2d: 3-18            [1, 256, 7, 7]            512
│    │    └─ReLU: 3-19                   [1, 256, 7, 7]            --
│    └─Bottleneck: 2-3                   [1, 256, 7, 7]            --
│    │    └─Conv2d: 3-20                 [1, 64, 7, 7]             16,384
│    │    └─BatchNorm2d: 3-21            [1, 64, 7, 7]             128
│    │    └─ReLU: 3-22                   [1, 64, 7, 7]             --
│    │    └─Conv2d: 3-23                 [1, 64, 7, 7]             36,864
│    │    └─BatchNorm2d: 3-24            [1, 64, 7, 7]             128
│    │    └─ReLU: 3-25                   [1, 64, 7, 7]             --
│    │    └─Conv2d: 3-26                 [1, 256, 7, 7]            16,384
│    │    └─BatchNorm2d: 3-27            [1, 256, 7, 7]            512
│    │    └─ReLU: 3-28                   [1, 256, 7, 7]            --
├─Sequential: 1-6                        [1, 512, 4, 4]            --
│    └─Bottleneck: 2-4                   [1, 512, 4, 4]            --
│    │    └─Conv2d: 3-29                 [1, 128, 7, 7]            32,768
│    │    └─BatchNorm2d: 3-30            [1, 128, 7, 7]            256
│    │    └─ReLU: 3-31                   [1, 128, 7, 7]            --
│    │    └─Conv2d: 3-32                 [1, 128, 4, 4]            147,456
│    │    └─BatchNorm2d: 3-33            [1, 128, 4, 4]            256
│    │    └─ReLU: 3-34                   [1, 128, 4, 4]            --
│    │    └─Conv2d: 3-35                 [1, 512, 4, 4]            65,536
│    │    └─BatchNorm2d: 3-36            [1, 512, 4, 4]            1,024
│    │    └─Sequential: 3-37             [1, 512, 4, 4]            132,096
│    │    └─ReLU: 3-38                   [1, 512, 4, 4]            --
│    └─Bottleneck: 2-5                   [1, 512, 4, 4]            --
│    │    └─Conv2d: 3-39                 [1, 128, 4, 4]            65,536
│    │    └─BatchNorm2d: 3-40            [1, 128, 4, 4]            256
│    │    └─ReLU: 3-41                   [1, 128, 4, 4]            --
│    │    └─Conv2d: 3-42                 [1, 128, 4, 4]            147,456
│    │    └─BatchNorm2d: 3-43            [1, 128, 4, 4]            256
│    │    └─ReLU: 3-44                   [1, 128, 4, 4]            --
│    │    └─Conv2d: 3-45                 [1, 512, 4, 4]            65,536
│    │    └─BatchNorm2d: 3-46            [1, 512, 4, 4]            1,024
│    │    └─ReLU: 3-47                   [1, 512, 4, 4]            --
│    └─Bottleneck: 2-6                   [1, 512, 4, 4]            --
│    │    └─Conv2d: 3-48                 [1, 128, 4, 4]            65,536
│    │    └─BatchNorm2d: 3-49            [1, 128, 4, 4]            256
│    │    └─ReLU: 3-50                   [1, 128, 4, 4]            --
│    │    └─Conv2d: 3-51                 [1, 128, 4, 4]            147,456
│    │    └─BatchNorm2d: 3-52            [1, 128, 4, 4]            256
│    │    └─ReLU: 3-53                   [1, 128, 4, 4]            --
│    │    └─Conv2d: 3-54                 [1, 512, 4, 4]            65,536
│    │    └─BatchNorm2d: 3-55            [1, 512, 4, 4]            1,024
│    │    └─ReLU: 3-56                   [1, 512, 4, 4]            --
│    └─Bottleneck: 2-7                   [1, 512, 4, 4]            --
│    │    └─Conv2d: 3-57                 [1, 128, 4, 4]            65,536
│    │    └─BatchNorm2d: 3-58            [1, 128, 4, 4]            256
│    │    └─ReLU: 3-59                   [1, 128, 4, 4]            --
│    │    └─Conv2d: 3-60                 [1, 128, 4, 4]            147,456
│    │    └─BatchNorm2d: 3-61            [1, 128, 4, 4]            256
│    │    └─ReLU: 3-62                   [1, 128, 4, 4]            --
│    │    └─Conv2d: 3-63                 [1, 512, 4, 4]            65,536
│    │    └─BatchNorm2d: 3-64            [1, 512, 4, 4]            1,024
│    │    └─ReLU: 3-65                   [1, 512, 4, 4]            --
├─Sequential: 1-7                        [1, 1024, 2, 2]           --
│    └─Bottleneck: 2-8                   [1, 1024, 2, 2]           --
│    │    └─Conv2d: 3-66                 [1, 256, 4, 4]            131,072
│    │    └─BatchNorm2d: 3-67            [1, 256, 4, 4]            512
│    │    └─ReLU: 3-68                   [1, 256, 4, 4]            --
│    │    └─Conv2d: 3-69                 [1, 256, 2, 2]            589,824
│    │    └─BatchNorm2d: 3-70            [1, 256, 2, 2]            512
│    │    └─ReLU: 3-71                   [1, 256, 2, 2]            --
│    │    └─Conv2d: 3-72                 [1, 1024, 2, 2]           262,144
│    │    └─BatchNorm2d: 3-73            [1, 1024, 2, 2]           2,048
│    │    └─Sequential: 3-74             [1, 1024, 2, 2]           526,336
│    │    └─ReLU: 3-75                   [1, 1024, 2, 2]           --
│    └─Bottleneck: 2-9                   [1, 1024, 2, 2]           --
│    │    └─Conv2d: 3-76                 [1, 256, 2, 2]            262,144
│    │    └─BatchNorm2d: 3-77            [1, 256, 2, 2]            512
│    │    └─ReLU: 3-78                   [1, 256, 2, 2]            --
│    │    └─Conv2d: 3-79                 [1, 256, 2, 2]            589,824
│    │    └─BatchNorm2d: 3-80            [1, 256, 2, 2]            512
│    │    └─ReLU: 3-81                   [1, 256, 2, 2]            --
│    │    └─Conv2d: 3-82                 [1, 1024, 2, 2]           262,144
│    │    └─BatchNorm2d: 3-83            [1, 1024, 2, 2]           2,048
│    │    └─ReLU: 3-84                   [1, 1024, 2, 2]           --
│    └─Bottleneck: 2-10                  [1, 1024, 2, 2]           --
│    │    └─Conv2d: 3-85                 [1, 256, 2, 2]            262,144
│    │    └─BatchNorm2d: 3-86            [1, 256, 2, 2]            512
│    │    └─ReLU: 3-87                   [1, 256, 2, 2]            --
│    │    └─Conv2d: 3-88                 [1, 256, 2, 2]            589,824
│    │    └─BatchNorm2d: 3-89            [1, 256, 2, 2]            512
│    │    └─ReLU: 3-90                   [1, 256, 2, 2]            --
│    │    └─Conv2d: 3-91                 [1, 1024, 2, 2]           262,144
│    │    └─BatchNorm2d: 3-92            [1, 1024, 2, 2]           2,048
│    │    └─ReLU: 3-93                   [1, 1024, 2, 2]           --
│    └─Bottleneck: 2-11                  [1, 1024, 2, 2]           --
│    │    └─Conv2d: 3-94                 [1, 256, 2, 2]            262,144
│    │    └─BatchNorm2d: 3-95            [1, 256, 2, 2]            512
│    │    └─ReLU: 3-96                   [1, 256, 2, 2]            --
│    │    └─Conv2d: 3-97                 [1, 256, 2, 2]            589,824
│    │    └─BatchNorm2d: 3-98            [1, 256, 2, 2]            512
│    │    └─ReLU: 3-99                   [1, 256, 2, 2]            --
│    │    └─Conv2d: 3-100                [1, 1024, 2, 2]           262,144
│    │    └─BatchNorm2d: 3-101           [1, 1024, 2, 2]           2,048
│    │    └─ReLU: 3-102                  [1, 1024, 2, 2]           --
│    └─Bottleneck: 2-12                  [1, 1024, 2, 2]           --
│    │    └─Conv2d: 3-103                [1, 256, 2, 2]            262,144
│    │    └─BatchNorm2d: 3-104           [1, 256, 2, 2]            512
│    │    └─ReLU: 3-105                  [1, 256, 2, 2]            --
│    │    └─Conv2d: 3-106                [1, 256, 2, 2]            589,824
│    │    └─BatchNorm2d: 3-107           [1, 256, 2, 2]            512
│    │    └─ReLU: 3-108                  [1, 256, 2, 2]            --
│    │    └─Conv2d: 3-109                [1, 1024, 2, 2]           262,144
│    │    └─BatchNorm2d: 3-110           [1, 1024, 2, 2]           2,048
│    │    └─ReLU: 3-111                  [1, 1024, 2, 2]           --
│    └─Bottleneck: 2-13                  [1, 1024, 2, 2]           --
│    │    └─Conv2d: 3-112                [1, 256, 2, 2]            262,144
│    │    └─BatchNorm2d: 3-113           [1, 256, 2, 2]            512
│    │    └─ReLU: 3-114                  [1, 256, 2, 2]            --
│    │    └─Conv2d: 3-115                [1, 256, 2, 2]            589,824
│    │    └─BatchNorm2d: 3-116           [1, 256, 2, 2]            512
│    │    └─ReLU: 3-117                  [1, 256, 2, 2]            --
│    │    └─Conv2d: 3-118                [1, 1024, 2, 2]           262,144
│    │    └─BatchNorm2d: 3-119           [1, 1024, 2, 2]           2,048
│    │    └─ReLU: 3-120                  [1, 1024, 2, 2]           --
├─Sequential: 1-8                        [1, 2048, 1, 1]           --
│    └─Bottleneck: 2-14                  [1, 2048, 1, 1]           --
│    │    └─Conv2d: 3-121                [1, 512, 2, 2]            524,288
│    │    └─BatchNorm2d: 3-122           [1, 512, 2, 2]            1,024
│    │    └─ReLU: 3-123                  [1, 512, 2, 2]            --
│    │    └─Conv2d: 3-124                [1, 512, 1, 1]            2,359,296
│    │    └─BatchNorm2d: 3-125           [1, 512, 1, 1]            1,024
│    │    └─ReLU: 3-126                  [1, 512, 1, 1]            --
│    │    └─Conv2d: 3-127                [1, 2048, 1, 1]           1,048,576
│    │    └─BatchNorm2d: 3-128           [1, 2048, 1, 1]           4,096
│    │    └─Sequential: 3-129            [1, 2048, 1, 1]           2,101,248
│    │    └─ReLU: 3-130                  [1, 2048, 1, 1]           --
│    └─Bottleneck: 2-15                  [1, 2048, 1, 1]           --
│    │    └─Conv2d: 3-131                [1, 512, 1, 1]            1,048,576
│    │    └─BatchNorm2d: 3-132           [1, 512, 1, 1]            1,024
│    │    └─ReLU: 3-133                  [1, 512, 1, 1]            --
│    │    └─Conv2d: 3-134                [1, 512, 1, 1]            2,359,296
│    │    └─BatchNorm2d: 3-135           [1, 512, 1, 1]            1,024
│    │    └─ReLU: 3-136                  [1, 512, 1, 1]            --
│    │    └─Conv2d: 3-137                [1, 2048, 1, 1]           1,048,576
│    │    └─BatchNorm2d: 3-138           [1, 2048, 1, 1]           4,096
│    │    └─ReLU: 3-139                  [1, 2048, 1, 1]           --
│    └─Bottleneck: 2-16                  [1, 2048, 1, 1]           --
│    │    └─Conv2d: 3-140                [1, 512, 1, 1]            1,048,576
│    │    └─BatchNorm2d: 3-141           [1, 512, 1, 1]            1,024
│    │    └─ReLU: 3-142                  [1, 512, 1, 1]            --
│    │    └─Conv2d: 3-143                [1, 512, 1, 1]            2,359,296
│    │    └─BatchNorm2d: 3-144           [1, 512, 1, 1]            1,024
│    │    └─ReLU: 3-145                  [1, 512, 1, 1]            --
│    │    └─Conv2d: 3-146                [1, 2048, 1, 1]           1,048,576
│    │    └─BatchNorm2d: 3-147           [1, 2048, 1, 1]           4,096
│    │    └─ReLU: 3-148                  [1, 2048, 1, 1]           --
├─AdaptiveAvgPool2d: 1-9                 [1, 2048, 1, 1]           --
├─Linear: 1-10                           [1, 1000]                 2,049,000
==========================================================================================
Total params: 25,557,032
Trainable params: 25,557,032
Non-trainable params: 0
Total mult-adds (M): 81.26
==========================================================================================
Input size (MB): 0.01
Forward/backward pass size (MB): 3.21
Params size (MB): 102.23
Estimated Total Size (MB): 105.44
==========================================================================================

This summary can be generated using the torchinfo package.

from typing import Any, Tuple

import torchinfo

def torchsummary_wrapper(
    model: nn.Module, input_size: Tuple[int, int, int, int], device: torch.device, **kwargs: Any
) -> torchinfo.model_statistics.ModelStatistics:
    # Thin wrapper around torchinfo.summary.
    return torchinfo.summary(model, input_size=input_size, device=device, **kwargs)

batch_size = 1
input_size = (batch_size, 3, 28, 28)

model_summary = torchsummary_wrapper(model, input_size=input_size, device=device)
print(model_summary)

Preprocessing

Next, we need an image to pass into the model. For our purposes, we will use the cat and dog photo as our sample image.

The image URL is defined as cat_dog_url.

cat_dog_url = "https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/master/examples/both.png"

We will also define a helper function url_to_numpy_and_pil that takes in an image URL and converts the image to its corresponding numpy array and PIL image.

from urllib.request import urlopen
import numpy as np
import PIL.Image

def url_to_numpy_and_pil(url: str) -> Tuple[np.ndarray, PIL.Image.Image]:
    image_pil = PIL.Image.open(urlopen(url)).convert("RGB")
    image_numpy = np.array(image_pil)
    return image_numpy, image_pil

We will call this function and unpack it into two variables, image_numpy and image_pil.

image_numpy, image_pil = url_to_numpy_and_pil(cat_dog_url)

Since we are using PyTorch as our framework, we need to perform the following preprocessing steps:

  1. Transform the image into a torch.Tensor.
  2. Resize the image to our desired shape.
  3. Normalize the image with the ImageNet mean and standard deviation, a common practice since the ImageNet pretrained model was trained on millions of images whose categories include our image classes (cat and dog).
  4. Expand the image by one dimension to make it a batch of size 1, as the model expects batches of images. In our case, the resized image of shape (3, 224, 224) becomes (1, 3, 224, 224).
  5. Lastly, move the tensor to the device [1].

We can compress steps 1-3 into a function get_transforms [2]. The resulting transforms are passed into preprocess, which first applies them and subsequently performs steps 4 and 5.

from typing import List

import torchvision.transforms as T

def get_transforms(image_size: int, mean: List[float], std: List[float]) -> T.Compose:
    # Covers steps 1-3: resize, convert to a tensor and normalize with ImageNet statistics.
    return T.Compose(
        [
            T.Resize((image_size, image_size)),
            T.ToTensor(),
            T.Normalize(mean=mean, std=std),
        ]
    )

def preprocess(image: PIL.Image.Image, transforms: T.Compose) -> torch.Tensor:
    # Covers steps 4-5: add a batch dimension and move the tensor to the device.
    image_tensor = transforms(image)
    image_tensor = image_tensor.unsqueeze(0)
    image_tensor = image_tensor.to(device)
    return image_tensor

The following code shows the constants for the image size, mean and standard deviation that are necessary for our transforms. Subsequently, we pass in image_pil and transforms into preprocess to obtain our image_tensor.

image_size = 224
imagenet_mean = [0.485, 0.456, 0.406]
imagenet_std = [0.229, 0.224, 0.225]
transforms = get_transforms(image_size, imagenet_mean, imagenet_std)

image_tensor = preprocess(image_pil, transforms)
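As a quick sanity check on step 4, we can confirm that the image was resized and the batch dimension was added:

print(image_tensor.shape)  # torch.Size([1, 3, 224, 224])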

Predictions

Now, we can run this pretrained model on the image_tensor.

logits = model(image_tensor)

The model outputs logits of shape (1, 1000), as expected, since the last layer has 1000 output neurons, one per ImageNet class.
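As an aside, since we only need predictions at this point (the Grad-CAM library used later runs its own forward and backward pass), the forward pass above could equally be wrapped in torch.no_grad() to avoid building the autograd graph:

with torch.no_grad():
    logits = model(image_tensor)
print(logits.shape)  # torch.Size([1, 1000])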

Let’s print out the top 3 prediction logits made by the model.

k = 3
top_k_logits = torch.topk(logits, k, dim=1, largest=True, sorted=True)
>>> print(top_k_logits)
torch.return_types.topk(
values=tensor([[16.5665, 16.0921, 15.6797]]),
indices=tensor([[242, 243, 282]]))

Notice that the values produced are not probabilities! This is by design: the model’s last layer is a Linear layer, not a Softmax layer. We can apply the familiar softmax function to these 1000 logits to squash them into values between 0 and 1 that sum to 1, so the outputs can be loosely interpreted as “probabilities”. Do take note, however, that not all deep neural networks are well calibrated.

Both logits and probabilities will output the same class index as softmax is a monotonic function that preserves the order.

softmax = nn.Softmax(dim=1)
probabilities = softmax(logits)
top_k_probabilities = torch.topk(probabilities, k, dim=1, largest=True, sorted=True)
>>> print(top_k_probabilities)
torch.return_types.topk(
values=tensor([[0.3500, 0.0242, 0.0171]]),
indices=tensor([[243, 242, 281]]))

The values [0.3500, 0.0242, 0.0171] correspond to the ImageNet class indices [243, 242, 281], which means the top three classes for this input image are (a snippet for looking the names up programmatically follows the list):

  • Index 243 corresponds to the class bull mastiff, with a probability of 35%.
  • Index 242 corresponds to the class boxer, with a probability of 2.42%.
  • Index 281 corresponds to the class tabby, with a probability of 1.71%.
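For reference, these class names can also be looked up programmatically. Recent torchvision releases expose them through the weight enum’s meta dictionary; a small illustrative snippet using the top_k_probabilities computed above:

categories = ResNet50_Weights.IMAGENET1K_V2.meta["categories"]
for prob, idx in zip(top_k_probabilities.values[0], top_k_probabilities.indices[0]):
    print(f"{categories[idx.item()]}: {prob.item():.2%}")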

This means the model is most confident in predicting a bull mastiff (dog) and less confident in predicting a tabby (cat). One should not be too alarmed by the low “probabilities”: ImageNet models are trained for multi-class classification over 1000 classes and are not designed to predict more than one object in an image at a time [3].

Display Grad-CAM

We can now use Grad-CAM to visualize where our model is “looking” for the class bull mastiff (dog).

Before that, we need to define some variables to pass into the PyTorch Grad-CAM API.

model: nn.Module
target_layers: List[nn.Module]
device: torch.device
targets: List[int]
image_tensor: torch.Tensor
image_numpy: np.ndarray

We have already defined model, device, image_tensor and image_numpy [4]; we are left with defining the last convolutional layer [5] (target_layers) and the class of choice (targets), which was mentioned earlier in the Grad-CAM section.

target_layers = [model.layer4[-1]]      # last convolutional block of ResNet-50
targets = [243]                         # ImageNet index 243: bull mastiff
image_normalized = image_numpy / 255.0  # scale pixel values to [0, 1] for the overlay

We are now ready to showcase Grad-CAM on the cat and dog image.

from dataclasses import dataclass, field
from typing import Callable, List, Optional

import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

@dataclass(frozen=False, init=True)
class GradCamWrapper:
    """Thin convenience wrapper around the PyTorch Grad-CAM API."""

    model: nn.Module
    target_layers: List[nn.Module]
    device: torch.device
    targets: List[int]
    image_tensor: torch.Tensor
    image_numpy: np.ndarray
    reshape_transform: Optional[Callable] = None
    use_cuda: bool = field(init=False)
    target_categories: List[ClassifierOutputTarget] = field(init=False)
    gradcam: GradCAM = field(init=False)

    def __post_init__(self) -> None:
        # torch.device(...) accepts both strings and torch.device objects.
        self.use_cuda = torch.device(self.device).type == "cuda"
        self.target_categories = [
            ClassifierOutputTarget(target) for target in self.targets
        ]
        self.gradcam = self._init_gradcam_object()

    def _init_gradcam_object(self) -> GradCAM:
        return GradCAM(
            model=self.model,
            target_layers=self.target_layers,
            use_cuda=self.use_cuda,
            reshape_transform=self.reshape_transform,
        )

    def _generate_heatmap(self) -> np.ndarray:
        # Returns a heatmap of shape (batch, H, W).
        heatmap = self.gradcam(
            input_tensor=self.image_tensor,
            targets=self.target_categories,
        )
        return heatmap

    def display(self) -> None:
        # Plot the original image, the heatmap and the overlay side by side.
        heatmap = self._generate_heatmap()
        heatmap = heatmap[0, :]  # single image in the batch
        visualization = show_cam_on_image(self.image_numpy, heatmap, use_rgb=True)

        fig, axes = plt.subplots(figsize=(20, 10), ncols=3)

        axes[0].imshow(self.image_numpy)
        axes[0].axis("off")

        axes[1].imshow(heatmap)
        axes[1].axis("off")

        axes[2].imshow(visualization)
        axes[2].axis("off")

        plt.show()

We just need to pass the variables defined earlier into a GradCamWrapper object and call its display method to visualize the original image, its Grad-CAM heatmap, and the overlaid image.

gradcam = GradCamWrapper(model, target_layers, device, targets, image_tensor, image_normalized)
gradcam.display()
Grad-CAM on the class dog.

To show where the model is looking for the class tabby (cat), we merely need to change targets to [281], as shown below.
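A minimal example, reusing the objects defined above:

targets_cat = [281]  # ImageNet index 281: tabby
gradcam_cat = GradCamWrapper(
    model, target_layers, device, targets_cat, image_tensor, image_normalized
)
gradcam_cat.display()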


End Notes

In this article, we described why Grad-CAM is useful, presented its high-level idea, and walked through code snippets showing how to apply it to your own pretrained models with an open-source Python package.

Throughout the code walkthrough, we used Grad-CAM to visualize our CNN model and showed that it focuses on a reasonable region of the image.

This, however, is only one part of the interpretability of a model.

What remains to be covered is how we can further use Grad-CAM for error analysis: to debug models that have an inherent bias (focusing on the wrong area yet producing good results) and models that are just outright wrong (often due to data quality issues). I would suggest reading the example “Is it wolf or is it snow”, where a model achieves great performance for the wrong reasons. The author used another technique, LIME, to showcase how such tools can be used to perform error analysis on your models.

Lastly, readers who are intrigued by the idea and want to take their learning further will find the reference section helpful. In addition, those who want to explore the TensorFlow counterpart can look at tf-explain.

The code contained in this article can be found in this Google Colab link; feel free to reach out to me at hongnan@aisingapore.org for clarifications.


References

1. Chollet, François. Deep Learning with Python, Second Edition. Manning Publications, 2021. Chapter 5.4.3.

2. Molnar, Christoph. Interpretable Machine Learning, 2022. Chapter 10.2.3.

3. Selvaraju, Ramprasaath R., Michael Cogswell, Abhishek Das, Ramakrishna Vedantam, Devi Parikh, and Dhruv Batra. “Grad-CAM: Visual Explanations from Deep Networks via Gradient-Based Localization.” 2017 IEEE International Conference on Computer Vision (ICCV), 2017. https://doi.org/10.1109/iccv.2017.74.


Footnotes

  1. Note that the image tensor must be on the same device as your model!
  2. We leverage PyTorch’s torchvision.transforms library to perform the necessary transformations.
  3. One could explore object detection or multi-label classification if you want your model to predict multiple objects at once.
  4. This should be normalized by 255 so that its values lie between 0 and 1.
  5. Theoretically, one can pass any layer as the target layer.
