TensorFlow batching: elements with same key in same batch of variable size
I have a pandas dataframe with id keys. For simplicity we can say the id keys are 0-99.
In a second column we have encodings of fixed length K. Each encoding is associated with an id key, and two or more encodings may share the same id key.
Example:
[0, encoding_1]
[0, encoding_2]
[1, encoding_3]
[2, encoding_4]
[2, encoding_5]
I'm able to get batches that contain the rows from each unique key and only those:

    ds = ds.group_by_window(
        key_func=lambda elem: tf.cast(elem['id_col'], tf.int64),
        reduce_func=lambda _, window: window.batch(batch_size),
        window_size=batch_size,
    )
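For context, here is a minimal runnable sketch of that behavior (the column names `id_col`/`enc_col`, the toy data, and the encoding length `K` are assumptions, not from the original post). Because the window size equals the batch size, each emitted batch holds the rows of exactly one unique id:

```python
import tensorflow as tf

K = 4           # encoding length (assumption)
batch_size = 8  # also used as the window size below

# Toy data: ids 0, 0, 1, 2, 2 with random fixed-length encodings.
ids = tf.constant([0, 0, 1, 2, 2], dtype=tf.int64)
encodings = tf.random.uniform((5, K))
ds = tf.data.Dataset.from_tensor_slices({'id_col': ids, 'enc_col': encodings})

ds = ds.group_by_window(
    key_func=lambda elem: tf.cast(elem['id_col'], tf.int64),
    reduce_func=lambda _, window: window.batch(batch_size),
    window_size=batch_size,
)

for batch in ds:
    # Every batch contains exactly one unique id.
    print(batch['id_col'].numpy())
```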
But this isn't ideal, because I want each batch to contain multiple unique keys, not just one (contrastive learning is the goal).
How would I get batches that follow this rule: they must be of some minimum size, and if an encoding of id key X is in the batch, then so are all the other encodings of id key X.
Any idea on how to approach this?
Thanks!
Solution 1:[1]
I think what you are looking for is a generator. Keras `model.fit()` accepts generators as input, so you can pass batches of different sizes.
What I would do:
- Create a list of same-id encodings from your dataframe (e.g. with a for-loop and pop). It should look something like this: `[Array(encoding_1, encoding_2), Array(encoding_3)]`
- Create a generator that yields the next batch from this list. Each batch contains as many entries from the list as you specify in the generator's input.
- Optional: Create a dataset with `Dataset.from_generator()`
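The steps above can be sketched roughly like this (the column names `id`/`enc`, the minimum batch size, and the encoding length `K` are all assumptions). The generator packs whole id groups into a batch until the minimum size is reached, so all encodings of a given id always end up in the same batch:

```python
import numpy as np
import pandas as pd

K = 4
rng = np.random.default_rng(0)
df = pd.DataFrame({
    'id': [0, 0, 1, 2, 2, 3],
    'enc': [rng.normal(size=K) for _ in range(6)],
})

# Step 1: one array of encodings per unique id.
groups = [np.stack(g['enc'].to_list()) for _, g in df.groupby('id')]

# Step 2: yield batches built from whole groups; a batch is emitted once
# it reaches min_size rows, so groups are never split across batches.
def batch_generator(groups, min_size=3):
    batch = []
    for group in groups:
        batch.append(group)
        if sum(len(g) for g in batch) >= min_size:
            yield np.concatenate(batch)
            batch = []
    if batch:  # flush the remainder (may be smaller than min_size)
        yield np.concatenate(batch)

for b in batch_generator(groups):
    print(b.shape)
```

For the optional third step, something like `tf.data.Dataset.from_generator(lambda: batch_generator(groups), output_signature=tf.TensorSpec(shape=(None, K), dtype=tf.float64))` should wrap this into a dataset of variable-size batches; shuffling `groups` between epochs would vary the batch composition.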
This is quite a lot of coding work, so unfortunately I don't have the time to do it, but let me know if you have any specific questions.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | PlzBePython |
