Neural network: same prediction for different inputs
I am getting the same prediction for different inputs. I am trying to use a regression neural network, because I want to predict values instead of classes. Since the data is huge, I am training on one example at a time. Here is a simplified version of my code.
from pathlib import Path

from astropy.io import fits
from keras.models import Sequential
from keras.layers import Dense

list_of_files = Path().cwd().glob("**/**/*S1D_A.fits")  # create the list of files

model = Sequential()
model.add(Dense(10000, input_dim=212207, kernel_initializer='normal', activation='relu'))
model.add(Dense(100, activation='relu'))
model.add(Dense(1, kernel_initializer='normal'))
model.compile(loss='mean_squared_error', optimizer='adam')

for i, file_name in enumerate(list_of_files):  # i counts the files seen so far
    data = fits.getdata(file_name)
    X = data.flux
    Y = data.rv
    # X is one input example with 212207 values/features
    # Y is one output value (float)
    if i < 6000000:  # out of 10000000
        model.fit(X.transpose(), Y, epochs=30, batch_size=1, verbose=0)
    else:
        prediction = model.predict(X.transpose())
I made sure that I am training on different examples from the ones I predict on, but I still get the same prediction value for every test input. For debugging I tried a much smaller input size than 212207, but that did not help. The dataset is balanced and shuffled. Input values range from 0 to 100,000 (0.1 million) and I haven't normalised them. Output values range from -30 to 0. I think I made a mistake in defining the model for regression. Can you please check whether the code is correct?
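The question notes that the inputs are not normalised and span 0 to 100,000 while the targets lie between -30 and 0. Unscaled inputs of that size can push a ReLU network into predicting a near-constant value, so standardising each feature is one thing worth trying; this is a suggestion, not part of the answer below. Because the files are read one at a time, the minimal NumPy sketch here accumulates per-feature mean and variance in a single streaming pass using Welford's algorithm. The class name `RunningStandardizer` is hypothetical (it is not a Keras or astropy API), and random toy data stands in for the real spectra.

```python
import numpy as np


class RunningStandardizer:
    """Hypothetical helper: per-feature mean/variance accumulated one example
    at a time (Welford's algorithm), so the full dataset never sits in memory."""

    def __init__(self, n_features):
        self.count = 0
        self.mean = np.zeros(n_features)
        self.m2 = np.zeros(n_features)  # running sum of squared deviations

    def update(self, x):
        x = np.asarray(x, dtype=np.float64).ravel()
        self.count += 1
        delta = x - self.mean             # deviation from the old mean
        self.mean += delta / self.count   # updated mean
        self.m2 += delta * (x - self.mean)

    @property
    def std(self):
        # Population standard deviation; zero-variance features get 1.0
        # so that transform() never divides by zero.
        var = self.m2 / max(self.count, 1)
        return np.sqrt(np.where(var > 0, var, 1.0))

    def transform(self, x):
        # Returns shape (1, n_features): one row, as Keras expects for one example.
        x = np.asarray(x, dtype=np.float64).ravel()
        return ((x - self.mean) / self.std).reshape(1, -1)


# Toy stand-in for the real data: 500 examples, 8 features in [0, 100000).
rng = np.random.default_rng(0)
examples = rng.uniform(0, 100_000, size=(500, 8))

scaler = RunningStandardizer(n_features=8)
for x in examples:          # first pass: collect statistics only
    scaler.update(x)

scaled = scaler.transform(examples[0])  # second pass: feed this to model.fit
```

In the original loop this would mean one pass over the files calling `scaler.update(data.flux)`, then a second pass calling `model.fit(scaler.transform(data.flux), ...)`.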
Solution 1:[1]
I think you meant to pass each record from the dataset instead of the whole dataset. Right now you are predicting on exactly the same data you train on.
This is what you want to execute:
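from keras.models import Sequential
from keras.layers import Dense

model = Sequential()
model.add(Dense(10000, input_dim=212207, kernel_initializer='normal', activation='relu'))
model.add(Dense(100, activation='relu'))
model.add(Dense(1, kernel_initializer='normal'))
model.compile(loss='mean_squared_error', optimizer='adam')

# X holds all examples; after transposing, each row is one example with 212207 features
X = X.transpose()
# train on the first 6000000 examples (Y is sliced at the same index so targets stay aligned)
model.fit(X[:6000000], Y[:6000000], epochs=30, batch_size=1, verbose=0)
# test on the remaining examples
prediction = model.predict(X[6000000:])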
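Whichever way the data is fed in, the feature rows and the targets have to be cut at the same index: if only `X` is sliced, `fit` receives a different number of inputs than targets (Keras raises an error for that), and if the cut points differ, inputs get paired with the wrong targets. A minimal NumPy sketch, assuming `X` is already shaped `(n_examples, n_features)`; the helper `train_test_split_rows` is hypothetical and the toy arrays stand in for the real spectra.

```python
import numpy as np


def train_test_split_rows(X, Y, n_train):
    """Hypothetical helper: split inputs and targets at the same boundary
    so that X_train[i] still corresponds to Y_train[i]."""
    X = np.asarray(X)
    Y = np.asarray(Y)
    if len(X) != len(Y):
        raise ValueError(f"X has {len(X)} rows but Y has {len(Y)} targets")
    return X[:n_train], Y[:n_train], X[n_train:], Y[n_train:]


# Toy stand-in for the real data: 10 examples with 4 features each.
X = np.arange(40, dtype=float).reshape(10, 4)
Y = -np.arange(10, dtype=float)

X_train, Y_train, X_test, Y_test = train_test_split_rows(X, Y, n_train=6)
# With the real arrays: model.fit(X_train, Y_train, ...) then model.predict(X_test)
```

Shuffling the row order before splitting (for example with `np.random.default_rng().permutation(len(X))` applied to both arrays) keeps the test set from simply being the last files on disk.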
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 |
