tf.function is slow to initialize

I am writing a CNN training algorithm in TensorFlow from scratch and am running into performance issues. I have been searching for the cause, and it seems to be a problem with initializing tf.function. The training function looks like this:
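A minimal sketch of the tracing behaviour in question, assuming TensorFlow 2.x. The first call to a `tf.function` with a new input shape and dtype traces the Python function into a graph. Later calls with the same signature reuse that graph. The names `traces` and `step` are illustrative and not part of the original code. Appending to a Python list is a side effect that only happens while tracing, so it counts traces.

```python
import tensorflow as tf

# Illustrative only: a Python list used to observe when tracing happens.
traces = []


@tf.function
def step(x):
    # Plain Python statements run only while TensorFlow traces the function.
    traces.append(x.shape)
    return tf.reduce_sum(tf.square(x))


x = tf.ones((24, 32, 32, 3))
for _ in range(5):
    step(x)  # same shape and dtype on every call

print(len(traces))  # 1: only the first call traced; the other four reused the graph
```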

import time

import tensorflow as tf


def get_train_func(model, loss_fn, optimizer):
    @tf.function
    def train(input, mask):
        time1 = time.time()
        with tf.GradientTape() as tape:
            output = model(input, training=True)       # forward pass
            predtime = time.time() - time1              # 0.49s
            time1 = time.time()
            output = tf.nn.softmax(output)
            loss = loss_fn(mask, output)
            losstime = time.time() - time1              # 0.011s
        time1 = time.time()
        grads = tape.gradient(loss, model.trainable_variables)   # backward pass
        gradtime = time.time() - time1                  # 0.43s
        time1 = time.time()
        optimizer.apply_gradients(zip(grads, model.trainable_variables))   # weight update
        opttime = time.time() - time1                   # 0.339s
        return output
    return train
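One consequence for the timings above, sketched under the same TensorFlow 2.x assumption: `time.time()` is ordinary Python, so inside a `tf.function` it is evaluated once during tracing and its result is frozen into the graph as a constant. The function `stamped` below is a hypothetical example. Calling it twice returns the same timestamp even though the wall clock has moved on.

```python
import time

import tensorflow as tf


@tf.function
def stamped(x):
    # time.time() runs during tracing only; its value becomes a graph constant.
    t = time.time()
    return x + 1.0, tf.constant(t, dtype=tf.float64)


x = tf.zeros((4,))
_, t1 = stamped(x)
time.sleep(0.05)  # the wall clock advances between the two calls
_, t2 = stamped(x)

print(float(t1) == float(t2))  # True: the graph replays the traced constant
```

`tf.print`, by contrast, becomes an op in the graph and runs on every call.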

The main code looks like this:

import time

import tensorflow as tf

import model_mySINet_tf

model = model_mySINet_tf.SINet(classes=2, p=2, q=8, chnn=1)
dataset_train, dataset_val = getData()
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
loss_fn = tf.keras.losses.BinaryCrossentropy()
epochs = 5
train_step = get_train_func(model, loss_fn, optimizer)   # created once, outside the loop

for epoch in range(epochs):
    # ---------------------Train------------------------------
    for step, (input, mask) in enumerate(dataset_train):
        input = tf.cast(input, tf.float32)      # (24, 224, 224, 3)
        mask = tf.cast(mask, tf.float32)        # (24, 224, 224, 1)
        start = time.time()
        output = train_step(input, mask)
        print("TOTAL TIME", time.time() - start)   # 5.13s
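A sketch of a related cause of repeated tracing. It assumes the input pipeline can yield batches of different sizes, for example a smaller final batch; whether `getData()` does this is not stated. Each new input shape makes `tf.function` trace again. Passing an `input_signature` with an unknown (`None`) batch dimension is one way to get a single trace. `step_default`, `step_sig` and the batch sizes are hypothetical.

```python
import tensorflow as tf

traces_default = []
traces_sig = []


@tf.function
def step_default(x):
    traces_default.append(x.shape)  # Python side effect: runs once per trace
    return tf.reduce_mean(x)


@tf.function(input_signature=[tf.TensorSpec(shape=[None, 224, 224, 3], dtype=tf.float32)])
def step_sig(x):
    traces_sig.append(x.shape)  # x.shape here is (None, 224, 224, 3)
    return tf.reduce_mean(x)


# Hypothetical batch sizes: two full batches of 24, then a smaller final batch of 10.
for batch_size in (24, 24, 10):
    x = tf.zeros((batch_size, 224, 224, 3))
    step_default(x)
    step_sig(x)

print(len(traces_default))  # 2: one trace for batch size 24, another for 10
print(len(traces_sig))      # 1: the None batch dimension accepts both sizes
```

For the two-argument `train` function in the question, the signature would list two specs, for example `tf.TensorSpec([None, 224, 224, 3], tf.float32)` and `tf.TensorSpec([None, 224, 224, 1], tf.float32)`, assuming those are the real image and mask shapes. Recent TensorFlow 2.x releases also accept `tf.function(reduce_retracing=True)` as a less explicit alternative.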

As you can see, the total time per step is 5.13s, but the times measured inside the train function only add up to 1.27s. I tried to fix this by building the step with get_train_func() and calling that once, outside the loop, but it does not seem to help. Is there an obvious mistake I am making? Grateful for any help!
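A hedged sketch of timing the step from outside the function, assuming TensorFlow 2.x. The first call (and any call that retraces) includes tracing and graph construction, so it is skipped as a warm-up. Calling `.numpy()` on the result waits for the value to be available before the clock is read. `time_steps` and `toy_step` are illustrative names, not from the question; `toy_step` only stands in for a real training step.

```python
import time

import tensorflow as tf


def time_steps(step_fn, batches, warmup=1):
    """Return wall-clock seconds per call, skipping the first `warmup` calls."""
    timings = []
    for i, batch in enumerate(batches):
        start = time.perf_counter()
        out = step_fn(batch)
        out.numpy()  # wait until the result is actually available
        if i >= warmup:
            timings.append(time.perf_counter() - start)
    return timings


@tf.function
def toy_step(x):
    # Stand-in for a training step: a small matrix multiplication.
    return tf.reduce_sum(tf.matmul(x, x, transpose_b=True))


batches = [tf.random.normal((24, 128)) for _ in range(6)]
timings = time_steps(toy_step, batches, warmup=1)
print(f"steady-state step time: {sum(timings) / len(timings):.6f}s")
```

Applied to the loop above, this amounts to ignoring the first step (and any step that follows a retrace) when judging the steady-state cost of `train_step`.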



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
