Self-convolution layer in TensorFlow

I want to implement a layer that takes a tensor of shape (4, 400) and returns a tensor of shape (8, 400), where the 5th row is the convolution of the 1st row with itself, the 6th row is the convolution of the 2nd row with itself, and so on.

I tried to define a python function:

def convolve_tensors(x):
    """Append to every sample the result of sliding each row over itself.

    For each sample ``x[i]`` of shape (4, deg), four extra rows are computed:
    row ``4 + k`` is ``tf.nn.conv1d`` applied to row ``k`` using that same row
    as the filter (``SAME`` padding, stride 1).

    NOTE: TensorFlow's convolution ops slide the filter without reversing it
    (cross-correlation). Reverse the kernel first if a flipped-kernel
    convolution is required.

    Args:
        x: Tensor of shape (batch, 4, deg). If the static batch size is
            unknown, the module-level ``BATCH_SIZE`` is used instead.

    Returns:
        Tensor of shape (batch, 8, deg): the 4 original rows followed by the
        4 computed rows, per sample.
    """
    batch = x.shape[0]
    if batch is None:
        # A Python loop needs a concrete count; fall back to the configured
        # batch size when the static batch dimension is unknown.
        batch = BATCH_SIZE
    deg = x.shape[-1]

    res = []
    for i in range(batch):
        rows = [x[i, ...]]
        for k in range(4):
            row = x[i, k, :]
            # conv1d expects input as (batch, width, channels) and filters as
            # (width, in_channels, out_channels): one signal of length `deg`
            # with a single channel, filtered by a kernel of length `deg`.
            signal = tf.reshape(row, (1, deg, 1))
            kernel = tf.reshape(row, (deg, 1, 1))
            y_k = tf.nn.conv1d(signal, kernel, stride=1, padding='SAME')
            rows.append(tf.reshape(y_k, (1, deg)))
        res.append(tf.concat(rows, axis=0))

    # Stack along a new leading axis so each sample keeps its own rows;
    # concatenating along the last axis and reshaping would interleave rows
    # belonging to different samples.
    return tf.stack(res, axis=0)

and then calling the layer using:

# Wrap the function in a Lambda layer and apply it to x
# (`Lambda` is presumably `tf.keras.layers.Lambda`; its import is not shown).
x = Lambda(convolve_tensors)(x)

I am not sure this is the right way to do it, and it is very slow. Any suggestions?



Solution 1:[1]

You could replace your first loop with tf.while_loop. It should be faster:

# Import TensorFlow and seed its global random generator so that the random
# sample input created further down is reproducible.
import tensorflow as tf
tf.random.set_seed(111)

def _body(i, ta, x, x_shape):
  """Loop body: build the 8-row output for sample ``i`` and append it to ``ta``.

  The output consists of the 4 rows of ``x[i]`` followed by 4 rows obtained by
  applying ``tf.nn.conv1d`` to each row with that same row as the filter
  (``SAME`` padding, stride 1).

  NOTE: TensorFlow's convolution ops slide the filter without reversing it
  (cross-correlation). Reverse the kernel first if a flipped-kernel
  convolution is required.

  Args:
    i: Scalar integer tensor, index of the sample to process.
    ta: ``tf.TensorArray`` collecting one (8, width) tensor per sample.
    x: Input tensor of shape (batch, 4, width).
    x_shape: Dynamic shape of ``x`` (``tf.shape(x)``).

  Returns:
    Tuple ``(i + 1, ta, x, x_shape)`` matching the loop variables.
  """
  x_i = x[i, ...]

  # Prefer the static width so output shapes stay known at graph-build time;
  # fall back to the dynamic shape when it is not available.
  width = x.shape[-1]
  if width is None:
    width = x_shape[-1]

  def _row_with_itself(k):
    # Use row k of the *current* sample as both signal and filter.
    row = x[i, k, :]
    # conv1d expects input as (batch, width, channels) and filters as
    # (width, in_channels, out_channels).
    signal = tf.reshape(row, (1, width, 1))
    kernel = tf.reshape(row, (width, 1, 1))
    out = tf.nn.conv1d(signal, kernel, stride=1, padding='SAME')
    return tf.reshape(out, (1, width))

  rows = [x_i] + [_row_with_itself(k) for k in range(4)]
  ta = ta.write(ta.size(), tf.concat(rows, axis=0))

  return tf.add(i, 1), ta, x, x_shape

