Skip to content

Commit dec3453

Browse files
author
ArturoAmorQ
committed
Prefer make_column_transformer as per INRIA#831
1 parent 52f244a commit dec3453

File tree

1 file changed

+5
-12
lines changed

1 file changed

+5
-12
lines changed

python_scripts/clustering_transformer.py

Lines changed: 5 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -35,19 +35,14 @@
3535
from sklearn.pipeline import make_pipeline
3636
from sklearn.preprocessing import StandardScaler
3737
from sklearn.linear_model import Ridge
38-
from sklearn.compose import ColumnTransformer
38+
from sklearn.compose import make_column_transformer
3939

4040
data_train, data_test, target_train, target_test = train_test_split(
4141
data, target, test_size=0.2, random_state=0
4242
)
4343
geo_columns = ["Latitude", "Longitude"]
4444
model_drop_geo = make_pipeline(
45-
ColumnTransformer(
46-
[
47-
("geo", "drop", geo_columns),
48-
],
49-
remainder="passthrough",
50-
),
45+
make_column_transformer(("drop", geo_columns), remainder="passthrough"),
5146
StandardScaler(),
5247
Ridge(alpha=1e-12),
5348
)
@@ -129,10 +124,8 @@
129124
from sklearn.pipeline import make_pipeline
130125

131126
model_cluster_geo = make_pipeline(
132-
ColumnTransformer(
133-
[
134-
("geo", KMeans(n_clusters=100), geo_columns),
135-
],
127+
make_column_transformer(
128+
(KMeans(n_clusters=100), geo_columns),
136129
remainder="passthrough",
137130
),
138131
StandardScaler(),
@@ -158,7 +151,7 @@
158151
# %%
159152
from sklearn.model_selection import GridSearchCV
160153

161-
param_name = "columntransformer__geo__n_clusters"
154+
param_name = "columntransformer__kmeans__n_clusters"
162155
param_grid = {param_name: [10, 30, 100, 300, 1_000, 3_000]}
163156
grid_search = GridSearchCV(
164157
model_cluster_geo, param_grid=param_grid, scoring="neg_mean_absolute_error"

0 commit comments

Comments
 (0)