TensorFlow 6x slower than PyTorch for simple 2-layer feedforward network

I've defined a very simple feedforward architecture

Dense(50, ReLU) -> Dense(50, ReLU) -> Dense(1, No Activation)

that I'm using to fit a linspace(-5.0, 5.0, batchSize) input to the standard normal probability density function. I would expect similar performance between PyTorch and TensorFlow for this simple case, but as it turns out, I get around 500 epochs/second with PyTorch but only 80 epochs/second with TensorFlow.

With a batch size of 2048, my GPU sits at around 30% CUDA usage in the PyTorch case and around 25% in the TensorFlow case.

PyTorch version:

import torch
import torch.nn as nn
from torchsummary import summary

import math
from tqdm import trange

batchSize = 2048
epochs = 1000

model = nn.Sequential(
          nn.Linear(1, 50, device='cuda'),
          nn.ReLU(),
          nn.Linear(50, 50, device='cuda'),
          nn.ReLU(),
          nn.Linear(50, 1, device='cuda')
        )

summary(model)
"""
Output:
    
=================================================================
Layer (type:depth-idx)                   Param #
=================================================================
├─Linear: 1-1                            100
├─ReLU: 1-2                              --
├─Linear: 1-3                            2,550
├─ReLU: 1-4                              --
├─Linear: 1-5                            51
=================================================================
Total params: 2,701
Trainable params: 2,701
Non-trainable params: 0
=================================================================
"""

opt = torch.optim.Adam(params=model.parameters(), lr=1e-2)

X = torch.linspace(-5.0, 5.0, batchSize, device='cuda').reshape(-1,1)
Y = (1/(math.sqrt(2*math.pi)))*torch.exp(-X**2/2)
loss = torch.nn.MSELoss()
for _ in trange(epochs):
    opt.zero_grad()
    Ypred = model(X)
    lossVal = loss(Ypred, Y)
    lossVal.backward()
    opt.step()

plt.plot(X.detach().cpu(), Ypred.detach().cpu())

TensorFlow version:

import tensorflow as tf

import math
from tqdm import trange

batchSize = 2048
epochs = 1000

print("Num GPUs Available: ", len(tf.config.list_physical_devices('GPU')))
# Output: Num GPUs Available:  1

model = tf.keras.Sequential(layers=(
    tf.keras.layers.Dense(50, activation=tf.nn.relu),
    tf.keras.layers.Dense(50, activation=tf.nn.relu),
    tf.keras.layers.Dense(1, activation=None))
    )



model.build(input_shape=(None, 1))
print(model.summary())
"""
Output:
    
 Layer (type)                Output Shape              Param #   
=================================================================
 dense_65 (Dense)            (None, 50)                100       
                                                                 
 dense_66 (Dense)            (None, 50)                2550      
                                                                 
 dense_67 (Dense)            (None, 1)                 51        
                                                                 
=================================================================
Total params: 2,701
Trainable params: 2,701
Non-trainable params: 0
_________________________________________________________________
"""

X = tf.expand_dims(tf.linspace(-5.0, 5.0, batchSize, axis=0), axis=-1)
Y = (1/(math.sqrt(2*math.pi)))*tf.exp(-X**2/2)

opt = tf.keras.optimizers.Adam(learning_rate=1e-2)
loss = tf.keras.losses.MeanSquaredError()

for _ in trange(epochs):
    with tf.GradientTape() as tape:
        Ypred = model(X)
        lossVal = loss(Y, Ypred)
    gradients = tape.gradient(lossVal, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))

If I start stripping stuff away from the training loop:

100 epochs/sec

for _ in trange(epochs):
    with tf.GradientTape() as tape:
        Ypred = model(X)
        lossVal = loss(Y, Ypred)
    gradients = tape.gradient(lossVal, model.trainable_variables)
    # opt.apply_gradients(zip(gradients, model.trainable_variables))

250 epochs/sec

for _ in trange(epochs):
    with tf.GradientTape() as tape:
        Ypred = model(X)
        lossVal = loss(Y, Ypred)
    # gradients = tape.gradient(lossVal, model.trainable_variables)
    # opt.apply_gradients(zip(gradients, model.trainable_variables))

450 epochs/sec

for _ in trange(epochs):
    with tf.GradientTape() as tape:
        Ypred = model(X)
        # lossVal = loss(Y, Ypred)
    # gradients = tape.gradient(lossVal, model.trainable_variables)
    # opt.apply_gradients(zip(gradients, model.trainable_variables))

The gradients = tape.gradient(lossVal, model.trainable_variables) line also causes 2% Copy usage, which of course is very detrimental to performance if I have to keep syncing between GPU and CPU (not sure why this would be needed in the first place here).



Solution 1:[1]

I figured it out, but in case someone else runs into the same thing:

@tf.function
def train_step():
    with tf.GradientTape() as tape:
        Ypred = model(X)
        lossVal = loss(Y, Ypred)
    gradients = tape.gradient(lossVal, model.trainable_variables)
    opt.apply_gradients(zip(gradients, model.trainable_variables))

for _ in trange(epochs):
    train_step()

By making the training step into a @tf.function, I get around 530 epochs/second, an improvement over PyTorch even!

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 user2978125