Mismatched number of elements between type spec and value in `to_representation_for_type`. Type spec has 2 elements, value has 5
I am using FedProx from TensorFlow Federated (tff.learning.algorithms.build_unweighted_fed_prox) to implement federated learning:
def model_fn():
    keras_model = create_keras_model()
    return tff.learning.from_keras_model(
        keras_model,
        input_spec=preprocessed_example_dataset.element_spec,
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
    )

iterative_process = tff.learning.algorithms.build_unweighted_fed_prox(
    model_fn, 0.001,
    client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=0.001),
    server_optimizer_fn=lambda: tf.keras.optimizers.SGD(learning_rate=1.0)
)

import nest_asyncio
nest_asyncio.apply()

state = iterative_process.initialize()
for round in range(3, 11):
    state = iterative_process.next(state.state, federated_train_data)
    print('round {:2d}, metrics={}'.format(round, state.metrics))
The result of training is:
round 3, 'sparse_categorical_accuracy'= 0.6435834
round 4, 'sparse_categorical_accuracy'= 0.6955319
round 5, 'sparse_categorical_accuracy'= 0.74295634
round 6, 'sparse_categorical_accuracy'= 0.78176934
round 7, 'sparse_categorical_accuracy'= 0.80838746
round 8, 'sparse_categorical_accuracy'= 0.8300672
round 9, 'sparse_categorical_accuracy'= 0.8486338
round 10, 'sparse_categorical_accuracy'= 0.86639416
But when I try to evaluate the model on the test data, I get an error:
evaluation = tff.learning.build_federated_evaluation(model_fn)
test_metrics = evaluation(state.state, federated_test_data)
TypeError: Mismatched number of elements between type spec and value in `to_representation_for_type`. Type spec has 2 elements, value has 5.
How do I fix it?
Solution 1:[1]
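Your evaluation method expects tff.learning.ModelWeights, but you are passing the entire state, which is a bigger structure that holds the model weights under its global_model_weights attribute. That is why the error reports 2 elements in the type spec (the trainable and non-trainable weights) against 5 in the value (the fields of the full state). So, this should work: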
test_metrics = evaluation(state.state.global_model_weights, federated_test_data)
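To make the two-versus-five count in the error message concrete, here is a plain-Python stand-in. It does not use TFF: the two namedtuples only mirror the field layout of tff.learning.ModelWeights and of the learning algorithm state, and check_element_count is a hypothetical helper written for this illustration, not a TFF function.

```python
from collections import namedtuple

# Stand-ins that only mirror the field layout of the TFF structures.
# They are NOT the TFF types themselves.
ModelWeights = namedtuple('ModelWeights', ['trainable', 'non_trainable'])
LearningAlgorithmState = namedtuple(
    'LearningAlgorithmState',
    ['global_model_weights', 'distributor', 'client_work',
     'aggregator', 'finalizer'])


def check_element_count(value, expected_type):
    """Hypothetical check: value must have as many fields as expected_type."""
    expected = len(expected_type._fields)
    if len(value) != expected:
        raise TypeError(
            'Type spec has {} elements, value has {}'.format(
                expected, len(value)))
    return value


weights = ModelWeights(trainable=[0.1, 0.2], non_trainable=[])
state = LearningAlgorithmState(weights, (), (), (), ())

# Passing only the weights matches the expected two-field structure.
check_element_count(state.global_model_weights, ModelWeights)

# Passing the whole five-field state reproduces the shape of the error.
try:
    check_element_count(state, ModelWeights)
except TypeError as err:
    message = str(err)
    print(message)
```

The real check happens inside TFF when the argument is converted to the type the evaluation computation declares, but the cause is the same: the argument has to be the weights, not the structure that contains them.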
Side note: assigning the return value of iterative_process.next to a Python variable named state can become very confusing, because it contains both the state of the program and the metrics, which is what leads you to write state.state.
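One way to avoid that naming problem is to unpack the output of each round into separately named variables. The sketch below is a minimal illustration, assuming only that the process exposes initialize() and next(state, data) and that next returns an object with .state and .metrics attributes, as the process in the question does. The helper name run_rounds is hypothetical and not part of TFF.

```python
from typing import Any, List, Tuple


def run_rounds(process: Any, client_data: Any,
               num_rounds: int) -> Tuple[Any, List[Any]]:
    """Runs num_rounds rounds and returns (final_state, metrics per round).

    Assumes process.next(state, data) returns an object with .state and
    .metrics attributes.
    """
    state = process.initialize()
    history = []
    for round_num in range(1, num_rounds + 1):
        # The output bundles the new state together with the round's metrics.
        output = process.next(state, client_data)
        # Keep the two apart so that `state` is always just the state.
        state = output.state
        history.append(output.metrics)
        print('round {:2d}, metrics={}'.format(round_num, output.metrics))
    return state, history
```

With the objects from the question, this would be called as state, _ = run_rounds(iterative_process, federated_train_data, 10), and the evaluation call would then read evaluation(state.global_model_weights, federated_test_data).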
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Jakub Konecny |
