TypeError: ('Keyword argument not understood:', 'query_shape') with Keras and TensorFlow 2.6.0 while implementing a ViT

I am running into this issue while converting my model from ".pb" to ".tflite". The code is the following:

# From PB to H5 -- then load from H5

import os
import tensorflow as tf
from tensorflow.keras.preprocessing import image
from keras.models import load_model

checkpoint_filepath = "drive/MyDrive/Tirocinio/ViT/Model/vit_" + num_dataset


# LOAD THE PB MODEL
New_Model = tf.keras.models.load_model(checkpoint_filepath)  # Load the TensorFlow SavedModel (PB)
#print(New_Model.summary())

# Save the model in H5 format and load it back (to check that it matches the PB format)
tf.keras.models.save_model(New_Model, checkpoint_filepath + '/vit_1_model.h5')  # Save the model in H5 format

# LOAD h5
loaded_model_from_h5 = load_model(
    checkpoint_filepath + '/vit_1_model.h5',
    custom_objects={
        'ShiftedPatchTokenization': ShiftedPatchTokenization,
        'PatchEncoder': PatchEncoder,
        'MultiHeadAttentionLSA': MultiHeadAttentionLSA,
    },
)  # Load the model saved in H5 format
print(loaded_model_from_h5.summary())
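
For context, the eventual goal after the H5 model loads is the TFLite conversion, roughly like the sketch below; the output filename is just what I intend to use, and I never get this far because of the error that follows:

# Intended TFLite conversion step (never reached, since loading fails)
converter = tf.lite.TFLiteConverter.from_keras_model(loaded_model_from_h5)
tflite_model = converter.convert()

with open(checkpoint_filepath + '/vit_1_model.tflite', 'wb') as f:
    f.write(tflite_model)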

Running the H5 loading code above gives this error:

/usr/local/lib/python3.7/dist-packages/keras/utils/generic_utils.py:497: CustomMaskWarning: Custom mask layers require a config and must override get_config. When loading, the custom mask layer must be passed to the custom_objects argument.
  category=CustomMaskWarning)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-24-965e827ed6ce> in <module>()
     17 
     18 # LOAD h5
---> 19 loaded_model_from_h5 = load_model(checkpoint_filepath+'/vit_1_model.h5', custom_objects ={'ShiftedPatchTokenization': ShiftedPatchTokenization, 'PatchEncoder':PatchEncoder, 'MultiHeadAttentionLSA':MultiHeadAttentionLSA}) # Loading the H5 Saved Model
     20 print(loaded_model_from_h5.summary())

/usr/local/lib/python3.7/dist-packages/keras/utils/generic_utils.py in validate_kwargs(kwargs, allowed_kwargs, error_message)
   1141   for kwarg in kwargs:
   1142     if kwarg not in allowed_kwargs:
-> 1143       raise TypeError(error_message, kwarg)
   1144 
   1145 

TypeError: ('Keyword argument not understood:', 'query_shape')
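
From the traceback, load_model rebuilds each layer by passing its saved config to __init__. As far as I can tell, on TF/Keras 2.6 the base MultiHeadAttention serializes build-time shape information into its config, which seems to be where the unexpected 'query_shape' keyword comes from. A quick check of my assumption on a plain MultiHeadAttention:

import tensorflow as tf

mha = tf.keras.layers.MultiHeadAttention(num_heads=2, key_dim=8)
x = tf.random.normal((1, 4, 8))
_ = mha(x, x)  # call once so the layer records its input signature

# On TF/Keras 2.6 the printed keys include 'query_shape', 'key_shape'
# and 'value_shape', none of which my subclass's __init__ accepts.
print(sorted(mha.get_config().keys()))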

**This is the implementation of the Vision Transformer:**

import math
import tensorflow as tf
from tensorflow.keras import layers

# IMAGE_SIZE, PATCH_SIZE, NUM_PATCHES, PROJECTION_DIM and LAYER_NORM_EPS
# are constants defined earlier in the notebook.

class ShiftedPatchTokenization(layers.Layer):
    def __init__(
        self,
        image_size=IMAGE_SIZE,
        patch_size=PATCH_SIZE,
        half_patch=PATCH_SIZE // 2,
        num_patches=NUM_PATCHES,
        projection_dim=PROJECTION_DIM,
        flatten_patches=None,
        projection=None,
        layer_norm=None,
        vanilla=False,
        **kwargs,
    ):
        super(ShiftedPatchTokenization, self).__init__(**kwargs)
        self.vanilla = vanilla  # Flag to switch to the vanilla patch extractor
        self.image_size = image_size
        self.patch_size = patch_size
        self.half_patch = patch_size // 2  # // is integer division, so this is an int
        self.flatten_patches = layers.Reshape((num_patches, -1))
        self.projection = layers.Dense(units=projection_dim)
        self.layer_norm = layers.LayerNormalization(epsilon=LAYER_NORM_EPS)

    # Overridden to avoid errors while saving the model
    def get_config(self):
        config = super().get_config().copy()
        config.update(
            {
                "image_size": self.image_size,
                "patch_size": self.patch_size,
                "half_patch": self.half_patch,
                "flatten_patches": self.flatten_patches,
                "vanilla": self.vanilla,
                "projection": self.projection,
                "layer_norm": self.layer_norm,
            }
        )
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    def crop_shift_pad(self, images, mode):
        # Build the diagonally shifted images
        if mode == "left-up":
            crop_height = self.half_patch
            crop_width = self.half_patch
            shift_height = 0
            shift_width = 0
        elif mode == "left-down":
            crop_height = 0
            crop_width = self.half_patch
            shift_height = self.half_patch
            shift_width = 0
        elif mode == "right-up":
            crop_height = self.half_patch
            crop_width = 0
            shift_height = 0
            shift_width = self.half_patch
        else:
            crop_height = 0
            crop_width = 0
            shift_height = self.half_patch
            shift_width = self.half_patch

        # Crop the shifted images and pad them
        crop = tf.image.crop_to_bounding_box(
            images,
            offset_height=crop_height,
            offset_width=crop_width,
            target_height=self.image_size - self.half_patch,
            target_width=self.image_size - self.half_patch,
        )
        shift_pad = tf.image.pad_to_bounding_box(
            crop,
            offset_height=shift_height,
            offset_width=shift_width,
            target_height=self.image_size,
            target_width=self.image_size,
        )
        return shift_pad

    def call(self, images):
        if not self.vanilla:
            # Concat the shifted images with the original image
            images = tf.concat(
                [
                    images,
                    self.crop_shift_pad(images, mode="left-up"),
                    self.crop_shift_pad(images, mode="left-down"),
                    self.crop_shift_pad(images, mode="right-up"),
                    self.crop_shift_pad(images, mode="right-down"),
                ],
                axis=-1,
            )
        # Patchify the images and flatten them
        patches = tf.image.extract_patches(
            images=images,
            sizes=[1, self.patch_size, self.patch_size, 1],
            strides=[1, self.patch_size, self.patch_size, 1],
            rates=[1, 1, 1, 1],
            padding="VALID",
        )
        flat_patches = self.flatten_patches(patches)
        if not self.vanilla:
            # Layer normalize the flat patches and linearly project it
            tokens = self.layer_norm(flat_patches)
            tokens = self.projection(tokens)
        else:
            # Linearly project the flat patches
            tokens = self.projection(flat_patches)
        return (tokens, patches)
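
As an aside, I am not sure that storing sub-layer objects (flatten_patches, projection, layer_norm) in the config is serializable at all; my understanding is that get_config should contain only plain, JSON-serializable constructor arguments, roughly like the sketch below (the same concern would apply to position_embedding and positions in PatchEncoder further down):

# Sketch of a config limited to plain values (my reading of the get_config contract)
def get_config(self):
    config = super().get_config()
    config.update({
        "image_size": self.image_size,
        "patch_size": self.patch_size,
        "vanilla": self.vanilla,
    })
    return config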


class PatchEncoder(layers.Layer):
    def __init__(
        self,
        num_patches=NUM_PATCHES,
        projection_dim=PROJECTION_DIM,
        position_embedding=None,
        positions=None,
        **kwargs
    ):
        super(PatchEncoder, self).__init__(**kwargs)
        self.num_patches = num_patches
        self.position_embedding = layers.Embedding(
            input_dim=num_patches, output_dim=projection_dim
        )
        self.positions = tf.range(start=0, limit=self.num_patches, delta=1)

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            "num_patches": self.num_patches,
            "position_embedding": self.position_embedding,
            "positions": self.positions.numpy(),
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    def call(self, encoded_patches):
        encoded_positions = self.position_embedding(self.positions)
        encoded_patches = encoded_patches + encoded_positions
        return encoded_patches

class MultiHeadAttentionLSA(tf.keras.layers.MultiHeadAttention):
    def __init__(
        self,
        tau=None,  # modified: this parameter was not here originally
        **kwargs
    ):
        super(MultiHeadAttentionLSA, self).__init__(**kwargs)
        # The trainable temperature term; its initial value is the
        # square root of the key dimension.
        self.tau = tf.Variable(math.sqrt(float(self._key_dim)), trainable=True)

    def get_config(self):
        config = super().get_config().copy()
        config.update({
            "tau": self.tau.numpy(),  # modified: previously this was just self.tau
        })
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)

    def _compute_attention(self, query, key, value, attention_mask=None, training=None):
        query = tf.multiply(query, 1.0 / self.tau)
        attention_scores = tf.einsum(self._dot_product_equation, key, query)
        attention_scores = self._masked_softmax(attention_scores, attention_mask) 
        attention_scores_dropout = self._dropout_layer(
            attention_scores, training=training
        )
        attention_output = tf.einsum(
            self._combine_equation, attention_scores_dropout, value
        )
        return attention_output, attention_scores
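
If I read the Keras 2.6 source correctly, the stock MultiHeadAttention.from_config pops 'query_shape', 'key_shape' and 'value_shape' before calling the constructor and then rebuilds the weights from them, whereas my override above passes the whole config straight to cls(**config). An untested sketch of what the override might need to do instead:

@classmethod
def from_config(cls, config):
    # Remove the build-time shape keys that __init__ does not accept,
    # then rebuild the attention weights from the saved signature.
    query_shape = config.pop("query_shape", None)
    key_shape = config.pop("key_shape", None)
    value_shape = config.pop("value_shape", None)
    layer = cls(**config)
    if query_shape is not None:
        layer._build_from_signature(query_shape, value_shape, key_shape)
    return layer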

def mlp(x, hidden_units, dropout_rate):
    for units in hidden_units:
        x = layers.Dense(units, activation=tf.nn.gelu)(x) 
        x = layers.Dropout(dropout_rate)(x)
    return x

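# Diagonal attention mask used by LSA: zeros on the diagonal so that
# each token is kept from attending to itself.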
diag_attn_mask = 1 - tf.eye(NUM_PATCHES)
diag_attn_mask = tf.cast([diag_attn_mask], dtype=tf.int8)

I tried to find a solution by changing the versions of both TensorFlow and Keras, but this ViT network can only be implemented from TensorFlow 2.6.0 onwards!
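
For reference, this is roughly how I have been switching versions in Colab and checking them (the exact pins are just examples of what I tried):

!pip install tensorflow==2.6.0 keras==2.6.0

import tensorflow as tf
import keras
print(tf.__version__, keras.__version__)  # expecting 2.6.0 for both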


