Failing to quantize a custom layer - Quantization Aware Training

I'm following the Quantization aware training comprehensive guide and struggling with QAT for custom layers, running TensorFlow 2.6.0 on Python 3.9.7. Below is a toy example of my problem:

I wrote a simple custom layer that wraps Conv2D:

class MyConv(tf.keras.layers.Layer):
    '''Custom Conv2D wrapper'''

    def __init__(self, filt=1, name=None, **kwargs):
        super(MyConv, self).__init__(name=name, **kwargs)
        self.filt = filt

    def get_config(self):
        config = super().get_config().copy()
        config.update({"filt": self.filt})
        return config

    def build(self, shape):
        self.conv = tf.keras.layers.Conv2D(self.filt, 1, padding="same")

    def call(self, input):
        return self.conv(input)

I created a small model with that layer, then recursively walked its layers and annotated them using tfmot.quantization.keras.quantize_annotate_layer (each custom layer could have more custom sub-layers that need to be quantized). Then I applied tfmot.quantization.keras.quantize_apply to the annotated model. The resulting model contains all the quantized layers except my custom layer, which was not quantized.

Model summary attached.
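For reference, the basic annotate-then-apply flow from the guide works as expected on built-in layers. This is a minimal sketch with placeholder names (not my actual model):

    import tensorflow as tf
    import tensorflow_model_optimization as tfmot

    def annotate(layer):
        # Annotate only the layer types that should be quantized.
        if isinstance(layer, tf.keras.layers.Conv2D):
            return tfmot.quantization.keras.quantize_annotate_layer(layer)
        return layer

    base = tf.keras.Sequential([
        tf.keras.layers.Conv2D(1, 3, padding="same", input_shape=(10, 10, 1)),
        tf.keras.layers.ReLU(),
    ])
    annotated = tf.keras.models.clone_model(base, clone_function=annotate)
    qat_model = tfmot.quantization.keras.quantize_apply(annotated)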

I'll note that when I replace the custom layer MyConv with the snippet below, as in the comprehensive guide, the quantization works:

class MyConv(tf.keras.layers.Conv2D):
    pass
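I assume that version works because MyConv is then an instance of tf.keras.layers.Conv2D, so it takes the built-in-layer branch of my apply_quantization function below. An illustrative sketch (DefaultCustomQuantizeConfig is defined in the full code):

    # MyConv now subclasses Conv2D, so isinstance(layer, Conv2D) is True and
    # apply_quantization annotates it exactly like a built-in Conv2D.
    layer = MyConv(1, 3, padding="same")
    annotated = tfmot.quantization.keras.quantize_annotate_layer(
        layer, DefaultCustomQuantizeConfig())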

Please help me solve this issue. Could the problem be with my QuantizeConfig?

Below is my full code:

import tensorflow as tf
import tensorflow_model_optimization as tfmot

class MyConv(tf.keras.layers.Layer):
    '''Custom Conv2D wrapper'''

    def __init__(self, filt=1, name=None, **kwargs):
        super(MyConv, self).__init__(name=name, **kwargs)
        self.filt = filt

    def get_config(self):
        config = super().get_config().copy()
        config.update({"filt": self.filt})
        return config

    def build(self, shape):
        # Annotate the inner Conv2D so it gets quantized as well.
        self.conv = tfmot.quantization.keras.quantize_annotate_layer(
            tf.keras.layers.Conv2D(self.filt, 1, padding="same"))

    def call(self, input):
        return self.conv(input)


def get_toy_model():
    input = tf.keras.Input((10, 10, 1), name='input')
    x = tf.keras.layers.Conv2D(1, 3, padding="same")(input)
    x = tf.keras.layers.ReLU()(x)
    x = MyConv()(x)
    for _ in range(2):
        y = tf.keras.layers.Conv2D(1, 3, padding="same")(x)
        y = tf.keras.layers.ReLU()(y)
    out = tf.keras.layers.Conv2D(1, 3, padding="same")(y)
    return tf.keras.Model(input, out, name='toy_Conv2D')

LastValueQuantizer = tfmot.quantization.keras.quantizers.LastValueQuantizer
MovingAverageQuantizer = tfmot.quantization.keras.quantizers.MovingAverageQuantizer


class DefaultCustomQuantizeConfig(tfmot.quantization.keras.QuantizeConfig):
    # Configure how to quantize weights.
    def get_weights_and_quantizers(self, layer):
        return []

    # Configure how to quantize activations.
    def get_activations_and_quantizers(self, layer):
        return []

    def set_quantize_weights(self, layer, quantize_weights):
        pass

    def set_quantize_activations(self, layer, quantize_activations):
        pass

    # Configure how to quantize outputs (may be equivalent to activations).
    def get_output_quantizers(self, layer):
        return [MovingAverageQuantizer(num_bits=8, per_axis=False,
                                       symmetric=False, narrow_range=False)]

    def get_config(self):
        return {}


def recursive_depth_layers(layer):
    # Walk the layer's attributes to find and annotate nested sub-layers.
    for l in list(layer.__dict__.values()):
        if isinstance(l, tf.keras.layers.Layer):
            recursive_depth_layers(l)
            if isinstance(l, (tf.keras.layers.Dense, tf.keras.layers.Conv2D,
                              tf.keras.layers.ReLU, tf.keras.layers.LeakyReLU,
                              tf.keras.layers.Activation)):
                ql = tfmot.quantization.keras.quantize_annotate_layer(
                    l, DefaultCustomQuantizeConfig())
                ql._name += "_" + l.name
                return ql


def apply_quantization(layer):
    # Built-in layers: annotate directly.
    if isinstance(layer, (tf.keras.layers.Dense, tf.keras.layers.Conv2D,
                          tf.keras.layers.ReLU, tf.keras.layers.LeakyReLU,
                          tf.keras.layers.Activation)):
        l = tfmot.quantization.keras.quantize_annotate_layer(
            layer, DefaultCustomQuantizeConfig())
        l._name += '_' + layer.name
        return l
    # Custom layers defined in this module: annotate sub-layers, then the layer itself.
    if layer.__module__ == "__main__":
        recursive_depth_layers(layer)
        l = tfmot.quantization.keras.quantize_annotate_layer(
            layer, DefaultCustomQuantizeConfig())
        l._name += '_' + layer.name
        return l
    return layer

model = get_toy_model()
model.summary()

annotated_model = tf.keras.models.clone_model(model, clone_function=apply_quantization)
annotated_model.summary()

quantize_scope = tfmot.quantization.keras.quantize_scope
with quantize_scope({'DefaultCustomQuantizeConfig': DefaultCustomQuantizeConfig, 'MyConv': MyConv}):
    quant_aware_model = tfmot.quantization.keras.quantize_apply(annotated_model)
    quant_aware_model._name += "_quant"
quant_aware_model.summary()
quant_aware_model.compile()
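To see which layers were actually wrapped, I list the layer types of the resulting model. A small sanity check of my own, not from the guide:

    # Quantized layers show up wrapped (e.g. as QuantizeWrapper);
    # the custom layer is the one left unwrapped.
    for layer in quant_aware_model.layers:
        print(layer.name, type(layer).__name__)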


Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
