torch.stack, mean, and state_dict in deep learning
I am following this tutorial for federated learning.
The author defines:

```python
import torch

def server_aggregate(global_model, client_models, client_lens):
    """
    This function has aggregation method 'wmean'.
    wmean takes the weighted mean of the weights of the models.
    """
    total = sum(client_lens)
    n = len(client_models)
    global_dict = global_model.state_dict()
    for k in global_dict.keys():
        global_dict[k] = torch.stack(
            [client_models[i].state_dict()[k].float() * (n * client_lens[i] / total)
             for i in range(len(client_models))], 0).mean(0)
    global_model.load_state_dict(global_dict)
    for model in client_models:
        model.load_state_dict(global_model.state_dict())
```
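For context, this is roughly how the function gets used; the model shape and the client sizes below are placeholders I made up, not the tutorial's actual values:

```python
import torch.nn as nn

# Placeholder setup: three clients with identical tiny models and
# different (made-up) dataset sizes.
global_model = nn.Linear(4, 2)
client_models = [nn.Linear(4, 2) for _ in range(3)]
client_lens = [100, 50, 25]  # samples held by each client

# Each client starts from the global weights; local training would go here.
for model in client_models:
    model.load_state_dict(global_model.state_dict())

server_aggregate(global_model, client_models, client_lens)
```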
As you can see, there is a piece of code where he takes a weighted average of the parameter values across the list of models called `client_models`:

```python
for k in global_dict.keys():
    global_dict[k] = torch.stack(
        [client_models[i].state_dict()[k].float() * (n * client_lens[i] / total)
         for i in range(len(client_models))], 0).mean(0)
```
My question is: why is there a multiplication by `n`? Isn't calling `.mean()` enough to compute the average?
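Here is a small numeric sketch (with made-up values) of what I believe the expression computes; it seems to coincide with a plain weighted sum using weights `client_lens[i] / total`:

```python
import torch

# Toy example: one scalar "parameter" per client, with made-up dataset sizes.
params = [torch.tensor(2.0), torch.tensor(4.0), torch.tensor(6.0)]
client_lens = [100, 50, 25]
total = sum(client_lens)
n = len(params)

# The tutorial's expression: scale each term by n * len_i / total, then mean.
via_mean = torch.stack([params[i] * (n * client_lens[i] / total)
                        for i in range(n)], 0).mean(0)

# A plain weighted sum with weights len_i / total, for comparison.
weighted_sum = sum(params[i] * (client_lens[i] / total) for i in range(n))

print(via_mean.item(), weighted_sum.item())  # both print ~3.1429
```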