pyspark explode one-hot encoded vector to each column with proper name
Applying one-hot encoding to multiple categorical columns:
from pyspark.ml import Pipeline
from pyspark.ml.feature import StringIndexer, OneHotEncoder

X_cat = X.select(cat_cols)
# One StringIndexer and one OneHotEncoder per categorical column
str_indexer = [StringIndexer(inputCol=col, outputCol=col + "_si", handleInvalid="skip")
               for col in cat_cols]
ohe = [OneHotEncoder(inputCol=f"{col}_si", outputCol=f"{col}_ohe", dropLast=True)
       for col in cat_cols]
# ohe.setDropLast(False)  # older Spark versions
pl = Pipeline(stages=str_indexer + ohe).fit(X_cat)
X_cat = pl.transform(X_cat)
si_cols = [col_nm for col_nm in X_cat.columns if col_nm.endswith("_si")]
ohe_cols = [col_nm for col_nm in X_cat.columns if col_nm.endswith("_ohe")]
X_cat_ohe = X_cat.select(ohe_cols)
gives me
root
|-- workclass_ohe: vector (nullable = true)
|-- education_ohe: vector (nullable = true)
|-- marital-status_ohe: vector (nullable = true)
|-- occupation_ohe: vector (nullable = true)
+-------------+---------------+
|workclass_ohe| education_ohe|
+-------------+---------------+
|(8,[4],[1.0])| (15,[2],[1.0])|
|(8,[1],[1.0])| (15,[2],[1.0])|
|(8,[0],[1.0])| (15,[0],[1.0])|
|(8,[0],[1.0])| (15,[5],[1.0])|
which is basically
workclass_ohe education_ohe
0 (0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0) (0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
1 (0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) (0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
2 (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
3 (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) (0.0, 0.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, ...
4 (1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0) (0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, ...
I want to explode each value in the vector into a new column with a proper name.
Desired output:
workclass_state-gov workclass_selfemp workclass_private workclass_middle education_1 education_2 education_3
0 0 0 1 0 0 1
0 1 0 1 0 0 1
1 0 0 0 1 0 0
...
From pyspark - Convert sparse vector obtained after one hot encoding into columns I could add new columns; however, from X_cat_ohe I cannot figure out which value (e.g. state-gov) corresponds to the 0th vector position, the 1st, and so on...
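For reference, that index-to-category mapping lives in the fitted StringIndexerModel: `labels[i]` is the category encoded as index `i`, and with `dropLast=True` the last label gets the all-zeros vector (no column of its own). A pure-Python sketch of the naming logic (the `labels` list here is hypothetical example data; with a real pipeline you would read it from the fitted indexer stage):

```python
# labels[i] is the category that StringIndexer encoded as index i;
# with dropLast=True the last label is dropped from the one-hot vector.
def ohe_column_names(col_nm, labels, drop_last=True):
    """Column names for the one-hot vector positions of col_nm, in vector order."""
    kept = labels[:-1] if drop_last else labels
    return [f"{col_nm}_{label}" for label in kept]

# With a fitted pipeline, read the labels from the indexer stage, e.g.:
#   labels = pl.stages[0].labels   # StringIndexerModel for the first column
labels = ["Private", "Self-emp", "State-gov", "Federal-gov"]  # hypothetical
names = ohe_column_names("workclass", labels)
# names == ["workclass_Private", "workclass_Self-emp", "workclass_State-gov"]
```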
Solution 1:[1]
Thanks to Dummy Encoding using Pyspark, I can extend it to multiple columns:
import pyspark.sql.functions as f

for col_nm in cat_cols:
    # One dummy column per distinct category of this column
    categories = X.select(col_nm).distinct().rdd.flatMap(lambda x: x).collect()
    exprs = [f.when(f.col(col_nm) == ct, 1).otherwise(0)
              .alias(f"{col_nm}_{ct}") for ct in categories]
    X = X.select(X.columns + exprs)
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | haneulkim |