def convolve_tensors(x):
  """Run ``_body`` once per entry along the first axis of ``x`` and stack the results.

  Args:
    x: Input tensor whose first axis is iterated over. The stacked result is
      reshaped to (-1, 8, 400), so every per-sample tensor written by
      ``_body`` must hold 8 * 400 values.

  Returns:
    Tensor of shape (batch, 8, 400) with the same dtype as ``x``.
  """
  i = tf.constant(0)
  # Match the input dtype instead of assuming float32, so float16/float64
  # inputs can be written into the array as well.
  ta = tf.TensorArray(dtype=x.dtype, size=0, dynamic_size=True)
  x_shape = tf.shape(x)

  def _has_more(i, ta, x, x_shape):
    # Continue until every index along the first axis has been processed.
    return tf.less(i, x_shape[0])

  _, ta, _, _ = tf.while_loop(_has_more, _body, loop_vars=(i, ta, x, x_shape))
  # The reshape makes the trailing (8, 400) dimensions explicit.
  return tf.reshape(ta.stack(), (-1, 8, 400))

# Demo input: a random batch of 2 samples, each of shape (4, 400).
x = tf.random.normal((2, 4, 400))
# Wrap the function in a Keras Lambda layer and print the layer's output.
layer = tf.keras.layers.Lambda(convolve_tensors)
print(layer(x))
tf.Tensor(
[[[ 7.55812705e-01  1.54472649e+00  1.63156021e+00 ...  1.56513906e+00
   -2.32601002e-01 -8.82739127e-01]
  [-1.44203651e+00  2.09071732e+00 -4.37746137e-01 ... -2.72640914e-01
    1.63588750e+00 -2.61778712e-01]
  [ 1.50723392e-02  8.18879828e-02  5.53403437e-01 ... -5.05646825e-01
    1.37793958e+00  1.60156024e+00]
  ...
  [-6.26580775e-01 -1.28060293e+00 -1.35258949e+00 ... -1.29752529e+00
    1.92829952e-01  7.31804848e-01]
  [ 4.74522680e-01  9.69827235e-01  1.02434409e+00 ...  9.82642889e-01
   -1.46034136e-01 -5.54211020e-01]
  [-1.13250780e+00 -2.31461406e+00 -2.44472551e+00 ... -2.34520030e+00
    3.48528713e-01  1.32269394e+00]]

 [[ 2.60277104e+00 -4.02214050e-01  7.82385767e-01 ...  1.00322634e-01
    7.35043705e-01 -1.37642205e+00]
  [-2.49467894e-01  1.63924351e-01 -5.34086585e-01 ... -5.54817200e-01
   -8.27848673e-01 -8.34841058e-02]
  [ 7.12225974e-01 -1.58221856e-01 -9.32399333e-01 ... -2.87807316e-01
    1.03814876e+00 -2.80721009e-01]
  ...
  [ 1.00409901e+00  2.05217290e+00  2.16753173e+00 ...  2.07929111e+00
   -3.09010983e-01 -1.17272103e+00]
  [-8.41171574e-03 -1.71918254e-02 -1.81582291e-02 ... -1.74190048e-02
    2.58870143e-03  9.82432626e-03]
  [-8.13263476e-01 -1.66214406e+00 -1.75557816e+00 ... -1.68410826e+00
    2.50281453e-01  9.49837804e-01]]], shape=(2, 8, 400), dtype=float32)

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 AloneTogether