Keras fit not training model weights

I am trying to fit a simple histogram model with custom weights and no input. It should fit a histogram for the data generated by:

import numpy as np
import tensorflow as tf

train_data = [max(0, int(np.round(np.random.randn() * 2 + 5))) for _ in range(1000)]
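For reference, the best a histogram model can do on this data is reproduce the empirical frequency of each bin, which is the maximum-likelihood estimate of a categorical distribution. The sketch below computes that target using only the standard library. It assumes `random.gauss(5, 2)` as a stand-in for `np.random.randn()*2+5` and a fixed seed, so the exact counts are illustrative and will not match any particular NumPy run.

```python
import random
from collections import Counter

# Assumption: random.gauss(5, 2) stands in for np.random.randn()*2+5, and a
# fixed seed makes the run repeatable; counts will differ from a NumPy run.
random.seed(0)
d = 15
train_data = [max(0, int(round(random.gauss(5, 2)))) for _ in range(1000)]

# Maximum-likelihood histogram: the empirical frequency of each bin 0..d-1.
counts = Counter(train_data)
target = [counts.get(k, 0) / len(train_data) for k in range(d)]

for k, p in enumerate(target):
    print(f"bin {k:2d}: {p:.3f}")
```

A trained model should end up with `softmax(theta)` close to these frequencies.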

The model is defined by

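d = 15  # number of histogram bins (classes 0..14)

class hist_model(tf.keras.Model):
    def __init__(self):
        super(hist_model, self).__init__()
        # One trainable row of d values, initialized to zero
        self._theta = self.add_weight(shape=[1, d], initializer='zeros', trainable=True)

    def call(self, x):
        # The input is ignored: every example gets the same [1, d] output
        return self._theta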
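Since `call` ignores its input, the whole model is a single row of `d` numbers. If those numbers are read as logits, as the custom training loop in this question does with `from_logits=True`, the loss for one label is the softmax cross-entropy. At the zero initialization that loss equals ln(15) ≈ 2.708, and its gradient `softmax(theta) - onehot(label)` is clearly non-zero. A standard-library sketch of that arithmetic follows; the helpers `softmax` and `nll_and_grad` are hypothetical names for illustration, not Keras APIs.

```python
import math

d = 15

def softmax(theta):
    # Subtract the max for numerical stability before exponentiating.
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    s = sum(exps)
    return [e / s for e in exps]

def nll_and_grad(theta, label):
    """Cross-entropy of one integer label under softmax(theta), treating
    theta as logits, plus the gradient with respect to theta."""
    p = softmax(theta)
    loss = -math.log(p[label])
    # d(loss)/d(theta_k) = p_k - [k == label]
    grad = [pk - (1.0 if k == label else 0.0) for k, pk in enumerate(p)]
    return loss, grad

theta = [0.0] * d                  # same starting point as a zero initializer
loss, grad = nll_and_grad(theta, 5)
print(f"loss = {loss:.4f}")        # equals ln(15) at theta = 0
print(f"grad[5] = {grad[5]:.4f}")  # 1/15 - 1, so SGD has something to follow
```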

The problem is that training with model.fit does nothing: the model weights don't change at all during training. I tried:

model = hist_model()
model.compile(optimizer=tf.keras.optimizers.SGD(learning_rate=1e-2),
              loss="sparse_categorical_crossentropy")
history = model.fit(train_data, train_data, verbose=2, batch_size=1, epochs=3)
model.summary()

This prints:

Epoch 1/3
1000/1000 - 1s - loss: 2.7080
Epoch 2/3
1000/1000 - 1s - loss: 2.7080
Epoch 3/3
1000/1000 - 1s - loss: 2.7080
Model: "hist_model_17"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
Total params: 15
Trainable params: 15
Non-trainable params: 0
________________________
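The logged loss of 2.7080 is ln(15), the cross-entropy of a uniform distribution over the 15 bins. One plausible explanation, assuming the TF 2.x implementation of `tf.keras.backend.sparse_categorical_crossentropy`: with the default `from_logits=False` the model output is treated as probabilities, clipped to `[epsilon, 1 - epsilon]` (epsilon is 1e-7 by default), logged, and then passed through softmax cross-entropy. With `theta` initialized to zero every entry is clipped to epsilon, so all log-values are equal (giving exactly ln(15)), and the clip passes back a zero gradient, so SGD never moves the weights. The pure-Python emulation below reproduces that assumed arithmetic; it is a sketch, not the Keras source, and `prob_path_loss_and_grad` is a hypothetical helper. Under this reading, compiling with `loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)`, which matches the custom loop, should let `fit` train.

```python
import math

d = 15
EPS = 1e-7  # assumed to match the default tf.keras.backend.epsilon()

def clip(x, lo, hi):
    return min(max(x, lo), hi)

def clip_grad(x, lo, hi):
    # Mirrors tf.clip_by_value's gradient: 1 inside [lo, hi], 0 where clipped.
    return 1.0 if lo <= x <= hi else 0.0

def prob_path_loss_and_grad(theta, label):
    """Assumed from_logits=False path: clip the outputs as probabilities,
    take the log, then apply softmax cross-entropy to those log values."""
    z = [math.log(clip(t, EPS, 1 - EPS)) for t in theta]
    m = max(z)
    exps = [math.exp(v - m) for v in z]
    s = sum(exps)
    p = [e / s for e in exps]
    loss = -math.log(p[label])
    # Chain rule: dL/dz_k = p_k - onehot_k and dz_k/dtheta_k = clip'(t) / clip(t)
    grad = [
        (pk - (1.0 if k == label else 0.0)) * clip_grad(t, EPS, 1 - EPS) / clip(t, EPS, 1 - EPS)
        for k, (pk, t) in enumerate(zip(p, theta))
    ]
    return loss, grad

theta = [0.0] * d  # zero initializer
loss, grad = prob_path_loss_and_grad(theta, 5)
print(f"loss = {loss:.4f}, ln(15) = {math.log(15):.4f}")
print("max |grad| =", max(abs(g) for g in grad))  # zero: every entry was clipped
```

The gradient is zero only because every output starts below epsilon; a strictly positive initialization would not be stuck in the same way, though treating unnormalized weights as probabilities would still be questionable.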

I also wrote a custom training loop for the same model, and it works well. This is the code for the custom training:

optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)
for epoch in range(3):
    running_loss = 0.0
    for data in train_data:
        with tf.GradientTape() as tape:
            loss_value = loss_fn(data, model(data))
        running_loss += loss_value.numpy()
        # Compute and apply the gradients outside the tape context
        grad = tape.gradient(loss_value, model.trainable_weights)
        optimizer.apply_gradients(zip(grad, model.trainable_weights))
    print(f'Epoch {epoch} mean loss: {running_loss / len(train_data)}')
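The custom loop trains because `from_logits=True` makes the per-example gradient `softmax(theta) - onehot(label)`, which never vanishes. The same per-example SGD can be sketched without TensorFlow. This assumes `random.gauss` in place of NumPy, a fixed seed, and the loop's learning rate of 1e-3, and it reports the mean loss per epoch. The mean loss should fall below ln(15) ≈ 2.708 in the first epoch and keep decreasing.

```python
import math
import random

# Assumption: standard-library stand-in for the NumPy data generator.
random.seed(0)
d = 15
train_data = [max(0, int(round(random.gauss(5, 2)))) for _ in range(1000)]

def softmax(theta):
    m = max(theta)
    exps = [math.exp(t - m) for t in theta]
    s = sum(exps)
    return [e / s for e in exps]

theta = [0.0] * d  # zero initializer
lr = 1e-3
epoch_losses = []
for epoch in range(3):
    running_loss = 0.0
    for label in train_data:
        p = softmax(theta)
        running_loss += -math.log(p[label])
        # Plain SGD step on the logits: theta -= lr * (p - onehot(label))
        for k in range(d):
            theta[k] -= lr * (p[k] - (1.0 if k == label else 0.0))
    epoch_losses.append(running_loss / len(train_data))
    print(f"Epoch {epoch} mean loss: {epoch_losses[-1]:.4f}")
```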

I still don't understand why the fit method doesn't work. What am I missing? Thanks!



Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
