ResNet-based (TensorFlow Keras) Siamese model gives `nan` validation loss in training when using TripletHardLoss (TripletSemiHardLoss too)

I have a model built on top of ResNet. I am using 25k images of a similar type; they contain text as well as some diagrams. With Euclidean distance + binary loss I got 95% accuracy with Inception, but switching to Triplet Hard / Semi-Hard Loss gave me a nan loss and almost 0 accuracy. Please tell me if there is something wrong with the code structure.

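import tensorflow as tf
import tensorflow_addons as tfa
from tensorflow import keras
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.resnet50 import preprocess_input as res50_pre, ResNet50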

shape = (224,224,3)
lr = 0.001
loss = tfa.losses.TripletSemiHardLoss()
epochs = 50
batch_size = 128 #254 gives 'log' referenced before assignment error


datagen = ImageDataGenerator(preprocessing_function=res50_pre,validation_split=0.2)

train_data = datagen.flow_from_dataframe(df,x_col='path',y_col='label',class_mode='sparse',target_size=(224,224),
                                         batch_size=batch_size,subset='training',seed=SEED)

val_data = datagen.flow_from_dataframe(df,x_col='path',y_col='label',class_mode='sparse',target_size=(224,224),
                                       batch_size=batch_size,subset='validation',seed=SEED)


base_model = ResNet50(weights='imagenet',input_shape=shape,include_top=False,pooling='avg')
base_model.trainable = True

inputs = keras.Input(shape=shape)
x = base_model(inputs,training=True)
outputs = keras.layers.Lambda(lambda x: tf.math.l2_normalize(x, axis=1))(x) # L2 normalize embeddings
model = keras.Model(inputs, outputs)

for layer in model.layers: # set all the parameters trainable
    layer.trainable = True
    
model.compile(optimizer=tf.keras.optimizers.Adam(lr),loss=loss,metrics=['accuracy'])

# len(train_data) is already the number of batches per epoch, so it is not divided by batch_size again
history = model.fit(train_data,epochs=epochs,steps_per_epoch=len(train_data),validation_data=val_data,verbose=2)

My group labels are values like 1, 2, 3 (not in order, and some are missing); images with the same value are the same type of data. I used class_mode='sparse' after converting the values to strings: str(1), str(3), etc.

My DataFrame has a `path` column and a `label` column (the screenshot from the original post is not available).



Solution 1:[1]

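Increase the batch size to reduce the probability that a mini-batch contains no valid triplets. A triplet needs two images with the same label and one image with a different label, so with many labels and a small batch, some batches may contain no same-label pair at all, and the loss for such a batch can come out as nan.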
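To check whether this is what happens with your data, you can inspect the labels of individual batches. The following is only a minimal sketch in plain Python; the helper names `count_positive_pairs` and `has_valid_triplet` are made up for illustration and are not part of TensorFlow Addons.

```python
from collections import Counter


def count_positive_pairs(labels):
    """Return the number of unordered same-label pairs in one batch."""
    counts = Counter(labels)
    # A label that occurs n times contributes n * (n - 1) / 2 pairs.
    return sum(n * (n - 1) // 2 for n in counts.values())


def has_valid_triplet(labels):
    """True if the batch can form at least one (anchor, positive, negative).

    That requires one label occurring at least twice (anchor + positive)
    and at least one other label in the batch (negative).
    """
    counts = Counter(labels)
    return len(counts) >= 2 and any(n >= 2 for n in counts.values())
```

Calling these on the label array of each batch drawn from the generator shows how often a batch has nothing to mine; if that happens regularly, a larger or better-balanced batch is needed.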

Edit: I published a package for generating balanced TF/Keras batches to solve this problem: https://github.com/ma7555/kerasgen
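The idea behind balanced batches can be sketched without any library: pick a few classes per batch and a fixed number of samples from each, so every batch is guaranteed to contain positives and negatives. The code below is a hand-rolled illustration, not the kerasgen API; the function name `balanced_batches` and its parameters are assumptions, and it only yields dataset indices, so loading the images (for example through a `keras.utils.Sequence`) is left out.

```python
import random
from collections import defaultdict


def balanced_batches(labels, classes_per_batch, samples_per_class,
                     num_batches, seed=0):
    """Yield lists of dataset indices with a fixed class composition.

    Each batch holds `classes_per_batch` distinct labels and exactly
    `samples_per_class` indices for each of them.
    """
    if classes_per_batch < 2 or samples_per_class < 2:
        raise ValueError("need at least 2 classes and 2 samples per class")

    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    # Only classes with at least two samples can provide a positive pair.
    eligible = [c for c, idxs in by_class.items() if len(idxs) >= 2]
    if len(eligible) < classes_per_batch:
        raise ValueError("not enough classes with at least two samples")

    for _ in range(num_batches):
        batch = []
        for c in rng.sample(eligible, classes_per_batch):
            idxs = by_class[c]
            if len(idxs) >= samples_per_class:
                picked = rng.sample(idxs, samples_per_class)
            else:
                # Small class: use every sample once, then repeat some
                # of them to fill the remaining slots.
                extra = rng.choices(idxs, k=samples_per_class - len(idxs))
                picked = list(idxs) + extra
            batch.extend(picked)
        yield batch
```

With `classes_per_batch=32` and `samples_per_class=4`, for instance, the batch size stays at 128 as in the question, but each batch is guaranteed to contain same-label pairs.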

Sources

This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.

Source: Stack Overflow
