TFF: define the usage of TensorFlow's take() function
I am trying to reproduce the federated learning implementation provided here: Working with tff's ClientData, in order to understand the code clearly. I have reached a point where I need clarification.
import collections

import tensorflow as tf

def preprocess_dataset(dataset):
    """Create batches of 5 examples, and limit to 3 batches."""
    def map_fn(input):
        # Flatten each image to 784 features and make the labels an int64 column.
        return collections.OrderedDict(
            x=tf.reshape(input['pixels'], shape=(-1, 784)),
            y=tf.cast(tf.reshape(input['label'], shape=(-1, 1)), tf.int64),
        )
    return dataset.batch(5).map(
        map_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE).take(5)
- What does `dataset.batch(5)` refer to? Are these batches taken from the data for training, with the 3 reserved for testing?
- What does `.take(5)` mean?
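For what it's worth, a minimal sketch on a toy dataset may help show what the two calls do. It assumes TensorFlow 2.x with eager execution; the 32-element `tf.data.Dataset.range` dataset and the variable names are hypothetical stand-ins for a client dataset, not part of the tutorial. In `tf.data`, `batch(5)` groups consecutive elements into batches of up to 5 examples, and `take(5)` keeps at most the first 5 elements of the dataset it is called on (here, the first 5 batches). Neither call splits the data into training and test sets, and since the code passes 5 rather than 3, the "3 batches" in the docstring does not appear to match what the code does.

```python
import tensorflow as tf

# Hypothetical stand-in for one client's data: the integers 0..31.
dataset = tf.data.Dataset.range(32)

# batch(5) groups consecutive elements: six batches of 5 and one final batch of 2.
batched = dataset.batch(5)

# take(5) keeps only the first 5 elements; after batching, each element is a batch.
limited = batched.take(5)

# Inspect what survives: the size of each remaining batch and the total example count.
batch_sizes = [int(batch.shape[0]) for batch in limited]
num_examples = sum(batch_sizes)

print(batch_sizes)   # [5, 5, 5, 5, 5]
print(num_examples)  # 25
```

Under these assumptions, the function in the question would give each client at most 5 batches of 5 examples, i.e. at most 25 examples per client; the remaining examples are simply not used.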
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
