How to gather all client weights at server in TFF?
I am trying to implement a custom aggregation in TFF by adapting the code from this tutorial. I would like to rewrite next_fn so that all the client weights are placed at the server for further computation. Since federated_collect was removed from tff-nightly, I am trying to do this with federated_aggregate.
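For reference, federated_aggregate combines client values through four pieces: a zero accumulator, accumulate (fold one client value into an accumulator), merge (combine two partial accumulators) and report (turn the final accumulator into the result). The sketch below is a plain-Python model of that contract, not TFF code. The helper simulate_federated_aggregate and the split of clients into groups are hypothetical, chosen only to show how the pieces fit together. Each group starts from a fresh zero, and merge must return the combined accumulator; a merge that returned its second argument would silently drop earlier partial results.

```python
from functools import reduce
from typing import Callable, List, TypeVar

T = TypeVar("T")  # one client's value
U = TypeVar("U")  # accumulator
R = TypeVar("R")  # reported result


def simulate_federated_aggregate(
    client_groups: List[List[T]],
    zero: Callable[[], U],
    accumulate: Callable[[U, T], U],
    merge: Callable[[U, U], U],
    report: Callable[[U], R],
) -> R:
    """Hypothetical model of the aggregation semantics (not TFF itself).

    Each group of clients is folded into its own fresh zero accumulator,
    then the partial accumulators are merged and the result is reported.
    """
    partials = [reduce(accumulate, group, zero()) for group in client_groups]
    return report(reduce(merge, partials, zero()))


def accumulate(acc: list, value) -> list:
    # Build a new list instead of mutating the accumulator in place.
    return acc + [value]


def merge(a: list, b: list) -> list:
    # Return the combined accumulator, not one of the two halves.
    return a + b


result = simulate_federated_aggregate(
    [["w1", "w2"], ["w3"]], zero=list, accumulate=accumulate,
    merge=merge, report=lambda acc: acc)
print(result)  # ['w1', 'w2', 'w3']
```

Passing zero as a factory (list) rather than a shared [] keeps the groups from appending into the same object. This works in plain Python because a list can grow freely; TFF values have static types, which matters for the error below.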
This is what I have so far:
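
    import tensorflow_federated as tff

    # federated_server_type, federated_dataset_type, client_update_fn and
    # server_update_fn are defined as in the tutorial; do_smth_with_result is a
    # placeholder for the custom server-side computation.

    def accumulate(x, y):
        # Append one client's weights to the accumulator.
        x.append(y)
        return x

    def merge(x, y):
        # Combine two partial accumulators and return the merged list.
        x.extend(y)
        return x

    @tff.federated_computation(federated_server_type, federated_dataset_type)
    def next_fn(server_state, federated_dataset):
        # Broadcast the current server weights to the clients.
        server_weights_at_client = tff.federated_broadcast(
            server_state.trainable_weights)
        # Run local training on each client.
        client_deltas = tff.federated_map(
            client_update_fn, (federated_dataset, server_weights_at_client))
        # Try to gather every client's result at the server.
        z = []
        agg_result = tff.federated_aggregate(client_deltas, z,
                                             accumulate=tff.tf_computation(accumulate),
                                             merge=tff.tf_computation(merge),
                                             report=tff.tf_computation(lambda x: x))
        new_weights = do_smth_with_result(agg_result)
        server_state = tff.federated_map(
            server_update_fn, (server_state, new_weights))
        return server_state
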
However, this results in the following exception:

      File "/home/yana/Documents/Uni/Thesis/grufedatt_try.py", line 351, in <module>
        def next_fn(server_state, federated_dataset):
      File "/home/yana/anaconda3/envs/fedenv/lib/python3.9/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 494, in __call__
        wrapped_func = self._strategy(
      File "/home/yana/anaconda3/envs/fedenv/lib/python3.9/site-packages/tensorflow_federated/python/core/impl/wrappers/computation_wrapper.py", line 222, in __call__
        result = fn_to_wrap(*args, **kwargs)
      File "/home/yana/Documents/Uni/Thesis/grufedatt_try.py", line 358, in next_fn
        agg_result = tff.federated_aggregate(client_deltas, z,
      File "/home/yana/anaconda3/envs/fedenv/lib/python3.9/site-packages/tensorflow_federated/python/core/impl/federated_context/intrinsics.py", line 140, in federated_aggregate
        raise TypeError(
    TypeError: Expected parameter `accumulate` to be of type (<<<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>>,<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>> -> <<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>>), but received (<<>,<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>> -> <<float32[9999,96],float32[96,1024],float32[256,1024],float32[1024],float32[256,96],float32[96]>>) instead.

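Reading the two signatures in the error side by side: TFF expects accumulate's first argument to have the same type as its result, <<float32[9999,96],...>>, but the first argument it actually receives is <>, because the zero value z = [] became an empty structure. A list that grows with each append has no single type, so the accumulator needs a fixed structure. One way to keep it fixed is to collect each weight tensor into an array with a leading client axis, starting from zero rows and concatenating one row per client. The sketch below models that idea in NumPy, not TFF. WEIGHT_SHAPES and its scaled-down shapes are hypothetical stand-ins for the six weight tensors in the error. Whether your TFF version accepts an accumulator type with an unknown leading dimension (for example float32[?,96]) is an assumption you would need to verify.

```python
from typing import List, Tuple

import numpy as np

# Hypothetical, scaled-down stand-ins for the model's weight shapes
# (the real ones in the error are [9999, 96], [96, 1024], ...).
WEIGHT_SHAPES: List[Tuple[int, ...]] = [(4, 3), (3,)]

Stacked = Tuple[np.ndarray, ...]  # one array per weight; axis 0 indexes clients


def zero() -> Stacked:
    # No clients collected yet: every array has 0 rows on the client axis.
    return tuple(np.zeros((0,) + s, dtype=np.float32) for s in WEIGHT_SHAPES)


def accumulate(acc: Stacked, client_weights: Tuple[np.ndarray, ...]) -> Stacked:
    # Add one client as a new row. The structure (number, rank and dtype of
    # the arrays) never changes; only the size of axis 0 grows.
    return tuple(
        np.concatenate([a, w[np.newaxis]], axis=0)
        for a, w in zip(acc, client_weights))


def merge(a: Stacked, b: Stacked) -> Stacked:
    # Combine two partial accumulators by stacking their client rows.
    return tuple(np.concatenate([x, y], axis=0) for x, y in zip(a, b))


def report(acc: Stacked) -> Stacked:
    return acc


# Usage: three simulated clients, aggregated in two partial groups.
rng = np.random.default_rng(0)
clients = [
    tuple(rng.standard_normal(s).astype(np.float32) for s in WEIGHT_SHAPES)
    for _ in range(3)
]
part1 = accumulate(accumulate(zero(), clients[0]), clients[1])
part2 = accumulate(zero(), clients[2])
stacked = report(merge(part1, part2))
print([a.shape for a in stacked])  # [(3, 4, 3), (3, 3)]
```

Because accumulate and merge only concatenate along axis 0, how the clients are grouped changes at most the row order, never the per-weight structure, and do_smth_with_result would receive one [num_clients, ...] array per weight.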
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
