Unexpected behavior in tf.data.Dataset map function
I am working on a problem where I need to apply a transformation to my dataset using the map function that tf.data.Dataset provides. The idea is to apply a transformation that relies on a random number and then chain it with another function.
The idea is something like this:
dataset = tf.data.Dataset.from_tensor_slices([1, 1, 1, 1, 1, 1])
dataset = dataset.map(lambda x: x + tf.random.uniform([], minval=0, maxval=9, dtype=tf.dtypes.int32))  # the map function should run only once
I thought that if I iterated over dataset twice I would get the same values; however, the result is the following:
ds = dataset.zip((dataset,dataset))
print(list(ds.as_numpy_iterator()))
#output -> [(8, 2), (2, 1), (8, 9), (2, 2), (6, 7), (2, 2)]
Any clues on how I can get exactly the same values after a .map transformation that relies on random numbers? It seems that the map function runs once per iteration instead of once overall, as I intended in the code snippet.
P.S.: Using a random seed does the trick, but it's just hiding the problem.
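To build intuition for what is going on, here is a plain-Python analogy (no TensorFlow; the generator and counter below are illustrative stand-ins): zipping a dataset with itself creates two independent iterators, and each iterator re-runs the stateful map body.

```python
import itertools

# Shared stateful source, standing in for the stateful tf.random op.
counter = itertools.count()

def mapped(values):
    # Each call builds a FRESH iterator that re-runs the "map" body.
    for v in values:
        yield v + next(counter)

# Like dataset.zip((dataset, dataset)): two independent iterations
# over the "same" pipeline, interleaved by zip.
pairs = list(zip(mapped([1, 1, 1]), mapped([1, 1, 1])))
print(pairs)  # [(1, 2), (3, 4), (5, 6)] -- the two "copies" disagree
```

Because both iterators advance the same shared state, the left and right elements of each pair are never drawn from the same call, which mirrors the mismatched tuples above.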
[EDITED]
What I need is to perform an operation like this:
ds = tf.data.Dataset.range(1, 10)
y = ds.map(lambda x: my_random_operation(x))
x = ds.map(lambda x: another_random_operation(x))
dataset = tf.data.Dataset.zip((x, y))
As you can see, x depends on y, but that behavior does not seem to be happening. That is why I asked how to apply the random operation first and then chain the map, to illustrate the point.
Solution 1:[1]
I think it makes a lot more sense to create separate datasets, zip them, and then perform a common operation.
import tensorflow as tf
ds1 = tf.data.Dataset.range(1, 4)
ds2 = tf.data.Dataset.range(4, 8)
ds = tf.data.Dataset.zip((ds1, ds2))
# [(1, 4), (2, 5), (3, 6)]
def add_random_number(a, b):
    random_number = tf.random.uniform([], minval=0, maxval=9, dtype=tf.dtypes.int64)
    return a + random_number, b + random_number
ds = ds.map(add_random_number)
print(list(ds.as_numpy_iterator()))
# [(7, 10), (3, 6), (10, 13)] +6, +1, +7
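The same idea in a plain-Python sketch (illustrative only, with a seeded `random.Random` standing in for `tf.random.uniform`): draw one random number per zipped pair and add it to both components, so they always shift together.

```python
import random

rng = random.Random(7)  # seeded only to make the sketch reproducible

ds1 = range(1, 4)
ds2 = range(4, 8)

def add_random_number(a, b):
    # One draw per pair, shared by both components.
    r = rng.randint(0, 9)
    return a + r, b + r

pairs = [add_random_number(a, b) for a, b in zip(ds1, ds2)]
# Each pair keeps its original difference of 3, whatever r was drawn.
```

Because the draw happens once per element of the zipped dataset, the two components can never disagree about which random number was used.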
Solution 2:[2]
I think you have to set a random seed to get deterministic behavior, because when zipping the two datasets they are iterated internally (similar to the Python zip function), triggering the map function again:
import tensorflow as tf
tf.random.set_seed(111)
dataset = tf.data.Dataset.from_tensor_slices([1, 1, 1, 1, 1, 1])
dataset = dataset.map(lambda x: x + tf.random.uniform([], minval=0, maxval=9, dtype=tf.dtypes.int32))
ds = dataset.zip((dataset, dataset))
print(list(ds.as_numpy_iterator()))
# [(7, 7), (8, 8), (8, 8), (2, 2), (8, 8), (6, 6)]
Update 1
Maybe something like this:
import tensorflow as tf
ds = tf.data.Dataset.range(1, 10)
def my_random_operation(x):
    # Or with stateless uniform
    return x + tf.random.uniform([], minval=0, maxval=9, dtype=tf.dtypes.int64, seed=1)

def another_random_operation(x):
    # Or with stateless uniform
    return x + tf.random.uniform([], minval=0, maxval=9, dtype=tf.dtypes.int64, seed=2)
y = ds.map(lambda x: my_random_operation(x))
x = ds.map(lambda x: another_random_operation(x))
dataset = tf.data.Dataset.zip((x, y))
print(list(dataset.as_numpy_iterator()))
What could also work is something like np.random.randint, which would be executed only once, since it won't be traced into a tf.Graph:
import tensorflow as tf
import numpy as np
ds = tf.data.Dataset.range(1, 10)
def my_random_operation(x):
    return x + np.random.randint(0, 9)

def another_random_operation(x):
    return x + np.random.randint(0, 9)
y = ds.map(lambda x: my_random_operation(x))
x = ds.map(lambda x: another_random_operation(x))
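To illustrate why the NumPy version yields a single fixed offset, here is a NumPy-only sketch (an analogy, not the actual tf.data machinery): the np.random.randint call runs once when the mapper is built, much like at trace time, and every element then reuses that baked-in constant.

```python
import numpy as np

def make_mapper():
    # Runs ONCE, when the mapper is built (analogous to tracing):
    offset = np.random.randint(0, 9)
    def mapper(x):
        # Every element reuses the same baked-in constant.
        return x + offset
    return mapper

f = make_mapper()
outputs = [f(x) for x in range(1, 10)]
# Consecutive outputs differ by exactly 1: one shared offset for all.
diffs = [b - a for a, b in zip(outputs, outputs[1:])]
print(diffs)  # [1, 1, 1, 1, 1, 1, 1, 1]
```

This also means the NumPy approach shifts the whole dataset by one constant rather than giving each element its own random value, which may or may not be what you want.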
Update 2
Regarding the reason behind this behavior, notice that the variable num_traces is incremented 18 times altogether, which corresponds to your 9 tuples (two map calls per tuple):
import tensorflow as tf
ds = tf.data.Dataset.range(1, 10)
num_traces = tf.Variable(0)
def my_random_operation(x):
    global num_traces
    num_traces.assign_add(1)
    return x + tf.random.uniform([], minval=0, maxval=9, dtype=tf.dtypes.int64)
ds = ds.map(lambda x: my_random_operation(x))
dataset = tf.data.Dataset.zip((ds, ds))
print(list(dataset.as_numpy_iterator()))
print(num_traces)
# [(2, 3), (9, 8), (7, 7), (11, 11), (8, 9), (14, 11), (13, 11), (14, 16), (12, 16)]
# <tf.Variable 'Variable:0' shape=() dtype=int32, numpy=18>
This behavior is in accordance with the information in the docs:
When used as an argument to a tf.function, different generator objects will cause retracing of the tf.function.
So one assumption is that retracing is triggered by zipping the datasets after the first execution of dataset.map. The datasets, as mentioned previously, are iterated internally and therefore trigger the map function again, because the values have changed since the first map call. Check this thread also.
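If the goal is to draw the random numbers exactly once and then reuse them everywhere, one option worth considering (my suggestion, not part of either answer above) is to materialize the mapped dataset, e.g. with tf.data.Dataset.cache() followed by one complete pass, or simply by snapshotting the values into memory. A plain-Python sketch of the idea:

```python
import random

rng = random.Random(42)  # seeded only to make the sketch reproducible
data = [1, 1, 1, 1, 1, 1]

# Run the random map ONCE and keep the results, mimicking
# dataset.map(...).cache() after one complete pass over the data.
snapshot = [x + rng.randint(0, 9) for x in data]

# Every later consumer reads the same materialized values.
pairs = list(zip(snapshot, snapshot))
print(all(a == b for a, b in pairs))  # True
```

Note that tf.data.Dataset.cache() only replays cached elements after the dataset has been iterated through completely once; partial or concurrent iteration before the cache is filled may still recompute elements, so a full warm-up pass is worth doing first.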
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | Nicolas Gervais |
| Solution 2 | |