Skip to content

Commit 96878ba

Browse files
authored
FIX Ensure that index is correct when global transform_output is pandas (scikit-learn#26454)
1 parent e5df5fe commit 96878ba

File tree

5 files changed

+28
-9
lines changed

5 files changed

+28
-9
lines changed

doc/whats_new/v1.3.rst

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,9 @@ Changelog
401401
- |Enhancement| Added the parameter `fill_value` to :class:`impute.IterativeImputer`.
402402
:pr:`25232` by :user:`Thijs van Weezel <ValueInvestorThijs>`.
403403

404+
- |Fix| :class:`impute.IterativeImputer` now correctly preserves the Pandas
405+
Index when the `set_config(transform_output="pandas")`. :pr:`26454` by `Thomas Fan`_.
406+
404407
:mod:`sklearn.inspection`
405408
.........................
406409

@@ -444,6 +447,12 @@ Changelog
444447
on linearly separable problems.
445448
:pr:`25214` by `Tom Dupre la Tour`_.
446449

450+
:mod:`sklearn.manifold`
451+
.......................
452+
453+
- |Fix| :class:`manifold.Isomap` now correctly preserves the Pandas
454+
Index when the `set_config(transform_output="pandas")`. :pr:`26454` by `Thomas Fan`_.
455+
447456
:mod:`sklearn.metrics`
448457
......................
449458

@@ -636,6 +645,9 @@ Changelog
636645
The `sample_interval_` attribute is deprecated and will be removed in 1.5.
637646
:pr:`25190` by :user:`Vincent Maladière <Vincent-Maladiere>`.
638647

648+
- |Fix| :class:`preprocessing.PowerTransformer` now correctly preserves the Pandas
649+
Index when the `set_config(transform_output="pandas")`. :pr:`26454` by `Thomas Fan`_.
650+
639651
- |Fix| :class:`preprocessing.PowerTransformer` now correcly raises error when
640652
using `method="box-cox"` on data with a constant `np.nan` column.
641653
:pr:`26400` by :user:`Yao Xiao <Charlie-XIAO>`.

sklearn/impute/_iterative.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -627,7 +627,7 @@ def _initial_imputation(self, X, in_fit=False):
627627
strategy=self.initial_strategy,
628628
fill_value=self.fill_value,
629629
keep_empty_features=self.keep_empty_features,
630-
)
630+
).set_output(transform="default")
631631
X_filled = self.initial_imputer_.fit_transform(X)
632632
else:
633633
X_filled = self.initial_imputer_.transform(X)

sklearn/manifold/_isomap.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def _fit_transform(self, X):
235235
tol=self.tol,
236236
max_iter=self.max_iter,
237237
n_jobs=self.n_jobs,
238-
)
238+
).set_output(transform="default")
239239

240240
if self.n_neighbors is not None:
241241
nbg = kneighbors_graph(

sklearn/preprocessing/_data.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3180,7 +3180,7 @@ def _fit(self, X, y=None, force_transform=False):
31803180
X[:, i] = transform_function(X[:, i], self.lambdas_[i])
31813181

31823182
if self.standardize:
3183-
self._scaler = StandardScaler(copy=False)
3183+
self._scaler = StandardScaler(copy=False).set_output(transform="default")
31843184
if force_transform:
31853185
X = self._scaler.fit_transform(X)
31863186
else:

sklearn/utils/estimator_checks.py

Lines changed: 13 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4424,7 +4424,7 @@ def _output_from_fit_transform(transformer, name, X, df, y):
44244424
return outputs
44254425

44264426

4427-
def _check_generated_dataframe(name, case, outputs_default, outputs_pandas):
4427+
def _check_generated_dataframe(name, case, index, outputs_default, outputs_pandas):
44284428
import pandas as pd
44294429

44304430
X_trans, feature_names_default = outputs_default
@@ -4434,7 +4434,12 @@ def _check_generated_dataframe(name, case, outputs_default, outputs_pandas):
44344434
# We always rely on the output of `get_feature_names_out` of the
44354435
# transformer used to generate the dataframe as a ground-truth of the
44364436
# columns.
4437-
expected_dataframe = pd.DataFrame(X_trans, columns=feature_names_pandas, copy=False)
4437+
# If a dataframe is passed into transform, then the output should have the same
4438+
# index
4439+
expected_index = index if case.endswith("df") else None
4440+
expected_dataframe = pd.DataFrame(
4441+
X_trans, columns=feature_names_pandas, copy=False, index=expected_index
4442+
)
44384443

44394444
try:
44404445
pd.testing.assert_frame_equal(df_trans, expected_dataframe)
@@ -4469,7 +4474,8 @@ def check_set_output_transform_pandas(name, transformer_orig):
44694474
set_random_state(transformer)
44704475

44714476
feature_names_in = [f"col{i}" for i in range(X.shape[1])]
4472-
df = pd.DataFrame(X, columns=feature_names_in, copy=False)
4477+
index = [f"index{i}" for i in range(X.shape[0])]
4478+
df = pd.DataFrame(X, columns=feature_names_in, copy=False, index=index)
44734479

44744480
transformer_default = clone(transformer).set_output(transform="default")
44754481
outputs_default = _output_from_fit_transform(transformer_default, name, X, df, y)
@@ -4483,7 +4489,7 @@ def check_set_output_transform_pandas(name, transformer_orig):
44834489

44844490
for case in outputs_default:
44854491
_check_generated_dataframe(
4486-
name, case, outputs_default[case], outputs_pandas[case]
4492+
name, case, index, outputs_default[case], outputs_pandas[case]
44874493
)
44884494

44894495

@@ -4511,7 +4517,8 @@ def check_global_ouptut_transform_pandas(name, transformer_orig):
45114517
set_random_state(transformer)
45124518

45134519
feature_names_in = [f"col{i}" for i in range(X.shape[1])]
4514-
df = pd.DataFrame(X, columns=feature_names_in, copy=False)
4520+
index = [f"index{i}" for i in range(X.shape[0])]
4521+
df = pd.DataFrame(X, columns=feature_names_in, copy=False, index=index)
45154522

45164523
transformer_default = clone(transformer).set_output(transform="default")
45174524
outputs_default = _output_from_fit_transform(transformer_default, name, X, df, y)
@@ -4528,5 +4535,5 @@ def check_global_ouptut_transform_pandas(name, transformer_orig):
45284535

45294536
for case in outputs_default:
45304537
_check_generated_dataframe(
4531-
name, case, outputs_default[case], outputs_pandas[case]
4538+
name, case, index, outputs_default[case], outputs_pandas[case]
45324539
)

0 commit comments

Comments
 (0)