Keras: Mimic PyTorch's conv2d and linear/dense weight initialization?

I am porting a model from PyTorch to Keras/TensorFlow, and I want to make sure I'm using the same algorithm for weight initialization. How do I mimic PyTorch's default weight initialization in Keras?



Solution 1:[1]

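If you work through the PyTorch initialization code (the reset_parameters methods of the linear and conv layers), you'll find that the weight initialization algorithm is surprisingly simple. The comment in that code is correct: calling kaiming_uniform_ with a=sqrt(5) is the same as sampling uniformly from [-1/sqrt(fan_in), 1/sqrt(fan_in)]. Just read that comment and mimic it.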
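
The simplification can be checked numerically without either framework installed. The sketch below re-derives the two bounds by hand; the helper names and the example fan_in are my own illustrations, not part of PyTorch or Keras, and the formulas are my reading of kaiming_uniform_ (leaky_relu gain) and of VarianceScaling with mode='fan_in' and distribution='uniform'.

```python
import math

def kaiming_uniform_bound(fan_in, a):
    """Bound of PyTorch's kaiming_uniform_ with the leaky_relu gain (hypothetical helper)."""
    gain = math.sqrt(2.0 / (1.0 + a ** 2))   # leaky_relu gain for negative slope a
    std = gain / math.sqrt(fan_in)
    return math.sqrt(3.0) * std              # uniform bound that yields this std

def variance_scaling_uniform_limit(fan_in, scale):
    """Limit of Keras VarianceScaling(mode='fan_in', distribution='uniform') (hypothetical helper)."""
    return math.sqrt(3.0 * scale / fan_in)

# Assumed example: a 3x3 conv kernel with 16 input channels.
fan_in = 3 * 3 * 16

pytorch_bound = kaiming_uniform_bound(fan_in, a=math.sqrt(5))
keras_limit = variance_scaling_uniform_limit(fan_in, scale=1 / 3)

# Both should reduce to 1 / sqrt(fan_in).
print(pytorch_bound, keras_limit, 1 / math.sqrt(fan_in))
```

With a = sqrt(5) the gain is sqrt(1/3), so the sqrt(3) factor cancels and both expressions collapse to 1/sqrt(fan_in). That is why scale=1/3 is the value to pass to VarianceScaling.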

Here's working Keras / TensorFlow code that mimics it:

import tensorflow as tf
from tensorflow.keras import layers

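class PytorchInitialization(tf.keras.initializers.VarianceScaling):
    """Uniform init in [-1/sqrt(fan_in), 1/sqrt(fan_in)], like PyTorch's default."""
    def __init__(self, seed=None):
        # VarianceScaling's uniform limit is sqrt(3 * scale / fan_in), so
        # scale=1/3 gives a limit of 1/sqrt(fan_in).
        super().__init__(
            scale=1 / 3, mode='fan_in', distribution='uniform', seed=seed)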

# Conv layer
conv = layers.Conv2D(32, 3, activation="relu", padding="SAME",
                     input_shape=(28, 28, 1),
                     kernel_initializer=PytorchInitialization(),
                     bias_initializer=PytorchInitialization())

# Dense / linear layer
classifier = layers.Dense(10,
                          kernel_initializer=PytorchInitialization(),
                          bias_initializer=PytorchInitialization())
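
One caveat about the bias, as far as I can tell: PyTorch draws the bias from the same [-1/sqrt(fan_in), 1/sqrt(fan_in)] range as the weight, using the weight's fan_in. A Keras initializer only sees the shape of the variable it is initializing, and for a 1-D bias of shape (units,) VarianceScaling takes fan_in to be units. The bias range therefore only matches when units happens to equal the kernel's fan_in. The sketch below computes the PyTorch-style bound from a Keras kernel shape; the helper names are hypothetical and the shapes are example values.

```python
import math

def fan_in_of(kernel_shape):
    """fan_in for a Keras kernel: Dense is (in, out), Conv2D is (kh, kw, in, out)."""
    receptive_field = math.prod(kernel_shape[:-2])  # 1 for Dense kernels
    return receptive_field * kernel_shape[-2]

def pytorch_bias_bound(kernel_shape):
    """Bound PyTorch would use for the bias of a layer with this kernel."""
    return 1.0 / math.sqrt(fan_in_of(kernel_shape))

# Assumed example: Dense(10) applied to 784 input features.
dense_kernel = (784, 10)
bound_from_kernel = pytorch_bias_bound(dense_kernel)   # uses fan_in = 784
bound_from_bias_shape = 1.0 / math.sqrt(10)            # what VarianceScaling uses for shape (10,)

print(bound_from_kernel, bound_from_bias_shape)
```

If an exact match matters, one option is to pass bias_initializer=tf.keras.initializers.RandomUniform(minval=-bound, maxval=bound) with bound computed as above for each layer. For many models the difference is unlikely to matter much after training, but it does change the initial distribution.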

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow

Solution Source
Solution 1 Marcus