Passing fit_params to cross_val_score in sklearn always returns nan
When I pass fit_params to cross_val_score in sklearn, it always returns nan:
from sklearn.model_selection import cross_val_score
from xgboost import XGBRegressor

parameters = {'n_estimators': 100}
model = XGBRegressor()
cv_results = cross_val_score(model,
                             X_train_,
                             y_train_,
                             cv=split,
                             scoring="neg_mean_squared_error",
                             fit_params=parameters,
                             n_jobs=-1)
The result is: [nan nan nan nan nan]
When I don't use fit_params, or pass an empty dict to it, the results are fine.
But as soon as I pass any parameter in fit_params, I get nan.
Does anyone know why?
Solution 1:[1]
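You get nan because n_estimators is not a parameter of the .fit() method. cross_val_score forwards fit_params to .fit(), so the call fails in every fold, and cross_val_score catches the exception and records its error_score, which defaults to nan, as that fold's score. With no fit_params (or an empty dict), nothing invalid reaches .fit(), so there is no error.
n_estimators is a constructor parameter of XGBRegressor(), so set it when you create the model: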
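To see the failure instead of nan, pass error_score="raise" to cross_val_score; recent scikit-learn releases also raise a ValueError on their own when every fold fails. The sketch below is a simplified, hypothetical model of the cross-validation loop, not scikit-learn's actual code: `ToyRegressor` and `toy_cross_val_score` are made-up names, and the toy fit() accepts only eval_metric, mirroring the fact that n_estimators is not a fit() argument.

```python
import math


class ToyRegressor:
    """Hypothetical sklearn-style regressor: hyperparameters go in __init__."""

    def __init__(self, n_estimators=10):
        self.n_estimators = n_estimators

    def fit(self, X, y, eval_metric=None):
        # Trivial "model": remember the mean of the training targets.
        self.mean_ = sum(y) / len(y)
        return self

    def predict(self, X):
        return [self.mean_ for _ in X]


def toy_cross_val_score(model, X, y, n_splits=5, fit_params=None,
                        error_score=math.nan):
    """Simplified CV loop: a fold whose fit() raises gets error_score."""
    fit_params = fit_params or {}
    fold_size = len(X) // n_splits
    scores = []
    for k in range(n_splits):
        start, stop = k * fold_size, (k + 1) * fold_size
        X_test, y_test = X[start:stop], y[start:stop]
        X_train, y_train = X[:start] + X[stop:], y[:start] + y[stop:]
        try:
            model.fit(X_train, y_train, **fit_params)
        except Exception:
            if error_score == "raise":
                raise
            scores.append(error_score)  # the exception is swallowed here
            continue
        preds = model.predict(X_test)
        mse = sum((p - t) ** 2 for p, t in zip(preds, y_test)) / len(y_test)
        scores.append(-mse)  # negated, like "neg_mean_squared_error"
    return scores


X = [[float(i)] for i in range(20)]
y = [float(i % 4) for i in range(20)]

# n_estimators passed to fit(): every fold fails silently.
bad = toy_cross_val_score(ToyRegressor(), X, y,
                          fit_params={"n_estimators": 100})
print(bad)  # [nan, nan, nan, nan, nan]

# n_estimators in the constructor, a genuine fit() option in fit_params.
good = toy_cross_val_score(ToyRegressor(n_estimators=100), X, y,
                           fit_params={"eval_metric": "mae"})
print(good)  # [-1.25, -1.25, -1.25, -1.25, -1.25]

# error_score="raise" surfaces the underlying TypeError.
try:
    toy_cross_val_score(ToyRegressor(), X, y,
                        fit_params={"n_estimators": 100}, error_score="raise")
except TypeError as exc:
    print("fit failed:", exc)
```

The nan is not a numerical problem with the model; it is the placeholder recorded for a fold that never finished fitting.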
model = XGBRegressor(n_estimators=100)
cv_results = cross_val_score(model,
                             X_train_,
                             y_train_,
                             cv=5,
                             scoring="neg_mean_squared_error")
The parameters that can be passed to the .fit() method (see its help page) are ones such as eval_metric and early_stopping_rounds, for example:
model = XGBRegressor(n_estimators=100)
parameters = {'eval_metric': 'mae'}
cv_results = cross_val_score(model,
                             X_train_,
                             y_train_,
                             cv=5,
                             scoring="neg_mean_squared_error",
                             fit_params=parameters)
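Because a bad key only shows up as nan, it can help to check fit_params against the estimator's fit() signature before cross-validating. Below is a minimal sketch using the standard-library inspect module; `unsupported_fit_params` and `ToyRegressor` are hypothetical names, and the check assumes fit() declares its keyword arguments explicitly (if fit() takes **kwargs, every key is reported as accepted). Also note that recent XGBoost releases deprecate passing eval_metric and early_stopping_rounds to fit() in favour of the constructor, so the accepted names depend on your installed version.

```python
import inspect


def unsupported_fit_params(estimator, fit_params):
    """Return the keys of fit_params that estimator.fit() cannot accept."""
    params = inspect.signature(estimator.fit).parameters.values()
    # A **kwargs parameter swallows any keyword, so nothing can be flagged.
    if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
        return []
    accepted = {
        p.name for p in params
        if p.kind in (inspect.Parameter.POSITIONAL_OR_KEYWORD,
                      inspect.Parameter.KEYWORD_ONLY)
    }
    return sorted(key for key in fit_params if key not in accepted)


class ToyRegressor:
    """Hypothetical estimator: n_estimators in __init__, eval_metric in fit."""

    def __init__(self, n_estimators=10):
        self.n_estimators = n_estimators

    def fit(self, X, y, eval_metric=None, early_stopping_rounds=None):
        return self


model = ToyRegressor(n_estimators=100)
print(unsupported_fit_params(model, {"n_estimators": 100}))  # ['n_estimators']
print(unsupported_fit_params(model, {"eval_metric": "mae"}))  # []
```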
Solution 2:[2]
@StupidWolf nailed it ...
I had a similar issue when passing a pipeline to cross_val_score, and solved it by passing the model directly, as in @StupidWolf's example.
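The pipeline case fails for a related reason: a scikit-learn Pipeline routes fit parameters using the `<step name>__<parameter>` naming convention, so an unprefixed key such as early_stopping_rounds is rejected instead of being forwarded to the final estimator, and inside cross_val_score that error again becomes nan. The sketch below imitates only the routing with plain dicts; `route_fit_params` and the step names "scaler" and "xgb" are hypothetical. Keep in mind that a routed eval_set reaches the final step as-is, without passing through the pipeline's earlier transformers.

```python
def route_fit_params(step_names, fit_params):
    """Split 'step__param' keys into per-step dicts, Pipeline-style."""
    routed = {name: {} for name in step_names}
    for key, value in fit_params.items():
        if "__" not in key:
            raise ValueError(
                f"{key!r} has no step prefix; use '<step>__{key}' instead"
            )
        step, param = key.split("__", 1)
        if step not in routed:
            raise ValueError(f"unknown step {step!r} in {key!r}")
        routed[step][param] = value
    return routed


steps = ["scaler", "xgb"]  # hypothetical step names

print(route_fit_params(steps, {"xgb__early_stopping_rounds": 5,
                               "xgb__verbose": False}))
# {'scaler': {}, 'xgb': {'early_stopping_rounds': 5, 'verbose': False}}

try:
    route_fit_params(steps, {"early_stopping_rounds": 5})
except ValueError as exc:
    print("rejected:", exc)
```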
xgb_model = XGBClassifier(**hyperparams, use_label_encoder=False, n_jobs=4)
fit_params = {'early_stopping_rounds': 5,
              'eval_metric': ['logloss'],
              'verbose': False,
              'eval_set': [(X_valid, y_valid)]}
# Multiply by -1 since sklearn calculates *negative* RMSE
scores = -1 * cross_val_score(xgb_model, X_train, y_train,
                              cv=5,
                              scoring='neg_root_mean_squared_error',
                              fit_params=fit_params,
                              verbose=1)
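The `-1 *` works because scikit-learn scorers follow a greater-is-better convention, so error metrics are reported negated: 'neg_root_mean_squared_error' gives -RMSE for each fold. A small arithmetic sketch of that sign flip; the per-fold values are made up for illustration and are not output from the code above.

```python
import math


def rmse(y_true, y_pred):
    """Root mean squared error of two equal-length sequences."""
    squared = [(t - p) ** 2 for t, p in zip(y_true, y_pred)]
    return math.sqrt(sum(squared) / len(squared))


# Hypothetical (true, predicted) values for two folds.
folds = [
    ([1.0, 2.0], [2.0, 3.0]),   # errors 1, 1  -> RMSE 1.0
    ([0.0, 0.0], [3.0, -3.0]),  # errors 3, -3 -> RMSE 3.0
]

neg_scores = [-rmse(t, p) for t, p in folds]  # what the scorer reports
scores = [-1 * s for s in neg_scores]         # back to positive RMSE
print(neg_scores)  # [-1.0, -3.0]
print(scores)      # [1.0, 3.0]
```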
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | |
| Solution 2 | J R |
