How to change the activation layer in a PyTorch pretrained module?

How do I change the activation layers of a PyTorch pretrained network? Here is my code:

print("All modules")
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)

print('Before changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
        child=nn.SELU()
        print(child)
print('after changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)

Here is my output:

All modules
ReLU(inplace=True)
Before changing activation
ReLU(inplace=True)
SELU()
after changing activation
ReLU(inplace=True)


Solution 1:[1]

._modules solves the problem for me.

for name, child in net.named_children():
    if isinstance(child, nn.ReLU) or isinstance(child, nn.SELU):
        # reassign through _modules so the change sticks on the parent module
        net._modules[name] = nn.SELU()
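
To check the swap, the same loop from the question can be rerun (a minimal check, assuming net is the network from the question):

for child in net.children():
    if isinstance(child, nn.ReLU) or isinstance(child, nn.SELU):
        print(child)
# SELU()  -- instead of ReLU(inplace=True)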

Solution 2:[2]

I'm assuming you use the module interface nn.ReLU to create the activation layers instead of the functional interface F.relu. If so, setattr works for me.

import torch
import torch.nn as nn

# This function recursively replaces all ReLU modules with SELU modules.
def replace_relu_to_selu(model):
    for child_name, child in model.named_children():
        if isinstance(child, nn.ReLU):
            setattr(model, child_name, nn.SELU())
        else:
            replace_relu_to_selu(child)

########## A toy example ##########
net = nn.Sequential(
            nn.Conv2d(3, 32, kernel_size=3, stride=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(32, 32, kernel_size=3, stride=1),
            nn.ReLU(inplace=True)
          )

########## Test ##########
print('Before changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
# Before changing activation
# ReLU(inplace=True)
# ReLU(inplace=True)


replace_relu_to_selu(net)

print('after changing activation')
for child in net.children():
    if isinstance(child,nn.ReLU) or isinstance(child,nn.SELU):
        print(child)
# after changing activation
# SELU()
# SELU()
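
As noted above, this relies on the activation being a submodule; if a model instead calls the functional interface inside forward, there is no child module for setattr to replace. A minimal sketch of that case (FunctionalNet is a hypothetical example, not from the original answer):

import torch.nn.functional as F

class FunctionalNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 32, kernel_size=3, stride=1)

    def forward(self, x):
        # F.relu is a plain function call, not a submodule, so replace_relu_to_selu(self) finds nothing to swap;
        # the forward method itself would have to be changed to switch activations here
        return F.relu(self.conv(x))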

Solution 3:[3]

Here is a general function for replacing any layer:

def replace_layers(model, old, new):
    for n, module in model.named_children():
        if len(list(module.children())) > 0:
            ## compound module, go inside it
            replace_layers(module, old, new)
            
        if isinstance(module, old):
            ## simple module
            setattr(model, n, new)

replace_layers(model, nn.ReLU, nn.ReLU6())

I struggled with this for a few days, so I did some digging and wrote a Kaggle notebook explaining how different types of layers/modules are accessed in PyTorch.
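
For example, the function can be pointed at a torchvision network; a minimal sketch, assuming torchvision is available (ReLU is swapped for SELU to match the question):

import torchvision.models as models

model = models.resnet18()  # weights omitted here; the swap works the same on a pretrained instance
replace_layers(model, nn.ReLU, nn.SELU())
print(model.relu)          # SELU()

Note that a single nn.SELU() instance ends up shared by every replaced position, which is fine for stateless activations.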

Solution 4:[4]

I will provide a more general solution that works for any layer (and avoids other issues, such as modifying a dictionary while looping through it, or nn.Module objects nested recursively inside one another).

import torch

def replace_bn(module, name):
    '''
    Recursively put the desired batch norm into the given nn.Module.

    Call with module = net to start.
    '''
    # go through all attributes of the module (e.g. network or layer) and replace batch norms if present
    for attr_str in dir(module):
        target_attr = getattr(module, attr_str)
        if type(target_attr) == torch.nn.BatchNorm2d:
            print('replaced: ', name, attr_str)
            new_bn = torch.nn.BatchNorm2d(target_attr.num_features, target_attr.eps, target_attr.momentum, target_attr.affine,
                                          track_running_stats=False)
            setattr(module, attr_str, new_bn)

    # iterate through immediate child modules. Note: the recursion is done by our code, no need to use named_modules()
    for name, immediate_child_module in module.named_children():
        replace_bn(immediate_child_module, name)

replace_bn(model, 'model')
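
For example, applied to a small custom module (SmallNet and Block are illustrative names, not from the original post), every BatchNorm2d attribute is swapped for one with track_running_stats=False, including the one inside the nested block:

class Block(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 32, kernel_size=3, stride=1)
        self.bn = nn.BatchNorm2d(32)

    def forward(self, x):
        return self.bn(self.conv(x))

class SmallNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.block = Block()  # nested module, reached through the named_children() recursion
        self.bn = nn.BatchNorm2d(32)

    def forward(self, x):
        return self.bn(self.block(x))

model = SmallNet()
replace_bn(model, 'model')
print(model.bn.track_running_stats)        # False
print(model.block.bn.track_running_stats)  # False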

The crux is that you need to keep changing the layers recursively (mainly because sometimes you will encounter attributes that are modules themselves). I think better code than the above would add another if statement (after the batch norm check) that detects whether you have to recurse, and recurses if so. The above works too, but it first changes the batch norms at the outer level (i.e. the first loop) and then uses another loop to make sure no other object that should be recursed into is missed (and then recurses).

Original post: https://discuss.pytorch.org/t/how-to-modify-a-pretrained-model/60509/10

credits: https://discuss.pytorch.org/t/replacing-convs-modules-with-custom-convs-then-notimplementederror/17736/3?u=brando_miranda

Solution 5:[5]

This works fine for me with the default PyTorch API:

def replace_layer(module: nn.Module, old: type, new: nn.Module):
    for name, m in module.named_children():
        if isinstance(m, old):
            setattr(module, name, new)

# HardSwish(True) here is a custom replacement module; model.apply visits every submodule recursively
model.apply(lambda m: replace_layer(m, nn.Hardswish, HardSwish(True)))
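
The same pattern answers the original question; a minimal sketch, reusing replace_layer from above (model.apply already walks every submodule, so no explicit recursion is needed):

model.apply(lambda m: replace_layer(m, nn.ReLU, nn.SELU()))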

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution 1: Hamdard
Solution 2: zihaozhihao
Solution 3: Ankur Singh
Solution 4: Charlie Parker
Solution 5: RedEyed