How to train a multi-output regression model with tf.data.Dataset?
When I train a regression network with multiple outputs, only the first output is learned properly. The second (and any further) output merely seems to follow the first one, as if the ground truth for the second output were never used during training. How do I shape a tf.data.Dataset and pass it to model.fit() so that the second output is trained as well?
```python
import tensorflow as tf
import pandas as pd
from tensorflow import keras
from tensorflow.keras import layers

# Create dataset from CSV
file = pd.read_csv('minimalDataset.csv', skipinitialspace=True)
input = file["input"].values
output1 = file["output1"].values
output2 = file["output2"].values
dataset = tf.data.Dataset.from_tensor_slices((input, (output1, output2))).batch(4)

# Create multi-output regression net
input_layer = keras.Input(shape=(1,))
x = layers.Dense(20, activation="relu")(input_layer)
x = layers.Dense(60, activation="relu")(x)
output_layer = layers.Dense(2)(x)
model = keras.Model(input_layer, output_layer)
model.compile(optimizer="adam", loss="mean_squared_error")

# Train model and make prediction (deliberately overfitting to illustrate the problem)
model.fit(dataset, epochs=500)
prediction = model.predict(dataset)
```
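To see what the model actually receives, one can inspect the `element_spec` of a dataset built the same way as in the question. This sketch hard-codes the four rows of minimalDataset.csv (from the table) as NumPy arrays instead of reading the CSV; the variable names are illustrative. The target part of the spec turns out to be a tuple of two separate 1-D tensors, not one tensor whose last dimension matches the two units of `Dense(2)`.

```python
import numpy as np
import tensorflow as tf

# The four rows of minimalDataset.csv, hard-coded instead of read with pandas
x = np.array([0, 1, 2, 3], dtype="float32")
output1 = np.array([-1, 2, 0, 1], dtype="float32")
output2 = np.array([1, 0, 2, -1], dtype="float32")

# Same structure as in the question: (input, (output1, output2))
dataset = tf.data.Dataset.from_tensor_slices((x, (output1, output2))).batch(4)

# element_spec mirrors the nesting of the elements: (x_spec, (y1_spec, y2_spec))
x_spec, y_spec = dataset.element_spec
print(x_spec)
print(y_spec)
```

The model has a single output of shape (batch, 2), so neither of these two (batch,) targets has the shape of that output. Judging by the predictions below, the second target is effectively ignored rather than combined with the first.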
minimalDataset.csv and predictions:
| input | output1 | output2 | prediction_output1 | prediction_output2 |
|---|---|---|---|---|
| 0 | -1 | 1 | -0.989956 | -0.989964 |
| 1 | 2 | 0 | 1.834444 | 1.845085 |
| 2 | 0 | 2 | 0.640249 | 0.596099 |
| 3 | 1 | -1 | 0.621426 | 0.646796 |
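The observation that the second output only follows the first can be quantified from the table. This small NumPy check uses only the values copied from the table and compares `prediction_output2` with both targets.

```python
import numpy as np

# Ground truth and predictions copied from the table
output1 = np.array([-1, 2, 0, 1], dtype=float)
output2 = np.array([1, 0, 2, -1], dtype=float)
pred1 = np.array([-0.989956, 1.834444, 0.640249, 0.621426])
pred2 = np.array([-0.989964, 1.845085, 0.596099, 0.646796])

def mse(a, b):
    """Mean squared error between two 1-D arrays."""
    return float(np.mean((a - b) ** 2))

# Compare the second prediction column against both targets
mse_pred2_vs_out1 = mse(pred2, output1)
mse_pred2_vs_out2 = mse(pred2, output2)

# Largest difference between the two prediction columns
max_gap = float(np.max(np.abs(pred1 - pred2)))

print(f"MSE(prediction_output2, output1) = {mse_pred2_vs_out1:.3f}")
print(f"MSE(prediction_output2, output2) = {mse_pred2_vs_out2:.3f}")
print(f"max |prediction_output1 - prediction_output2| = {max_gap:.3f}")
```

`prediction_output2` is far closer to `output1` (MSE of roughly 0.13) than to its own target `output2` (roughly 3.0), and the two prediction columns never differ by more than about 0.05.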
If I instead create two independent dense output layers, the second output is learned accurately, but I get two losses:
```python
output_layer = (layers.Dense(1)(x), layers.Dense(1)(x))
```
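Getting two losses is not necessarily a problem: for a model with several outputs, Keras minimizes the sum of the per-output losses (each weighted by `loss_weights`, which default to 1) and reports that sum as `loss`. A minimal sketch of the two-head variant, assuming the four table rows as NumPy data and hypothetical layer names `output1`/`output2`:

```python
import numpy as np
import tensorflow as tf
from tensorflow import keras

# Table rows as column vectors (assumed data layout, not read from the CSV)
x = np.array([[0], [1], [2], [3]], dtype="float32")
output1 = np.array([[-1], [2], [0], [1]], dtype="float32")
output2 = np.array([[1], [0], [2], [-1]], dtype="float32")

# The tuple of targets lines up with the tuple of two model outputs
dataset = tf.data.Dataset.from_tensor_slices((x, (output1, output2))).batch(4)

inputs = keras.Input(shape=(1,))
h = keras.layers.Dense(20, activation="relu")(inputs)
h = keras.layers.Dense(60, activation="relu")(h)
out1 = keras.layers.Dense(1, name="output1")(h)  # hypothetical layer name
out2 = keras.layers.Dense(1, name="output2")(h)  # hypothetical layer name
model = keras.Model(inputs, (out1, out2))

# A single loss is applied to each output; the per-output losses are summed
model.compile(optimizer="adam", loss="mean_squared_error")
history = model.fit(dataset, epochs=5, verbose=0)

preds = model.predict(x, verbose=0)  # one (4, 1) array per output
print(sorted(history.history))
```

If one output should count more than the other, `loss_weights` in `model.compile` can rebalance the sum.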
Note: I want to use tf.data.Dataset because I build a 20k-element image/CSV dataset with it and apply per-element transformations as preprocessing.
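Since the preprocessing already happens per element, one alternative is to keep the columns separate and combine the targets inside a `Dataset.map` step. This is a sketch of that idea, not the approach of the answer below: `pack_targets` is a hypothetical helper, and the data are the four rows from the table.

```python
import numpy as np
import tensorflow as tf

x = np.array([0, 1, 2, 3], dtype="float32")
output1 = np.array([-1, 2, 0, 1], dtype="float32")
output2 = np.array([1, 0, 2, -1], dtype="float32")

def pack_targets(inp, y1, y2):
    """Hypothetical per-element step (other preprocessing could go here too).

    Returns (input, target) with the two scalar targets stacked into a
    vector of shape (2,), matching a Dense(2) output layer.
    """
    inp = tf.expand_dims(inp, -1)  # shape () -> (1,), matching keras.Input(shape=(1,))
    target = tf.stack([y1, y2])    # shape (2,)
    return inp, target

dataset = (
    tf.data.Dataset.from_tensor_slices((x, output1, output2))
    .map(pack_targets)  # tuple elements are unpacked into the three arguments
    .batch(4)
)

x_spec, y_spec = dataset.element_spec
first_targets = next(iter(dataset))[1].numpy()
print(x_spec.shape, y_spec.shape)
```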
Solution 1:[1]
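`tf.data.Dataset.from_tensor_slices()` slices along the first dimension. Writing the targets as `[output1, output2]` produces a tensor of shape (2, N), so it has to be transposed to shape (N, 2); each dataset element then carries both targets as one vector that matches the `Dense(2)` output. (Transposing the 1-D `input` is a no-op, but harmless.)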
```python
dataset = tf.data.Dataset.from_tensor_slices((tf.transpose(input), (tf.transpose([output1, output2])))).batch(4)
```
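The fix can be checked by looking at the shapes. This sketch (again with the four table rows hard-coded in place of the CSV) confirms that `tf.transpose([output1, output2])` equals stacking the columns along axis 1, that transposing the 1-D `input` leaves it unchanged, and that each batch of targets now has shape (batch, 2), like the `Dense(2)` output.

```python
import numpy as np
import tensorflow as tf

input = np.array([0, 1, 2, 3], dtype="float32")  # named as in the question (shadows the built-in)
output1 = np.array([-1, 2, 0, 1], dtype="float32")
output2 = np.array([1, 0, 2, -1], dtype="float32")

# [output1, output2] has shape (2, 4); transposing gives (4, 2)
targets = tf.transpose([output1, output2])
same_as_stack = np.array_equal(targets.numpy(), np.stack([output1, output2], axis=1))

# Transposing a rank-1 tensor does nothing
input_unchanged = np.array_equal(tf.transpose(input).numpy(), input)

dataset = tf.data.Dataset.from_tensor_slices((tf.transpose(input), targets)).batch(4)
_, y_spec = dataset.element_spec

print(targets.shape, same_as_stack, input_unchanged, y_spec.shape)
```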
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | finndus |
