Pretrained TensorFlow model RGB -> RGBY channel extension

I am working on a protein analysis project. We receive images* of proteins captured with 4 filters (Red, Green, Blue and Yellow). Each of those RGBY channels contains unique data, as different cellular structures are visible under different filters.

The idea is to use a pre-trained network, e.g. VGG19, and extend the number of input channels from the default 3 to 4. Something like this:

<img src="https://i.stack.imgur.com/TZKka.png" alt="VGG model with RGB extended to RGBY">

Picture: VGG model with RGB extended to RGBY

The Y channel should be a copy of one of the existing pretrained channels; that way the pretrained weights can still be used.
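One way to guarantee that the Y channel is an exact copy of a pretrained channel is to duplicate the chosen channel's kernel slice along the input-channel axis before assigning it to the new 4-channel layer. A minimal NumPy sketch, using a random stand-in array with the shape of VGG19's first conv kernel rather than real pretrained weights:

```python
import numpy as np

# Stand-in for the pretrained first-layer kernel of an RGB model, with the
# Keras kernel layout (kernel_h, kernel_w, in_channels, out_channels).
rgb_kernel = np.random.randn(3, 3, 3, 64).astype(np.float32)

# Duplicate the B channel (index 2) to create the Y channel, giving an
# RGBY kernel of shape (3, 3, 4, 64).
y_slice = rgb_kernel[:, :, 2:3, :]
rgby_kernel = np.concatenate([rgb_kernel, y_slice], axis=2)

print(rgby_kernel.shape)  # (3, 3, 4, 64)
# The new channel is an exact copy of the B channel:
print(np.array_equal(rgby_kernel[:, :, 3, :], rgb_kernel[:, :, 2, :]))  # True
```

The resulting kernel can then be assigned to the first convolutional layer of the 4-channel model via `layer.set_weights([rgby_kernel, bias])`; every deeper layer keeps its pretrained weights unchanged, since only the input layer's shape differs.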

Does anyone have an idea of how such extension of a pretrained network can be achieved?

* Author of the collage - Allunia from Kaggle, "Protein Atlas - Exploration and Baseline" kernel.



Solution 1:[1]

Beyond the RGBY case, the following snippet works generally by growing or shrinking the dimensions of each layer's weight and/or bias arrays as needed. Please refer to the NumPy documentation on what numpy.resize does: in the original question's case it grows the input-channel axis from 3 to 4, filling the extra entries with repeated copies of the flattened original values.

import numpy as np
import tensorflow as tf
...

model = ...  # your RGBY model is here
pretrained_model = tf.keras.models.load_model(...)  # pretrained RGB model

# the following assumes that the layers of the two models match up and
# only the shapes of the weights and/or biases differ
for pretrained_layer, layer in zip(pretrained_model.layers, model.layers):
    pretrained = pretrained_layer.get_weights()
    target = layer.get_weights()
    if len(pretrained) == 0:  # skip input, pooling and other weightless layers
        continue
    try:
        # set the pretrained weights as-is whenever the shapes already match
        layer.set_weights(pretrained)
    except ValueError:
        # numpy.resize to the rescue whenever there is a shape mismatch
        for idx, (l1, l2) in enumerate(zip(pretrained, target)):
            target[idx] = np.resize(l1, l2.shape)

        layer.set_weights(target)
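One caveat worth knowing about numpy.resize: it fills the enlarged array with repeated copies of the *flattened* original values, so for a Keras kernel of shape (h, w, in, out) the extra input channel does not come out as an aligned copy of any single pretrained channel. A quick check on a small stand-in array (not real weights):

```python
import numpy as np

# Tiny stand-in kernel, Keras layout (h, w, in_channels, out_channels).
small = np.arange(2 * 2 * 3 * 2).reshape(2, 2, 3, 2).astype(np.float32)

# Grow the input-channel axis from 3 to 4, as the snippet above does.
grown = np.resize(small, (2, 2, 4, 2))

# np.resize repeats the flattened values, so the new channel is NOT an
# aligned copy of channel 2:
print(np.array_equal(grown[:, :, 3, :], small[:, :, 2, :]))  # False
# ...and the first len(small) flat entries are simply the original values:
print(np.array_equal(grown.ravel()[: small.size], small.ravel()))  # True
```

So the loop above keeps all of the pretrained information and produces a valid 4-channel model, but if you need the Y channel to be a strict copy of a specific channel, build that kernel explicitly (e.g. with np.concatenate along the input-channel axis) instead of relying on np.resize.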

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
