Extract weights of a TFLite model directly into an Android Studio variable

I'm having some difficulties writing an extract_weights and an initialize function for a tf.Module model that I later convert to TFLite.

The idea is that I want to use this model for on-device training. The project architecture is as follows:

- first, I create a transfer learning model that will later be used for training;
- then, I upload this model to my Android application, where I train it using the tflite Interpreter;
- the model is trained in a federated way using a Flower server.

The problem I have at the moment is that Flower needs to collect the weights from each device as ByteBuffers after each training loop, but I don't understand how I could save them in my Android application.
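For reference on the byte layout itself, here is a minimal Python sketch, assuming each weight is a float32 array and the receiving side reads little-endian floats (as a Java ByteBuffer set to ByteOrder.nativeOrder() does on Android devices, which are little-endian). The helpers weights_to_buffers and buffers_to_weights are hypothetical; they are not part of the Flower or TensorFlow APIs and do not reflect Flower's own serialization format.

```python
import numpy as np

# Explicit little-endian float32, so the byte layout does not depend on the host.
DTYPE = np.dtype("<f4")


def weights_to_buffers(weights):
    """Hypothetical helper: serialize each weight array to raw float32 bytes."""
    return [np.ascontiguousarray(w, dtype=DTYPE).tobytes() for w in weights]


def buffers_to_weights(buffers, shapes):
    """Hypothetical helper: rebuild arrays from raw bytes.

    The bytes carry no shape information, so the shapes must be known
    (or sent) separately.
    """
    return [np.frombuffer(buf, dtype=DTYPE).reshape(shape)
            for buf, shape in zip(buffers, shapes)]


# Example: a 3x2 kernel and a 2-element bias, as a small dense head might have.
kernel = np.arange(6, dtype=np.float32).reshape(3, 2)
bias = np.array([0.5, -1.0], dtype=np.float32)

buffers = weights_to_buffers([kernel, bias])
print([len(b) for b in buffers])  # 4 bytes per float32 value
restored = buffers_to_weights(buffers, [kernel.shape, bias.shape])
```

Keeping one buffer per tensor (rather than one concatenated blob) mirrors the one-ByteBuffer-per-weight shape of the problem described above, at the cost of having to track the tensor order on both sides.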

These are the methods I wrote:

```python
import tensorflow as tf

# Methods of TransferLearningModel (a tf.Module subclass); self.head_model
# and SIGNATURE_DICT are defined elsewhere in the model code.

@tf.function
def extract_weights(self):
    """
    Extracts the trainable weights of the head model.

    Parameters:
        None

    Returns:
        Dict mapping each weight name to its tensor (weights and biases).
    """
    tmp_dict = {}
    tensor_names = [weight.name for weight in self.head_model.weights]
    tensors_to_save = [weight.read_value() for weight in self.head_model.weights]
    for index, layer in enumerate(tensors_to_save):
        tmp_dict[tensor_names[index]] = layer

    return tmp_dict


@tf.function(input_signature=[SIGNATURE_DICT])
def initialize_weights(self, weights):
    """
    Initializes the weights of the head model.

    Parameters:
        weights: Dict mapping each weight name to the tensor used for initialization.

    Returns:
        None
    """
    tensor_names = [weight.name for weight in self.head_model.weights]
    for i, tensor in enumerate(self.head_model.weights):
        tensor.assign(weights[tensor_names[i]])
```

Note that when I instantiate a TransferLearningModel object (my model class, which subclasses tf.Module) and call these two functions, I get no problems, but when I try to convert them to TFLite I get this error:

ValueError: Got a non-Tensor value <tf.Operation 'StatefulPartitionedCall' type=StatefulPartitionedCall> for key 'output_0' in the output of the function __inference_initialize_weights_8582 used to generate the SavedModel signature 'initialize'. Outputs for functions used as signatures must be a single Tensor, a sequence of Tensors, or a dictionary from string to Tensor.

I understand the error, but I don't see why I have to return something when I am simply initializing the weights of my model.
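The error message itself states the constraint: a function exported as a SavedModel signature must output tensors, and with no return value the exporter ends up seeing the assign call's operation instead. A minimal sketch of one common workaround, assuming the export and conversion steps follow the TFLite on-device training pattern (resource variables enabled), is to make every signature function return tensors, here the values that were just assigned. HeadModule, its kernel/bias variables and the signature keys "extract"/"initialize" are hypothetical stand-ins for TransferLearningModel's head_model and signatures, not the original code.

```python
import tempfile

import numpy as np
import tensorflow as tf


class HeadModule(tf.Module):
    """Hypothetical minimal stand-in for the head model: one kernel, one bias."""

    def __init__(self):
        super().__init__()
        self.kernel = tf.Variable(tf.zeros([4, 2]), name="kernel")
        self.bias = tf.Variable(tf.zeros([2]), name="bias")

    @tf.function(input_signature=[])
    def extract_weights(self):
        # Signature outputs must be tensors, so return a dict of name -> tensor.
        return {"kernel": self.kernel.read_value(), "bias": self.bias.read_value()}

    @tf.function(input_signature=[
        tf.TensorSpec([4, 2], tf.float32, name="kernel"),
        tf.TensorSpec([2], tf.float32, name="bias"),
    ])
    def initialize_weights(self, kernel, bias):
        self.kernel.assign(kernel)
        self.bias.assign(bias)
        # Returning the freshly assigned values gives the signature real tensor
        # outputs instead of a bare operation with no tensor output.
        return {"kernel": self.kernel.read_value(), "bias": self.bias.read_value()}


module = HeadModule()
export_dir = tempfile.mkdtemp()
tf.saved_model.save(
    module,
    export_dir,
    signatures={
        "extract": module.extract_weights.get_concrete_function(),
        "initialize": module.initialize_weights.get_concrete_function(),
    },
)

# Convert both signatures, keeping the variables as mutable resource variables.
converter = tf.lite.TFLiteConverter.from_saved_model(
    export_dir, signature_keys=["extract", "initialize"])
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,
    tf.lite.OpsSet.SELECT_TF_OPS,
]
converter.experimental_enable_resource_variables = True
tflite_model = converter.convert()

# Run the signatures the way an on-device Interpreter would.
interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter.allocate_tensors()
initialize = interpreter.get_signature_runner("initialize")
extract = interpreter.get_signature_runner("extract")

initialize(kernel=np.ones([4, 2], dtype=np.float32),
           bias=np.full([2], 0.5, dtype=np.float32))
weights = extract()
print(weights["bias"])
```

Echoing the weights is a design choice: any tensor output (even a constant status value) satisfies the signature requirement, but returning the assigned values lets the caller confirm the update took effect.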



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow