'DeepLab V3 Plus Custom Model Implementation

I am trying to implement deeplabv3plus in pytorch. This is my code for creating deeplabv3plus head. But I am getting an Index error.

low_level_feature = self.project( feature['low_level'] ) IndexError: too many indices for tensor of dimension 4

def _segm_model(name, backbone_name, num_classes, pretrained_backbone=True):
aspp_dilate = [6, 12, 18]
backbone = resnet.__dict__[backbone_name](
    pretrained=pretrained_backbone,
    replace_stride_with_dilation=[False, False, True])

inplanes = 2048
low_level_planes = 256

if name == 'deeplabv3plus':
    return_layers = {'layer4': 'out', 'layer1': 'low_level'}
    classifier = DeepLabHeadV3Plus(inplanes, low_level_planes, num_classes, aspp_dilate)
backbone = IntermediateLayerGetter(backbone, return_layers=return_layers)

model = deeplabv3.DeepLabV3(backbone, classifier)
return model

class DeepLabHeadV3Plus(nn.Module):
def __init__(self, in_channels, low_level_channels, num_classes, aspp_dilate=[12, 24, 36]):
    super(DeepLabHeadV3Plus, self).__init__()
    self.project = nn.Sequential( 
        nn.Conv2d(low_level_channels, 48, 1, bias=False),
        nn.BatchNorm2d(48),
        nn.ReLU(inplace=True),
    )

    self.aspp = ASPP(in_channels, aspp_dilate)

    self.classifier = nn.Sequential(
        nn.Conv2d(304, 256, 3, padding=1, bias=False),
        nn.BatchNorm2d(256),
        nn.ReLU(inplace=True),
        nn.Conv2d(256, num_classes, 1)
    )
    self._init_weight()

def forward(self, feature):
    low_level_feature = self.project( feature['low_level'] )
    output_feature = self.aspp(feature['out'])
    output_feature = F.interpolate(output_feature, size=low_level_feature.shape[2:], mode='bilinear', align_corners=False)
    return self.classifier( torch.cat( [ low_level_feature, output_feature ], dim=1 ) )

def _init_weight(self):
    for m in self.modules():
        if isinstance(m, nn.Conv2d):
            nn.init.kaiming_normal_(m.weight)
        elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
            nn.init.constant_(m.weight, 1)
            nn.init.constant_(m.bias, 0)


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source