Visualisation with GradCAM from MONAI expects a larger input image

I want to use GradCAM to visualize the regions of 3D MRI images that were most important for classifying them as healthy or ill.

Therefore, I use

cam = GradCAM(nn_module=densenet, target_layers="class_layers.relu") 
result = cam(x=torch.rand((1, 1, 7, 7, 7)))

where my densenet is defined as:

self.densenet = densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=1)

This throws the error:

RuntimeError: input image (T: 1 H: 1 W: 1) smaller than kernel size (kT: 2 kH: 2 kW: 2)

Raising D, H and W to 30:

result = cam(x=torch.rand((1, 1, 30, 30, 30)))

leads to

Traceback (most recent call last):

File "/var/folders/79/z7g43_0x08g2yj7w6lb_j5280000gn/T/ipykernel_3133/4028907207.py", line 1, in result = cam(x=torch.rand((64, 1, 30, 30, 30))) #result mri image nehmen

File "/Users/Wu/opt/anaconda3/lib/python3.9/site-packages/monai/visualize/class_activation_maps.py", line 380, in call acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx)

File "/Users/Wu/opt/anaconda3/lib/python3.9/site-packages/monai/visualize/class_activation_maps.py", line 360, in compute_map _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph)

File "/Users/Wu/opt/anaconda3/lib/python3.9/site-packages/monai/visualize/class_activation_maps.py", line 135, in call acti = tuple(self.activations[layer] for layer in self.target_layers)

File "/Users/Wu/opt/anaconda3/lib/python3.9/site-packages/monai/visualize/class_activation_maps.py", line 135, in acti = tuple(self.activations[layer] for layer in self.target_layers)

KeyError: 'class_layers.relu'

My own DenseNet class (a PyTorch Lightning module that wraps MONAI's densenet121 as `self.densenet`) is:

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Mar  8 23:23:24 2022

@author: Wu
"""

### methods from PyTorch Lightning


import torch
import pytorch_lightning as pl
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
import torch.nn.functional as F
from torch import nn
import monai.networks.nets.densenet as densenet
import pycm
import numpy as np


class DenseNet(pl.LightningModule):
    """Binary classifier for single-channel 3D volumes built on MONAI's DenseNet121.

    The MONAI network is stored as ``self.densenet``. Seen from this module, its
    sub-module names therefore carry a ``densenet.`` prefix. For GradCAM, either
    use ``target_layers="densenet.class_layers.relu"``, or pass ``model.densenet``
    as ``nn_module`` and keep ``"class_layers.relu"``.

    Batches are dicts with an ``'mri'`` tensor and a ``'label'`` tensor of 0/1
    targets, presumably one label per sample (TODO confirm against the DataLoader).

    Args:
        learning_rate: learning rate for the Adam optimizer.
    """

    def __init__(self, learning_rate=1e-4):
        super().__init__()
        # One output logit per sample; paired with BCEWithLogitsLoss below.
        self.densenet = densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=1)
        self.learning_rate = learning_rate
        self.class_loss = torch.nn.BCEWithLogitsLoss(reduction='mean')  # binary cross-entropy on raw logits
        # Per-epoch accumulators; they are also reset by the on_*_epoch_start hooks.
        self.train_predictions = []
        self.train_labels = []
        self.predictions = []
        self.labels = []

    def forward(self, x):
        """Return the raw (pre-sigmoid) logits of the wrapped DenseNet for ``x``."""
        return self.densenet(x)

    def configure_optimizers(self):
        """Optimize all parameters with Adam at ``self.learning_rate``."""
        return torch.optim.Adam(self.parameters(), lr=self.learning_rate)

    def _shared_step(self, batch):
        """Run one batch through the model and return ``(loss, predictions, labels)``.

        ``predictions`` (0.0/1.0) and ``labels`` are flat lists of floats, so they
        can be passed directly to ``pycm.ConfusionMatrix``.
        """
        mri, y = batch['mri'], batch['label']
        logits = self(mri).float()
        # Give the target the same shape as the logits, e.g. (B,) -> (B, 1).
        # The previous ``y.unsqueeze(0)`` produced (1, B), which only matches
        # (B, 1) logits when the batch size is 1.
        y = y.float().reshape(logits.shape)
        loss = self.class_loss(logits, y)
        # logit >= 0 is equivalent to sigmoid(logit) >= 0.5.
        y_hat = (logits >= 0).float()
        # Flatten both tensors; tolist() on a (B, 1) tensor would give nested
        # [[...]] entries that do not line up with the flat label list.
        return loss, y_hat.flatten().tolist(), y.flatten().tolist()

    def _log_epoch_metrics(self, labels, predictions, stage):
        """Log accuracy and macro/micro F1 from a pycm confusion matrix under ``stage``."""
        if not labels:
            # Nothing was collected this epoch; pycm cannot build a matrix from empty vectors.
            return
        cm = pycm.ConfusionMatrix(actual_vector=labels, predict_vector=predictions)
        self.log(f'accuracy/{stage}', cm.Overall_ACC)
        # Undefined statistics are expected as the string 'None' (that is what this
        # code skips). Compare by value: ``is not 'None'`` tests object identity.
        if cm.F1_Macro != 'None':
            self.log(f'f1_macro/{stage}', cm.F1_Macro)
        if cm.F1_Micro != 'None':
            self.log(f'f1_micro/{stage}', cm.F1_Micro)

    def on_train_epoch_start(self):
        """Reset the training prediction/label accumulators."""
        self.train_predictions = []
        self.train_labels = []

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss and collect predictions for epoch metrics."""
        class_loss, preds, labels = self._shared_step(batch)
        self.log('loss/train', class_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.train_predictions.extend(preds)
        self.train_labels.extend(labels)
        return class_loss

    def on_train_epoch_end(self, outputs=None):
        """Log training epoch metrics.

        ``outputs`` is optional so that both older Lightning versions (which pass
        it) and newer ones (which call the hook without arguments) work.
        """
        self._log_epoch_metrics(self.train_labels, self.train_predictions, 'train')

    def on_validation_epoch_start(self):
        """Reset the validation prediction/label accumulators."""
        self.predictions = []
        self.labels = []

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss and collect predictions for epoch metrics."""
        class_loss, preds, labels = self._shared_step(batch)
        self.log('loss/valid', class_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.predictions.extend(preds)
        self.labels.extend(labels)

    def on_validation_epoch_end(self):
        """Log validation epoch metrics."""
        self._log_epoch_metrics(self.labels, self.predictions, 'valid')

The output of densenet.state_dict().keys() is shown below (abridged). There is no class_layers.relu in it, so what would be the equivalent layer?

densenet.state_dict().keys() . . . .

'densenet.features.denseblock4.denselayer1.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer1.layers.conv1.weight', 'densenet.features.denseblock4.denselayer1.layers.norm2.weight', 'densenet.features.denseblock4.denselayer1.layers.norm2.bias', 'densenet.features.denseblock4.denselayer1.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer1.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer1.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer1.layers.conv2.weight', 'densenet.features.denseblock4.denselayer2.layers.norm1.weight', 'densenet.features.denseblock4.denselayer2.layers.norm1.bias', 'densenet.features.denseblock4.denselayer2.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer2.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer2.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer2.layers.conv1.weight', 'densenet.features.denseblock4.denselayer2.layers.norm2.weight', 'densenet.features.denseblock4.denselayer2.layers.norm2.bias', 'densenet.features.denseblock4.denselayer2.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer2.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer2.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer2.layers.conv2.weight', 'densenet.features.denseblock4.denselayer3.layers.norm1.weight', 'densenet.features.denseblock4.denselayer3.layers.norm1.bias', 'densenet.features.denseblock4.denselayer3.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer3.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer3.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer3.layers.conv1.weight', 'densenet.features.denseblock4.denselayer3.layers.norm2.weight', 'densenet.features.denseblock4.denselayer3.layers.norm2.bias', 'densenet.features.denseblock4.denselayer3.layers.norm2.running_mean', 
'densenet.features.denseblock4.denselayer3.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer3.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer3.layers.conv2.weight', 'densenet.features.denseblock4.denselayer4.layers.norm1.weight', 'densenet.features.denseblock4.denselayer4.layers.norm1.bias', 'densenet.features.denseblock4.denselayer4.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer4.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer4.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer4.layers.conv1.weight', 'densenet.features.denseblock4.denselayer4.layers.norm2.weight', 'densenet.features.denseblock4.denselayer4.layers.norm2.bias', 'densenet.features.denseblock4.denselayer4.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer4.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer4.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer4.layers.conv2.weight', 'densenet.features.denseblock4.denselayer5.layers.norm1.weight', 'densenet.features.denseblock4.denselayer5.layers.norm1.bias', 'densenet.features.denseblock4.denselayer5.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer5.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer5.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer5.layers.conv1.weight', 'densenet.features.denseblock4.denselayer5.layers.norm2.weight', 'densenet.features.denseblock4.denselayer5.layers.norm2.bias', 'densenet.features.denseblock4.denselayer5.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer5.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer5.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer5.layers.conv2.weight', 'densenet.features.denseblock4.denselayer6.layers.norm1.weight', 'densenet.features.denseblock4.denselayer6.layers.norm1.bias', 
'densenet.features.denseblock4.denselayer6.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer6.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer6.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer6.layers.conv1.weight', 'densenet.features.denseblock4.denselayer6.layers.norm2.weight', 'densenet.features.denseblock4.denselayer6.layers.norm2.bias', 'densenet.features.denseblock4.denselayer6.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer6.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer6.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer6.layers.conv2.weight', 'densenet.features.denseblock4.denselayer7.layers.norm1.weight', 'densenet.features.denseblock4.denselayer7.layers.norm1.bias', 'densenet.features.denseblock4.denselayer7.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer7.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer7.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer7.layers.conv1.weight', 'densenet.features.denseblock4.denselayer7.layers.norm2.weight', 'densenet.features.denseblock4.denselayer7.layers.norm2.bias', 'densenet.features.denseblock4.denselayer7.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer7.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer7.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer7.layers.conv2.weight', 'densenet.features.denseblock4.denselayer8.layers.norm1.weight', 'densenet.features.denseblock4.denselayer8.layers.norm1.bias', 'densenet.features.denseblock4.denselayer8.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer8.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer8.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer8.layers.conv1.weight', 'densenet.features.denseblock4.denselayer8.layers.norm2.weight', 
'densenet.features.denseblock4.denselayer8.layers.norm2.bias', 'densenet.features.denseblock4.denselayer8.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer8.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer8.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer8.layers.conv2.weight', 'densenet.features.denseblock4.denselayer9.layers.norm1.weight', 'densenet.features.denseblock4.denselayer9.layers.norm1.bias', 'densenet.features.denseblock4.denselayer9.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer9.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer9.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer9.layers.conv1.weight', 'densenet.features.denseblock4.denselayer9.layers.norm2.weight', 'densenet.features.denseblock4.denselayer9.layers.norm2.bias', 'densenet.features.denseblock4.denselayer9.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer9.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer9.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer9.layers.conv2.weight', 'densenet.features.denseblock4.denselayer10.layers.norm1.weight', 'densenet.features.denseblock4.denselayer10.layers.norm1.bias', 'densenet.features.denseblock4.denselayer10.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer10.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer10.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer10.layers.conv1.weight', 'densenet.features.denseblock4.denselayer10.layers.norm2.weight', 'densenet.features.denseblock4.denselayer10.layers.norm2.bias', 'densenet.features.denseblock4.denselayer10.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer10.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer10.layers.norm2.num_batches_tracked', 
'densenet.features.denseblock4.denselayer10.layers.conv2.weight', 'densenet.features.denseblock4.denselayer11.layers.norm1.weight', 'densenet.features.denseblock4.denselayer11.layers.norm1.bias', 'densenet.features.denseblock4.denselayer11.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer11.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer11.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer11.layers.conv1.weight', 'densenet.features.denseblock4.denselayer11.layers.norm2.weight', 'densenet.features.denseblock4.denselayer11.layers.norm2.bias', 'densenet.features.denseblock4.denselayer11.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer11.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer11.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer11.layers.conv2.weight', 'densenet.features.denseblock4.denselayer12.layers.norm1.weight', 'densenet.features.denseblock4.denselayer12.layers.norm1.bias', 'densenet.features.denseblock4.denselayer12.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer12.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer12.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer12.layers.conv1.weight', 'densenet.features.denseblock4.denselayer12.layers.norm2.weight', 'densenet.features.denseblock4.denselayer12.layers.norm2.bias', 'densenet.features.denseblock4.denselayer12.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer12.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer12.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer12.layers.conv2.weight', 'densenet.features.denseblock4.denselayer13.layers.norm1.weight', 'densenet.features.denseblock4.denselayer13.layers.norm1.bias', 'densenet.features.denseblock4.denselayer13.layers.norm1.running_mean', 
'densenet.features.denseblock4.denselayer13.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer13.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer13.layers.conv1.weight', 'densenet.features.denseblock4.denselayer13.layers.norm2.weight', 'densenet.features.denseblock4.denselayer13.layers.norm2.bias', 'densenet.features.denseblock4.denselayer13.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer13.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer13.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer13.layers.conv2.weight', 'densenet.features.denseblock4.denselayer14.layers.norm1.weight', 'densenet.features.denseblock4.denselayer14.layers.norm1.bias', 'densenet.features.denseblock4.denselayer14.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer14.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer14.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer14.layers.conv1.weight', 'densenet.features.denseblock4.denselayer14.layers.norm2.weight', 'densenet.features.denseblock4.denselayer14.layers.norm2.bias', 'densenet.features.denseblock4.denselayer14.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer14.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer14.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer14.layers.conv2.weight', 'densenet.features.denseblock4.denselayer15.layers.norm1.weight', 'densenet.features.denseblock4.denselayer15.layers.norm1.bias', 'densenet.features.denseblock4.denselayer15.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer15.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer15.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer15.layers.conv1.weight', 'densenet.features.denseblock4.denselayer15.layers.norm2.weight', 
'densenet.features.denseblock4.denselayer15.layers.norm2.bias', 'densenet.features.denseblock4.denselayer15.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer15.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer15.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer15.layers.conv2.weight', 'densenet.features.denseblock4.denselayer16.layers.norm1.weight', 'densenet.features.denseblock4.denselayer16.layers.norm1.bias', 'densenet.features.denseblock4.denselayer16.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer16.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer16.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer16.layers.conv1.weight', 'densenet.features.denseblock4.denselayer16.layers.norm2.weight', 'densenet.features.denseblock4.denselayer16.layers.norm2.bias', 'densenet.features.denseblock4.denselayer16.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer16.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer16.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer16.layers.conv2.weight', 'densenet.features.norm5.weight', 'densenet.features.norm5.bias', 'densenet.features.norm5.running_mean', 'densenet.features.norm5.running_var', 'densenet.features.norm5.num_batches_tracked', 'densenet.class_layers.out.weight', 'densenet.class_layers.out.bias'])



Solution 1:[1]

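Pay attention to which model you import and pass to GradCAM.

Layer names in `target_layers` are relative to the module passed as `nn_module`. Your LightningModule stores the network as `self.densenet`, so every sub-module name gains a `densenet.` prefix, as your `state_dict()` keys show. You have two options:
- use `target_layers="densenet.class_layers.relu"`, or
- pass the inner network (`nn_module=model.densenet`) and keep `"class_layers.relu"`.

The ReLU layer is missing from `state_dict()` only because it has no parameters; `model.named_modules()` lists every name that can be hooked.

Also make the input volume large enough (e.g. 64×64×64) to survive DenseNet's downsampling.

With the bare MONAI network, the following code works: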

import torch
from monai.networks.nets import DenseNet121
from monai.visualize import GradCAM

# Minimal GradCAM example on a bare (unwrapped) MONAI DenseNet121.
model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=1)
# target_layers is a module name relative to nn_module; for the bare network
# it is "class_layers.relu". If the network is instead stored as the
# `densenet` attribute of another module (e.g. a LightningModule), either pass
# the inner network (nn_module=wrapper.densenet) or prefix the name:
# target_layers="densenet.class_layers.relu".
# ReLU has no parameters, so it never appears in state_dict(); list valid
# names with [name for name, _ in model.named_modules()] instead.
cam = GradCAM(nn_module=model, target_layers="class_layers.relu") 
# The volume must be large enough to survive DenseNet's repeated downsampling.
# A 7x7x7 input is reduced to 1x1x1 before a 2x2x2 pooling layer and fails;
# 64x64x64 is large enough.
result = cam(x=torch.rand(1, 1, 64, 64, 64))

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Linminxiang