Visualisation with GradCAM from MONAI expects larger input image
I want to use GradCAM on 3D MRI images to visualize the regions that were most important for classifying a scan as healthy or ill.
Therefore, I use
# Attach Grad-CAM to the layer named "class_layers.relu" of the network.
cam = GradCAM(nn_module=densenet, target_layers="class_layers.relu")
# Compute the activation map for a random single-channel 7x7x7 volume.
result = cam(x=torch.rand((1, 1, 7, 7, 7)))
where my densenet is defined as:
self.densenet = densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=1)
this throws the error:
RuntimeError: input image (T: 1 H: 1 W: 1) smaller than kernel size (kT: 2 kH: 2 kW: 2)
Raising H,W,D to 30
result = cam(x=torch.rand((1, 1, 30, 30, 30)))
leads to
Traceback (most recent call last):
File "/var/folders/79/z7g43_0x08g2yj7w6lb_j5280000gn/T/ipykernel_3133/4028907207.py", line 1, in result = cam(x=torch.rand((64, 1, 30, 30, 30))) #result mri image nehmen
File "/Users/Wu/opt/anaconda3/lib/python3.9/site-packages/monai/visualize/class_activation_maps.py", line 380, in call acti_map = self.compute_map(x, class_idx=class_idx, retain_graph=retain_graph, layer_idx=layer_idx)
File "/Users/Wu/opt/anaconda3/lib/python3.9/site-packages/monai/visualize/class_activation_maps.py", line 360, in compute_map _, acti, grad = self.nn_module(x, class_idx=class_idx, retain_graph=retain_graph)
File "/Users/Wu/opt/anaconda3/lib/python3.9/site-packages/monai/visualize/class_activation_maps.py", line 135, in call acti = tuple(self.activations[layer] for layer in self.target_layers)
File "/Users/Wu/opt/anaconda3/lib/python3.9/site-packages/monai/visualize/class_activation_maps.py", line 135, in acti = tuple(self.activations[layer] for layer in self.target_layers)
KeyError: 'class_layers.relu'
And using my own densenet class:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Tue Mar 8 23:23:24 2022

