How to evaluate the model performance of keras tuner.search()?
I'm currently trying to visualise the performance of my prediction model by showing the val_mse at every epoch. The code that used to work for model.fit() doesn't work for tuner.search(). Can anyone provide some guidance on this? Thank you.
Previous code:
import matplotlib.pyplot as plt
import pandas as pd

def plot_model(history):
    # history is the History object returned by model.fit()
    hist = pd.DataFrame(history.history)
    hist['epoch'] = history.epoch

    plt.figure()
    plt.xlabel('Epoch')
    plt.ylabel('Mean Absolute Error')
    plt.plot(hist['epoch'], hist['mae'], label='Train Error')
    plt.plot(hist['epoch'], hist['val_mae'], label='Val Error')
    plt.legend()
    plt.ylim([0, 20])

    plt.figure()
    plt.xlabel('Epoch')
    plt.ylabel('Mean Square Error')
    plt.plot(hist['epoch'], hist['mse'], label='Train Error')
    plt.plot(hist['epoch'], hist['val_mse'], label='Val Error')
    plt.legend()
    plt.ylim([0, 400])

plot_model(history)
Keras Tuner code:
history = tuner.search(x=normed_train_data,
                       y=y_train,
                       epochs=200,
                       batch_size=64,
                       validation_data=(normed_test_data, y_test),
                       callbacks=[early_stopping])
Solution 1:[1]
Before using tuner.search to search for the best model, install and import keras_tuner:
!pip install keras-tuner --upgrade
import keras_tuner as kt
from tensorflow import keras
Then define the hyperparameters (hp) inside the model-building function, for example:
def build_model(hp):
    model = keras.Sequential()
    model.add(keras.layers.Dense(
        hp.Choice('units', [8, 16, 32]),  # define the hyperparameter
        activation='relu'))
    model.add(keras.layers.Dense(1, activation='relu'))
    # Track MAE and MSE so that fit() history contains 'mae', 'mse', 'val_mae' and 'val_mse'
    model.compile(loss='mse', metrics=['mae', 'mse'])
    return model
Initialize the tuner:
tuner = kt.RandomSearch(build_model, objective='val_loss', max_trials=5)
Now start the search with tuner.search and retrieve the best model:
tuner.search(x=normed_train_data,
             y=y_train,
             epochs=200,
             batch_size=64,
             validation_data=(normed_test_data, y_test),
             callbacks=[early_stopping])
best_model = tuner.get_best_models()[0]
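Note that tuner.search() does not return a History object (it returns None), so there is nothing for plot_model to plot. Instead, take best_model, or a fresh model rebuilt from the best hyperparameters, and call fit() on it yourself: the History object that fit() returns can be passed to plot_model to show the training and validation error per epoch. Whether the tuned model actually lowers the loss compared with your original one depends on your data and on the search space.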
Please check this link as a reference for more detail.
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 |
