Is the loss function for a high-dimensional dataset supposed to be non-convex?
The MSE loss function is convex by definition, but when I generate a high-dimensional dataset the problem no longer looks convex.
import numpy as np
from sklearn.datasets import make_regression

rng = np.random.RandomState(0)
X, y, coefficients = make_regression(n_samples=1000, n_features=100,
                                     n_informative=100, effective_rank=3,
                                     noise=0.0, coef=True, random_state=rng)
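The `effective_rank=3` setting makes the singular values of `X` decay sharply, so the data effectively lives in a low-dimensional subspace. A quick standalone check (same generator settings as above, minus the coefficients):

```python
import numpy as np
from sklearn.datasets import make_regression

# effective_rank=3 gives a bell-shaped singular-value profile in which
# a few directions carry most of the variance.
rng = np.random.RandomState(0)
X, y = make_regression(n_samples=1000, n_features=100, n_informative=100,
                       effective_rank=3, noise=0.0, random_state=rng)

s = np.linalg.svd(X, compute_uv=False)  # singular values, descending
print(s[:5] / s[0])                     # the leading few dominate
```

So even before splitting, `X.T @ X` is close to rank-deficient and its eigenvalue spectrum will contain values near zero.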
Now splitting it into train, validation, and test sets:
from sklearn.model_selection import train_test_split
def ttv_split(dataX, dataY, train_ratio, validation_ratio, test_ratio):
    x_train, x_test, y_train, y_test = train_test_split(
        dataX, dataY, test_size=round(1 - train_ratio, 2), random_state=42)
    x_val, x_test, y_val, y_test = train_test_split(
        x_test, y_test, test_size=test_ratio / (test_ratio + validation_ratio),
        random_state=42)
    return x_train, y_train, x_val, y_val, x_test, y_test

x_train, y_train, x_val, y_val, x_test, y_test = ttv_split(X, y, 0.05, 0.025, 1 - 0.075)
print("TrainingX Shape = ", x_train.shape, "TrainY shape = ", y_train.shape)
print("ValidationX Shape = ", x_val.shape, "ValY shape = ", y_val.shape)
print("TestX Shape = ", x_test.shape, "TestY shape = ", y_test.shape)
# Output:
TrainingX Shape = (50, 100) TrainY shape = (50,)
ValidationX Shape = (25, 100) ValY shape = (25,)
TestX Shape = (925, 100) TestY shape = (925,)
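Note that the training split has only 50 samples for 100 features. Independently of `effective_rank`, a Gram matrix built from an n × p matrix with n < p has rank at most n, which already forces at least 50 zero eigenvalues in `P_tr`. A self-contained sketch with a fresh random matrix (not the actual split above):

```python
import numpy as np

# A 50 x 100 matrix, matching the training-split shape above.
rng = np.random.RandomState(0)
A = rng.randn(50, 100)

G = A.T @ A                       # 100 x 100 Gram matrix
print(G.shape)                    # (100, 100)
print(np.linalg.matrix_rank(G))   # at most 50, since rank(A'A) = rank(A)
```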
Now, converting to quadratic-program (QP) form and computing the eigenvalues:
def quadratic_form(X, Y):
    # ||Xw - y||^2 = w'Pw + Qw + y'y
    P = np.matmul(X.T, X)
    Q = -2 * np.matmul(Y.T, X)
    YtY = np.matmul(Y.T, Y)
    return P, Q, YtY
# Training, val and Test QP
P_tr, Q_tr, YtY_tr = quadratic_form( x_train , y_train )
P_val, Q_val, YtY_val = quadratic_form( x_val , y_val )
P_test, Q_test, YtY_test = quadratic_form( x_test , y_test )
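As a sanity check (self-contained, with fresh random data rather than the split above), the quadratic form should reproduce the squared error exactly: ||Xw - y||² = wᵀPw + Qw + yᵀy for any w.

```python
import numpy as np

def quadratic_form(X, Y):
    P = np.matmul(X.T, X)
    Q = -2 * np.matmul(Y.T, X)
    YtY = np.matmul(Y.T, Y)
    return P, Q, YtY

rng = np.random.RandomState(0)
X = rng.randn(20, 5)
y = rng.randn(20)
P, Q, YtY = quadratic_form(X, y)

w = rng.randn(5)                 # an arbitrary weight vector
lhs = w @ P @ w + Q @ w + YtY    # value of the quadratic form
rhs = np.sum((X @ w - y) ** 2)   # squared-error value
print(np.isclose(lhs, rhs))      # True
```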
np.linalg.eigvals(P_tr)
# Output:
array([ 4.95125590e-02+0.00000000e+00j, 3.59197391e-02+0.00000000e+00j,
3.37542388e-02+0.00000000e+00j, 2.42792525e-02+0.00000000e+00j,
1.59022501e-02+0.00000000e+00j, 1.45537102e-02+0.00000000e+00j,
1.07802480e-02+0.00000000e+00j, 9.44592832e-03+0.00000000e+00j,
8.84433990e-03+0.00000000e+00j, 9.23145161e-03+0.00000000e+00j,
7.92091792e-03+0.00000000e+00j, 7.58909877e-03+0.00000000e+00j,
5.83030469e-03+0.00000000e+00j, 5.03155542e-03+0.00000000e+00j,
4.47966869e-03+0.00000000e+00j, 3.74600673e-03+0.00000000e+00j,
3.43228263e-03+0.00000000e+00j, 3.27013758e-03+0.00000000e+00j,
2.82173367e-03+0.00000000e+00j, 2.72670221e-03+0.00000000e+00j,
2.38166848e-03+0.00000000e+00j, 2.31223051e-03+0.00000000e+00j,
1.93671892e-03+0.00000000e+00j, 1.78860097e-03+0.00000000e+00j,
1.68008668e-03+0.00000000e+00j, 1.54299547e-03+0.00000000e+00j,
1.40873741e-03+0.00000000e+00j, 1.09740264e-03+0.00000000e+00j,
1.19998089e-03+0.00000000e+00j, 9.80403331e-04+0.00000000e+00j,
9.03025762e-04+0.00000000e+00j, 8.31922845e-04+0.00000000e+00j,
7.53863100e-04+0.00000000e+00j, 7.11801927e-04+0.00000000e+00j,
6.05375236e-04+0.00000000e+00j, 5.28839466e-04+0.00000000e+00j,
4.93423676e-04+0.00000000e+00j, 4.60982147e-04+0.00000000e+00j,
4.25653579e-04+0.00000000e+00j, 2.04024243e-04+0.00000000e+00j,
2.52868938e-04+0.00000000e+00j, 2.85681188e-04+0.00000000e+00j,
3.54971061e-04+0.00000000e+00j, 3.48320578e-04+0.00000000e+00j,
1.67819656e-04+0.00000000e+00j, 1.36960992e-04+0.00000000e+00j,
7.17807793e-05+0.00000000e+00j, 8.51775301e-05+0.00000000e+00j,
1.02035484e-04+0.00000000e+00j, 1.05242480e-04+0.00000000e+00j,
3.13944338e-18+0.00000000e+00j, -2.55435320e-18+7.36995996e-19j,
-2.55435320e-18-7.36995996e-19j, -2.74984812e-18+0.00000000e+00j,
-2.18137804e-18+1.12544689e-19j, -2.18137804e-18-1.12544689e-19j,
2.32659249e-18+7.44570928e-20j, 2.32659249e-18-7.44570928e-20j,
1.94702401e-18+4.43765321e-19j, 1.94702401e-18-4.43765321e-19j,
-1.77691688e-18+0.00000000e+00j, 1.72290651e-18+0.00000000e+00j,
-1.58792504e-18+0.00000000e+00j, -1.36379048e-18+5.13259206e-19j,
-1.36379048e-18-5.13259206e-19j, -1.31631049e-18+0.00000000e+00j,
1.61107269e-18+1.89156345e-19j, 1.61107269e-18-1.89156345e-19j,
1.13646471e-18+8.04003199e-19j, 1.13646471e-18-8.04003199e-19j,
-1.16234738e-18+3.42680578e-19j, -1.16234738e-18-3.42680578e-19j,
-1.14431437e-18+1.74642682e-19j, -1.14431437e-18-1.74642682e-19j,
1.43019099e-18+2.69165895e-19j, 1.43019099e-18-2.69165895e-19j,
-5.25583651e-19+6.24253440e-19j, -5.25583651e-19-6.24253440e-19j,
-1.33992327e-19+7.06675184e-19j, -1.33992327e-19-7.06675184e-19j,
-8.22843670e-19+1.30454635e-20j, -8.22843670e-19-1.30454635e-20j,
1.04544115e-18+0.00000000e+00j, -6.73238048e-19+0.00000000e+00j,
9.03262709e-19+3.30466756e-19j, 9.03262709e-19-3.30466756e-19j,
5.04446680e-19+5.35740963e-19j, 5.04446680e-19-5.35740963e-19j,
-1.66403419e-19+2.50921451e-19j, -1.66403419e-19-2.50921451e-19j,
-3.72846642e-19+1.15075692e-19j, -3.72846642e-19-1.15075692e-19j,
7.63884776e-19+3.00700070e-19j, 7.63884776e-19-3.00700070e-19j,
6.09980347e-19+0.00000000e+00j, 3.61128263e-19+2.85977381e-19j,
3.61128263e-19-2.85977381e-19j, 1.58873504e-19+0.00000000e+00j,
5.01289831e-19+1.26078990e-19j, 5.01289831e-19-1.26078990e-19j])
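For what it's worth, since P = XᵀX is symmetric, `np.linalg.eigvalsh` is the appropriate routine: it exploits symmetry and returns real eigenvalues, whereas `np.linalg.eigvals` treats P as a general matrix, and roundoff then produces the tiny complex and negative values of order 1e-18 seen above. A standalone sketch with a fresh 50 × 100 matrix:

```python
import numpy as np

rng = np.random.RandomState(0)
X = rng.randn(50, 100)
P = X.T @ X                     # symmetric positive semidefinite

ev = np.linalg.eigvalsh(P)      # real eigenvalues, ascending order
print(ev.dtype)                 # float64 (real, not complex)
# Any negative values are numerically zero relative to the largest one:
print(ev.min() > -1e-8 * ev.max())  # True
```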
I don't know what I am missing here. Some help, please :)
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
