Getting `ValueError` when plotting features with colours in Python
I have the following data, which needs to be linearly classified using least squares. I wanted to visualise the data and plot the features with colours, but I got the following error when building the colour list colour_cond:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
Note that data_t is made of 1s and 0s.
import numpy as np
import matplotlib.pyplot as plt
import glob
from scipy.io import loadmat
%matplotlib inline
data = glob.glob('Mydata_A.mat')
data_c1 = np.array([loadmat(entry, variable_names= ("X"), squeeze_me=True)["X"][:,0] for entry in data])
data_c2 = np.array([loadmat(entry, variable_names= ("X"), squeeze_me=True)["X"][:,1] for entry in data])
data_t = np.array([loadmat(entry, variable_names= ("T"), squeeze_me=True)["T"][:] for entry in data])
colour_cond=['red' if t==1 else 'blue' for t in data_t]
plt.scatter(data_c1,data_c2,colour=colour_cond)
plt.xlabel('X1')
plt.ylabel('X2')
plt.title('Training Data (X1,X2)')
plt.show()
Solution 1:[1]
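Your problem is that the arrays data_c1, data_c2 and data_t seem to have more than one dimension. Wrapping the per-file arrays in np.array([...]) adds a leading axis, so with a single file the shape is (1, N) rather than (N,). In this line: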
colour_cond=['red' if t==1 else 'blue' for t in data_t]
the variable t is not a scalar but a NumPy array, so t == 1 yields a boolean array whose truth value is ambiguous inside an if expression. I suggest you ravel (i.e. flatten) all your arrays:
import glob
import numpy as np
import matplotlib.pyplot as plt
from scipy.io import loadmat
%matplotlib inline
# Collect the .mat file paths; the comprehensions below iterate over this list
entries = glob.glob('Mydata_A.mat')
# ravel() flattens each (n_files, N) array to 1-D so iteration yields scalars
data_c1 = np.array([
    loadmat(entry, variable_names=("X"), squeeze_me=True)["X"][:, 0]
    for entry in entries]).ravel()
data_c2 = np.array([
    loadmat(entry, variable_names=("X"), squeeze_me=True)["X"][:, 1]
    for entry in entries]).ravel()
data_t = np.array([
    loadmat(entry, variable_names=("T"), squeeze_me=True)["T"][:]
    for entry in entries]).ravel()
colour_cond = ['red' if t==1 else 'blue' for t in data_t]
plt.scatter(data_c1, data_c2, color=colour_cond)
plt.xlabel('X1')
plt.ylabel('X2')
plt.title('Training Data (X1,X2)')
plt.show()
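To check the explanation without the .mat file, here is a minimal sketch that reproduces the error and the fix. The array labels_2d is made-up data standing in for data_t, on the assumption that it has shape (1, N) after loading a single file. The np.where line is an optional vectorised alternative and is not part of the original answer.

```python
import numpy as np

# Made-up stand-in for data_t: one file loaded inside a list comprehension
# gives shape (1, N) instead of (N,).
labels_2d = np.array([[1, 0, 1, 1, 0]])
print(labels_2d.shape)

# Iterating a 2-D array yields whole rows, so `row == 1` is a boolean array,
# and using it as an `if` condition raises the ValueError from the question.
try:
    colours = ['red' if row == 1 else 'blue' for row in labels_2d]
except ValueError as exc:
    print('ValueError:', exc)

# After ravel() every item is a scalar, so the comparison is unambiguous.
labels = labels_2d.ravel()
colours = ['red' if t == 1 else 'blue' for t in labels]
print(colours)

# Vectorised alternative: build the colour array without a Python loop.
colours_vec = np.where(labels == 1, 'red', 'blue')
print(colours_vec)
```

Either colours or colours_vec can then be passed to plt.scatter through the color keyword, as in the corrected code above.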
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | molinav |
