Using MobileNet instead of ResNet18 as the feature backbone of a pretrained model in PyTorch

I have a pretrained ResNet18 model, and I now want to replace its feature backbone with MobileNet using PyTorch. Is there a recommended way to implement this?

In the code below, I want to use MobileNet instead of ResNet as the feature-extraction backbone.


import torch from model.backbone import resnet import numpy as np

class conv_bn_relu(torch.nn.Module):
    """Conv2d -> BatchNorm2d -> ReLU applied in sequence.

    Args:
        in_channels: Number of channels of the input feature map.
        out_channels: Number of channels produced by the convolution.
        kernel_size: Convolution kernel size.
        stride: Convolution stride.
        padding: Zero-padding added to the input of the convolution.
        dilation: Spacing between kernel elements.
        bias: Whether the convolution has a learnable bias. Defaults to
            False.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, bias=False):
        super().__init__()
        self.conv = torch.nn.Conv2d(
            in_channels,
            out_channels,
            kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            bias=bias,
        )
        self.bn = torch.nn.BatchNorm2d(out_channels)
        self.relu = torch.nn.ReLU()

    def forward(self, x):
        """Return ``relu(bn(conv(x)))``."""
        x = self.conv(x)
        x = self.bn(x)
        x = self.relu(x)
        return x

class parsingNet(torch.nn.Module):
    """Classification network built on a ``resnet`` feature backbone.

    The backbone is called once per forward pass and must return three
    feature maps. The last one is reduced to 8 channels by a 1x1
    convolution, flattened, and fed to a fully connected head whose output
    is reshaped to ``cls_dim``. When ``use_aux`` is true, an auxiliary
    branch additionally fuses the three feature maps into a per-pixel map
    with ``cls_dim[-1] + 1`` channels.

    Args:
        size: Two-element input size; stored as ``self.size`` and split
            into ``self.w`` / ``self.h``.
        pretrained: Forwarded to ``resnet``.
        backbone: Backbone identifier forwarded to ``resnet``. ``'18'`` and
            ``'34'`` select the 128/256/512 channel widths; any other value
            selects 512/1024/2048.
        cls_dim: ``(num_gridding, num_cls_per_lane, num_of_lanes)``, where
            ``num_cls_per_lane`` is the number of row anchors.
        use_aux: Whether to build and run the auxiliary branch.
    """

    def __init__(self, size=(288, 800), pretrained=True, backbone='50',
                 cls_dim=(37, 10, 4), use_aux=False):
        super(parsingNet, self).__init__()

        self.size = size
        self.w = size[0]
        self.h = size[1]
        self.cls_dim = cls_dim  # (num_gridding, num_cls_per_lane, num_of_lanes)
        # num_cls_per_lane is the number of row anchors
        self.use_aux = use_aux
        self.total_dim = np.prod(cls_dim)

        # input : nchw,
        # output: (w+1) * sample_rows * 4
        self.model = resnet(backbone, pretrained=pretrained)

        # The '18' and '34' backbones expose narrower feature maps than the
        # other variants, so every layer that touches them switches width.
        narrow = backbone in ['34', '18']

        if self.use_aux:
            self.aux_header2 = torch.nn.Sequential(
                conv_bn_relu(128 if narrow else 512, 128,
                             kernel_size=3, stride=1, padding=1),
            )
            self.aux_header3 = torch.nn.Sequential(
                conv_bn_relu(256 if narrow else 1024, 128,
                             kernel_size=3, stride=1, padding=1),
            )
            self.aux_header4 = torch.nn.Sequential(
                conv_bn_relu(512 if narrow else 2048, 128,
                             kernel_size=3, stride=1, padding=1),
            )
            # 384 input channels = the three 128-channel headers concatenated.
            self.aux_combine = torch.nn.Sequential(
                conv_bn_relu(384, 256, 3, padding=2, dilation=2),
                conv_bn_relu(256, 128, 3, padding=2, dilation=2),
                conv_bn_relu(128, 128, 3, padding=2, dilation=2),
                conv_bn_relu(128, 128, 3, padding=4, dilation=4),
                torch.nn.Conv2d(128, cls_dim[-1] + 1, 1),
                # output : n, num_of_lanes+1, h, w
            )

        # NOTE(review): there is no nonlinearity between the two Linear
        # layers as shown here, so they compose to a single affine map;
        # confirm whether an activation was lost from this snippet.
        self.cls = torch.nn.Sequential(
            torch.nn.Linear(1800, 2048),
            torch.nn.Linear(2048, self.total_dim),
        )

        # 1x1 convolution that squeezes the last feature map to 8 channels.
        # 1800 = 8 * 9 * 25, which matches a 288x800 input if the backbone
        # downsamples by 32 -- TODO confirm against the backbone in use.
        self.pool = torch.nn.Conv2d(512 if narrow else 2048, 8, 1)
        # (w+1) * sample_rows * 4
        # 37 * 10 * 4

    def forward(self, x):
        """Run the network on a batch ``x`` in NCHW layout.

        Returns:
            ``group_cls`` reshaped to ``(-1, *cls_dim)``; when ``use_aux``
            is true, the tuple ``(group_cls, aux_seg)`` instead.
        """
        x2, x3, fea = self.model(x)

        if self.use_aux:
            x2 = self.aux_header2(x2)
            x3 = self.aux_header3(x3)
            # Upsample so all three maps can be concatenated channel-wise;
            # assumes x3 and fea are 1/2 and 1/4 the spatial size of x2.
            x3 = torch.nn.functional.interpolate(x3, scale_factor=2, mode='bilinear')
            x4 = self.aux_header4(fea)
            x4 = torch.nn.functional.interpolate(x4, scale_factor=4, mode='bilinear')
            aux_seg = torch.cat([x2, x3, x4], dim=1)
            aux_seg = self.aux_combine(aux_seg)
        else:
            # Only reset when the auxiliary branch is disabled; doing it
            # unconditionally would discard the segmentation output.
            aux_seg = None

        fea = self.pool(fea).view(-1, 1800)

        group_cls = self.cls(fea).view(-1, *self.cls_dim)

        if self.use_aux:
            return group_cls, aux_seg

        return group_cls

def initialize_weights(*models):
    """Initialize the weights of every given model in place.

    Args:
        *models: Modules, or lists of modules, to initialize.
    """
    for model in models:
        real_init_weights(model)


def real_init_weights(m):
    """Recursively initialize ``m`` in place.

    * ``list``: each element is initialized in turn.
    * ``Conv2d``: Kaiming-normal weights (ReLU gain), zero bias if present.
    * ``Linear``: weights drawn from a normal with mean 0 and std 0.01.
    * ``BatchNorm2d``: weight set to 1, bias set to 0.
    * any other ``Module``: its direct children are initialized recursively.
    * anything else: reported on stdout and left untouched.
    """
    if isinstance(m, list):
        for mini_m in m:
            real_init_weights(mini_m)
        return

    if isinstance(m, torch.nn.Conv2d):
        torch.nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, torch.nn.Linear):
        m.weight.data.normal_(0.0, std=0.01)
    elif isinstance(m, torch.nn.BatchNorm2d):
        torch.nn.init.constant_(m.weight, 1)
        torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, torch.nn.Module):
        # Container or composite module: descend into its children.
        for mini_m in m.children():
            real_init_weights(mini_m)
    else:
        print('unkonwn module', m)



This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source