How to select only a few columns in a scikit-learn column selector pipeline?
I was reading the scikit-learn tutorial about the column transformer. The given example (https://scikit-learn.org/stable/modules/generated/sklearn.compose.make_column_selector.html#sklearn.compose.make_column_selector) works, but when I try to select only a few columns, it gives me an error.
MWE
import numpy as np
import pandas as pd
import seaborn as sns
from sklearn.compose import make_column_transformer
from sklearn.compose import make_column_selector
df = sns.load_dataset('tips')
mycols = ['tip','sex']
ct = make_column_transformer(make_column_selector(pattern=mycols))
ct.fit_transform(df)
Required
I want only the selected columns in the output.
NOTE
Of course, I know I can do df[mycols]; I am looking for a scikit-learn pipeline example.
Solution 1:[1]
If you don't mind using mlxtend, it has a built-in transformer for that.
Using mlxtend
from mlxtend.feature_selection import ColumnSelector
pipe = ColumnSelector(mycols)
pipe.fit_transform(df)
For sklearn >= 0.20
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
import seaborn as sns
df = sns.load_dataset('tips')
mycols = ['tip','sex']
pipeline = Pipeline([
    ("selector", ColumnTransformer([
        ("selector", "passthrough", mycols)
    ], remainder="drop"))
])
pipeline.fit_transform(df)
For sklearn < 0.20
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline
class FeatureSelector(BaseEstimator, TransformerMixin):
    def __init__(self, columns):
        self.columns = columns

    def fit(self, X, y=None):
        return self

    def transform(self, X, y=None):
        return X[self.columns]

pipeline = Pipeline([
    ('selector', FeatureSelector(columns=mycols))
])
pipeline.fit_transform(df)[:5]
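Because transform simply indexes X, this custom selector hands back a DataFrame rather than a NumPy array. A small self-contained sketch of that behaviour, with a hypothetical three-row df standing in for the tips dataset:

```python
import pandas as pd
from sklearn.base import BaseEstimator, TransformerMixin
from sklearn.pipeline import Pipeline


class FeatureSelector(BaseEstimator, TransformerMixin):
    """Select a fixed list of DataFrame columns by name."""

    def __init__(self, columns):
        self.columns = columns

    def fit(self, X, y=None):
        # Nothing to learn; selection is purely by name
        return self

    def transform(self, X, y=None):
        return X[self.columns]


# Hypothetical stand-in for sns.load_dataset('tips')
df = pd.DataFrame({
    "total_bill": [16.99, 10.34, 21.01],
    "tip": [1.01, 1.66, 3.50],
    "sex": ["Female", "Male", "Male"],
})
mycols = ["tip", "sex"]

pipeline = Pipeline([("selector", FeatureSelector(columns=mycols))])
out = pipeline.fit_transform(df)
print(type(out).__name__, list(out.columns))
```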
Solution 2:[2]
I may be a bit late, but you can also select columns using sklearn's ColumnTransformer by setting the transformer to "passthrough" and remainder="drop":
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
pipe = Pipeline([
    ("selector", ColumnTransformer([
        ("selector", "passthrough", mycols)
    ], remainder="drop"))
])
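As for the original attempt: make_column_selector takes pattern as a single regular-expression string rather than a list of names, so the list would first have to be turned into a regex. A minimal sketch of that idea, assuming scikit-learn 0.22 or later and a hypothetical small df in place of the tips dataset:

```python
import re

import pandas as pd
from sklearn.compose import ColumnTransformer, make_column_selector

# Hypothetical stand-in for sns.load_dataset('tips')
df = pd.DataFrame({
    "total_bill": [16.99, 10.34, 21.01],
    "tip": [1.01, 1.66, 3.50],
    "sex": ["Female", "Male", "Male"],
})
mycols = ["tip", "sex"]

# Anchored, non-capturing alternation that matches exactly the listed names
pattern = "^(?:" + "|".join(re.escape(c) for c in mycols) + ")$"
selector = make_column_selector(pattern=pattern)

ct = ColumnTransformer(
    [("selector", "passthrough", selector)],
    remainder="drop",
)
out = ct.fit_transform(df)  # NumPy array holding only the matched columns
print(selector(df), out.shape)
```

With this variant the columns come back in the DataFrame's own column order, not in the order of mycols, because the selector filters the existing column index.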
Sources
This article follows the attribution requirements of Stack Overflow and is licensed under CC BY-SA 3.0.
Source: Stack Overflow
| Solution | Source |
|---|---|
| Solution 1 | |
| Solution 2 | Jens |