@author: Wu
"""
# Methods from PyTorch Lightning
import torch
import pytorch_lightning as pl
from sklearn.metrics import accuracy_score, recall_score, precision_score, f1_score
import torch.nn.functional as F
from torch import nn
import monai.networks.nets.densenet as densenet
import pycm
import numpy as np


class DenseNet(pl.LightningModule):
    """Lightning wrapper around MONAI's 3-D ``densenet121`` with a single logit output.

    The wrapped network is stored as ``self.densenet`` (one input channel, one
    output channel) and is trained with ``BCEWithLogitsLoss``. Per-epoch
    predictions and labels are collected and summarised with a pycm
    confusion matrix.

    Args:
        learning_rate: Learning rate handed to the Adam optimizer.
    """

    def __init__(self, learning_rate=1e-4):
        super().__init__()
        self.densenet = densenet.densenet121(spatial_dims=3, in_channels=1, out_channels=1)
        self.learning_rate = learning_rate
        # Binary cross-entropy computed on raw logits, averaged over elements.
        self.class_loss = torch.nn.BCEWithLogitsLoss(reduction='mean')
        print('1')

    def forward(self, x):
        """Return the raw logits of the wrapped DenseNet for input ``x``."""
        logits = self.densenet(x)
        print('2')
        print(logits.dtype)
        return logits

    def configure_optimizers(self):
        """Return an Adam optimizer over all parameters of this module."""
        print('3')
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

    def on_train_epoch_start(self):
        """Reset the per-epoch training prediction/label buffers."""
        print('4')
        self.train_predictions = []
        self.train_labels = []

    def training_step(self, batch, batch_idx):
        """Compute and log the training loss for one batch.

        Args:
            batch: Mapping with the keys ``'mri'`` (network input) and
                ``'label'`` (target).
            batch_idx: Index of the batch (unused).

        Returns:
            The BCE-with-logits loss for the batch.
        """
        print('5')
        mri, y = batch['mri'], batch['label']
        y = y.float()
        print(y.dtype)
        logits = self(mri)
        logits = logits.float()
        # A logit >= 0 corresponds to a sigmoid probability >= 0.5.
        y_hat = (logits >= 0).float()
        # NOTE(review): presumably logits are (batch, 1) and labels (batch,);
        # unsqueeze(0) then only matches the logits for a batch of one sample
        # -- confirm the intended batch size.
        class_loss = self.class_loss(logits, y.unsqueeze(0))
        self.log('loss/train', class_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.train_predictions.extend(y_hat.tolist())
        self.train_labels.extend(y.tolist())
        return class_loss

    def on_train_epoch_end(self, outputs):
        """Log accuracy and F1 scores accumulated over the training epoch."""
        print('6')
        self._log_epoch_metrics('train', self.train_labels, self.train_predictions)

    def on_validation_epoch_start(self):
        """Reset the per-epoch validation prediction/label buffers."""
        print('7')
        self.predictions = []
        self.labels = []

    def validation_step(self, batch, batch_idx):
        """Compute and log the validation loss for one batch.

        Args:
            batch: Mapping with the keys ``'mri'`` (network input) and
                ``'label'`` (target).
            batch_idx: Index of the batch (unused).
        """
        print('8')
        mri, y = batch['mri'], batch['label']
        logits = self(mri)
        # The former bare ``float(tensor)`` calls were dropped: their results
        # were discarded and they raise for tensors with more than one element.
        y = y.float()
        print(y.dtype)
        logits = logits.float()
        y_hat = (logits >= 0).float()
        class_loss = self.class_loss(logits, y.unsqueeze(0))
        self.log('loss/valid', class_loss, on_step=True, on_epoch=True, prog_bar=True, logger=True)
        self.predictions.extend(y_hat.tolist())
        self.labels.extend(y.tolist())

    def on_validation_epoch_end(self):
        """Log accuracy and F1 scores accumulated over the validation epoch."""
        print('9')
        self._log_epoch_metrics('valid', self.labels, self.predictions)

    def _log_epoch_metrics(self, stage, labels, predictions):
        """Log accuracy and macro/micro F1 under ``<metric>/<stage>``.

        Args:
            stage: Suffix of the logged metric names (``'train'`` or ``'valid'``).
            labels: Collected ground-truth labels for the epoch.
            predictions: Collected thresholded predictions for the epoch.
        """
        cm = pycm.ConfusionMatrix(actual_vector=labels, predict_vector=predictions)
        self.log(f'accuracy/{stage}', cm.Overall_ACC)
        # Skip scores equal to the string 'None' (presumably pycm's marker for
        # an undefined statistic). Compare by value: ``is not`` against a
        # string literal tests identity and is implementation-dependent.
        if cm.F1_Macro != 'None':
            self.log(f'f1_macro/{stage}', cm.F1_Macro)
        if cm.F1_Micro != 'None':
            self.log(f'f1_micro/{stage}', cm.F1_Micro)
The output of densenet.state_dict().keys() is shown below. It contains no class_layers.relu, so what would be the equivalent layer?
densenet.state_dict().keys() . . . .
'densenet.features.denseblock4.denselayer1.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer1.layers.conv1.weight', 'densenet.features.denseblock4.denselayer1.layers.norm2.weight', 'densenet.features.denseblock4.denselayer1.layers.norm2.bias', 'densenet.features.denseblock4.denselayer1.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer1.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer1.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer1.layers.conv2.weight', 'densenet.features.denseblock4.denselayer2.layers.norm1.weight', 'densenet.features.denseblock4.denselayer2.layers.norm1.bias', 'densenet.features.denseblock4.denselayer2.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer2.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer2.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer2.layers.conv1.weight', 'densenet.features.denseblock4.denselayer2.layers.norm2.weight', 'densenet.features.denseblock4.denselayer2.layers.norm2.bias', 'densenet.features.denseblock4.denselayer2.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer2.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer2.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer2.layers.conv2.weight', 'densenet.features.denseblock4.denselayer3.layers.norm1.weight', 'densenet.features.denseblock4.denselayer3.layers.norm1.bias', 'densenet.features.denseblock4.denselayer3.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer3.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer3.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer3.layers.conv1.weight', 'densenet.features.denseblock4.denselayer3.layers.norm2.weight', 'densenet.features.denseblock4.denselayer3.layers.norm2.bias', 'densenet.features.denseblock4.denselayer3.layers.norm2.running_mean', 
'densenet.features.denseblock4.denselayer3.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer3.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer3.layers.conv2.weight', 'densenet.features.denseblock4.denselayer4.layers.norm1.weight', 'densenet.features.denseblock4.denselayer4.layers.norm1.bias', 'densenet.features.denseblock4.denselayer4.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer4.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer4.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer4.layers.conv1.weight', 'densenet.features.denseblock4.denselayer4.layers.norm2.weight', 'densenet.features.denseblock4.denselayer4.layers.norm2.bias', 'densenet.features.denseblock4.denselayer4.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer4.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer4.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer4.layers.conv2.weight', 'densenet.features.denseblock4.denselayer5.layers.norm1.weight', 'densenet.features.denseblock4.denselayer5.layers.norm1.bias', 'densenet.features.denseblock4.denselayer5.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer5.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer5.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer5.layers.conv1.weight', 'densenet.features.denseblock4.denselayer5.layers.norm2.weight', 'densenet.features.denseblock4.denselayer5.layers.norm2.bias', 'densenet.features.denseblock4.denselayer5.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer5.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer5.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer5.layers.conv2.weight', 'densenet.features.denseblock4.denselayer6.layers.norm1.weight', 'densenet.features.denseblock4.denselayer6.layers.norm1.bias', 
'densenet.features.denseblock4.denselayer6.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer6.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer6.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer6.layers.conv1.weight', 'densenet.features.denseblock4.denselayer6.layers.norm2.weight', 'densenet.features.denseblock4.denselayer6.layers.norm2.bias', 'densenet.features.denseblock4.denselayer6.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer6.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer6.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer6.layers.conv2.weight', 'densenet.features.denseblock4.denselayer7.layers.norm1.weight', 'densenet.features.denseblock4.denselayer7.layers.norm1.bias', 'densenet.features.denseblock4.denselayer7.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer7.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer7.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer7.layers.conv1.weight', 'densenet.features.denseblock4.denselayer7.layers.norm2.weight', 'densenet.features.denseblock4.denselayer7.layers.norm2.bias', 'densenet.features.denseblock4.denselayer7.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer7.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer7.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer7.layers.conv2.weight', 'densenet.features.denseblock4.denselayer8.layers.norm1.weight', 'densenet.features.denseblock4.denselayer8.layers.norm1.bias', 'densenet.features.denseblock4.denselayer8.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer8.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer8.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer8.layers.conv1.weight', 'densenet.features.denseblock4.denselayer8.layers.norm2.weight', 
'densenet.features.denseblock4.denselayer8.layers.norm2.bias', 'densenet.features.denseblock4.denselayer8.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer8.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer8.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer8.layers.conv2.weight', 'densenet.features.denseblock4.denselayer9.layers.norm1.weight', 'densenet.features.denseblock4.denselayer9.layers.norm1.bias', 'densenet.features.denseblock4.denselayer9.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer9.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer9.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer9.layers.conv1.weight', 'densenet.features.denseblock4.denselayer9.layers.norm2.weight', 'densenet.features.denseblock4.denselayer9.layers.norm2.bias', 'densenet.features.denseblock4.denselayer9.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer9.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer9.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer9.layers.conv2.weight', 'densenet.features.denseblock4.denselayer10.layers.norm1.weight', 'densenet.features.denseblock4.denselayer10.layers.norm1.bias', 'densenet.features.denseblock4.denselayer10.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer10.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer10.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer10.layers.conv1.weight', 'densenet.features.denseblock4.denselayer10.layers.norm2.weight', 'densenet.features.denseblock4.denselayer10.layers.norm2.bias', 'densenet.features.denseblock4.denselayer10.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer10.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer10.layers.norm2.num_batches_tracked', 
'densenet.features.denseblock4.denselayer10.layers.conv2.weight', 'densenet.features.denseblock4.denselayer11.layers.norm1.weight', 'densenet.features.denseblock4.denselayer11.layers.norm1.bias', 'densenet.features.denseblock4.denselayer11.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer11.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer11.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer11.layers.conv1.weight', 'densenet.features.denseblock4.denselayer11.layers.norm2.weight', 'densenet.features.denseblock4.denselayer11.layers.norm2.bias', 'densenet.features.denseblock4.denselayer11.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer11.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer11.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer11.layers.conv2.weight', 'densenet.features.denseblock4.denselayer12.layers.norm1.weight', 'densenet.features.denseblock4.denselayer12.layers.norm1.bias', 'densenet.features.denseblock4.denselayer12.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer12.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer12.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer12.layers.conv1.weight', 'densenet.features.denseblock4.denselayer12.layers.norm2.weight', 'densenet.features.denseblock4.denselayer12.layers.norm2.bias', 'densenet.features.denseblock4.denselayer12.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer12.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer12.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer12.layers.conv2.weight', 'densenet.features.denseblock4.denselayer13.layers.norm1.weight', 'densenet.features.denseblock4.denselayer13.layers.norm1.bias', 'densenet.features.denseblock4.denselayer13.layers.norm1.running_mean', 
'densenet.features.denseblock4.denselayer13.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer13.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer13.layers.conv1.weight', 'densenet.features.denseblock4.denselayer13.layers.norm2.weight', 'densenet.features.denseblock4.denselayer13.layers.norm2.bias', 'densenet.features.denseblock4.denselayer13.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer13.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer13.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer13.layers.conv2.weight', 'densenet.features.denseblock4.denselayer14.layers.norm1.weight', 'densenet.features.denseblock4.denselayer14.layers.norm1.bias', 'densenet.features.denseblock4.denselayer14.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer14.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer14.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer14.layers.conv1.weight', 'densenet.features.denseblock4.denselayer14.layers.norm2.weight', 'densenet.features.denseblock4.denselayer14.layers.norm2.bias', 'densenet.features.denseblock4.denselayer14.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer14.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer14.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer14.layers.conv2.weight', 'densenet.features.denseblock4.denselayer15.layers.norm1.weight', 'densenet.features.denseblock4.denselayer15.layers.norm1.bias', 'densenet.features.denseblock4.denselayer15.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer15.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer15.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer15.layers.conv1.weight', 'densenet.features.denseblock4.denselayer15.layers.norm2.weight', 
'densenet.features.denseblock4.denselayer15.layers.norm2.bias', 'densenet.features.denseblock4.denselayer15.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer15.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer15.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer15.layers.conv2.weight', 'densenet.features.denseblock4.denselayer16.layers.norm1.weight', 'densenet.features.denseblock4.denselayer16.layers.norm1.bias', 'densenet.features.denseblock4.denselayer16.layers.norm1.running_mean', 'densenet.features.denseblock4.denselayer16.layers.norm1.running_var', 'densenet.features.denseblock4.denselayer16.layers.norm1.num_batches_tracked', 'densenet.features.denseblock4.denselayer16.layers.conv1.weight', 'densenet.features.denseblock4.denselayer16.layers.norm2.weight', 'densenet.features.denseblock4.denselayer16.layers.norm2.bias', 'densenet.features.denseblock4.denselayer16.layers.norm2.running_mean', 'densenet.features.denseblock4.denselayer16.layers.norm2.running_var', 'densenet.features.denseblock4.denselayer16.layers.norm2.num_batches_tracked', 'densenet.features.denseblock4.denselayer16.layers.conv2.weight', 'densenet.features.norm5.weight', 'densenet.features.norm5.bias', 'densenet.features.norm5.running_mean', 'densenet.features.norm5.running_var', 'densenet.features.norm5.num_batches_tracked', 'densenet.class_layers.out.weight', 'densenet.class_layers.out.bias'])
Solution 1:[1]
You may need to pay attention to how the model is imported and constructed; the following code should help you.
import torch
from monai.networks.nets import DenseNet121
from monai.visualize import GradCAM
# Build a 3-D DenseNet121 with one input channel and one output channel.
model = DenseNet121(spatial_dims=3, in_channels=1, out_channels=1)
# Attach Grad-CAM directly to this network, targeting the layer named "class_layers.relu".
cam = GradCAM(nn_module=model, target_layers="class_layers.relu")
# Compute the activation map for a random single-channel 64x64x64 volume.
result = cam(x=torch.rand(1, 1, 64, 64, 64))
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Linminxiang |
