TensorFlow Lambda layer fft2d

I am building a CNN where the input is a grayscale image (256x256x1) and I want to add a Fourier transform layer which should output a shape (256x256x2), with the 2 channels for real and imaginary. I found tf.signal.fft2d on https://www.tensorflow.org/api_docs/python/tf/signal/fft2d . Unfortunately it is hard to find any example or explanation of how to use it concretely... I have tried:

X_input = Input(input_shape,)

X_input_fft=Lambda(lambda v: tf.cast(tf.compat.v1.spectral.rfft2d(v),dtype=tf.float32))(X_input)

l1Conv1 = Conv2D(filters = 16, kernel_size = (5,5), strides = 1, padding ='same',
                 data_format='channels_last',
                 kernel_initializer= initializers.he_normal(seed=None), 
                 bias_initializer='zeros')(X_input_fft)

but honestly I don't know what I am doing ...

Also, for the last layer, I would like to do an inverse fft, something like:

myLastLayer= Lambda(lambda v: tf.cast(tf.compat.v1.spectral.irfft2d(tf.cast(v, dtype=tf.complex64)),dtype=tf.float32))(myBeforeLastLayer)


Solution 1:[1]

I'm sorry that this answer comes 2 years late, but I think it will help a lot of people dealing with TensorFlow's fft2d.

The first thing you should know is that, according to the documentation, TensorFlow computes fft2d over "the inner-most 2 dimensions of input", which simply means the last two axes. For a channels-last image of shape (height, width, channels), those would be width and channels, so you have to transpose the input tensor to put height and width last before calling it.
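To make the "inner-most 2 dimensions" point concrete, here is a minimal sketch that compares tf.signal.fft2d against NumPy's np.fft.fft2 taken explicitly over the two spatial axes. The small random image and the variable names are assumptions made for illustration only.

```python
import numpy as np
import tensorflow as tf

# Hypothetical test image in channels-last layout: (height, width, channels).
rng = np.random.default_rng(0)
image = rng.standard_normal((8, 8, 3)).astype("float32")

# Reference: NumPy 2-D FFT taken explicitly over the two spatial axes.
expected = np.fft.fft2(image, axes=(0, 1))

# Applied directly, fft2d transforms the last two axes (width, channels).
naive = tf.signal.fft2d(tf.cast(image, tf.complex64)).numpy()

# Moving channels to the front makes (height, width) the inner-most axes.
channels_first = tf.transpose(tf.cast(image, tf.complex64), perm=[2, 0, 1])
spatial = tf.transpose(tf.signal.fft2d(channels_first), perm=[1, 2, 0]).numpy()

print(np.allclose(spatial, expected, atol=1e-3))  # transposed version vs. reference
print(np.allclose(naive, expected, atol=1e-3))    # direct call vs. reference
```

Only the transposed version should agree with the spatial reference; the direct call mixes the width and channel axes.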

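A function that does what you need, for a single image of shape (height, width, channels), would be this one.

import tensorflow as tf

def fft2d_function(x, dtype = "complex64"):
    # x: a single image of shape (height, width, channels), no batch axis
    x = tf.transpose(x, perm = [2, 0, 1])  # channels first, so H and W are inner-most
    x = tf.cast(x, dtype)  # fft2d requires a complex input
    x_f = tf.signal.fft2d(x)
    x_f = tf.transpose(x_f, perm = [1, 2, 0])  # back to (height, width, channels)
    real_x_f, imag_x_f = tf.math.real(x_f), tf.math.imag(x_f)
    return real_x_f, imag_x_f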

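Or, if you are sure that the input is a real signal, you can use rfft2d instead.

def rfft2d_function(x):
    # x: a single real-valued (float) image of shape (height, width, channels)
    x = tf.transpose(x, perm = [2, 0, 1])  # channels first
    x_f = tf.signal.rfft2d(x)  # takes a float tensor, returns a complex one
    x_f = tf.transpose(x_f, perm = [1, 2, 0])  # channels last again
    real_x_f, imag_x_f = tf.math.real(x_f), tf.math.imag(x_f)
    return real_x_f, imag_x_f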
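One thing to keep in mind with rfft2d is that it returns only the non-redundant half of the spectrum along the last transformed axis, so a 256x256 input gives 256x129 coefficients rather than 256x256. A small sketch to check this, using a random channels-first tensor as a stand-in for real data:

```python
import tensorflow as tf

# Hypothetical input: one real-valued channel, already in (channels, height, width) order.
x = tf.random.normal((1, 256, 256))

full = tf.signal.fft2d(tf.cast(x, tf.complex64))  # full complex spectrum
half = tf.signal.rfft2d(x)                        # non-redundant half only

print(full.shape)  # last axis keeps all 256 frequencies
print(half.shape)  # last axis keeps 256 // 2 + 1 = 129 frequencies

# The half spectrum should match the first 129 columns of the full one,
# up to float32 rounding; compare relative to the largest coefficient.
relative_diff = float(
    tf.reduce_max(tf.abs(full[..., :129] - half)) / tf.reduce_max(tf.abs(full))
)
print(relative_diff)
```

If the network needs a 256x256 spectrum per channel, fft2d on a complex cast of the input is the simpler choice.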

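Besides, if you want to perform the inverse of these functions, it would look like this.

def ifft2d_function(x_r_i_tuple):
    # x_r_i_tuple: (real, imaginary) parts, each of shape (height, width, channels)
    real_x_f, imag_x_f = x_r_i_tuple
    x_f = tf.complex(real_x_f, imag_x_f)
    x_f = tf.transpose(x_f, perm = [2, 0, 1])  # channels first
    x_hat = tf.signal.ifft2d(x_f)  # result is complex; use tf.math.real for a real image
    x_hat = tf.transpose(x_hat, perm = [1, 2, 0])  # channels last again
    return x_hat

def irfft2d_function(x_r_i_tuple):
    # Inverse of rfft2d_function: takes the half spectrum, returns a real tensor
    real_x_f, imag_x_f = x_r_i_tuple
    x_f = tf.complex(real_x_f, imag_x_f)
    x_f = tf.transpose(x_f, perm = [2, 0, 1])  # channels first
    x_hat = tf.signal.irfft2d(x_f)
    x_hat = tf.transpose(x_hat, perm = [1, 2, 0])  # channels last again
    return x_hat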
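The functions above take a single (height, width, channels) tensor. Inside a Keras model the tensors usually carry a leading batch axis, and the question asks for real and imaginary parts as two channels, so here is a hedged sketch of a batched variant. The names fft2d_channels and ifft2d_channels are hypothetical, and stacking real and imaginary parts along the channel axis is just one possible layout.

```python
import tensorflow as tf

def fft2d_channels(x):
    """(batch, H, W, C) real -> (batch, H, W, 2*C): real parts, then imaginary parts."""
    x = tf.transpose(x, perm=[0, 3, 1, 2])            # (batch, C, H, W)
    x_f = tf.signal.fft2d(tf.cast(x, tf.complex64))   # FFT over H and W
    x_f = tf.transpose(x_f, perm=[0, 2, 3, 1])        # (batch, H, W, C)
    return tf.concat([tf.math.real(x_f), tf.math.imag(x_f)], axis=-1)

def ifft2d_channels(y):
    """(batch, H, W, 2*C) -> (batch, H, W, C): inverse of fft2d_channels."""
    real, imag = tf.split(y, num_or_size_splits=2, axis=-1)
    y_f = tf.transpose(tf.complex(real, imag), perm=[0, 3, 1, 2])
    x_hat = tf.signal.ifft2d(y_f)
    x_hat = tf.transpose(x_hat, perm=[0, 2, 3, 1])
    # Keep only the real part, assuming the original signal was real.
    return tf.math.real(x_hat)

# Round trip on a hypothetical batch of grayscale images.
images = tf.random.normal((4, 256, 256, 1))
spectrum = fft2d_channels(images)
restored = ifft2d_channels(spectrum)
max_error = float(tf.reduce_max(tf.abs(restored - images)))
print(spectrum.shape, restored.shape, max_error)
```

These functions should be usable as Lambda layers, for example tf.keras.layers.Lambda(fft2d_channels)(X_input), although that wrapping is not exercised in the sketch. Note that ifft2d_channels only makes sense as a last layer if the preceding layer really outputs 2*C channels meant as real and imaginary parts.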

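Finally, an important operation when working with Fourier transforms is fftshift, which moves the zero-frequency component to the center of the spectrum. TensorFlow also provides it: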

fourier_x = tf.signal.fftshift(fourier_x)
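By default (axes=None) tf.signal.fftshift shifts every axis, which would also roll the channel or batch axis of a multi-dimensional tensor. A small sketch of restricting the shift to the spatial axes, using a constant image as a made-up input so the result is easy to verify:

```python
import tensorflow as tf

# Hypothetical input: a constant image, so all spectral energy sits at zero frequency.
x = tf.ones((2, 8, 8))  # (channels, height, width)
magnitude = tf.abs(tf.signal.fft2d(tf.cast(x, tf.complex64)))

# Shift only the two spatial axes; axes=None would also roll the channel axis.
centered = tf.signal.fftshift(magnitude, axes=[1, 2])

# ifftshift with the same axes undoes the shift.
restored = tf.signal.ifftshift(centered, axes=[1, 2])

# The zero-frequency peak moves from index (0, 0) to the center (4, 4).
print(float(magnitude[0, 0, 0]), float(centered[0, 4, 4]))
```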

I hope this answer helps someone dealing with Fourier transforms in TensorFlow.

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 David Santiago Morales Norato