Replicate Keras tutorial without caching
I am trying to replicate this tutorial on the official Keras website. It is a guided example of transfer learning that shows how to use a pre-trained model on the well-known cats vs. dogs dataset.
My question concerns the part where they cache the datasets and set the prefetch buffer size:
batch_size = 32
train_ds = train_ds.cache().batch(batch_size).prefetch(buffer_size=10)
validation_ds = validation_ds.cache().batch(batch_size).prefetch(buffer_size=10)
test_ds = test_ds.cache().batch(batch_size).prefetch(buffer_size=10)
If I skip this part, I can no longer replicate the tutorial, because I get an error when I fit the model:
ValueError: Input 0 is incompatible with layer model_7: expected shape=(None, 150, 150, 3), found shape=(150, 150, 3)
QUESTION
What do I need to change to run the training without caching and the related steps?
Solution 1:[1]
To remove caching and prefetching (prefetching keeps a buffer of elements prepared ahead of time) from this code, simply remove those method calls from the input pipeline and keep .batch(), as follows:
batch_size = 32
train_ds = train_ds.batch(batch_size)
validation_ds = validation_ds.batch(batch_size)
test_ds = test_ds.batch(batch_size)
The rest of the tutorial code will then work as before; the only difference is that the datasets are no longer cached or prefetched.
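As a quick check that training runs with a batched but uncached pipeline, here is a minimal sketch. It does not use the tutorial's data or its pre-trained base model: the random 150x150x3 arrays and the tiny pooling-plus-dense model are hypothetical stand-ins chosen only so the example runs quickly. It assumes TensorFlow 2.x with Keras available as tensorflow.keras.

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Hypothetical stand-in data: 8 random "images" with the tutorial's 150x150x3 shape
rng = np.random.default_rng(0)
images = rng.random((8, 150, 150, 3), dtype=np.float32)
labels = rng.integers(0, 2, size=(8,)).astype(np.float32)
train_ds = tf.data.Dataset.from_tensor_slices((images, labels))

# Only .batch() is kept: no .cache() and no .prefetch()
batch_size = 4
train_ds = train_ds.batch(batch_size)

# Hypothetical tiny model standing in for the tutorial's transfer-learning model
inputs = keras.Input(shape=(150, 150, 3))
x = keras.layers.GlobalAveragePooling2D()(inputs)
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)
model.compile(
    optimizer="adam",
    loss=keras.losses.BinaryCrossentropy(from_logits=True),
)

# Training runs because every element now has a leading batch dimension
history = model.fit(train_ds, epochs=1, verbose=0)
print(history.history["loss"])
```

Caching and prefetching only affect input-pipeline performance, so removing them should not change what the model sees; removing .batch() does.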
As others have pointed out, the .batch() method is independent of caching, but it is required to group the dataset into batches that are fed to the model one after another. When you skipped this entire code block, you skipped the .batch() call as well, and that is what caused the error.
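To see how .batch() groups consecutive elements, here is a small sketch using tf.data.Dataset.range in place of the image datasets (an illustrative assumption; the tutorial's datasets contain (image, label) pairs instead of integers).

```python
import tensorflow as tf

# Ten scalar elements, grouped into batches of 4 and yielded in order
batches = [b.numpy().tolist() for b in tf.data.Dataset.range(10).batch(4)]
print(batches)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]

# With drop_remainder=True the incomplete final batch is discarded
full_only = [
    b.numpy().tolist()
    for b in tf.data.Dataset.range(10).batch(4, drop_remainder=True)
]
print(full_only)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

By default the last batch can be smaller than batch_size, which is why the batch size is not fixed in the dataset's shape.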
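.batch() combines consecutive elements into batches, which adds an outer (batch) dimension to every element the dataset yields. Standard Keras models and layers (like the ones used in that tutorial) expect this dimension in their input; it is the None in the expected shape, because the batch size is not fixed in advance. Without .batch(), each element has shape (150, 150, 3), which is why you got the error "expected shape=(None, 150, 150, 3), found shape=(150, 150, 3)".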
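The shape mismatch in the error can be reproduced by inspecting a dataset's element_spec before and after batching. This sketch uses a hypothetical dataset of five blank 150x150 RGB images rather than the tutorial's data, and also shows tf.expand_dims as one way to give a single image a batch dimension of size 1 by hand.

```python
import numpy as np
import tensorflow as tf

# Hypothetical dataset of 5 blank 150x150 RGB images with integer labels
images = np.zeros((5, 150, 150, 3), dtype=np.float32)
labels = np.zeros((5,), dtype=np.int64)
ds = tf.data.Dataset.from_tensor_slices((images, labels))

# Each unbatched element is a single image with no leading batch dimension
unbatched_shape = ds.element_spec[0].shape
print(unbatched_shape)  # (150, 150, 3)

# After .batch(), elements gain a leading dimension; it is None because the
# final batch may be smaller than batch_size (drop_remainder defaults to False)
batched_shape = ds.batch(32).element_spec[0].shape
print(batched_shape)  # (None, 150, 150, 3)

# A single image can be given a batch dimension of size 1 manually
single = tf.expand_dims(images[0], axis=0)
print(single.shape)  # (1, 150, 150, 3)
```

The unbatched shape matches the "found" shape in the error, and the batched shape matches the "expected" one.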
You can read more about method chaining here, about batching and the .batch() method here, about the .cache() method here and about the .prefetch() method here.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Angus Maiden |
