From 819cf0c091bcf0bdafd3f74a984ff1abd38fb9eb Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Thu, 24 Apr 2025 13:35:26 +0200 Subject: [PATCH 01/31] Remove requires fitting and finalize routine --- .../abstract_pipeline_elements/core.py | 207 ++++++------------ molpipeline/error_handling.py | 134 ++++++------ molpipeline/mol2any/mol2bool.py | 12 +- .../mol2any/mol2concatinated_vector.py | 108 ++++----- molpipeline/pipeline/_molpipeline.py | 84 ++++--- tests/test_elements/test_error_handling.py | 10 +- 6 files changed, 225 insertions(+), 330 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/core.py b/molpipeline/abstract_pipeline_elements/core.py index 9e26bce0..c5907a50 100644 --- a/molpipeline/abstract_pipeline_elements/core.py +++ b/molpipeline/abstract_pipeline_elements/core.py @@ -5,8 +5,7 @@ import abc import copy import inspect -from collections.abc import Iterable -from typing import Any, NamedTuple, Self +from typing import TYPE_CHECKING, Any, NamedTuple, Self from uuid import uuid4 from joblib import Parallel, delayed @@ -16,6 +15,9 @@ from molpipeline.utils.multi_proc import check_available_cores +if TYPE_CHECKING: + from collections.abc import Iterable + class InvalidInstance(NamedTuple): """Object which is returned when an instance cannot be processed. @@ -29,6 +31,7 @@ class InvalidInstance(NamedTuple): element_name: str | None Optional name of the element which could not be processed. The name of the pipeline element is often more descriptive than the id. + """ element_id: str @@ -42,6 +45,7 @@ def __repr__(self) -> str: ------- str String representation of InvalidInstance. + """ return ( f"InvalidInstance({self.element_name or self.element_id}, {self.message})" @@ -75,6 +79,7 @@ def __repr__(self) -> str: ------- str String representation of RemovedInstance. + """ return f"RemovedInstance({self.filter_element_id}, {self.message})" @@ -83,7 +88,6 @@ class ABCPipelineElement(abc.ABC): """Ancestor of all PipelineElements.""" name: str - _requires_fitting: bool = False uuid: str def __init__( @@ -102,6 +106,7 @@ def __init__( Number of cores used for processing. uuid: str | None, optional Unique identifier of the PipelineElement. + """ if name is None: name = self.__class__.__name__ @@ -119,13 +124,13 @@ def __repr__(self) -> str: ------- str String representation of object. + """ parm_list = [] for key, value in self._get_non_default_parameters().items(): parm_list.append(f"{key}={value}") parm_str = ", ".join(parm_list) - repr_str = f"{self.__class__.__name__}({parm_str})" - return repr_str + return f"{self.__class__.__name__}({parm_str})" def _get_non_default_parameters(self) -> dict[str, Any]: """Return all parameters which are not default. @@ -134,6 +139,7 @@ def _get_non_default_parameters(self) -> dict[str, Any]: ------- dict[str, Any] Dictionary of non default parameters. 
+ """ signature = inspect.signature(self.__class__.__init__) self_params = self.get_params() @@ -144,9 +150,8 @@ def _get_non_default_parameters(self) -> dict[str, Any]: if value_default is inspect.Parameter.empty: non_default_params[parm] = self_params[parm] continue - if parm in self_params: - if value_default.default != self_params[parm]: - non_default_params[parm] = self_params[parm] + if parm in self_params and value_default.default != self_params[parm]: + non_default_params[parm] = self_params[parm] return non_default_params def get_params(self, deep: bool = True) -> dict[str, Any]: @@ -161,6 +166,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Parameters of the object. + """ if deep: return { @@ -176,7 +182,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: } def set_params(self, **parameters: Any) -> Self: - """As the setter function cannot be assessed with super(), this method is implemented for inheritance. + """Set the parameters of the object. Parameters ---------- @@ -197,7 +203,7 @@ def set_params(self, **parameters: Any) -> Self: for att_name, att_value in parameters.items(): if not hasattr(self, att_name): raise ValueError( - f"Cannot set attribute {att_name} on {self.__class__.__name__}" + f"Cannot set attribute {att_name} on {self.__class__.__name__}", ) setattr(self, att_name, att_value) return self @@ -219,45 +225,25 @@ def n_jobs(self, n_jobs: int) -> None: """ self._n_jobs = check_available_cores(n_jobs) - @property - def requires_fitting(self) -> bool: - """Return whether the object requires fitting or not.""" - return self._requires_fitting - - def fit(self, values: Any, labels: Any = None) -> Self: + def fit( + self, + values: Any, # noqa: ARG002 + labels: Any = None, # noqa: ARG002 + ) -> Self: """Fit object to input_values. - Most objects might not need fitting, but it is implemented for consitency for all PipelineElements. - Parameters ---------- values: Any List of molecule representations. - labels: Any + labels: Any, optional Optional label for fitting. Returns ------- Self Fitted object. - """ - _ = self.fit_transform(values, labels) - return self - - def fit_to_result(self, values: Any) -> Self: # pylint: disable=unused-argument - """Fit object to result of transformed values. - Fit object to the result of the transform function. This is useful catching nones and removed molecules. - - Parameters - ---------- - values: Any - List of molecule representations. - - Returns - ------- - Self - Fitted object. """ return self @@ -272,7 +258,7 @@ def fit_transform( Parameters ---------- values: Any - Apply transformation specified in transform_single to all molecules in the value_list. + Apply transform_single to all molecules in the value_list. labels: Any Optional label for fitting. @@ -280,6 +266,7 @@ def fit_transform( ------- Any List of instances in new representation. + """ @abc.abstractmethod @@ -289,28 +276,15 @@ def transform(self, values: Any) -> Any: Parameters ---------- values: Any - Iterable of molecule representations (SMILES, MolBlocks RDKit Molecules, PhysChem vectors etc.). + Iterable of molecule representations (SMILES, MolBlocks RDKit Molecules, + PhysChem vectors etc.). Input depends on the concrete PipelineElement. Returns ------- Any Transformed input_values. - """ - - @abc.abstractmethod - def transform_single(self, value: Any) -> Any: - """Transform a single value. - - Parameters - ---------- - value: Any - Value to be transformed. - Returns - ------- - Any - Transformed value. 
""" @@ -337,6 +311,7 @@ def __init__( Number of cores used for processing. uuid: str | None, optional Unique identifier of the PipelineElement. + """ super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) self._is_fitted = False @@ -346,11 +321,6 @@ def input_type(self) -> str: """Return the input type.""" return self._input_type - @property - def is_fitted(self) -> bool: - """Return whether the object is fitted or not.""" - return self._is_fitted - @property def output_type(self) -> str: """Return the output type.""" @@ -373,24 +343,6 @@ def parameters(self, **parameters: Any) -> None: """ self.set_params(**parameters) - def fit_to_result(self, values: Any) -> Self: - """Fit object to result of transformed values. - - Fit object to the result of the transform function. This is useful catching nones and removed molecules. - - Parameters - ---------- - values: Any - List of molecule representations. - - Returns - ------- - Self - Fitted object. - """ - self._is_fitted = True - return super().fit_to_result(values) - def fit_transform( self, values: Any, @@ -401,7 +353,7 @@ def fit_transform( Parameters ---------- values: Any - Apply transformation specified in transform_single to all molecules in the value_list. + Apply transform_single to all molecules in the value_list. labels: Any Optional label for fitting. @@ -409,20 +361,15 @@ def fit_transform( ------- Any List of molecules in new representation. + """ - self._is_fitted = True - if self.requires_fitting: - pre_value_list = self.pretransform(values) - self.fit_to_result(pre_value_list) - output_list = self.finalize_list(pre_value_list) - if hasattr(self, "assemble_output"): - return self.assemble_output(output_list) + self.fit(values, labels) return self.transform(values) def transform_single(self, value: Any) -> Any: """Transform a single molecule to the new representation. - RemovedMolecule objects are passed without change, as no transformations are applicable. + RemovedMolecule objects are passed without change. Parameters ---------- @@ -432,20 +379,19 @@ def transform_single(self, value: Any) -> Any: Returns ------- Any - New representation of the molecule. (Eg. SMILES, RDKit Mol, Descriptor-Vector, ...) + New representation of the molecule. + (Eg. SMILES, RDKit Mol, Descriptor-Vector, ...) + """ if isinstance(value, InvalidInstance): return value - pre_value = self.pretransform_single(value) - if isinstance(pre_value, InvalidInstance): - return pre_value - return self.finalize_single(pre_value) + return self.pretransform_single(value) def assemble_output(self, value_list: Iterable[Any]) -> Any: """Aggregate rows, which in most cases is just return the list. - Some representations might be better representd as a single object. For example a list of vectors can - be transformed to a matrix. + Some representations might be better representd as a single object. + For example a list of vectors can be transformed to a matrix. Parameters ---------- @@ -456,16 +402,18 @@ def assemble_output(self, value_list: Iterable[Any]) -> Any: ------- Any Aggregated output. This can also be the original input. + """ return list(value_list) @abc.abstractmethod def pretransform_single(self, value: Any) -> Any: - """Transform the instance, but skips parameters learned during fitting. + """Transform the instance. This is the first step for the full transformation. - It is followed by the finalize_single method and assemble output which collects all single transformations. 
- These functions are split as they need to be accessed separately from outside the object. + It is followed by the finalize_single method and assemble output which collects + all single transformations. These functions are split as they need to be + accessed separately from outside the object. Parameters ---------- @@ -476,25 +424,11 @@ def pretransform_single(self, value: Any) -> Any: ------- Any Pretransformed value. (Skips applying parameters learned during fitting) - """ - def finalize_single(self, value: Any) -> Any: - """Apply parameters learned during fitting to a single instance. - - Parameters - ---------- - value: Any - Value obtained from pretransform_single. - - Returns - ------- - Any - Finalized value. """ - return value def pretransform(self, value_list: Iterable[Any]) -> list[Any]: - """Transform input_values according to object rules without fitting specifics. + """Transform input_values according to object rules. Parameters ---------- @@ -505,31 +439,12 @@ def pretransform(self, value_list: Iterable[Any]) -> list[Any]: ------- list[Any] Transformed input_values. - """ - parallel = Parallel(n_jobs=self.n_jobs) - output_values = parallel( - delayed(self.pretransform_single)(value) for value in value_list - ) - return output_values - - def finalize_list(self, value_list: Iterable[Any]) -> list[Any]: - """Transform list of values according to parameters learned during fitting. - Parameters - ---------- - value_list: Iterable[Any] - List of values to be transformed. - - Returns - ------- - list[Any] - List of transformed values. """ parallel = Parallel(n_jobs=self.n_jobs) - output_values = parallel( - delayed(self.finalize_single)(value) for value in value_list + return parallel( + delayed(self.pretransform_single)(value) for value in value_list ) - return output_values def transform(self, values: Any) -> Any: """Transform input_values according to object rules. @@ -537,18 +452,18 @@ def transform(self, values: Any) -> Any: Parameters ---------- values: Any - Iterable of molecule representations (SMILES, MolBlocks RDKit Molecules, PhysChem vectors etc.). + Iterable of molecule representations (SMILES, MolBlocks RDKit Molecules, + PhysChem vectors etc.). Input depends on the concrete PipelineElement. Returns ------- Any Transformed input_values. + """ - output_rows = self.pretransform(values) - output_rows = self.finalize_list(output_rows) - output = self.assemble_output(output_rows) - return output + output_values = self.pretransform(values) + return self.assemble_output(output_values) class MolToMolPipelineElement(TransformingPipelineElement, abc.ABC): @@ -568,7 +483,8 @@ def transform(self, values: list[OptionalMol]) -> list[OptionalMol]: Returns ------- list[OptionalMol] - List of molecules or InvalidInstances, if corresponding transformation was not successful. + List of molecules or InvalidInstances. + """ mol_list: list[OptionalMol] = super().transform(values) # Stupid mypy... return mol_list @@ -585,17 +501,22 @@ def transform_single(self, value: OptionalMol) -> OptionalMol: ------- OptionalMol Transformed molecule if transformation was successful, else InvalidInstance. 
+ """ + if isinstance(value, InvalidInstance): + return value try: return super().transform_single(value) except Exception as exc: if isinstance(value, Chem.Mol): logger.error( - f"Failed to process: {Chem.MolToSmiles(value)} | ({self.name}) ({self.uuid})" + f"Failed to process: {Chem.MolToSmiles(value)} | ({self.name}) " + f"({self.uuid})", ) else: logger.error( - f"Failed to process: {value} ({type(value)}) | ({self.name}) ({self.uuid})" + f"Failed to process: {value} ({type(value)}) | ({self.name}) " + f"({self.uuid})", ) raise exc @@ -614,6 +535,7 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: ------- OptionalMol Transformed molecule if transformation was successful, else InvalidInstance. + """ @@ -633,14 +555,15 @@ def transform(self, values: Any) -> list[OptionalMol]: Returns ------- list[OptionalMol] - List of molecules or InvalidInstances, if corresponding representation was invalid. + List of molecules or InvalidInstances. + """ mol_list: list[OptionalMol] = super().transform(values) # Stupid mypy... return mol_list @abc.abstractmethod def pretransform_single(self, value: Any) -> OptionalMol: - """Transform the instance to a molecule, but skip parameters learned during fitting. + """Transform the instance to a molecule. Parameters ---------- @@ -651,6 +574,7 @@ def pretransform_single(self, value: Any) -> OptionalMol: ------- OptionalMol Obtained molecule if valid representation, else InvalidInstance. + """ @@ -661,7 +585,7 @@ class MolToAnyPipelineElement(TransformingPipelineElement, abc.ABC): @abc.abstractmethod def pretransform_single(self, value: RDKitMol) -> Any: - """Transform the molecule, but skip parameters learned during fitting. + """Transform the molecule. Parameters ---------- @@ -672,4 +596,5 @@ def pretransform_single(self, value: RDKitMol) -> Any: ------- Any Transformed molecule. + """ diff --git a/molpipeline/error_handling.py b/molpipeline/error_handling.py index 87cad027..b36c4f15 100644 --- a/molpipeline/error_handling.py +++ b/molpipeline/error_handling.py @@ -2,8 +2,7 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence -from typing import Any, Generic, Self, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar import numpy as np import numpy.typing as npt @@ -14,7 +13,11 @@ RemovedInstance, TransformingPipelineElement, ) -from molpipeline.utils.molpipeline_types import AnyVarSeq, TypeFixedVarSeq + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from molpipeline.utils.molpipeline_types import AnyVarSeq, TypeFixedVarSeq __all__ = ["ErrorFilter", "FilterReinserter", "_MultipleErrorFilter"] @@ -23,7 +26,7 @@ class ErrorFilter(ABCPipelineElement): - """Collects None values and can fill Dummy values to matrices where None values were removed.""" + """Filter to remove InvalidInstances from a list of values.""" element_ids: set[str] error_indices: list[int] @@ -64,7 +67,7 @@ def __init__( if element_ids is None: if not filter_everything: raise ValueError( - "If element_ids is None, filter_everything must be True" + "If element_ids is None, filter_everything must be True", ) element_ids = set() if not isinstance(element_ids, set): @@ -72,7 +75,6 @@ def __init__( self.element_ids = element_ids self.filter_everything = filter_everything self.n_total = 0 - self._requires_fitting = True @classmethod def from_element_list( @@ -99,10 +101,15 @@ def from_element_list( ------- Self Constructed ErrorFilter object. 
+ """ element_ids = {element.uuid for element in element_list} return cls( - element_ids, filter_everything=False, name=name, n_jobs=n_jobs, uuid=uuid + element_ids, + filter_everything=False, + name=name, + n_jobs=n_jobs, + uuid=uuid, ) def get_params(self, deep: bool = True) -> dict[str, Any]: @@ -117,6 +124,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Parameter names mapped to their values. + """ params = super().get_params(deep=deep) params["filter_everything"] = self.filter_everything @@ -168,14 +176,17 @@ def check_removal(self, value: Any) -> bool: ------- bool True if value should be removed. + """ if not isinstance(value, InvalidInstance): return False - if self.filter_everything or value.element_id in self.element_ids: - return True - return False + return self.filter_everything or value.element_id in self.element_ids - def fit(self, values: AnyVarSeq, labels: Any = None) -> Self: + def fit( + self, + values: AnyVarSeq, # noqa: ARG002 + labels: Any = None, # noqa: ARG002 + ) -> Self: """Fit to input values. Only for compatibility with sklearn Pipelines. @@ -191,11 +202,14 @@ def fit(self, values: AnyVarSeq, labels: Any = None) -> Self: ------- Self Fitted ErrorFilter. + """ return self def fit_transform( - self, values: TypeFixedVarSeq, labels: Any = None + self, + values: TypeFixedVarSeq, + labels: Any = None, ) -> TypeFixedVarSeq: """Transform values and return a list without the None values. @@ -205,22 +219,24 @@ def fit_transform( ---------- values: TypeFixedVarSeq Iterable to which element is fitted and which is subsequently transformed. - labels: Any - Label used for fitting. (Not used, but required for compatibility with sklearn) + labels: Any, optional + Label used for fitting. + (Not used, but required for compatibility with sklearn) Returns ------- TypeFixedVarSeq Iterable where invalid instances were removed. + """ self.fit(values, labels) return self.transform(values) def co_transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: - """Remove rows at positions which contained discarded values during transformation. + """Remove rows at positions which contained discarded values in `transform`. - This ensures that rows of this instance maintain a one to one correspondence with the rows of data seen during - transformation. + This ensures that rows of this instance maintain a one to one correspondence + with the rows of data seen during transformation. Parameters ---------- @@ -266,6 +282,7 @@ def transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: ------- TypeFixedVarSeq Iterable where invalid instances were removed. + """ self.n_total = len(values) self.error_indices = [] @@ -286,21 +303,7 @@ def transform_single(self, value: Any) -> Any: ------- Any Transformed value. - """ - return self.pretransform_single(value) - def pretransform_single(self, value: Any) -> Any: - """Transform a single value. - - Parameters - ---------- - value: Any - Value to be transformed. - - Returns - ------- - Any - Transformed value. """ if self.check_removal(value): return RemovedInstance( @@ -322,6 +325,7 @@ def __init__(self, error_filter_list: list[ErrorFilter]) -> None: ---------- error_filter_list: list[ErrorFilter] List of ErrorFilter objects. + """ self.error_filter_list = error_filter_list prior_remover_dict = {} @@ -341,13 +345,14 @@ def transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: ------- TypeFixedVarSeq Iterable where invalid instances were removed. 
+ """ for error_filter in self.error_filter_list: values = error_filter.transform(values) return values def co_transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: - """Remove rows at positions which contained discarded values during transformation. + """Remove rows at positions which contained discarded values in `transform`. Parameters ---------- @@ -357,7 +362,8 @@ def co_transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: Returns ------- TypeFixedVarSeq - Iterable where rows are removed which were removed during the transformation. + Iterable where rows are removed which were removed in `transform`. + """ for error_filter in self.error_filter_list: values = error_filter.co_transform(values) @@ -375,6 +381,7 @@ def fit_transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: ------- TypeFixedVarSeq Iterable where invalid instances were removed. + """ for error_filter in self.error_filter_list: values = error_filter.fit_transform(values) @@ -410,7 +417,8 @@ def register_removed(self, index: int, value: RemovedInstance) -> None: break else: raise ValueError( - f"Invalid instance not captured by any ErrorFilter: {value.filter_element_id}" + f"Invalid instance not captured by any ErrorFilter: " + f"{value.filter_element_id}", ) def set_total(self, total: int) -> None: @@ -457,6 +465,7 @@ def __init__( Number of parallel jobs to use. uuid: str | None, optional UUID of the pipeline element. + """ super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) self.error_filter_id = error_filter_id @@ -491,6 +500,7 @@ def from_error_filter( ------- Self Constructed FilterReinserter object. + """ filler = cls( error_filter_id=error_filter.uuid, @@ -514,6 +524,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Parameter names mapped to their values. + """ params = super().get_params(deep=deep) if deep: @@ -539,6 +550,7 @@ def set_params(self, **parameters: Any) -> Self: ------- Self The instance itself. + """ parameter_copy = dict(parameters) if "error_filter_id" in parameter_copy: @@ -570,6 +582,7 @@ def error_filter(self, error_filter: ErrorFilter) -> None: ---------- error_filter: ErrorFilter ErrorFilter to set. + """ self._error_filter = error_filter @@ -603,9 +616,9 @@ def select_error_filter(self, error_filter_list: list[ErrorFilter]) -> Self: # pylint: disable=unused-argument def fit( self, - values: TypeFixedVarSeq, - labels: Any = None, - **params: Any, + values: TypeFixedVarSeq, # noqa: ARG002 + labels: Any = None, # noqa: ARG002 + **params: Any, # noqa: ARG002 ) -> Self: """Fit to input values. @@ -615,8 +628,9 @@ def fit( ---------- values: TypeFixedVarSeq Values used for fitting. - labels: Any - Label used for fitting. (Not used, but required for compatibility with sklearn) + labels: Any, optional + Label used for fitting. + (Not used, but required for compatibility with sklearn) **params: Any Additional keyword arguments. (Not used) @@ -624,6 +638,7 @@ def fit( ------- Self Fitted FilterReinserter. + """ return self @@ -631,8 +646,8 @@ def fit( def fit_transform( self, values: TypeFixedVarSeq, - labels: Any = None, - **params: Any, + labels: Any = None, # noqa: ARG002 + **params: Any, # noqa: ARG002 ) -> TypeFixedVarSeq: """Transform values and return a list without the Invalid values. @@ -642,8 +657,9 @@ def fit_transform( ---------- values: TypeFixedVarSeq Iterable to which element is fitted and which is subsequently transformed. - labels: Any - Label used for fitting. 
(Not used, but required for compatibility with sklearn) + labels: Any, optional + Label used for fitting. + (Not used, but required for compatibility with sklearn) **params: Any Additional keyword arguments. (Not used) @@ -651,6 +667,7 @@ def fit_transform( ------- TypeFixedVarSeq Iterable where invalid instances were removed. + """ self.fit(values) return self.transform(values) @@ -667,21 +684,7 @@ def transform_single(self, value: Any) -> Any: ------- Any Transformed value. - """ - return self.pretransform_single(value) - - def pretransform_single(self, value: Any) -> Any: - """Transform a single value. - - Parameters - ---------- - value: Any - Value to be transformed. - Returns - ------- - Any - Transformed value. """ if ( isinstance(value, RemovedInstance) @@ -693,7 +696,7 @@ def pretransform_single(self, value: Any) -> Any: def transform( self, values: TypeFixedVarSeq, - **params: Any, + **params: Any, # noqa: ARG002 ) -> TypeFixedVarSeq: """Transform iterable of values by removing invalid instances. @@ -718,11 +721,13 @@ def transform( """ if len(values) != self.error_filter.n_total - len( - self.error_filter.error_indices + self.error_filter.error_indices, ): raise ValueError( f"Length of values does not match length of values in fit. " - f"Expected: {self.error_filter.n_total - len(self.error_filter.error_indices)} - Received :{len(values)}" + f"Expected: " + f"{self.error_filter.n_total - len(self.error_filter.error_indices)} " + f"- Received :{len(values)}", ) return self.fill_with_dummy(values) @@ -756,7 +761,7 @@ def _fill_list(self, list_to_fill: Sequence[_S]) -> Sequence[_S | _T]: next_value_pos += 1 if len(list_to_fill) != next_value_pos: raise AssertionError( - "Length of list does not match length of values in fit" + "Length of list does not match length of values in fit", ) return filled_list @@ -771,7 +776,9 @@ def _fill_numpy_arr(self, value_array: npt.NDArray[Any]) -> npt.NDArray[Any]: Returns ------- npt.NDArray[Any] - Numpy array where dummy values were inserted to replace instances which could not be processed. + Numpy array where dummy values were inserted to replace instances which + could not be processed. + """ fill_value = self.fill_value output_shape = list(value_array.shape) @@ -808,7 +815,8 @@ def fill_with_dummy( Returns ------- AnyVarSeq - Iterable where dummy values were inserted to replace molecules which could not be processed. + Iterable where dummy values were inserted to replace molecules which could + not be processed. """ if isinstance(value_container, list): diff --git a/molpipeline/mol2any/mol2bool.py b/molpipeline/mol2any/mol2bool.py index c605a6ba..1529e79b 100644 --- a/molpipeline/mol2any/mol2bool.py +++ b/molpipeline/mol2any/mol2bool.py @@ -26,16 +26,16 @@ def pretransform_single(self, value: Any) -> bool: ------- str Binary representation of molecule. + """ - if isinstance(value, InvalidInstance): - return False - return True + return not isinstance(value, InvalidInstance) def transform_single(self, value: Any) -> Any: """Transform a single molecule to a bool representation. Valid molecules are passed as True, InvalidInstances are passed as False. - RemovedMolecule objects are passed without change, as no transformations are applicable. + RemovedMolecule objects are passed without change, as no transformations are + applicable. Parameters ---------- @@ -46,6 +46,6 @@ def transform_single(self, value: Any) -> Any: ------- Any Bool representation of the molecule. 
+ """ - pre_value = self.pretransform_single(value) - return self.finalize_single(pre_value) + return self.pretransform_single(value) diff --git a/molpipeline/mol2any/mol2concatinated_vector.py b/molpipeline/mol2any/mol2concatinated_vector.py index 5955b4d1..70412d84 100644 --- a/molpipeline/mol2any/mol2concatinated_vector.py +++ b/molpipeline/mol2any/mol2concatinated_vector.py @@ -1,9 +1,8 @@ -"""Classes for creating arrays from multiple concatenated descriptors or fingerprints.""" +"""Classes for descriptors from multiple descriptors or fingerprints.""" from __future__ import annotations -from collections.abc import Iterable -from typing import Any, Self +from typing import TYPE_CHECKING, Any, Self import numpy as np import numpy.typing as npt @@ -17,11 +16,15 @@ from molpipeline.abstract_pipeline_elements.mol2any.mol2bitvector import ( MolToFingerprintPipelineElement, ) -from molpipeline.utils.molpipeline_types import RDKitMol + +if TYPE_CHECKING: + from collections.abc import Iterable + + from molpipeline.utils.molpipeline_types import RDKitMol class MolToConcatenatedVector(MolToAnyPipelineElement): - """Creates a concatenated descriptor vectored from multiple MolToAny PipelineElements.""" + """A concatenated descriptor vector from multiple MolToAny PipelineElements.""" _element_list: list[tuple[str, MolToAnyPipelineElement]] @@ -52,7 +55,8 @@ def __init__( uuid: str | None, optional (default=None) UUID of the pipeline element. If None, a random UUID is generated. kwargs: Any - Additional keyword arguments. Can be used to set parameters of the pipeline elements. + Additional keyword arguments. + Can be used to set parameters of the pipeline elements. Raises ------ @@ -158,7 +162,7 @@ def _create_feature_names( def _set_element_execution_details( self, - element_list: list[tuple[str, MolToAnyPipelineElement]], + element_list: list[tuple[str, MolToAnyPipelineElement]], # noqa: ARG002 ) -> None: """Set output type and requires fitting for the concatenated vector. @@ -176,10 +180,6 @@ def _set_element_execution_details( self._output_type = output_types.pop() else: self._output_type = "mixed" - self._requires_fitting = any( - element[1]._requires_fitting # pylint: disable=protected-access - for element in element_list - ) def get_params(self, deep: bool = True) -> dict[str, Any]: """Return all parameters defining the object. @@ -311,12 +311,14 @@ def assemble_output( Parameters ---------- value_list: Iterable[npt.NDArray[np.float64]] - List of molecular descriptors or fingerprints which are concatenated to a single matrix. + List of molecular descriptors or fingerprints which are concatenated to a + single matrix. Returns ------- npt.NDArray[np.float64] - Matrix of shape (n_molecules, n_features) with concatenated features specified during init. + Matrix of shape (n_molecules, n_features) with concatenated features + specified during init. """ return np.vstack(list(value_list)) @@ -332,7 +334,8 @@ def transform(self, values: list[RDKitMol]) -> npt.NDArray[np.float64]: Returns ------- npt.NDArray[np.float64] - Matrix of shape (n_molecules, n_features) with concatenated features specified during init. + Matrix of shape (n_molecules, n_features) with concatenated features + specified during init. """ output: npt.NDArray[np.float64] = super().transform(values) @@ -341,14 +344,14 @@ def transform(self, values: list[RDKitMol]) -> npt.NDArray[np.float64]: def fit( self, values: list[RDKitMol], - labels: Any = None, + labels: Any = None, # noqa: ARG002 ) -> Self: """Fit each pipeline element. 
Parameters
         ----------
         values: list[RDKitMol]
-            List of molecules used to fit the pipeline elements creating the concatenated vector.
+            List of molecules used to fit the pipeline elements.
         labels: Any
             Labels for the molecules. Not used.
 
         Returns
         -------
         Self
             Fitted pipeline element.
 
         """
         for element in self._element_list:
             element[1].fit(values)
         return self
 
     def pretransform_single(
         self,
         value: RDKitMol,
-    ) -> list[npt.NDArray[np.float64] | dict[int, int]] | InvalidInstance:
+    ) -> npt.NDArray[np.float64] | InvalidInstance:
         """Get pretransform of each element and concatenate for output.
 
         Parameters
         ----------
         value: RDKitMol
             Molecule to be transformed.
 
         Returns
         -------
         list[npt.NDArray[np.float64] | dict[int, int]] | InvalidInstance
             List of pretransformed values of each pipeline element.
             If any element returns None, InvalidInstance is returned.
 
         """
-        final_vector = []
+        transformed_list = []
         error_message = ""
         for name, pipeline_element in self._element_list:
-            vector = pipeline_element.pretransform_single(value)
-            if isinstance(vector, InvalidInstance):
+            transformed_value = pipeline_element.pretransform_single(value)
+            if isinstance(transformed_value, InvalidInstance):
                 error_message += f"{self.name}__{name} returned an InvalidInstance."
                 break
-
-            final_vector.append(vector)
-        else:  # no break
-            return final_vector
-        return InvalidInstance(self.uuid, error_message, self.name)
-
-    def finalize_single(self, value: Any) -> Any:
-        """Finalize the output of transform_single.
-
-        Parameters
-        ----------
-        value: Any
-            Output of transform_single.
-
-        Returns
-        -------
-        Any
-            Finalized output.
-
-        """
-        final_vector_list = []
-        for (_, element), sub_value in zip(self._element_list, value, strict=True):
-            final_value = element.finalize_single(sub_value)
-            if isinstance(element, MolToFingerprintPipelineElement) and isinstance(
-                final_value,
+            if isinstance(
+                pipeline_element,
+                MolToFingerprintPipelineElement,
+            ) and isinstance(
+                transformed_value,
                 dict,
             ):
-                vector = np.zeros(element.n_bits)
-                vector[list(final_value.keys())] = np.array(list(final_value.values()))
+                vector = np.zeros(pipeline_element.n_bits)
+                vector[list(transformed_value.keys())] = np.array(
+                    list(transformed_value.values()),
+                )
                 final_value = vector
-            if not isinstance(final_value, np.ndarray):
-                final_value = np.array(final_value)
-            final_vector_list.append(final_value)
-        return np.hstack(final_vector_list)
-
-    def fit_to_result(self, values: Any) -> Self:
-        """Fit the pipeline element to the result of transform_single.
-
-        Parameters
-        ----------
-        values: Any
-            Output of transform_single.
-
-        Returns
-        -------
-        Self
-            Fitted pipeline element.
-
-        """
-        for element, value in zip(
-            self._element_list, zip(*values, strict=True), strict=True
-        ):
-            element[1].fit_to_result(value)
-        return self
+            else:
+                final_value = np.array(transformed_value)
+
+            transformed_list.append(final_value)
+        else:  # no break
+            return np.hstack(transformed_list)
+        return InvalidInstance(self.uuid, error_message, self.name)
diff --git a/molpipeline/pipeline/_molpipeline.py b/molpipeline/pipeline/_molpipeline.py
index 3a9f2def..c7a3b3a0 100644
--- a/molpipeline/pipeline/_molpipeline.py
+++ b/molpipeline/pipeline/_molpipeline.py
@@ -2,8 +2,7 @@
 
 from __future__ import annotations
 
-from collections.abc import Iterable
-from typing import Any, Self
+from typing import TYPE_CHECKING, Any, Self
 
 import numpy as np
 from joblib import Parallel, delayed
@@ -21,16 +20,19 @@
     FilterReinserter,
     _MultipleErrorFilter,
 )
-from molpipeline.utils.molpipeline_types import TypeFixedVarSeq
 from molpipeline.utils.multi_proc import check_available_cores
 
+if TYPE_CHECKING:
+    from collections.abc import Iterable
+
+    from molpipeline.utils.molpipeline_types import TypeFixedVarSeq
+
 
 class _MolPipeline:
     """Contains the PipeElements which describe the functionality of the pipeline."""
 
     _n_jobs: int
     _element_list: list[ABCPipelineElement]
-    _requires_fitting: bool
 
     def __init__(
         self,
@@ -53,9 +55,6 @@ def __init__(
         self._element_list = element_list
         self.n_jobs = n_jobs
         self.name = name
-        self._requires_fitting = any(
-            element.requires_fitting for element in self._element_list
-        )
 
     @property
     def _filter_elements(self) -> list[ErrorFilter]:
@@ -95,7 +94,8 @@ def n_jobs(self, requested_jobs: int) -> None:
         ----------
         requested_jobs: int
             Number of cores requested for transformation steps.
-            If fewer cores than requested are available, the number of cores is set to maximum available.
+            If fewer cores than requested are available, the number of cores is set to
+            the maximum available.
 
         """
         self._n_jobs = check_available_cores(requested_jobs)
@@ -112,16 +112,11 @@ def parameters(self, parameter_dict: dict[str, Any]) -> None:
         Parameters
         ----------
         parameter_dict: dict[str, Any]
-            Dictionary containing the parameter names and corresponding values to be set.
+            Dictionary of parameter names and corresponding values to be set.
 
         """
         self.set_params(**parameter_dict)
 
-    @property
-    def requires_fitting(self) -> bool:
-        """Return whether the pipeline requires fitting."""
-        return self._requires_fitting
-
     def get_params(self, deep: bool = True) -> dict[str, Any]:
         """Get all parameters defining the object.
 
@@ -133,7 +128,8 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
         Returns
         -------
         dict[str, Any]
-            Dictionary containing the parameter names and corresponding values.
+            Dictionary of parameter names and corresponding values.
+
         """
         if deep:
             return {
@@ -153,12 +149,13 @@ def set_params(self, **parameter_dict: Any) -> Self:
         Parameters
         ----------
         parameter_dict: Any
-            Dictionary containing the parameter names and corresponding values to be set.
+            Dictionary of parameter names and corresponding values to be set.
 
         Returns
         -------
         Self
             MolPipeline object with updated parameters.
+
         """
         if "element_list" in parameter_dict:
             self._element_list = parameter_dict["element_list"]
@@ -176,27 +173,27 @@ def element_list(self) -> list[ABCPipelineElement]:
     def _get_meta_element_list(
         self,
     ) -> list[ABCPipelineElement | _MolPipeline]:
-        """Merge elements which do not require fitting to a meta element which improves parallelization.
+        """Merge elements which do not require fitting to a meta element.
+ + This improves the parallelization of the pipeline. Returns ------- list[ABCPipelineElement | _MolPipeline] List of pipeline elements and meta elements. + """ meta_element_list: list[ABCPipelineElement | _MolPipeline] = [] no_fit_element_list: list[ABCPipelineElement] = [] for element in self._element_list: - if ( - isinstance(element, TransformingPipelineElement) - and not element.requires_fitting - ): + if isinstance(element, TransformingPipelineElement): no_fit_element_list.append(element) else: if len(no_fit_element_list) == 1: meta_element_list.append(no_fit_element_list[0]) elif len(no_fit_element_list) > 1: meta_element_list.append( - _MolPipeline(no_fit_element_list, n_jobs=self.n_jobs) + _MolPipeline(no_fit_element_list, n_jobs=self.n_jobs), ) no_fit_element_list = [] meta_element_list.append(element) @@ -204,7 +201,7 @@ def _get_meta_element_list( meta_element_list.append(no_fit_element_list[0]) elif len(no_fit_element_list) > 1: meta_element_list.append( - _MolPipeline(no_fit_element_list, n_jobs=self.n_jobs) + _MolPipeline(no_fit_element_list, n_jobs=self.n_jobs), ) return meta_element_list @@ -229,20 +226,20 @@ def fit( ------- Self Fitted MolPipeline. + """ _ = y # Making pylint happy _ = fit_params # Making pylint happy - if self.requires_fitting: - self.fit_transform(x_input) + self.fit_transform(x_input) return self def fit_transform( # pylint: disable=unused-argument self, x_input: Any, - y: Any = None, - **fit_params: dict[str, Any], + y: Any = None, # noqa: ARG002 + **fit_params: dict[str, Any], # noqa: ARG002 ) -> Any: - """Fit the MolPipeline according to x_input and return the transformed molecules. + """Fit to x_input and return the transformed molecules. Parameters ---------- @@ -253,16 +250,11 @@ def fit_transform( # pylint: disable=unused-argument fit_params: Any Parameters. Only for SKlearn compatibility. - Raises - ------ - AssertionError - If a subpipeline requires fitting, which by definition should not be the - case. - Returns ------- Any Transformed molecules. + """ iter_input = x_input @@ -271,7 +263,7 @@ def fit_transform( # pylint: disable=unused-argument removed_rows[error_filter] = [] iter_idx_array = np.arange(len(iter_input)) - # The meta elements merge steps which do not require fitting to improve parallelization + # The meta elements merge steps which do not require fitting for i_element in self._get_meta_element_list(): if not isinstance(i_element, (TransformingPipelineElement, _MolPipeline)): continue @@ -283,12 +275,6 @@ def fit_transform( # pylint: disable=unused-argument idx = iter_idx_array[idx] removed_rows[error_filter].append(idx) iter_idx_array = error_filter.co_transform(iter_idx_array) - if i_element.requires_fitting: - if isinstance(i_element, _MolPipeline): - raise AssertionError("No subpipline should require fitting!") - i_element.fit_to_result(iter_input) - if isinstance(i_element, TransformingPipelineElement): - iter_input = i_element.finalize_list(iter_input) iter_input = i_element.assemble_output(iter_input) i_element.n_jobs = 1 @@ -311,7 +297,7 @@ def fit_transform( # pylint: disable=unused-argument return iter_input def transform_single(self, input_value: Any) -> Any: - """Transform a single input according to the sequence of provided PipelineElements. + """Transform a single input according to the sequence of PipelineElements. Parameters ---------- @@ -322,14 +308,16 @@ def transform_single(self, input_value: Any) -> Any: ------- Any Transformed molecular representation. 
+ """ log_block = BlockLogs() iter_value = input_value for p_element in self._element_list: try: - if not isinstance(iter_value, RemovedInstance): - iter_value = p_element.transform_single(iter_value) - elif isinstance(p_element, FilterReinserter): + if not isinstance(iter_value, RemovedInstance) or isinstance( + p_element, + FilterReinserter, + ): iter_value = p_element.transform_single(iter_value) except MolSanitizeException as err: iter_value = InvalidInstance( @@ -341,7 +329,7 @@ def transform_single(self, input_value: Any) -> Any: return iter_value def pretransform(self, x_input: Any) -> Any: - """Transform the input according to the sequence BUT skip the assemble output step. + """Transform the input according to the sequence without assemble_output step. Parameters ---------- @@ -352,6 +340,7 @@ def pretransform(self, x_input: Any) -> Any: ------- Any Transformed molecular representations. + """ return list(self._transform_iterator(x_input)) @@ -367,6 +356,7 @@ def transform(self, x_input: Any) -> Any: ------- Any Transformed molecular representations. + """ output_generator = self._transform_iterator(x_input) return self.assemble_output(output_generator) @@ -383,6 +373,7 @@ def assemble_output(self, value_list: Iterable[Any]) -> Any: ------- Any Assembled output. + """ last_element = self._transforming_elements[-1] if hasattr(last_element, "assemble_output"): @@ -433,5 +424,6 @@ def co_transform(self, x_input: TypeFixedVarSeq) -> TypeFixedVarSeq: ------- Any Filtered molecular representations. + """ return self._filter_elements_agg.co_transform(x_input) diff --git a/tests/test_elements/test_error_handling.py b/tests/test_elements/test_error_handling.py index d6a79245..ad1130a3 100644 --- a/tests/test_elements/test_error_handling.py +++ b/tests/test_elements/test_error_handling.py @@ -22,6 +22,8 @@ TEST_SMILES = ["NCCCO", "abc", "c1ccccc1"] EXPECTED_OUTPUT = ["NCCCO", None, "c1ccccc1"] +TOLERANCE = 0.000001 + class NoneTest(unittest.TestCase): """Unittest for None Handling.""" @@ -95,7 +97,7 @@ def test_dummy_remove_physchem_record_molpipeline(self) -> None: out = pipeline.transform(TEST_SMILES) out2 = pipeline2.fit_transform(TEST_SMILES) self.assertEqual(out.shape, out2.shape) - self.assertTrue(np.max(np.abs(out - out2)) < 0.000001) + self.assertTrue(np.max(np.abs(out - out2)) < TOLERANCE) def test_dummy_remove_physchem_record_autodetect_molpipeline(self) -> None: """Assert that invalid smiles are transformed to None.""" @@ -108,14 +110,14 @@ def test_dummy_remove_physchem_record_autodetect_molpipeline(self) -> None: ("mol2physchem", mol2physchem), ("remove_none", remove_none), ], + n_jobs=1, ) pipeline2 = clone(pipeline) pipeline.fit(TEST_SMILES) out = pipeline.transform(TEST_SMILES) - print(pipeline2["remove_none"].filter_everything) out2 = pipeline2.fit_transform(TEST_SMILES) self.assertEqual(out.shape, out2.shape) - self.assertTrue(np.max(np.abs(out - out2)) < 0.000001) + self.assertTrue(np.max(np.abs(out - out2)) < TOLERANCE) def test_dummy_fill_physchem_record_molpipeline(self) -> None: """Assert that invalid smiles are transformed to None.""" @@ -141,7 +143,7 @@ def test_dummy_fill_physchem_record_molpipeline(self) -> None: out2 = pipeline2.fit_transform(TEST_SMILES) self.assertEqual(out.shape, out2.shape) self.assertEqual(out.shape, (3, 215)) - self.assertTrue(np.nanmax(np.abs(out - out2)) < 0.000001) + self.assertTrue(np.nanmax(np.abs(out - out2)) < TOLERANCE) def test_replace_mixed_datatypes(self) -> None: """Assert that invalid values are replaced by fill value.""" From 
c1556c37e921883818ebb1c3e0bc39bcb1e68982 Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Thu, 24 Apr 2025 13:41:00 +0200
Subject: [PATCH 02/31] linters

---
 molpipeline/abstract_pipeline_elements/core.py |  4 ++--
 molpipeline/mol2any/mol2concatinated_vector.py | 18 ++++--------------
 2 files changed, 6 insertions(+), 16 deletions(-)

diff --git a/molpipeline/abstract_pipeline_elements/core.py b/molpipeline/abstract_pipeline_elements/core.py
index c5907a50..ee3d39fa 100644
--- a/molpipeline/abstract_pipeline_elements/core.py
+++ b/molpipeline/abstract_pipeline_elements/core.py
@@ -227,8 +227,8 @@ def n_jobs(self, n_jobs: int) -> None:
 
     def fit(
         self,
-        values: Any,  # noqa: ARG002
-        labels: Any = None,  # noqa: ARG002
+        values: Any,  # noqa: ARG002 # pylint: disable=unused-argument
+        labels: Any = None,  # noqa: ARG002 # pylint: disable=unused-argument
     ) -> Self:
         """Fit object to input_values.
 
diff --git a/molpipeline/mol2any/mol2concatinated_vector.py b/molpipeline/mol2any/mol2concatinated_vector.py
index 70412d84..fc5a5dd0 100644
--- a/molpipeline/mol2any/mol2concatinated_vector.py
+++ b/molpipeline/mol2any/mol2concatinated_vector.py
@@ -70,7 +70,7 @@ def __init__(
         self._use_feature_names_prefix = use_feature_names_prefix
         super().__init__(name=name, n_jobs=n_jobs, uuid=uuid)
         # set element execution details
-        self._set_element_execution_details(self._element_list)
+        self._set_element_execution_details()
         # set feature names
         self._feature_names = self._create_feature_names(
             self._element_list,
@@ -158,18 +160,8 @@ def _create_feature_names(
         )
         return feature_names
 
-    def _set_element_execution_details(
-        self,
-        element_list: list[tuple[str, MolToAnyPipelineElement]],  # noqa: ARG002
-    ) -> None:
-        """Set output type and requires fitting for the concatenated vector.
-
-        Parameters
-        ----------
-        element_list: list[tuple[str, MolToAnyPipelineElement]]
-            List of pipeline elements.
-
-        """
+    def _set_element_execution_details(self) -> None:
+        """Set the output type for the concatenated vector."""
         output_types = set()
         for _, element in self._element_list:
             element.n_jobs = self.n_jobs
@@ -243,7 +233,7 @@ def _set_element_list(
         if len(element_list) == 0:
             raise ValueError("element_list must contain at least one element.")
         # reset element execution details
-        self._set_element_execution_details(self._element_list)
+        self._set_element_execution_details()
         step_params: dict[str, dict[str, Any]] = {}
         step_dict = dict(self._element_list)
         to_delete_list = []

From a6716fb02237038b30868300a0b494647f59c3eb Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Thu, 24 Apr 2025 13:44:53 +0200
Subject: [PATCH 03/31] move transform_single back in (for now)

---
 .../abstract_pipeline_elements/core.py        | 19 +++++++++++++++++++
 1 file changed, 19 insertions(+)

diff --git a/molpipeline/abstract_pipeline_elements/core.py b/molpipeline/abstract_pipeline_elements/core.py
index ee3d39fa..5f99724c 100644
--- a/molpipeline/abstract_pipeline_elements/core.py
+++ b/molpipeline/abstract_pipeline_elements/core.py
@@ -287,6 +287,25 @@ def transform(self, values: Any) -> Any:
 
         """
 
+    @abc.abstractmethod
+    def transform_single(self, value: Any) -> Any:
+        """Transform a single molecule to the new representation.
+
+        RemovedMolecule objects are passed without change.
+
+        Parameters
+        ----------
+        value: Any
+            Current representation of the molecule. (Eg. SMILES, RDKit Mol, ...)
+
+        Returns
+        -------
+        Any
+            New representation of the molecule.
+            (Eg. SMILES, RDKit Mol, Descriptor-Vector, ...)
+ + """ + class TransformingPipelineElement(ABCPipelineElement): """Ancestor of all PipelineElements.""" From 1051fec77fb237986ad2e29d77bbf6d928acc955 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Thu, 24 Apr 2025 14:57:11 +0200 Subject: [PATCH 04/31] remove parameter property --- molpipeline/pipeline/_molpipeline.py | 17 ----------------- 1 file changed, 17 deletions(-) diff --git a/molpipeline/pipeline/_molpipeline.py b/molpipeline/pipeline/_molpipeline.py index c7a3b3a0..559f93c9 100644 --- a/molpipeline/pipeline/_molpipeline.py +++ b/molpipeline/pipeline/_molpipeline.py @@ -100,23 +100,6 @@ def n_jobs(self, requested_jobs: int) -> None: """ self._n_jobs = check_available_cores(requested_jobs) - @property - def parameters(self) -> dict[str, Any]: - """Get all parameters defining the object.""" - return self.get_params() - - @parameters.setter - def parameters(self, parameter_dict: dict[str, Any]) -> None: - """Set parameters of the pipeline and pipeline elements. - - Parameters - ---------- - parameter_dict: dict[str, Any] - Dictionary of parameter names and corresponding values to be set. - - """ - self.set_params(**parameter_dict) - def get_params(self, deep: bool = True) -> dict[str, Any]: """Get all parameters defining the object. From e32f72da6268e8b998bdaccc4a9db21c322ed954 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Thu, 24 Apr 2025 19:09:59 +0200 Subject: [PATCH 05/31] rewrite to use mixins --- .../abstract_pipeline_elements/core.py | 150 ++++++++++-------- molpipeline/error_handling.py | 25 ++- molpipeline/pipeline/_molpipeline.py | 74 ++++----- molpipeline/pipeline/_skl_pipeline.py | 122 +++++++++----- tests/test_elements/test_error_handling.py | 2 +- tests/utils/mock_element.py | 21 ++- 6 files changed, 246 insertions(+), 148 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/core.py b/molpipeline/abstract_pipeline_elements/core.py index 5f99724c..e5bb51bc 100644 --- a/molpipeline/abstract_pipeline_elements/core.py +++ b/molpipeline/abstract_pipeline_elements/core.py @@ -287,25 +287,6 @@ def transform(self, values: Any) -> Any: """ - @abc.abstractmethod - def transform_single(self, value: Any) -> Any: - """Transform a single molecule to the new representation. - - RemovedMolecule objects are passed without change. - - Parameters - ---------- - value: Any - Current representation of the molecule. (Eg. SMILES, RDKit Mol, ...) - - Returns - ------- - Any - New representation of the molecule. - (Eg. SMILES, RDKit Mol, Descriptor-Vector, ...) - - """ - class TransformingPipelineElement(ABCPipelineElement): """Ancestor of all PipelineElements.""" @@ -385,27 +366,6 @@ def fit_transform( self.fit(values, labels) return self.transform(values) - def transform_single(self, value: Any) -> Any: - """Transform a single molecule to the new representation. - - RemovedMolecule objects are passed without change. - - Parameters - ---------- - value: Any - Current representation of the molecule. (Eg. SMILES, RDKit Mol, ...) - - Returns - ------- - Any - New representation of the molecule. - (Eg. SMILES, RDKit Mol, Descriptor-Vector, ...) - - """ - if isinstance(value, InvalidInstance): - return value - return self.pretransform_single(value) - def assemble_output(self, value_list: Iterable[Any]) -> Any: """Aggregate rows, which in most cases is just return the list. 
@@ -426,25 +386,56 @@ def assemble_output(self, value_list: Iterable[Any]) -> Any: return list(value_list) @abc.abstractmethod - def pretransform_single(self, value: Any) -> Any: - """Transform the instance. + def pretransform(self, value_list: Iterable[Any]) -> list[Any]: + """Transform input_values according to object rules. - This is the first step for the full transformation. - It is followed by the finalize_single method and assemble output which collects - all single transformations. These functions are split as they need to be - accessed separately from outside the object. + Parameters + ---------- + value_list: Iterable[Any] + Iterable of instances to be pretransformed. + + Returns + ------- + list[Any] + Transformed input_values. + + """ + + def transform(self, values: Any) -> Any: + """Transform input_values according to object rules. Parameters ---------- - value: Any - Value to be pretransformed. + values: Any + Iterable of molecule representations (SMILES, MolBlocks RDKit Molecules, + PhysChem vectors etc.). + Input depends on the concrete PipelineElement. Returns ------- Any - Pretransformed value. (Skips applying parameters learned during fitting) + Transformed input_values. + + """ + output_values = self.pretransform(values) + return self.assemble_output(output_values) + + +class SingleInstanceTransformerMixin(abc.ABC): + """Mixin for single instance processing.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize SingleInstanceTransformerMixin. + + Parameters + ---------- + args: Any + Arguments for the object. + kwargs: Any + Keyword arguments for the object. """ + super().__init__(*args, **kwargs) def pretransform(self, value_list: Iterable[Any]) -> list[Any]: """Transform input_values according to object rules. @@ -460,32 +451,59 @@ def pretransform(self, value_list: Iterable[Any]) -> list[Any]: Transformed input_values. """ - parallel = Parallel(n_jobs=self.n_jobs) + parallel = Parallel(n_jobs=self.n_jobs) # type: ignore[attr-defined] return parallel( delayed(self.pretransform_single)(value) for value in value_list ) - def transform(self, values: Any) -> Any: - """Transform input_values according to object rules. + def transform_single(self, value: Any) -> Any: + """Transform a single molecule to the new representation. + + RemovedMolecule objects are passed without change. Parameters ---------- - values: Any - Iterable of molecule representations (SMILES, MolBlocks RDKit Molecules, - PhysChem vectors etc.). - Input depends on the concrete PipelineElement. + value: Any + Current representation of the molecule. (Eg. SMILES, RDKit Mol, ...) Returns ------- Any - Transformed input_values. + New representation of the molecule. + (Eg. SMILES, RDKit Mol, Descriptor-Vector, ...) + + """ + if isinstance(value, InvalidInstance): + return value + return self.pretransform_single(value) + + @abc.abstractmethod + def pretransform_single(self, value: Any) -> Any: + """Transform the instance. + + This is the first step for the full transformation. + It is followed by the finalize_single method and assemble output which collects + all single transformations. These functions are split as they need to be + accessed separately from outside the object. + + Parameters + ---------- + value: Any + Value to be pretransformed. + + Returns + ------- + Any + Pretransformed value. 
(Skips applying parameters learned during fitting) """ - output_values = self.pretransform(values) - return self.assemble_output(output_values) -class MolToMolPipelineElement(TransformingPipelineElement, abc.ABC): +class MolToMolPipelineElement( + SingleInstanceTransformerMixin, + TransformingPipelineElement, + abc.ABC, +): """Abstract PipelineElement where input and outputs are molecules.""" _input_type = "RDKitMol" @@ -558,7 +576,11 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: """ -class AnyToMolPipelineElement(TransformingPipelineElement, abc.ABC): +class AnyToMolPipelineElement( + SingleInstanceTransformerMixin, + TransformingPipelineElement, + abc.ABC, +): """Abstract PipelineElement which creates molecules from different inputs.""" _output_type = "RDKitMol" @@ -597,7 +619,11 @@ def pretransform_single(self, value: Any) -> OptionalMol: """ -class MolToAnyPipelineElement(TransformingPipelineElement, abc.ABC): +class MolToAnyPipelineElement( + SingleInstanceTransformerMixin, + TransformingPipelineElement, + abc.ABC, +): """Abstract PipelineElement which creates molecules from different inputs.""" _input_type = "RDKitMol" diff --git a/molpipeline/error_handling.py b/molpipeline/error_handling.py index b36c4f15..c68d83f5 100644 --- a/molpipeline/error_handling.py +++ b/molpipeline/error_handling.py @@ -6,11 +6,13 @@ import numpy as np import numpy.typing as npt +from loguru import logger from molpipeline.abstract_pipeline_elements.core import ( ABCPipelineElement, InvalidInstance, RemovedInstance, + SingleInstanceTransformerMixin, TransformingPipelineElement, ) @@ -25,7 +27,7 @@ _S = TypeVar("_S") -class ErrorFilter(ABCPipelineElement): +class ErrorFilter(SingleInstanceTransformerMixin, ABCPipelineElement): """Filter to remove InvalidInstances from a list of values.""" element_ids: set[str] @@ -257,6 +259,9 @@ def co_transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: """ if self.n_total != len(values): + logger.error( + f"Expected {self.n_total} values, but got {len(values)}", + ) raise ValueError("Length of values does not match length of values in fit") if isinstance(values, list): out_list = [] @@ -294,6 +299,24 @@ def transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: def transform_single(self, value: Any) -> Any: """Transform a single value. + Overrides parent method to not skip InvalidInstances. + + Parameters + ---------- + value: Any + Value to be transformed. + + Returns + ------- + Any + The original value or a RemovedInstance. + + """ + return self.pretransform_single(value) + + def pretransform_single(self, value: Any) -> Any: + """Transform a single value. 
+ Parameters ---------- value: Any diff --git a/molpipeline/pipeline/_molpipeline.py b/molpipeline/pipeline/_molpipeline.py index 559f93c9..cf6dbf94 100644 --- a/molpipeline/pipeline/_molpipeline.py +++ b/molpipeline/pipeline/_molpipeline.py @@ -13,6 +13,7 @@ ABCPipelineElement, InvalidInstance, RemovedInstance, + SingleInstanceTransformerMixin, TransformingPipelineElement, ) from molpipeline.error_handling import ( @@ -23,7 +24,7 @@ from molpipeline.utils.multi_proc import check_available_cores if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Iterable, Sequence from molpipeline.utils.molpipeline_types import TypeFixedVarSeq @@ -32,11 +33,11 @@ class _MolPipeline: """Contains the PipeElements which describe the functionality of the pipeline.""" _n_jobs: int - _element_list: list[ABCPipelineElement] + _element_list: Sequence[ABCPipelineElement] def __init__( self, - element_list: list[ABCPipelineElement], + element_list: Sequence[ABCPipelineElement], n_jobs: int = 1, name: str = "MolPipeline", ) -> None: @@ -151,42 +152,7 @@ def set_params(self, **parameter_dict: Any) -> Self: @property def element_list(self) -> list[ABCPipelineElement]: """Get a shallow copy from the list of pipeline elements.""" - return self._element_list[:] # [:] to create shallow copy. - - def _get_meta_element_list( - self, - ) -> list[ABCPipelineElement | _MolPipeline]: - """Merge elements which do not require fitting to a meta element. - - This improves the parallelization of the pipeline. - - Returns - ------- - list[ABCPipelineElement | _MolPipeline] - List of pipeline elements and meta elements. - - """ - meta_element_list: list[ABCPipelineElement | _MolPipeline] = [] - no_fit_element_list: list[ABCPipelineElement] = [] - for element in self._element_list: - if isinstance(element, TransformingPipelineElement): - no_fit_element_list.append(element) - else: - if len(no_fit_element_list) == 1: - meta_element_list.append(no_fit_element_list[0]) - elif len(no_fit_element_list) > 1: - meta_element_list.append( - _MolPipeline(no_fit_element_list, n_jobs=self.n_jobs), - ) - no_fit_element_list = [] - meta_element_list.append(element) - if len(no_fit_element_list) == 1: - meta_element_list.append(no_fit_element_list[0]) - elif len(no_fit_element_list) > 1: - meta_element_list.append( - _MolPipeline(no_fit_element_list, n_jobs=self.n_jobs), - ) - return meta_element_list + return list(self._element_list) # to create shallow copy. def fit( self, @@ -233,6 +199,11 @@ def fit_transform( # pylint: disable=unused-argument fit_params: Any Parameters. Only for SKlearn compatibility. + Raises + ------ + AssertionError + If the PipelineElement is not a TransformingPipelineElement. 
+ Returns ------- Any @@ -247,9 +218,18 @@ def fit_transform( # pylint: disable=unused-argument iter_idx_array = np.arange(len(iter_input)) # The meta elements merge steps which do not require fitting - for i_element in self._get_meta_element_list(): - if not isinstance(i_element, (TransformingPipelineElement, _MolPipeline)): + for i_element in self._element_list: + if not isinstance( + i_element, + (SingleInstanceTransformerMixin, _MolPipeline), + ): + continue + if isinstance(i_element, ErrorFilter): continue + if not isinstance(i_element, TransformingPipelineElement): + raise AssertionError( + "PipelineElement is not a TransformingPipelineElement.", + ) i_element.n_jobs = self.n_jobs iter_input = i_element.pretransform(iter_input) for error_filter in self._filter_elements: @@ -287,6 +267,11 @@ def transform_single(self, input_value: Any) -> Any: input_value: Any Molecular representation which is subsequently transformed. + Raises + ------ + AssertionError + If the PipelineElement is not a SingleInstanceTransformerMixin. + Returns ------- Any @@ -296,6 +281,13 @@ def transform_single(self, input_value: Any) -> Any: log_block = BlockLogs() iter_value = input_value for p_element in self._element_list: + if not isinstance( + p_element, + (SingleInstanceTransformerMixin, _MolPipeline), + ): + raise AssertionError( + "PipelineElement is not a SingleInstanceTransformerMixin.", + ) try: if not isinstance(iter_value, RemovedInstance) or isinstance( p_element, diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py index 8a9d0507..5fa47e6b 100644 --- a/molpipeline/pipeline/_skl_pipeline.py +++ b/molpipeline/pipeline/_skl_pipeline.py @@ -4,29 +4,29 @@ from __future__ import annotations -from collections.abc import Iterable from copy import deepcopy -from typing import Any, Literal, Self, TypeVar +from typing import TYPE_CHECKING, Any, Literal, Self -import joblib import numpy as np import numpy.typing as npt from loguru import logger -from sklearn.base import _fit_context, clone +from sklearn.base import _fit_context, clone # noqa: PLC2701 from sklearn.pipeline import Pipeline as _Pipeline -from sklearn.pipeline import _final_estimator_has, _fit_transform_one -from sklearn.utils import Bunch -from sklearn.utils._tags import Tags, get_tags +from sklearn.pipeline import _final_estimator_has, _fit_transform_one # noqa: PLC2701 +from sklearn.utils._tags import Tags, get_tags # noqa: PLC2701 from sklearn.utils.metadata_routing import ( MetadataRouter, MethodMapping, - _routing_enabled, + _routing_enabled, # noqa: PLC2701 process_routing, ) from sklearn.utils.metaestimators import available_if from sklearn.utils.validation import check_memory -from molpipeline.abstract_pipeline_elements.core import ABCPipelineElement +from molpipeline.abstract_pipeline_elements.core import ( + ABCPipelineElement, + SingleInstanceTransformerMixin, +) from molpipeline.error_handling import ErrorFilter, FilterReinserter from molpipeline.pipeline._molpipeline import _MolPipeline from molpipeline.post_prediction import ( @@ -42,11 +42,13 @@ ) from molpipeline.utils.value_checks import is_empty -__all__ = ["Pipeline"] +if TYPE_CHECKING: + from collections.abc import Iterable -# Type definitions -_T = TypeVar("_T") -# Cannot be moved to utils.molpipeline_types due to circular imports + import joblib + from sklearn.utils import Bunch + +__all__ = ["Pipeline"] _IndexedStep = tuple[int, str, AnyElement] @@ -97,7 +99,7 @@ def _set_error_resinserter(self) -> None: n_filter for _, n_filter in 
self.steps if isinstance(n_filter, ErrorFilter) ] for step in self.steps: - if isinstance(step[1], PostPredictionWrapper): + if isinstance(step[1], PostPredictionWrapper): # noqa: SIM102 if isinstance(step[1].wrapped_estimator, FilterReinserter): error_replacer_list.append(step[1].wrapped_estimator) for error_replacer in error_replacer_list: @@ -158,7 +160,8 @@ def _iter( ) -> Iterable[_AggregatedPipelineStep]: """Iterate over all non post-processing steps. - Steps which are children of a ABCPipelineElement were aggregated to a MolPipeline. + Steps which are children of a ABCPipelineElement were aggregated to a + MolPipeline. Parameters ---------- @@ -197,9 +200,8 @@ def _iter( if last_element is None: raise AssertionError("Pipeline needs to have at least one step!") - if with_final and last_element[2] is not None: - if last_element[2] != "passthrough": - yield last_element + if with_final and last_element[2] not in {None, "passthrough"}: + yield last_element @property def _estimator_type(self) -> Any: @@ -208,7 +210,7 @@ def _estimator_type(self) -> Any: return None if hasattr(self._final_estimator, "_estimator_type"): # pylint: disable=protected-access - return self._final_estimator._estimator_type + return self._final_estimator._estimator_type # noqa: SLF001 return None @property @@ -227,9 +229,9 @@ def _final_estimator( return last_element[2] # pylint: disable=too-many-locals,too-many-branches - def _fit( + def _fit( # noqa: PLR0912 self, - X: Any, + X: Any, # noqa: N803 y: Any = None, routed_params: dict[str, Any] | None = None, raw_params: dict[str, Any] | None = None, @@ -306,7 +308,7 @@ def _fit( ) # Fit or load from cache the current transformer - X, fitted_transformer = fit_transform_one_cached( + X, fitted_transformer = fit_transform_one_cached( # noqa: N806 cloned_transformer, X, y, @@ -341,7 +343,7 @@ def _fit( def _transform( self, - X: Any, # pylint: disable=invalid-name + X: Any, # pylint: disable=invalid-name # noqa: N803 routed_params: Bunch, ) -> Any: """Transform the data, and skip final estimator. @@ -452,6 +454,11 @@ def _agg_non_postpred_steps( When filter_passthrough is True, 'passthrough' and None transformers are filtered out. + Raises + ------ + AssertionError + If the step is not a PipelineElement. + Yields ------ _AggregatedPipelineStep @@ -459,9 +466,14 @@ def _agg_non_postpred_steps( transformer. """ + aggregated_transformer_list: list[tuple[int, str, ABCPipelineElement]] aggregated_transformer_list = [] for i, (name_i, step_i) in enumerate(self._non_post_processing_steps()): - if isinstance(step_i, ABCPipelineElement): + if not isinstance(step_i, ABCPipelineElement): + raise AssertionError( + "Step is not a PipelineElement, hence cannot be aggregated.", + ) + if isinstance(step_i, SingleInstanceTransformerMixin): aggregated_transformer_list.append((i, name_i, step_i)) else: if aggregated_transformer_list: @@ -493,7 +505,12 @@ def _agg_non_postpred_steps( # estimators in Pipeline.steps are not validated yet prefer_skip_nested_validation=False, ) - def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self: + def fit( + self, + X: Any, # noqa: N803 + y: Any = None, + **fit_params: Any, + ) -> Self: """Fit the model. 
Fit all the transformers one after the other and transform the @@ -521,10 +538,10 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self: """ routed_params = self._check_method_params(method="fit", props=fit_params) - Xt, yt = self._fit(X, y, routed_params) # pylint: disable=invalid-name + xt, yt = self._fit(X, y, routed_params) with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): if self._final_estimator != "passthrough": - if is_empty(Xt): + if is_empty(xt): logger.warning( "All input rows were filtered out! Model is not fitted!", ) @@ -532,7 +549,7 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self: fit_params_last_step = routed_params[ self._non_post_processing_steps()[-1][0] ] - self._final_estimator.fit(Xt, yt, **fit_params_last_step["fit"]) + self._final_estimator.fit(xt, yt, **fit_params_last_step["fit"]) return self @@ -567,7 +584,12 @@ def _can_decision_function(self) -> bool: # estimators in Pipeline.steps are not validated yet prefer_skip_nested_validation=False, ) - def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any: + def fit_transform( + self, + X: Any, # noqa: N803 + y: Any = None, + **params: Any, + ) -> Any: """Fit the model and transform with the final estimator. Fits all the transformers one after the other and transform the @@ -636,7 +658,11 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any: return iter_input @available_if(_final_estimator_has("predict")) - def predict(self, X: Any, **params: Any) -> Any: + def predict( + self, + X: Any, # noqa: N803 + **params: Any, + ) -> Any: """Transform the data, and apply `predict` with the final estimator. Call `transform` of each transformer in the pipeline. The transformed @@ -692,7 +718,8 @@ def predict(self, X: Any, **params: Any) -> Any: iter_input = self._final_estimator.predict(iter_input, **params) else: raise AssertionError( - "Final estimator does not implement predict, hence this function should not be available.", + "Final estimator does not implement predict, " + "hence this function should not be available.", ) for _, post_element in self._post_processing_steps(): iter_input = post_element.transform(iter_input) @@ -703,7 +730,12 @@ def predict(self, X: Any, **params: Any) -> Any: # estimators in Pipeline.steps are not validated yet prefer_skip_nested_validation=False, ) - def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any: + def fit_predict( + self, + X: Any, # noqa: N803 + y: Any = None, + **params: Any, + ) -> Any: """Transform the data, and apply `fit_predict` with the final estimator. Call `fit_transform` of each transformer in the pipeline. The @@ -765,7 +797,11 @@ def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any: return y_pred @available_if(_final_estimator_has("predict_proba")) - def predict_proba(self, X: Any, **params: Any) -> Any: + def predict_proba( + self, + X: Any, # noqa: N803 + **params: Any, + ) -> Any: """Transform the data, and apply `predict_proba` with the final estimator. Call `transform` of each transformer in the pipeline. The transformed @@ -839,7 +875,11 @@ def _can_transform(self) -> bool: ) @available_if(_can_transform) - def transform(self, X: Any, **params: Any) -> Any: + def transform( + self, + X: Any, # noqa: N803 + **params: Any, + ) -> Any: """Transform the data, and apply `transform` with the final estimator. Call `transform` of each transformer in the pipeline. 
The transformed @@ -896,7 +936,11 @@ def transform(self, X: Any, **params: Any) -> Any: return iter_input @available_if(_can_decision_function) - def decision_function(self, X: Any, **params: Any) -> Any: + def decision_function( + self, + X: Any, # noqa: N803 + **params: Any, + ) -> Any: """Transform the data, and apply `decision_function` with the final estimator. Parameters @@ -972,7 +1016,7 @@ def classes_(self) -> list[Any] | npt.NDArray[Any]: return last_step.classes_ raise ValueError("Last step has no classes_ attribute.") - def __sklearn_tags__(self) -> Tags: + def __sklearn_tags__(self) -> Tags: # noqa: PLW3201 """Return the sklearn tags. Notes @@ -1011,7 +1055,8 @@ def __sklearn_tags__(self) -> Tags: pass try: - # Only the _final_estimator is changed from the original implementation is changed in the following 2 lines + # Only the _final_estimator is changed from the original implementation is + # changed in the following 2 lines if ( self._final_estimator is not None and self._final_estimator != "passthrough" @@ -1083,7 +1128,8 @@ def get_metadata_routing(self) -> MetadataRouter: router.add(method_mapping=method_mapping, **{name: trans}) - # Only the _non_post_processing_steps is changed from the original implementation is changed in the following line + # Only the _non_post_processing_steps is changed from the original + # implementation is changed in the following line final_name, final_est = self._non_post_processing_steps()[-1] if final_est is None or final_est == "passthrough": return router diff --git a/tests/test_elements/test_error_handling.py b/tests/test_elements/test_error_handling.py index ad1130a3..8ee8bf53 100644 --- a/tests/test_elements/test_error_handling.py +++ b/tests/test_elements/test_error_handling.py @@ -248,7 +248,7 @@ def test_replace_mixed_datatypes_expected_failures(self) -> None: ) pipeline2 = clone(pipeline) - self.assertRaises(ValueError, pipeline.fit, test_values) + pipeline.fit(test_values) self.assertRaises(ValueError, pipeline.transform, test_values) self.assertRaises(ValueError, pipeline2.fit_transform, test_values) diff --git a/tests/utils/mock_element.py b/tests/utils/mock_element.py index 2b89987c..5e7c6c65 100644 --- a/tests/utils/mock_element.py +++ b/tests/utils/mock_element.py @@ -3,18 +3,24 @@ from __future__ import annotations import copy -from collections.abc import Iterable -from typing import Any, Self +from typing import TYPE_CHECKING, Any, Self import numpy as np from molpipeline.abstract_pipeline_elements.core import ( InvalidInstance, + SingleInstanceTransformerMixin, TransformingPipelineElement, ) +if TYPE_CHECKING: + from collections.abc import Iterable -class MockTransformingPipelineElement(TransformingPipelineElement): + +class MockTransformingPipelineElement( + SingleInstanceTransformerMixin, + TransformingPipelineElement, +): """Mock element for testing.""" def __init__( @@ -40,6 +46,7 @@ def __init__( Unique identifier of PipelineElement. n_jobs: int, default=1 Number of jobs to run in parallel. + """ super().__init__(name=name, uuid=uuid, n_jobs=n_jobs) if invalid_values is None: @@ -59,6 +66,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Dictionary containing all parameters defining the object. + """ params = super().get_params(deep) if deep: @@ -81,6 +89,7 @@ def set_params(self, **parameters: Any) -> Self: ------- Self MockTransformingPipelineElement with updated parameters. 
+ """ super().set_params(**parameters) if "invalid_values" in parameters: @@ -101,6 +110,7 @@ def pretransform_single(self, value: Any) -> Any: ------- Any Other value. + """ if value in self.invalid_values: return InvalidInstance( @@ -113,8 +123,8 @@ def pretransform_single(self, value: Any) -> Any: def assemble_output(self, value_list: Iterable[Any]) -> Any: """Aggregate rows, which in most cases is just return the list. - Some representations might be better representd as a single object. For example a list of vectors can - be transformed to a matrix. + Some representations might be better representd as a single object. + For example a list of vectors can be transformed to a matrix. Parameters ---------- @@ -125,6 +135,7 @@ def assemble_output(self, value_list: Iterable[Any]) -> Any: ------- Any Aggregated output. This can also be the original input. + """ if self.return_as_numpy_array: return np.array(list(value_list)) From 06f47532891b6b1c43fe0c78cebc10b36556967c Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 7 May 2025 13:05:33 +0200 Subject: [PATCH 06/31] remove unnecessary type check --- molpipeline/pipeline/_skl_pipeline.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py index 5fa47e6b..0c35d7b8 100644 --- a/molpipeline/pipeline/_skl_pipeline.py +++ b/molpipeline/pipeline/_skl_pipeline.py @@ -469,10 +469,6 @@ def _agg_non_postpred_steps( aggregated_transformer_list: list[tuple[int, str, ABCPipelineElement]] aggregated_transformer_list = [] for i, (name_i, step_i) in enumerate(self._non_post_processing_steps()): - if not isinstance(step_i, ABCPipelineElement): - raise AssertionError( - "Step is not a PipelineElement, hence cannot be aggregated.", - ) if isinstance(step_i, SingleInstanceTransformerMixin): aggregated_transformer_list.append((i, name_i, step_i)) else: From 8045215ea0d34739ac5de32a696418ac6bc3c368 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 7 May 2025 13:07:00 +0200 Subject: [PATCH 07/31] remove Raises from docu --- molpipeline/pipeline/_skl_pipeline.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py index 0c35d7b8..b83340b5 100644 --- a/molpipeline/pipeline/_skl_pipeline.py +++ b/molpipeline/pipeline/_skl_pipeline.py @@ -454,11 +454,6 @@ def _agg_non_postpred_steps( When filter_passthrough is True, 'passthrough' and None transformers are filtered out. - Raises - ------ - AssertionError - If the step is not a PipelineElement. 
- Yields ------ _AggregatedPipelineStep From be4dd082118da52ab22de114978f55e577399e19 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 7 May 2025 13:09:16 +0200 Subject: [PATCH 08/31] remove pylint ignore --- molpipeline/pipeline/_skl_pipeline.py | 1 - 1 file changed, 1 deletion(-) diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py index b83340b5..0e7efc88 100644 --- a/molpipeline/pipeline/_skl_pipeline.py +++ b/molpipeline/pipeline/_skl_pipeline.py @@ -209,7 +209,6 @@ def _estimator_type(self) -> Any: if self._final_estimator is None or self._final_estimator == "passthrough": return None if hasattr(self._final_estimator, "_estimator_type"): - # pylint: disable=protected-access return self._final_estimator._estimator_type # noqa: SLF001 return None From d1411471d6a4437ac3c45aed3c35aafaeb6376ff Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Tue, 13 May 2025 11:42:17 +0200 Subject: [PATCH 09/31] fix unittests and rewrite --- molpipeline/error_handling.py | 21 - molpipeline/pipeline/_molpipeline.py | 404 ------- molpipeline/pipeline/_skl_adapter_pipeline.py | 987 ++++++++++++++++++ molpipeline/pipeline/_skl_pipeline.py | 834 ++++++--------- 4 files changed, 1293 insertions(+), 953 deletions(-) delete mode 100644 molpipeline/pipeline/_molpipeline.py create mode 100644 molpipeline/pipeline/_skl_adapter_pipeline.py diff --git a/molpipeline/error_handling.py b/molpipeline/error_handling.py index c68d83f5..7d41d425 100644 --- a/molpipeline/error_handling.py +++ b/molpipeline/error_handling.py @@ -695,27 +695,6 @@ def fit_transform( self.fit(values) return self.transform(values) - def transform_single(self, value: Any) -> Any: - """Transform a single value. - - Parameters - ---------- - value: Any - Value to be transformed. - - Returns - ------- - Any - Transformed value. - - """ - if ( - isinstance(value, RemovedInstance) - and value.filter_element_id == self.error_filter.uuid - ): - return self.fill_value - return value - def transform( self, values: TypeFixedVarSeq, diff --git a/molpipeline/pipeline/_molpipeline.py b/molpipeline/pipeline/_molpipeline.py deleted file mode 100644 index cf6dbf94..00000000 --- a/molpipeline/pipeline/_molpipeline.py +++ /dev/null @@ -1,404 +0,0 @@ -"""Defines the pipeline which handles pipeline elements for molecular operations.""" - -from __future__ import annotations - -from typing import TYPE_CHECKING, Any, Self - -import numpy as np -from joblib import Parallel, delayed -from rdkit.Chem.rdchem import MolSanitizeException -from rdkit.rdBase import BlockLogs - -from molpipeline.abstract_pipeline_elements.core import ( - ABCPipelineElement, - InvalidInstance, - RemovedInstance, - SingleInstanceTransformerMixin, - TransformingPipelineElement, -) -from molpipeline.error_handling import ( - ErrorFilter, - FilterReinserter, - _MultipleErrorFilter, -) -from molpipeline.utils.multi_proc import check_available_cores - -if TYPE_CHECKING: - from collections.abc import Iterable, Sequence - - from molpipeline.utils.molpipeline_types import TypeFixedVarSeq - - -class _MolPipeline: - """Contains the PipeElements which describe the functionality of the pipeline.""" - - _n_jobs: int - _element_list: Sequence[ABCPipelineElement] - - def __init__( - self, - element_list: Sequence[ABCPipelineElement], - n_jobs: int = 1, - name: str = "MolPipeline", - ) -> None: - """Initialize MolPipeline. - - Parameters - ---------- - element_list: list[ABCPipelineElement] - List of Pipeline Elements which form the pipeline. 
- n_jobs: - Number of cores used. - name: str - Name of pipeline. - - """ - self._element_list = element_list - self.n_jobs = n_jobs - self.name = name - - @property - def _filter_elements(self) -> list[ErrorFilter]: - """Get the elements which filter the input.""" - return [ - element - for element in self._element_list - if isinstance(element, ErrorFilter) - ] - - @property - def _filter_elements_agg(self) -> _MultipleErrorFilter: - """Get the aggregated filter element.""" - return _MultipleErrorFilter(self._filter_elements) - - @property - def _transforming_elements( - self, - ) -> list[TransformingPipelineElement | _MolPipeline]: - """Get the elements which transform the input.""" - return [ - element - for element in self._element_list - if isinstance(element, (TransformingPipelineElement, _MolPipeline)) - ] - - @property - def n_jobs(self) -> int: - """Return the number of cores to use in transformation step.""" - return self._n_jobs - - @n_jobs.setter - def n_jobs(self, requested_jobs: int) -> None: - """Set the number of cores to use in transformation step. - - Parameters - ---------- - requested_jobs: int - Number of cores requested for transformation steps. - If fewer cores than requested are available, the number of cores is set to - the maximum available. - - """ - self._n_jobs = check_available_cores(requested_jobs) - - def get_params(self, deep: bool = True) -> dict[str, Any]: - """Get all parameters defining the object. - - Parameters - ---------- - deep: bool - If True get a deep copy of the parameters. - - Returns - ------- - dict[str, Any] - Dictionary of parameter names and corresponding values. - - """ - if deep: - return { - "element_list": self.element_list, - "n_jobs": self.n_jobs, - "name": self.name, - } - return { - "element_list": self._element_list, - "n_jobs": self.n_jobs, - "name": self.name, - } - - def set_params(self, **parameter_dict: Any) -> Self: - """Set parameters of the pipeline and pipeline elements. - - Parameters - ---------- - parameter_dict: Any - Dictionary of parameter names and corresponding values to be set. - - Returns - ------- - Self - MolPipeline object with updated parameters. - - """ - if "element_list" in parameter_dict: - self._element_list = parameter_dict["element_list"] - if "n_jobs" in parameter_dict: - self.n_jobs = int(parameter_dict["n_jobs"]) - if "name" in parameter_dict: - self.name = str(parameter_dict["name"]) - return self - - @property - def element_list(self) -> list[ABCPipelineElement]: - """Get a shallow copy from the list of pipeline elements.""" - return list(self._element_list) # to create shallow copy. - - def fit( - self, - x_input: Any, - y: Any = None, - **fit_params: dict[Any, Any], - ) -> Self: - """Fit the MolPipeline according to x_input. - - Parameters - ---------- - x_input: Any - Molecular representations which are subsequently processed. - y: Any - Optional label of input. Only for SKlearn compatibility. - fit_params: Any - Parameters. Only for SKlearn compatibility. - - Returns - ------- - Self - Fitted MolPipeline. - - """ - _ = y # Making pylint happy - _ = fit_params # Making pylint happy - self.fit_transform(x_input) - return self - - def fit_transform( # pylint: disable=unused-argument - self, - x_input: Any, - y: Any = None, # noqa: ARG002 - **fit_params: dict[str, Any], # noqa: ARG002 - ) -> Any: - """Fit to x_input and return the transformed molecules. - - Parameters - ---------- - x_input: Any - Molecular representations which are subsequently processed. - y: Any - Optional label of input. 
Only for SKlearn compatibility. - fit_params: Any - Parameters. Only for SKlearn compatibility. - - Raises - ------ - AssertionError - If the PipelineElement is not a TransformingPipelineElement. - - Returns - ------- - Any - Transformed molecules. - - """ - iter_input = x_input - - removed_rows: dict[ErrorFilter, list[int]] = {} - for error_filter in self._filter_elements: - removed_rows[error_filter] = [] - iter_idx_array = np.arange(len(iter_input)) - - # The meta elements merge steps which do not require fitting - for i_element in self._element_list: - if not isinstance( - i_element, - (SingleInstanceTransformerMixin, _MolPipeline), - ): - continue - if isinstance(i_element, ErrorFilter): - continue - if not isinstance(i_element, TransformingPipelineElement): - raise AssertionError( - "PipelineElement is not a TransformingPipelineElement.", - ) - i_element.n_jobs = self.n_jobs - iter_input = i_element.pretransform(iter_input) - for error_filter in self._filter_elements: - iter_input = error_filter.transform(iter_input) - for idx in error_filter.error_indices: - idx = iter_idx_array[idx] - removed_rows[error_filter].append(idx) - iter_idx_array = error_filter.co_transform(iter_idx_array) - iter_input = i_element.assemble_output(iter_input) - i_element.n_jobs = 1 - - # Set removed rows to filter elements to allow for correct co_transform - iter_idx_array = np.arange(len(x_input)) - for error_filter in self._filter_elements: - removed_idx_list = removed_rows[error_filter] - error_filter.error_indices = [] - for new_idx, _idx in enumerate(iter_idx_array): - if _idx in removed_idx_list: - error_filter.error_indices.append(new_idx) - error_filter.n_total = len(iter_idx_array) - iter_idx_array = error_filter.co_transform(iter_idx_array) - error_replacer_list = [ - ele for ele in self._element_list if isinstance(ele, FilterReinserter) - ] - for error_replacer in error_replacer_list: - error_replacer.select_error_filter(self._filter_elements) - iter_input = error_replacer.transform(iter_input) - return iter_input - - def transform_single(self, input_value: Any) -> Any: - """Transform a single input according to the sequence of PipelineElements. - - Parameters - ---------- - input_value: Any - Molecular representation which is subsequently transformed. - - Raises - ------ - AssertionError - If the PipelineElement is not a SingleInstanceTransformerMixin. - - Returns - ------- - Any - Transformed molecular representation. - - """ - log_block = BlockLogs() - iter_value = input_value - for p_element in self._element_list: - if not isinstance( - p_element, - (SingleInstanceTransformerMixin, _MolPipeline), - ): - raise AssertionError( - "PipelineElement is not a SingleInstanceTransformerMixin.", - ) - try: - if not isinstance(iter_value, RemovedInstance) or isinstance( - p_element, - FilterReinserter, - ): - iter_value = p_element.transform_single(iter_value) - except MolSanitizeException as err: - iter_value = InvalidInstance( - p_element.uuid, - f"RDKit MolSanitizeException: {err.args}", - p_element.name, - ) - del log_block - return iter_value - - def pretransform(self, x_input: Any) -> Any: - """Transform the input according to the sequence without assemble_output step. - - Parameters - ---------- - x_input: Any - Molecular representations which are subsequently transformed. - - Returns - ------- - Any - Transformed molecular representations. 
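Before its removal here, transform_single wrapped every element call in an RDKit-aware guard. Condensed to the essential pattern (sketch; the helper name is invented, and the same RDKit imports reappear in the rewritten _skl_pipeline.py later in this patch):

    from rdkit.Chem.rdchem import MolSanitizeException

    from molpipeline.abstract_pipeline_elements.core import InvalidInstance

    def guarded_transform_single(element, value):
        # Turn RDKit sanitization errors into InvalidInstance records
        # instead of aborting the whole batch.
        try:
            return element.transform_single(value)
        except MolSanitizeException as err:
            return InvalidInstance(
                element.uuid,
                f"RDKit MolSanitizeException: {err.args}",
                element.name,
            )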
- - """ - return list(self._transform_iterator(x_input)) - - def transform(self, x_input: Any) -> Any: - """Transform the input according to the sequence of provided PipelineElements. - - Parameters - ---------- - x_input: Any - Molecular representations which are subsequently transformed. - - Returns - ------- - Any - Transformed molecular representations. - - """ - output_generator = self._transform_iterator(x_input) - return self.assemble_output(output_generator) - - def assemble_output(self, value_list: Iterable[Any]) -> Any: - """Assemble the output of the pipeline. - - Parameters - ---------- - value_list: Iterable[Any] - Generator which yields the output of the pipeline. - - Returns - ------- - Any - Assembled output. - - """ - last_element = self._transforming_elements[-1] - if hasattr(last_element, "assemble_output"): - return last_element.assemble_output(value_list) - return list(value_list) - - def _transform_iterator(self, x_input: Any) -> Any: - """Transform the input according to the sequence of provided PipelineElements. - - Parameters - ---------- - x_input: Any - Molecular representations which are subsequently transformed. - - Yields - ------ - Any - Transformed molecular representations. - - """ - agg_filter = self._filter_elements_agg - for filter_element in self._filter_elements: - filter_element.error_indices = [] - parallel = Parallel( - n_jobs=self.n_jobs, - return_as="generator", - batch_size="auto", - ) - output_generator = parallel( - delayed(self.transform_single)(value) for value in x_input - ) - for i, transformed_value in enumerate(output_generator): - if isinstance(transformed_value, RemovedInstance): - agg_filter.register_removed(i, transformed_value) - else: - yield transformed_value - agg_filter.set_total(len(x_input)) - - def co_transform(self, x_input: TypeFixedVarSeq) -> TypeFixedVarSeq: - """Filter flagged rows from the input. - - Parameters - ---------- - x_input: Any - Molecular representations which are subsequently filtered. - - Returns - ------- - Any - Filtered molecular representations. 
- - """ - return self._filter_elements_agg.co_transform(x_input) diff --git a/molpipeline/pipeline/_skl_adapter_pipeline.py b/molpipeline/pipeline/_skl_adapter_pipeline.py new file mode 100644 index 00000000..cae1b858 --- /dev/null +++ b/molpipeline/pipeline/_skl_adapter_pipeline.py @@ -0,0 +1,987 @@ +"""Module to change functions of the sklearn pipeline.""" + +from __future__ import annotations + +from itertools import islice +from typing import TYPE_CHECKING, Any, Literal, Self + +import numpy as np +import numpy.typing as npt +from loguru import logger +from sklearn.base import _fit_context, clone # noqa: PLC2701 +from sklearn.pipeline import Pipeline as _Pipeline +from sklearn.pipeline import _final_estimator_has, _fit_transform_one # noqa: PLC2701 +from sklearn.utils.metadata_routing import ( + MetadataRouter, + MethodMapping, + _routing_enabled, # noqa: PLC2701 + process_routing, +) +from sklearn.utils.metaestimators import available_if +from sklearn.utils.validation import check_memory + +from molpipeline.error_handling import ErrorFilter, FilterReinserter +from molpipeline.post_prediction import ( + PostPredictionTransformation, + PostPredictionWrapper, +) +from molpipeline.utils.logging import print_elapsed_time +from molpipeline.utils.molpipeline_types import ( + AnyElement, + AnyPredictor, + AnyStep, + AnyTransformer, +) +from molpipeline.utils.value_checks import is_empty + +if TYPE_CHECKING: + from collections.abc import Generator + + import joblib + from sklearn.utils import Bunch + + from molpipeline.abstract_pipeline_elements.core import ( + ABCPipelineElement, + ) + + +_IndexedStep = tuple[int, str, AnyElement] + + +class AdapterPipeline(_Pipeline): + """Defines the pipeline which handles pipeline elements.""" + + steps: list[AnyStep] + # * Adapted methods from sklearn.pipeline.Pipeline * + + @property + def _estimator_type(self) -> Any: + """Return the estimator type.""" + if self._final_estimator is None or self._final_estimator == "passthrough": + return None + if hasattr(self._final_estimator, "_estimator_type"): + return self._final_estimator._estimator_type # noqa: SLF001 + return None + + @property + def _final_estimator( + self, + ) -> Literal["passthrough"] | AnyTransformer | AnyPredictor | ABCPipelineElement: + """Return the lst estimator which is not a PostprocessingTransformer.""" + return self._modified_steps[-1][1] + + @property + def _modified_steps( + self, + ) -> list[AnyStep]: + """Return modified version of steps. + + Returns only steps before the first PostPredictionTransformation. + + Raises + ------ + AssertionError + If a PostPredictionTransformation is found before the last step. 
+ + """ + non_post_processing_steps: list[AnyStep] = [] + start_adding = False + for step_name, step_estimator in self.steps[::-1]: + if not isinstance(step_estimator, PostPredictionTransformation): + start_adding = True + if start_adding: + if isinstance(step_estimator, PostPredictionTransformation): + raise AssertionError( + "PipelineElement of type PostPredictionTransformation occured " + "before the last step.", + ) + non_post_processing_steps.append((step_name, step_estimator)) + return list(non_post_processing_steps[::-1]) + + @property + def _post_processing_steps(self) -> list[tuple[str, PostPredictionTransformation]]: + """Return last steps which are PostPredictionTransformation.""" + post_processing_steps = [] + for step_name, step_estimator in self.steps[::-1]: + if isinstance(step_estimator, PostPredictionTransformation): + post_processing_steps.append((step_name, step_estimator)) + else: + break + return list(post_processing_steps[::-1]) + + @property + def classes_(self) -> list[Any] | npt.NDArray[Any]: + """Return the classes of the last element. + + PostPredictionTransformation elements are not considered as last element. + + Raises + ------ + ValueError + If the last step is passthrough or has no classes_ attribute. + + """ + check_last = [ + step + for step in self.steps + if not isinstance(step[1], PostPredictionTransformation) + ] + last_step = check_last[-1][1] + if last_step == "passthrough": + raise ValueError("Last step is passthrough.") + if hasattr(last_step, "classes_"): + return last_step.classes_ + raise ValueError("Last step has no classes_ attribute.") + + def __init__( + self, + steps: list[AnyStep], + *, + memory: str | joblib.Memory | None = None, + verbose: bool = False, + n_jobs: int = 1, + ): + """Initialize Pipeline. + + Parameters + ---------- + steps: list[tuple[str, AnyTransformer | AnyPredictor | ABCPipelineElement]] + List of (name, Estimator) tuples. + memory: str | joblib.Memory | None, optional + Path to cache transformers. + verbose: bool, optional + If True, print additional information. + n_jobs: int, optional + Number of cores used for aggregated steps. + + """ + super().__init__(steps, memory=memory, verbose=verbose) + self.n_jobs = n_jobs + self._set_error_resinserter() + + def _set_error_resinserter(self) -> None: + """Connect the error resinserters with the error filters.""" + error_replacer_list = [ + e_filler + for _, e_filler in self.steps + if isinstance(e_filler, FilterReinserter) + ] + error_filter_list = [ + n_filter for _, n_filter in self.steps if isinstance(n_filter, ErrorFilter) + ] + for step in self.steps: + if isinstance(step[1], PostPredictionWrapper): # noqa: SIM102 + if isinstance(step[1].wrapped_estimator, FilterReinserter): + error_replacer_list.append(step[1].wrapped_estimator) + for error_replacer in error_replacer_list: + error_replacer.select_error_filter(error_filter_list) + + def _validate_steps(self) -> None: + """Validate the steps. + + Raises + ------ + TypeError + If the steps do not implement fit and transform or are not 'passthrough'. 
+ + """ + names = [name for name, _ in self.steps] + + # validate names + self._validate_names(names) + + # validate estimators + estimator = self._modified_steps[-1][1] + + for _, transformer in self._modified_steps[:-1]: + if transformer is None or transformer == "passthrough": + continue + if not ( + hasattr(transformer, "fit") or hasattr(transformer, "fit_transform") + ) or not hasattr(transformer, "transform"): + raise TypeError( + f"All intermediate steps should be " + f"transformers and implement fit and transform " + f"or be the string 'passthrough' " + f"'{transformer}' (type {type(transformer)}) doesn't", + ) + + # We allow last estimator to be None as an identity transformation + if ( + estimator is not None + and estimator != "passthrough" + and not hasattr(estimator, "fit") + ): + raise TypeError( + f"Last step of Pipeline should implement fit " + f"or be the string 'passthrough'. " + f"'{estimator}' (type {type(estimator)}) doesn't", + ) + + # validate post-processing steps + # Calling steps automatically validates them + _ = self._post_processing_steps + + def _iter( + self, + with_final: bool = True, + filter_passthrough: bool = True, + ) -> Generator[ + tuple[int, str, AnyElement] | tuple[list[int], list[str], AnyElement], + Any, + None, + ]: + """Iterate over all non post-processing steps. + + Steps which are children of a ABCPipelineElement were aggregated to a + MolPipeline. + + Parameters + ---------- + with_final: bool, optional + If True, the final estimator is included. + filter_passthrough: bool, optional + If True, passthrough steps are filtered out. + + Yields + ------ + _AggregatedPipelineStep + The _AggregatedPipelineStep is composed of the index, the name and the + transformer. + + """ + stop = len(self._modified_steps) + if not with_final: + stop -= 1 + + for idx, (name, trans) in enumerate(islice(self._modified_steps, 0, stop)): + if not filter_passthrough or (trans is not None and trans != "passthrough"): + yield idx, name, trans + + # pylint: disable=too-many-locals,too-many-branches + def _fit( # noqa: PLR0912 + self, + X: Any, # noqa: N803 + y: Any = None, + routed_params: dict[str, Any] | None = None, + raw_params: dict[str, Any] | None = None, + ) -> tuple[Any, Any]: + """Fit the model by fitting all transformers except the final estimator. + + Data can be subsetted by the transformers. + + Parameters + ---------- + X : Any + Training data. + y : Any, optional (default=None) + Training objectives. + routed_params : dict[str, Any], optional + Parameters for each step as returned by process_routing. + Although this is marked as optional, it should not be None. + The awkward (argward?) typing is due to inheritance from sklearn. + Can be an empty dictionary. + raw_params : dict[str, Any], optional + Parameters passed by the user, used when `transform_input` + + Raises + ------ + AssertionError + If routed_params is None or if the transformer is 'passthrough'. + AssertionError + If the names are a list and the step is not a Pipeline. + + Returns + ------- + tuple[Any, Any] + The transformed data and the transformed objectives. 
+ + """ + # shallow copy of steps - this should really be steps_ + self.steps = list(self.steps) + self._validate_steps() + if routed_params is None: + raise AssertionError("routed_params should not be None.") + + # Set up the memory + memory: joblib.Memory = check_memory(self.memory) + + fit_transform_one_cached = memory.cache(_fit_transform_one) + for step in self._iter(with_final=False, filter_passthrough=False): + step_idx, name, transformer = step + if transformer is None or transformer == "passthrough": + with print_elapsed_time("Pipeline", self._log_message(step_idx)): + continue + + if hasattr(memory, "location") and memory.location is None: + # we do not clone when caching is disabled to + # preserve backward compatibility + cloned_transformer = transformer + else: + cloned_transformer = clone(transformer) + if isinstance(name, list): + if not isinstance(transformer, _Pipeline): + raise AssertionError( + "If the name is a list, the transformer must be a Pipeline.", + ) + if routed_params: + step_params = { + "element_parameters": [routed_params[n] for n in name], + } + else: + step_params = {} + else: + step_params = self._get_metadata_for_step( + step_idx=step_idx, + step_params=routed_params[name], + all_params=raw_params, + ) + + # Fit or load from cache the current transformer + X, fitted_transformer = fit_transform_one_cached( # noqa: N806 # type: ignore + cloned_transformer, + X, + y, + None, + message_clsname="Pipeline", + message=self._log_message(step_idx), + params=step_params, + ) + # Replace the transformer of the step with the fitted + # transformer. This is necessary when loading the transformer + # from the cache. + if isinstance(fitted_transformer, AdapterPipeline): + ele_list = [step[1] for step in fitted_transformer.steps] + if not isinstance(name, list) or not isinstance(step_idx, list): + raise AssertionError() + if not len(name) == len(step_idx) == len(ele_list): + raise AssertionError() + for idx_i, name_i, ele_i in zip(step_idx, name, ele_list, strict=True): + self.steps[idx_i] = (name_i, ele_i) + if y is not None: + y = fitted_transformer.co_transform(y) + for idx_i, name_i, ele_i in zip(step_idx, name, ele_list, strict=True): + self.steps[idx_i] = (name_i, ele_i) + self._set_error_resinserter() + elif isinstance(name, list) or isinstance(step_idx, list): + raise AssertionError() + else: + self.steps[step_idx] = (name, fitted_transformer) + if is_empty(X): + return np.array([]), np.array([]) + return X, y + + def _transform( + self, + X: Any, # pylint: disable=invalid-name # noqa: N803 + routed_params: Bunch, + ) -> Any: + """Transform the data, and skip final estimator. + + Call `transform` of each transformer in the pipeline except the last one, + + Parameters + ---------- + X : iterable + Data to predict on. Must fulfill input requirements of first step + of the pipeline. + + routed_params: Bunch + parameters for each step as returned by process_routing + + Raises + ------ + AssertionError + If one of the transformers is 'passthrough' or does not implement + `transform`. + + Returns + ------- + Any + Result of calling `transform` on the second last estimator. 
+ + """ + iter_input = X + do_routing = _routing_enabled() + if do_routing: + logger.warning("Routing is enabled and NOT fully tested!") + + for _, name, transform in self._iter(with_final=False): + if transform == "passthrough": + raise AssertionError("Passthrough should have been filtered out.") + if hasattr(transform, "transform"): + if do_routing: + iter_input = transform.transform( # type: ignore[call-arg] + iter_input, + routed_params[name].transform, + ) + else: + iter_input = transform.transform(iter_input) + else: + raise AssertionError( + f"Non transformer ocurred in transformation step: {transform}.", + ) + return iter_input + + @_fit_context( + # estimators in Pipeline.steps are not validated yet + prefer_skip_nested_validation=False, + ) + def fit( + self, + X: Any, # noqa: N803 + y: Any = None, + **fit_params: Any, + ) -> Self: + """Fit the model. + + Fit all the transformers one after the other and transform the + data. Finally, fit the transformed data using the final estimator. + + Parameters + ---------- + X : iterable + Training data. Must fulfill input requirements of first step of the + pipeline. + + y : iterable, default=None + Training targets. Must fulfill label requirements for all steps of + the pipeline. + + **fit_params : dict of string -> object + Parameters passed to the ``fit`` method of each step, where + each parameter name is prefixed such that parameter ``p`` for step + ``s`` has key ``s__p``. + + Returns + ------- + self : object + Pipeline with fitted steps. + + """ + routed_params = self._check_method_params(method="fit", props=fit_params) + xt, yt = self._fit(X, y, routed_params) + with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): + if self._final_estimator != "passthrough": + if is_empty(xt): + logger.warning( + "All input rows were filtered out! Model is not fitted!", + ) + else: + fit_params_last_step = routed_params[self._modified_steps[-1][0]] + self._final_estimator.fit(xt, yt, **fit_params_last_step["fit"]) + + return self + + def _can_fit_transform(self) -> bool: + """Check if the final estimator can fit_transform or is passthrough. + + Returns + ------- + bool + True if the final estimator can fit_transform or is passthrough. + + """ + return ( + self._final_estimator == "passthrough" + or hasattr(self._final_estimator, "transform") + or hasattr(self._final_estimator, "fit_transform") + ) + + def _can_decision_function(self) -> bool: + """Check if the final estimator implements decision_function. + + Returns + ------- + bool + True if the final estimator implements decision_function. + + """ + return hasattr(self._final_estimator, "decision_function") + + @available_if(_can_fit_transform) + @_fit_context( + # estimators in Pipeline.steps are not validated yet + prefer_skip_nested_validation=False, + ) + def fit_transform( + self, + X: Any, # noqa: N803 + y: Any = None, + **params: Any, + ) -> Any: + """Fit the model and transform with the final estimator. + + Fits all the transformers one after the other and transform the + data. Then uses `fit_transform` on transformed data with the final + estimator. + + Parameters + ---------- + X : iterable + Training data. Must fulfill input requirements of first step of the + pipeline. + + y : iterable, default=None + Training targets. Must fulfill label requirements for all steps of + the pipeline. + + **params : Any + Parameters passed to the ``fit`` method of each step, where + each parameter name is prefixed such that parameter ``p`` for step + ``s`` has key ``s__p``. 
+ + Raises + ------ + TypeError + If the last step does not implement `fit_transform` or `fit` and + `transform`. + + Returns + ------- + Xt : ndarray of shape (n_samples, n_transformed_features) + Transformed samples. + + """ + routed_params = self._check_method_params(method="fit_transform", props=params) + iter_input, iter_label = self._fit(X, y, routed_params) + last_step = self._final_estimator + with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): + if last_step == "passthrough": + pass + elif is_empty(iter_input): + logger.warning("All input rows were filtered out! Model is not fitted!") + else: + last_step_params = routed_params[self._modified_steps[-1][0]] + if hasattr(last_step, "fit_transform"): + iter_input = last_step.fit_transform( + iter_input, + iter_label, + **last_step_params["fit_transform"], + ) + elif hasattr(last_step, "transform") and hasattr(last_step, "fit"): + last_step.fit(iter_input, iter_label, **last_step_params["fit"]) + iter_input = last_step.transform( + iter_input, + **last_step_params["transform"], + ) + else: + raise TypeError( + f"fit_transform of the final estimator" + f" {last_step.__class__.__name__} {last_step_params} does not " + f"match fit_transform of Pipeline {self.__class__.__name__}", + ) + for _, post_element in self._post_processing_steps: + iter_input = post_element.fit_transform(iter_input, iter_label) + return iter_input + + @available_if(_final_estimator_has("predict")) + def predict( + self, + X: Any, # noqa: N803 + **params: Any, + ) -> Any: + """Transform the data, and apply `predict` with the final estimator. + + Call `transform` of each transformer in the pipeline. The transformed + data are finally passed to the final estimator that calls `predict` + method. Only valid if the final estimator implements `predict`. + + Parameters + ---------- + X : iterable + Data to predict on. Must fulfill input requirements of first step + of the pipeline. + + **params : dict of string -> object + Parameters to the ``predict`` called at the end of all + transformations in the pipeline. Note that while this may be + used to return uncertainties from some models with return_std + or return_cov, uncertainties that are generated by the + transformations in the pipeline are not propagated to the + final estimator. + + .. versionadded:: 0.20 + + Raises + ------ + AssertionError + If the final estimator does not implement `predict`. + In this case this function should not be available. + + Returns + ------- + y_pred : ndarray + Result of calling `predict` on the final estimator. 
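A typical round trip, sketched with placeholder names: when a trailing FilterReinserter-based post-prediction step is present, predictions line up with the original input again.

    predictions = pipeline.predict(input_values)
    # Rows dropped by an ErrorFilter mid-pipeline are restored by the
    # post-processing loop in the method body below, e.g. as the reinserter's
    # fill_value (such as NaN), so len(predictions) == len(input_values).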
+ + """ + if _routing_enabled(): + routed_params = process_routing(self, "predict", **params) + else: + routed_params = process_routing(self, "predict") + + iter_input = self._transform(X, routed_params) + + if self._final_estimator == "passthrough": + pass + elif is_empty(iter_input): + iter_input = [] + elif hasattr(self._final_estimator, "predict"): + if _routing_enabled(): + iter_input = self._final_estimator.predict( + iter_input, + **routed_params[self._modified_steps[-1][0]].predict, + ) + else: + iter_input = self._final_estimator.predict(iter_input, **params) + else: + raise AssertionError( + "Final estimator does not implement predict, " + "hence this function should not be available.", + ) + for _, post_element in self._post_processing_steps: + iter_input = post_element.transform(iter_input) + return iter_input + + @available_if(_final_estimator_has("fit_predict")) + @_fit_context( + # estimators in Pipeline.steps are not validated yet + prefer_skip_nested_validation=False, + ) + def fit_predict( + self, + X: Any, # noqa: N803 + y: Any = None, + **params: Any, + ) -> Any: + """Transform the data, and apply `fit_predict` with the final estimator. + + Call `fit_transform` of each transformer in the pipeline. The + transformed data are finally passed to the final estimator that calls + `fit_predict` method. Only valid if the final estimator implements + `fit_predict`. + + Parameters + ---------- + X : iterable + Training data. Must fulfill input requirements of first step of + the pipeline. + + y : iterable, default=None + Training targets. Must fulfill label requirements for all steps + of the pipeline. + + **params : dict of string -> object + Parameters passed to the ``fit`` method of each step, where + each parameter name is prefixed such that parameter ``p`` for step + ``s`` has key ``s__p``. + + Raises + ------ + AssertionError + If the final estimator does not implement `fit_predict`. + In this case this function should not be available. + + Returns + ------- + y_pred : ndarray + Result of calling `fit_predict` on the final estimator. + + """ + routed_params = self._check_method_params(method="fit_predict", props=params) + iter_input, iter_label = self._fit(X, y, routed_params) + + params_last_step = routed_params[self._modified_steps[-1][0]] + with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): + if self._final_estimator == "passthrough": + y_pred = iter_input + elif is_empty(iter_input): + logger.warning("All input rows were filtered out! Model is not fitted!") + iter_input = [] + y_pred = [] + elif hasattr(self._final_estimator, "fit_predict"): + y_pred = self._final_estimator.fit_predict( + iter_input, + iter_label, + **params_last_step.get("fit_predict", {}), + ) + else: + raise AssertionError( + "Final estimator does not implement fit_predict, " + "hence this function should not be available.", + ) + for _, post_element in self._post_processing_steps: + y_pred = post_element.fit_transform(y_pred, iter_label) + return y_pred + + @available_if(_final_estimator_has("predict_proba")) + def predict_proba( + self, + X: Any, # noqa: N803 + **params: Any, + ) -> Any: + """Transform the data, and apply `predict_proba` with the final estimator. + + Call `transform` of each transformer in the pipeline. The transformed + data are finally passed to the final estimator that calls `predict_proba` + method. Only valid if the final estimator implements `predict_proba`. + + Parameters + ---------- + X : iterable + Data to predict on. 
Must fulfill input requirements of first step + of the pipeline. + + **params : dict of string -> object + Parameters to the ``predict`` called at the end of all + transformations in the pipeline. Note that while this may be + used to return uncertainties from some models with return_std + or return_cov, uncertainties that are generated by the + transformations in the pipeline are not propagated to the + final estimator. + + Raises + ------ + AssertionError + If the final estimator does not implement `predict_proba`. + In this case this function should not be available. + + Returns + ------- + y_pred : ndarray + Result of calling `predict_proba` on the final estimator. + + """ + routed_params = process_routing(self, "predict_proba", **params) + iter_input = self._transform(X, routed_params) + + if self._final_estimator == "passthrough": + pass + elif is_empty(iter_input): + iter_input = [] + elif hasattr(self._final_estimator, "predict_proba"): + if _routing_enabled(): + iter_input = self._final_estimator.predict_proba( + iter_input, + **routed_params[self._modified_steps[-1][0]].predict_proba, + ) + else: + iter_input = self._final_estimator.predict_proba(iter_input, **params) + else: + raise AssertionError( + "Final estimator does not implement predict_proba, " + "hence this function should not be available.", + ) + for _, post_element in self._post_processing_steps: + iter_input = post_element.transform(iter_input) + return iter_input + + def _can_transform(self) -> bool: + """Check if the final estimator can transform or is passthrough. + + Returns + ------- + bool + True if the final estimator can transform or is passthrough. + + """ + return self._final_estimator == "passthrough" or hasattr( + self._final_estimator, + "transform", + ) + + @available_if(_can_transform) + def transform( + self, + X: Any, # noqa: N803 + **params: Any, + ) -> Any: + """Transform the data, and apply `transform` with the final estimator. + + Call `transform` of each transformer in the pipeline. The transformed + data are finally passed to the final estimator that calls + `transform` method. Only valid if the final estimator + implements `transform`. + + This also works where final estimator is `None` in which case all prior + transformations are applied. + + Parameters + ---------- + X : iterable + Data to transform. Must fulfill input requirements of first step + of the pipeline. + **params : Any + Parameters to the ``transform`` method of each estimator. + + Raises + ------ + AssertionError + If the final estimator does not implement `transform` or + `fit_transform` or is passthrough. + + Returns + ------- + Xt : ndarray of shape (n_samples, n_transformed_features) + Transformed data. + + """ + routed_params = process_routing(self, "transform", **params) + iter_input = X + for _, name, transform in self._iter(): + if transform == "passthrough": + continue + if hasattr(transform, "transform"): + iter_input = transform.transform( + iter_input, + **routed_params[name].transform, + ) + else: + raise AssertionError( + "Non transformer ocurred in transformation step." + "This should have been caught in the validation step.", + ) + for _, post_element in self._post_processing_steps: + iter_input = post_element.transform(iter_input, **params) + return iter_input + + @available_if(_can_decision_function) + def decision_function( + self, + X: Any, # noqa: N803 + **params: Any, + ) -> Any: + """Transform the data, and apply `decision_function` with the final estimator. 
+
+        Parameters
+        ----------
+        X : iterable
+            Data to transform. Must fulfill input requirements of first step
+            of the pipeline.
+        **params : Any
+            Parameters to the ``decision_function`` method of the final estimator.
+
+        Raises
+        ------
+        AssertionError
+            If the final estimator does not implement `decision_function`.
+
+        Returns
+        -------
+        Any
+            Result of calling `decision_function` on the final estimator.
+
+        """
+        if _routing_enabled():
+            routed_params = process_routing(self, "decision_function", **params)
+        else:
+            routed_params = process_routing(self, "decision_function")
+
+        iter_input = self._transform(X, routed_params)
+        if self._final_estimator == "passthrough":
+            pass
+        elif is_empty(iter_input):
+            iter_input = []
+        elif hasattr(self._final_estimator, "decision_function"):
+            if _routing_enabled():
+                iter_input = self._final_estimator.decision_function(
+                    iter_input,
+                    **routed_params[self._final_estimator].predict,
+                )
+            else:
+                iter_input = self._final_estimator.decision_function(
+                    iter_input,
+                    **params,
+                )
+        else:
+            raise AssertionError(
+                "Final estimator does not implement `decision_function`, "
+                "hence this function should not be available.",
+            )
+        for _, post_element in self._post_processing_steps:
+            iter_input = post_element.transform(iter_input)
+        return iter_input
+
+    def get_metadata_routing(self) -> MetadataRouter:
+        """Get metadata routing of this object.
+
+        Please check :ref:`User Guide <metadata_routing>` on how the routing
+        mechanism works.
+
+        Notes
+        -----
+        This method is copied from the original sklearn implementation.
+        Changes are marked with a comment.
+
+        Returns
+        -------
+        MetadataRouter
+            A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
+            routing information.
+
+        """
+        router = MetadataRouter(owner=self.__class__.__name__)
+
+        # first we add all steps except the last one
+        for _, name, trans in self._iter(with_final=False, filter_passthrough=True):
+            method_mapping = MethodMapping()
+            # fit, fit_predict, and fit_transform call fit_transform if it
+            # exists, or else fit and transform
+            if hasattr(trans, "fit_transform"):
+                (
+                    method_mapping.add(caller="fit", callee="fit_transform")
+                    .add(caller="fit_transform", callee="fit_transform")
+                    .add(caller="fit_predict", callee="fit_transform")
+                )
+            else:
+                (
+                    method_mapping.add(caller="fit", callee="fit")
+                    .add(caller="fit", callee="transform")
+                    .add(caller="fit_transform", callee="fit")
+                    .add(caller="fit_transform", callee="transform")
+                    .add(caller="fit_predict", callee="fit")
+                    .add(caller="fit_predict", callee="transform")
+                )
+
+            (
+                method_mapping.add(caller="predict", callee="transform")
+                .add(caller="predict", callee="transform")
+                .add(caller="predict_proba", callee="transform")
+                .add(caller="decision_function", callee="transform")
+                .add(caller="predict_log_proba", callee="transform")
+                .add(caller="transform", callee="transform")
+                .add(caller="inverse_transform", callee="inverse_transform")
+                .add(caller="score", callee="transform")
+            )
+
+            router.add(method_mapping=method_mapping, **{name: trans})
+
+        # Only the following line differs from the original implementation:
+        # the final step is looked up via _modified_steps.
+        final_name, final_est = self._modified_steps[-1]
+        if final_est is None or final_est == "passthrough":
+            return router
+
+        # then we add the last step
+        method_mapping = MethodMapping()
+        if hasattr(final_est, "fit_transform"):
+            method_mapping.add(caller="fit_transform", callee="fit_transform")
+        else:
+            method_mapping.add(caller="fit", callee="fit").add(
caller="fit", + callee="transform", + ) + ( + method_mapping.add(caller="fit", callee="fit") + .add(caller="predict", callee="predict") + .add(caller="fit_predict", callee="fit_predict") + .add(caller="predict_proba", callee="predict_proba") + .add(caller="decision_function", callee="decision_function") + .add(caller="predict_log_proba", callee="predict_log_proba") + .add(caller="transform", callee="transform") + .add(caller="inverse_transform", callee="inverse_transform") + .add(caller="score", callee="score") + ) + + router.add(method_mapping=method_mapping, **{final_name: final_est}) + return router diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py index 0e7efc88..3c2cb608 100644 --- a/molpipeline/pipeline/_skl_pipeline.py +++ b/molpipeline/pipeline/_skl_pipeline.py @@ -1,37 +1,40 @@ """Defines a pipeline is exposed to the user, accessible via pipeline.""" -# pylint: disable=too-many-lines - from __future__ import annotations from copy import deepcopy +from itertools import islice from typing import TYPE_CHECKING, Any, Literal, Self import numpy as np import numpy.typing as npt +from joblib import Parallel, delayed from loguru import logger -from sklearn.base import _fit_context, clone # noqa: PLC2701 -from sklearn.pipeline import Pipeline as _Pipeline -from sklearn.pipeline import _final_estimator_has, _fit_transform_one # noqa: PLC2701 +from rdkit.Chem.rdchem import MolSanitizeException +from rdkit.rdBase import BlockLogs +from sklearn.base import _fit_context # noqa: PLC2701 +from sklearn.pipeline import _final_estimator_has # noqa: PLC2701 from sklearn.utils._tags import Tags, get_tags # noqa: PLC2701 from sklearn.utils.metadata_routing import ( - MetadataRouter, - MethodMapping, _routing_enabled, # noqa: PLC2701 process_routing, ) from sklearn.utils.metaestimators import available_if -from sklearn.utils.validation import check_memory from molpipeline.abstract_pipeline_elements.core import ( ABCPipelineElement, + InvalidInstance, + RemovedInstance, SingleInstanceTransformerMixin, ) -from molpipeline.error_handling import ErrorFilter, FilterReinserter -from molpipeline.pipeline._molpipeline import _MolPipeline +from molpipeline.error_handling import ( + ErrorFilter, + FilterReinserter, + _MultipleErrorFilter, +) +from molpipeline.pipeline._skl_adapter_pipeline import AdapterPipeline from molpipeline.post_prediction import ( PostPredictionTransformation, - PostPredictionWrapper, ) from molpipeline.utils.logging import print_elapsed_time from molpipeline.utils.molpipeline_types import ( @@ -43,24 +46,76 @@ from molpipeline.utils.value_checks import is_empty if TYPE_CHECKING: - from collections.abc import Iterable + from collections.abc import Generator, Iterable import joblib - from sklearn.utils import Bunch + + from molpipeline.utils.molpipeline_types import TypeFixedVarSeq __all__ = ["Pipeline"] _IndexedStep = tuple[int, str, AnyElement] -_AggStep = tuple[list[int], list[str], _MolPipeline] +_AggStep = tuple[list[int], list[str], "Pipeline"] _AggregatedPipelineStep = _IndexedStep | _AggStep -class Pipeline(_Pipeline): +def _agg_transformers( + transformer_list: list[tuple[int, str, AnyTransformer]], + n_jobs: int = 1, +) -> tuple[list[int], list[str], AnyElement] | tuple[int, str, AnyTransformer]: + """Aggregate transformers to a single step. + + Parameters + ---------- + transformer_list: list[tuple[int, str, AnyTransformer]] + List of transformers to aggregate. + n_jobs: int, optional + Number of cores used for aggregated steps. 
+ + Returns + ------- + tuple[list[int], list[str], AnyElement] | tuple[int, str, AnyTransformer] + Aggregated transformer. + If the list contains only one transformer, it is returned as is. + + """ + index_list = [step[0] for step in transformer_list] + name_list = [step[1] for step in transformer_list] + if len(transformer_list) == 1: + return transformer_list[0] + return ( + index_list, + name_list, + Pipeline( + [(step[1], step[2]) for step in transformer_list], + n_jobs=n_jobs, + ), + ) + + +class Pipeline(AdapterPipeline, ABCPipelineElement): """Defines the pipeline which handles pipeline elements.""" steps: list[AnyStep] - # * Adapted methods from sklearn.pipeline.Pipeline * + + @property + def _filter_elements(self) -> list[ErrorFilter]: + """Get the elements which filter the input.""" + return [step[1] for step in self.steps if isinstance(step[1], ErrorFilter)] + + @property + def _filter_elements_agg(self) -> _MultipleErrorFilter: + """Get the aggregated filter element.""" + return _MultipleErrorFilter(self._filter_elements) + + @property + def supports_single_instance(self) -> bool: + """Check if the pipeline supports single instance.""" + return all( + isinstance(step[1], SingleInstanceTransformerMixin) + for step in self._modified_steps + ) def __init__( self, @@ -88,23 +143,6 @@ def __init__( self.n_jobs = n_jobs self._set_error_resinserter() - def _set_error_resinserter(self) -> None: - """Connect the error resinserters with the error filters.""" - error_replacer_list = [ - e_filler - for _, e_filler in self.steps - if isinstance(e_filler, FilterReinserter) - ] - error_filter_list = [ - n_filter for _, n_filter in self.steps if isinstance(n_filter, ErrorFilter) - ] - for step in self.steps: - if isinstance(step[1], PostPredictionWrapper): # noqa: SIM102 - if isinstance(step[1].wrapped_estimator, FilterReinserter): - error_replacer_list.append(step[1].wrapped_estimator) - for error_replacer in error_replacer_list: - error_replacer.select_error_filter(error_filter_list) - def _validate_steps(self) -> None: """Validate the steps. @@ -120,7 +158,7 @@ def _validate_steps(self) -> None: self._validate_names(names) # validate estimators - non_post_processing_steps = [e for _, _, e in self._agg_non_postpred_steps()] + non_post_processing_steps = [e for _, _, e in self._iter()] transformer_list = non_post_processing_steps[:-1] estimator = non_post_processing_steps[-1] @@ -151,13 +189,17 @@ def _validate_steps(self) -> None: # validate post-processing steps # Calling steps automatically validates them - _ = self._post_processing_steps() + _ = self._post_processing_steps def _iter( self, with_final: bool = True, filter_passthrough: bool = True, - ) -> Iterable[_AggregatedPipelineStep]: + ) -> Generator[ + tuple[int, str, AnyElement] | tuple[list[int], list[str], Pipeline], + Any, + None, + ]: """Iterate over all non post-processing steps. Steps which are children of a ABCPipelineElement were aggregated to a @@ -170,11 +212,6 @@ def _iter( filter_passthrough: bool, optional If True, passthrough steps are filtered out. - Raises - ------ - AssertionError - If the pipeline has no steps. - Yields ------ _AggregatedPipelineStep @@ -182,26 +219,40 @@ def _iter( transformer. 
""" - last_element: _AggregatedPipelineStep | None = None - - # This loop delays the output by one in order to identify the last step - for step in self._agg_non_postpred_steps(): - # Only happens for the first step - if last_element is None: - last_element = step - continue - if not filter_passthrough or ( - step[2] is not None and step[2] != "passthrough" + transformers_to_agg: list[tuple[int, str, AnyTransformer]] + transformers_to_agg = [] + final_transformer_list: list[ + tuple[int, str, AnyElement] | tuple[list[int], list[str], AnyElement] + ] = [] + for i, (name_i, step_i) in enumerate(super()._modified_steps): + if ( + isinstance(step_i, SingleInstanceTransformerMixin) + and not self.supports_single_instance ): - yield last_element - last_element = step + transformers_to_agg.append((i, name_i, step_i)) + else: + if transformers_to_agg: + if len(transformers_to_agg) == 1: + final_transformer_list.append(transformers_to_agg[0]) + else: + final_transformer_list.append( + _agg_transformers(transformers_to_agg, self.n_jobs), + ) + transformers_to_agg = [] + final_transformer_list.append((i, name_i, step_i)) - # This can only happen if no steps are set. - if last_element is None: - raise AssertionError("Pipeline needs to have at least one step!") + # yield last step if anything remains + if transformers_to_agg: + final_transformer_list.append( + _agg_transformers(transformers_to_agg, self.n_jobs), + ) + stop = len(final_transformer_list) + if not with_final: + stop -= 1 - if with_final and last_element[2] not in {None, "passthrough"}: - yield last_element + for idx, name, trans in islice(final_transformer_list, 0, stop): + if not filter_passthrough or (trans is not None and trans != "passthrough"): + yield idx, name, trans @property def _estimator_type(self) -> Any: @@ -215,282 +266,13 @@ def _estimator_type(self) -> Any: @property def _final_estimator( self, - ) -> ( - Literal["passthrough"] - | AnyTransformer - | AnyPredictor - | _MolPipeline - | ABCPipelineElement - ): + ) -> Literal["passthrough"] | AnyTransformer | AnyPredictor | ABCPipelineElement: """Return the lst estimator which is not a PostprocessingTransformer.""" - element_list = list(self._agg_non_postpred_steps()) - last_element = element_list[-1] + element_list = list(self._iter(with_final=True)) + steps = [s for s in element_list if not isinstance(s[2], ErrorFilter)] + last_element = steps[-1] return last_element[2] - # pylint: disable=too-many-locals,too-many-branches - def _fit( # noqa: PLR0912 - self, - X: Any, # noqa: N803 - y: Any = None, - routed_params: dict[str, Any] | None = None, - raw_params: dict[str, Any] | None = None, - ) -> tuple[Any, Any]: - """Fit the model by fitting all transformers except the final estimator. - - Data can be subsetted by the transformers. - - Parameters - ---------- - X : Any - Training data. - y : Any, optional (default=None) - Training objectives. - routed_params : dict[str, Any], optional - Parameters for each step as returned by process_routing. - Although this is marked as optional, it should not be None. - The awkward (argward?) typing is due to inheritance from sklearn. - Can be an empty dictionary. - raw_params : dict[str, Any], optional - Parameters passed by the user, used when `transform_input` - - Raises - ------ - AssertionError - If routed_params is None or if the transformer is 'passthrough'. - AssertionError - If the names are a list and the step is not a Pipeline. - - Returns - ------- - tuple[Any, Any] - The transformed data and the transformed objectives. 
- - """ - # shallow copy of steps - this should really be steps_ - self.steps = list(self.steps) - self._validate_steps() - if routed_params is None: - raise AssertionError("routed_params should not be None.") - - # Set up the memory - memory: joblib.Memory = check_memory(self.memory) - - fit_transform_one_cached = memory.cache(_fit_transform_one) - for step in self._iter(with_final=False, filter_passthrough=False): - step_idx, name, transformer = step - if transformer is None or transformer == "passthrough": - with print_elapsed_time("Pipeline", self._log_message(step_idx)): - continue - - if hasattr(memory, "location") and memory.location is None: - # we do not clone when caching is disabled to - # preserve backward compatibility - cloned_transformer = transformer - else: - cloned_transformer = clone(transformer) - if isinstance(cloned_transformer, _MolPipeline): - if routed_params: - step_params = { - "element_parameters": [routed_params[n] for n in name], - } - else: - step_params = {} - elif isinstance(name, list): - raise AssertionError( - "Names should not be a list, when the step is not a Pipeline", - ) - else: - step_params = self._get_metadata_for_step( - step_idx=step_idx, - step_params=routed_params[name], - all_params=raw_params, - ) - - # Fit or load from cache the current transformer - X, fitted_transformer = fit_transform_one_cached( # noqa: N806 - cloned_transformer, - X, - y, - None, - message_clsname="Pipeline", - message=self._log_message(step_idx), - params=step_params, - ) - # Replace the transformer of the step with the fitted - # transformer. This is necessary when loading the transformer - # from the cache. - if isinstance(fitted_transformer, _MolPipeline): - ele_list = fitted_transformer.element_list - if not isinstance(name, list) or not isinstance(step_idx, list): - raise AssertionError() - if not len(name) == len(step_idx) == len(ele_list): - raise AssertionError() - for idx_i, name_i, ele_i in zip(step_idx, name, ele_list, strict=True): - self.steps[idx_i] = (name_i, ele_i) - if y is not None: - y = fitted_transformer.co_transform(y) - for idx_i, name_i, ele_i in zip(step_idx, name, ele_list, strict=True): - self.steps[idx_i] = (name_i, ele_i) - self._set_error_resinserter() - elif isinstance(name, list) or isinstance(step_idx, list): - raise AssertionError() - else: - self.steps[step_idx] = (name, fitted_transformer) - if is_empty(X): - return np.array([]), np.array([]) - return X, y - - def _transform( - self, - X: Any, # pylint: disable=invalid-name # noqa: N803 - routed_params: Bunch, - ) -> Any: - """Transform the data, and skip final estimator. - - Call `transform` of each transformer in the pipeline except the last one, - - Parameters - ---------- - X : iterable - Data to predict on. Must fulfill input requirements of first step - of the pipeline. - - routed_params: Bunch - parameters for each step as returned by process_routing - - Raises - ------ - AssertionError - If one of the transformers is 'passthrough' or does not implement - `transform`. - - Returns - ------- - Any - Result of calling `transform` on the second last estimator. 
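The deleted `_fit` above also documents why `co_transform` exists: whenever a step drops rows from X, the labels must be filtered by the same indices, otherwise the final estimator would be fitted on misaligned pairs. A minimal index-bookkeeping sketch of that invariant (plain Python, names are illustrative):

x_values = ["CCO", "not-a-smiles", "c1ccccc1"]
y_values = [0.5, 1.2, 0.7]

# Suppose a transformer flags row 1 as invalid and removes it.
error_indices = [1]
kept = [i for i in range(len(x_values)) if i not in error_indices]

x_filtered = [x_values[i] for i in kept]
y_filtered = [y_values[i] for i in kept]  # what co_transform does to y

assert len(x_filtered) == len(y_filtered) == 2
assert y_filtered == [0.5, 0.7]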
- - """ - iter_input = X - do_routing = _routing_enabled() - if do_routing: - logger.warning("Routing is enabled and NOT fully tested!") - - for _, name, transform in self._iter(with_final=False): - if is_empty(iter_input): - if isinstance(transform, _MolPipeline): - _ = transform.transform(iter_input) - iter_input = [] - break - if transform == "passthrough": - raise AssertionError("Passthrough should have been filtered out.") - if hasattr(transform, "transform"): - if do_routing: - iter_input = transform.transform( # type: ignore[call-arg] - iter_input, - routed_params[name].transform, - ) - else: - iter_input = transform.transform(iter_input) - else: - raise AssertionError( - f"Non transformer ocurred in transformation step: {transform}.", - ) - return iter_input - - # * New implemented methods * - def _non_post_processing_steps( - self, - ) -> list[AnyStep]: - """Return all steps before the first PostPredictionTransformation. - - Raises - ------ - AssertionError - If a PostPredictionTransformation is found before the last step. - - Returns - ------- - list[AnyStep] - List of steps before the first PostPredictionTransformation. - - """ - non_post_processing_steps: list[AnyStep] = [] - start_adding = False - for step_name, step_estimator in self.steps[::-1]: - if not isinstance(step_estimator, PostPredictionTransformation): - start_adding = True - if start_adding: - if isinstance(step_estimator, PostPredictionTransformation): - raise AssertionError( - "PipelineElement of type PostPredictionTransformation occured " - "before the last step.", - ) - non_post_processing_steps.append((step_name, step_estimator)) - return list(non_post_processing_steps[::-1]) - - def _post_processing_steps(self) -> list[tuple[str, PostPredictionTransformation]]: - """Return last steps which are PostPredictionTransformation. - - Returns - ------- - list[tuple[str, PostPredictionTransformation]] - List of tuples containing the name and the PostPredictionTransformation. - - """ - post_processing_steps = [] - for step_name, step_estimator in self.steps[::-1]: - if isinstance(step_estimator, PostPredictionTransformation): - post_processing_steps.append((step_name, step_estimator)) - else: - break - return list(post_processing_steps[::-1]) - - def _agg_non_postpred_steps( - self, - ) -> Iterable[_AggregatedPipelineStep]: - """Generate (idx, (name, trans)) tuples from self.steps. - - When filter_passthrough is True, 'passthrough' and None transformers - are filtered out. - - Yields - ------ - _AggregatedPipelineStep - The _AggregatedPipelineStep is composed of the index, the name and the - transformer. 
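Both removed helpers implement a single split: walk the steps from the back, collect the trailing PostPredictionTransformation instances, and treat everything in front of them as the modelling pipeline. A compact stand-alone equivalent, assuming only that post-processing steps are recognisable by a predicate:

def split_steps(steps, is_post):
    """Split steps into (main, trailing_post); only a trailing block counts."""
    n_post = 0
    for _, estimator in reversed(steps):
        if not is_post(estimator):
            break
        n_post += 1
    cut = len(steps) - n_post
    return steps[:cut], steps[cut:]

steps = [("mol2fp", "transformer"), ("rf", "model"), ("reinsert", "post")]
main, post = split_steps(steps, is_post=lambda e: e == "post")
assert main == [("mol2fp", "transformer"), ("rf", "model")]
assert post == [("reinsert", "post")]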
- - """ - aggregated_transformer_list: list[tuple[int, str, ABCPipelineElement]] - aggregated_transformer_list = [] - for i, (name_i, step_i) in enumerate(self._non_post_processing_steps()): - if isinstance(step_i, SingleInstanceTransformerMixin): - aggregated_transformer_list.append((i, name_i, step_i)) - else: - if aggregated_transformer_list: - index_list = [step[0] for step in aggregated_transformer_list] - name_list = [step[1] for step in aggregated_transformer_list] - transformer_list = [step[2] for step in aggregated_transformer_list] - if len(aggregated_transformer_list) == 1: - yield index_list[0], name_list[0], transformer_list[0] - else: - pipeline = _MolPipeline(transformer_list, n_jobs=self.n_jobs) - yield index_list, name_list, pipeline - aggregated_transformer_list = [] - yield i, name_i, step_i - - # yield last step if anything remains - if aggregated_transformer_list: - index_list = [step[0] for step in aggregated_transformer_list] - name_list = [step[1] for step in aggregated_transformer_list] - transformer_list = [step[2] for step in aggregated_transformer_list] - - if len(aggregated_transformer_list) == 1: - yield index_list[0], name_list[0], transformer_list[0] - - elif len(aggregated_transformer_list) > 1: - pipeline = _MolPipeline(transformer_list, n_jobs=self.n_jobs) - yield index_list, name_list, pipeline - @_fit_context( # estimators in Pipeline.steps are not validated yet prefer_skip_nested_validation=False, @@ -527,6 +309,8 @@ def fit( Pipeline with fitted steps. """ + if self.supports_single_instance: + return self routed_params = self._check_method_params(method="fit", props=fit_params) xt, yt = self._fit(X, y, routed_params) with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): @@ -536,9 +320,7 @@ def fit( "All input rows were filtered out! Model is not fitted!", ) else: - fit_params_last_step = routed_params[ - self._non_post_processing_steps()[-1][0] - ] + fit_params_last_step = routed_params[self._modified_steps[-1][0]] self._final_estimator.fit(xt, yt, **fit_params_last_step["fit"]) return self @@ -603,9 +385,8 @@ def fit_transform( Raises ------ - TypeError - If the last step does not implement `fit_transform` or `fit` and - `transform`. + AssertionError + If PipelineElement is not a SingleInstanceTransformerMixin or Pipeline. Returns ------- @@ -613,38 +394,57 @@ def fit_transform( Transformed samples. """ - routed_params = self._check_method_params(method="fit_transform", props=params) - iter_input, iter_label = self._fit(X, y, routed_params) - last_step = self._final_estimator - with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): - if last_step == "passthrough": - pass - elif is_empty(iter_input): - logger.warning("All input rows were filtered out! 
Model is not fitted!") - else: - last_step_params = routed_params[ - self._non_post_processing_steps()[-1][0] - ] - if hasattr(last_step, "fit_transform"): - iter_input = last_step.fit_transform( - iter_input, - iter_label, - **last_step_params["fit_transform"], - ) - elif hasattr(last_step, "transform") and hasattr(last_step, "fit"): - last_step.fit(iter_input, iter_label, **last_step_params["fit"]) - iter_input = last_step.transform( - iter_input, - **last_step_params["transform"], - ) - else: - raise TypeError( - f"fit_transform of the final estimator" - f" {last_step.__class__.__name__} {last_step_params} does not " - f"match fit_transform of Pipeline {self.__class__.__name__}", - ) - for _, post_element in self._post_processing_steps(): - iter_input = post_element.fit_transform(iter_input, iter_label) + if not self.supports_single_instance: + return super().fit_transform(X, **params) + iter_input = X + removed_rows: dict[ErrorFilter, list[int]] = {} + for error_filter in self._filter_elements: + removed_rows[error_filter] = [] + iter_idx_array = np.arange(len(iter_input)) + + # The meta elements merge steps which do not require fitting + for idx, _, i_element in self._iter(with_final=True): + if not isinstance( + i_element, + SingleInstanceTransformerMixin, + ) and ( + isinstance(i_element, Pipeline) + and not i_element.supports_single_instance + ): + raise AssertionError( + "PipelineElement is not a SingleInstanceTransformerMixin.", + ) + if isinstance(i_element, (ErrorFilter, FilterReinserter)): + continue + i_element.n_jobs = self.n_jobs + iter_input = i_element.pretransform(iter_input) + for error_filter in self._filter_elements: + iter_input = error_filter.transform(iter_input) + for idx in error_filter.error_indices: + idx = iter_idx_array[idx] + removed_rows[error_filter].append(idx) + iter_idx_array = error_filter.co_transform(iter_idx_array) + iter_input = i_element.assemble_output(iter_input) + i_element.n_jobs = 1 + + # Set removed rows to filter elements to allow for correct co_transform + iter_idx_array = np.arange(len(X)) + for error_filter in self._filter_elements: + removed_idx_list = removed_rows[error_filter] + error_filter.error_indices = [] + for new_idx, _idx in enumerate(iter_idx_array): + if _idx in removed_idx_list: + error_filter.error_indices.append(new_idx) + error_filter.n_total = len(iter_idx_array) + iter_idx_array = error_filter.co_transform(iter_idx_array) + error_replacer_list = [ + ele for _, ele in self.steps if isinstance(ele, FilterReinserter) + ] + for error_replacer in error_replacer_list: + error_replacer.select_error_filter(self._filter_elements) + iter_input = error_replacer.transform(iter_input) + for _, post_element in self._post_processing_steps: + iter_input = post_element.fit_transform(iter_input, y) return iter_input @available_if(_final_estimator_has("predict")) @@ -702,7 +502,7 @@ def predict( if _routing_enabled(): iter_input = self._final_estimator.predict( iter_input, - **routed_params[self._non_post_processing_steps()[-1][0]].predict, + **routed_params[self._modified_steps[-1][0]].predict, ) else: iter_input = self._final_estimator.predict(iter_input, **params) @@ -711,7 +511,7 @@ def predict( "Final estimator does not implement predict, " "hence this function should not be available.", ) - for _, post_element in self._post_processing_steps(): + for _, post_element in self._post_processing_steps: iter_input = post_element.transform(iter_input) return iter_input @@ -763,7 +563,7 @@ def fit_predict( routed_params = 
self._check_method_params(method="fit_predict", props=params) iter_input, iter_label = self._fit(X, y, routed_params) - params_last_step = routed_params[self._non_post_processing_steps()[-1][0]] + params_last_step = routed_params[self._modified_steps[-1][0]] with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): if self._final_estimator == "passthrough": y_pred = iter_input @@ -782,74 +582,10 @@ def fit_predict( "Final estimator does not implement fit_predict, " "hence this function should not be available.", ) - for _, post_element in self._post_processing_steps(): + for _, post_element in self._post_processing_steps: y_pred = post_element.fit_transform(y_pred, iter_label) return y_pred - @available_if(_final_estimator_has("predict_proba")) - def predict_proba( - self, - X: Any, # noqa: N803 - **params: Any, - ) -> Any: - """Transform the data, and apply `predict_proba` with the final estimator. - - Call `transform` of each transformer in the pipeline. The transformed - data are finally passed to the final estimator that calls `predict_proba` - method. Only valid if the final estimator implements `predict_proba`. - - Parameters - ---------- - X : iterable - Data to predict on. Must fulfill input requirements of first step - of the pipeline. - - **params : dict of string -> object - Parameters to the ``predict`` called at the end of all - transformations in the pipeline. Note that while this may be - used to return uncertainties from some models with return_std - or return_cov, uncertainties that are generated by the - transformations in the pipeline are not propagated to the - final estimator. - - Raises - ------ - AssertionError - If the final estimator does not implement `predict_proba`. - In this case this function should not be available. - - Returns - ------- - y_pred : ndarray - Result of calling `predict_proba` on the final estimator. - - """ - routed_params = process_routing(self, "predict_proba", **params) - iter_input = self._transform(X, routed_params) - - if self._final_estimator == "passthrough": - pass - elif is_empty(iter_input): - iter_input = [] - elif hasattr(self._final_estimator, "predict_proba"): - if _routing_enabled(): - iter_input = self._final_estimator.predict_proba( - iter_input, - **routed_params[ - self._non_post_processing_steps()[-1][0] - ].predict_proba, - ) - else: - iter_input = self._final_estimator.predict_proba(iter_input, **params) - else: - raise AssertionError( - "Final estimator does not implement predict_proba, " - "hence this function should not be available.", - ) - for _, post_element in self._post_processing_steps(): - iter_input = post_element.transform(iter_input) - return iter_input - def _can_transform(self) -> bool: """Check if the final estimator can transform or is passthrough. @@ -900,28 +636,26 @@ def transform( Transformed data. """ - routed_params = process_routing(self, "transform", **params) - iter_input = X - for _, name, transform in self._iter(): - if transform == "passthrough": - continue - if is_empty(iter_input): - # This is done to prime the error filters - if isinstance(transform, _MolPipeline): - _ = transform.transform(iter_input) - iter_input = [] - break - if hasattr(transform, "transform"): - iter_input = transform.transform( - iter_input, - **routed_params[name].transform, - ) - else: - raise AssertionError( - "Non transformer ocurred in transformation step." 
- "This should have been caught in the validation step.", - ) - for _, post_element in self._post_processing_steps(): + if self.supports_single_instance: + output_generator = self._transform_iterator(X) + iter_input = self.assemble_output(output_generator) + else: + routed_params = process_routing(self, "transform", **params) + iter_input = X + for _, name, transform in self._iter(): + if transform == "passthrough": + continue + if hasattr(transform, "transform"): + iter_input = transform.transform( + iter_input, + **routed_params[name].transform, + ) + else: + raise AssertionError( + "Non transformer ocurred in transformation step." + "This should have been caught in the validation step.", + ) + for _, post_element in self._post_processing_steps: iter_input = post_element.transform(iter_input, **params) return iter_input @@ -978,7 +712,7 @@ def decision_function( "Final estimator does not implement `decision_function`, " "hence this function should not be available.", ) - for _, post_element in self._post_processing_steps(): + for _, post_element in self._post_processing_steps: iter_input = post_element.transform(iter_input) return iter_input @@ -1064,86 +798,130 @@ def __sklearn_tags__(self) -> Tags: # noqa: PLW3201 return tags - def get_metadata_routing(self) -> MetadataRouter: - """Get metadata routing of this object. + def transform_single(self, input_value: Any) -> Any: + """Transform a single input according to the sequence of PipelineElements. - Please check :ref:`User Guide ` on how the routing - mechanism works. + Parameters + ---------- + input_value: Any + Molecular representation which is subsequently transformed. - Notes - ----- - This method is copied from the original sklearn implementation. - Changes are marked with a comment. + Raises + ------ + AssertionError + If the PipelineElement is not a SingleInstanceTransformerMixin. Returns ------- - MetadataRouter - A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating - routing information. + Any + Transformed molecular representation. 
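Conceptually, `transform_single` threads one value through every element and stops applying elements once the value has been flagged, so failures travel as sentinel objects rather than exceptions. A stripped-down model of that control flow (plain callables and a toy `Invalid` class stand in for pipeline elements and `InvalidInstance`):

class Invalid:
    """Stand-in for molpipeline's InvalidInstance sentinel."""

def parse(value):
    """First element: fails on anything that is not a string."""
    return value.upper() if isinstance(value, str) else Invalid()

def featurize(value):
    """Second element: a toy descriptor."""
    return len(value)

def run_single(value, elements=(parse, featurize)):
    for element in elements:
        if isinstance(value, Invalid):
            break  # later elements are skipped once the instance is flagged
        value = element(value)
    return value

assert run_single("ccO") == 3
assert isinstance(run_single(42), Invalid)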
""" - router = MetadataRouter(owner=self.__class__.__name__) - - # first we add all steps except the last one - for _, name, trans in self._iter(with_final=False, filter_passthrough=True): - method_mapping = MethodMapping() - # fit, fit_predict, and fit_transform call fit_transform if it - # exists, or else fit and transform - if hasattr(trans, "fit_transform"): - ( - method_mapping.add(caller="fit", callee="fit_transform") - .add(caller="fit_transform", callee="fit_transform") - .add(caller="fit_predict", callee="fit_transform") + log_block = BlockLogs() + iter_value = input_value + for _, p_element in self._modified_steps: + if not isinstance(p_element, SingleInstanceTransformerMixin): + raise AssertionError( + "PipelineElement is not a SingleInstanceTransformerMixin.", ) - else: - ( - method_mapping.add(caller="fit", callee="fit") - .add(caller="fit", callee="transform") - .add(caller="fit_transform", callee="fit") - .add(caller="fit_transform", callee="transform") - .add(caller="fit_predict", callee="fit") - .add(caller="fit_predict", callee="transform") + if not isinstance(p_element, ABCPipelineElement): + raise AssertionError( + "PipelineElement is not a ABCPipelineElement.", ) + try: + if not isinstance(iter_value, RemovedInstance) or isinstance( + p_element, + FilterReinserter, + ): + iter_value = p_element.transform_single(iter_value) + except MolSanitizeException as err: + iter_value = InvalidInstance( + p_element.uuid, + f"RDKit MolSanitizeException: {err.args}", + p_element.name, + ) + del log_block + return iter_value - ( - method_mapping.add(caller="predict", callee="transform") - .add(caller="predict", callee="transform") - .add(caller="predict_proba", callee="transform") - .add(caller="decision_function", callee="transform") - .add(caller="predict_log_proba", callee="transform") - .add(caller="transform", callee="transform") - .add(caller="inverse_transform", callee="inverse_transform") - .add(caller="score", callee="transform") - ) + def pretransform(self, x_input: Any) -> Any: + """Transform the input according to the sequence without assemble_output step. - router.add(method_mapping=method_mapping, **{name: trans}) + Parameters + ---------- + x_input: Any + Molecular representations which are subsequently transformed. - # Only the _non_post_processing_steps is changed from the original - # implementation is changed in the following line - final_name, final_est = self._non_post_processing_steps()[-1] - if final_est is None or final_est == "passthrough": - return router + Returns + ------- + Any + Transformed molecular representations. - # then we add the last step - method_mapping = MethodMapping() - if hasattr(final_est, "fit_transform"): - method_mapping.add(caller="fit_transform", callee="fit_transform") - else: - method_mapping.add(caller="fit", callee="fit").add( - caller="fit", - callee="transform", - ) - ( - method_mapping.add(caller="fit", callee="fit") - .add(caller="predict", callee="predict") - .add(caller="fit_predict", callee="fit_predict") - .add(caller="predict_proba", callee="predict_proba") - .add(caller="decision_function", callee="decision_function") - .add(caller="predict_log_proba", callee="predict_log_proba") - .add(caller="transform", callee="transform") - .add(caller="inverse_transform", callee="inverse_transform") - .add(caller="score", callee="score") + """ + return list(self._transform_iterator(x_input)) + + def assemble_output(self, value_list: Iterable[Any]) -> Any: + """Assemble the output of the pipeline. 
+ + Parameters + ---------- + value_list: Iterable[Any] + Generator which yields the output of the pipeline. + + Returns + ------- + Any + Assembled output. + + """ + final_estimator = self._final_estimator + if hasattr(final_estimator, "assemble_output"): + return final_estimator.assemble_output(value_list) + return list(value_list) + + def _transform_iterator(self, x_input: Any) -> Any: + """Transform the input according to the sequence of provided PipelineElements. + + Parameters + ---------- + x_input: Any + Molecular representations which are subsequently transformed. + + Yields + ------ + Any + Transformed molecular representations. + + """ + agg_filter = self._filter_elements_agg + for filter_element in self._filter_elements: + filter_element.error_indices = [] + parallel = Parallel( + n_jobs=self.n_jobs, + return_as="generator", + batch_size="auto", + ) + output_generator = parallel( + delayed(self.transform_single)(value) for value in x_input ) + for i, transformed_value in enumerate(output_generator): + if isinstance(transformed_value, RemovedInstance): + agg_filter.register_removed(i, transformed_value) + else: + yield transformed_value + agg_filter.set_total(len(x_input)) + + def co_transform(self, x_input: TypeFixedVarSeq) -> TypeFixedVarSeq: + """Filter flagged rows from the input. - router.add(method_mapping=method_mapping, **{final_name: final_est}) - return router + Parameters + ---------- + x_input: Any + Molecular representations which are subsequently filtered. + + Returns + ------- + Any + Filtered molecular representations. + + """ + return self._filter_elements_agg.co_transform(x_input) From d356b844dd288ef50ecc45eb4597bad0111d739a Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Tue, 13 May 2025 16:58:10 +0200 Subject: [PATCH 10/31] Change inheritance --- molpipeline/error_handling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/molpipeline/error_handling.py b/molpipeline/error_handling.py index 7d41d425..16923fd1 100644 --- a/molpipeline/error_handling.py +++ b/molpipeline/error_handling.py @@ -27,7 +27,7 @@ _S = TypeVar("_S") -class ErrorFilter(SingleInstanceTransformerMixin, ABCPipelineElement): +class ErrorFilter(SingleInstanceTransformerMixin, TransformingPipelineElement): """Filter to remove InvalidInstances from a list of values.""" element_ids: set[str] From 3b684cd527bac3e93216a64ccbcad2726db2b8f4 Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Tue, 13 May 2025 17:06:45 +0200 Subject: [PATCH 11/31] Type cast --- molpipeline/pipeline/_skl_pipeline.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py index 3c2cb608..29d34d0c 100644 --- a/molpipeline/pipeline/_skl_pipeline.py +++ b/molpipeline/pipeline/_skl_pipeline.py @@ -421,8 +421,7 @@ def fit_transform( for error_filter in self._filter_elements: iter_input = error_filter.transform(iter_input) for idx in error_filter.error_indices: - idx = iter_idx_array[idx] - removed_rows[error_filter].append(idx) + removed_rows[error_filter].append(int(iter_idx_array[idx])) iter_idx_array = error_filter.co_transform(iter_idx_array) iter_input = i_element.assemble_output(iter_input) i_element.n_jobs = 1 From 08260792868429438c7e9da9167412d0248bf15e Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Tue, 13 May 2025 17:14:11 +0200 Subject: [PATCH 12/31] type hints --- molpipeline/pipeline/_skl_adapter_pipeline.py | 2 +- molpipeline/pipeline/_skl_pipeline.py | 90 ++++++++++++------- 2 
files changed, 60 insertions(+), 32 deletions(-) diff --git a/molpipeline/pipeline/_skl_adapter_pipeline.py b/molpipeline/pipeline/_skl_adapter_pipeline.py index cae1b858..82613c04 100644 --- a/molpipeline/pipeline/_skl_adapter_pipeline.py +++ b/molpipeline/pipeline/_skl_adapter_pipeline.py @@ -227,7 +227,7 @@ def _iter( with_final: bool = True, filter_passthrough: bool = True, ) -> Generator[ - tuple[int, str, AnyElement] | tuple[list[int], list[str], AnyElement], + tuple[int, str, AnyElement] | tuple[list[int], list[str], AdapterPipeline], Any, None, ]: diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py index 29d34d0c..fa6bb18f 100644 --- a/molpipeline/pipeline/_skl_pipeline.py +++ b/molpipeline/pipeline/_skl_pipeline.py @@ -4,7 +4,7 @@ from copy import deepcopy from itertools import islice -from typing import TYPE_CHECKING, Any, Literal, Self +from typing import TYPE_CHECKING, Any, Literal, Self, TypeIs import numpy as np import numpy.typing as npt @@ -26,6 +26,7 @@ InvalidInstance, RemovedInstance, SingleInstanceTransformerMixin, + TransformingPipelineElement, ) from molpipeline.error_handling import ( ErrorFilter, @@ -46,7 +47,7 @@ from molpipeline.utils.value_checks import is_empty if TYPE_CHECKING: - from collections.abc import Generator, Iterable + from collections.abc import Generator, Iterable, Sequence import joblib @@ -61,21 +62,21 @@ def _agg_transformers( - transformer_list: list[tuple[int, str, AnyTransformer]], + transformer_list: Sequence[tuple[int, str, AnyElement]], n_jobs: int = 1, -) -> tuple[list[int], list[str], AnyElement] | tuple[int, str, AnyTransformer]: +) -> tuple[list[int], list[str], Pipeline] | tuple[int, str, AnyElement]: """Aggregate transformers to a single step. Parameters ---------- - transformer_list: list[tuple[int, str, AnyTransformer]] + transformer_list: list[tuple[int, str, AnyElement]] List of transformers to aggregate. n_jobs: int, optional Number of cores used for aggregated steps. Returns ------- - tuple[list[int], list[str], AnyElement] | tuple[int, str, AnyTransformer] + tuple[list[int], list[str], Pipeline] | tuple[int, str, AnyElement] Aggregated transformer. If the list contains only one transformer, it is returned as is. @@ -94,7 +95,28 @@ def _agg_transformers( ) -class Pipeline(AdapterPipeline, ABCPipelineElement): +def check_single_instance_support( + estimator: Any, +) -> TypeIs[SingleInstanceTransformerMixin | Pipeline]: + """Check if the estimator supports single instance processing. + + Parameters + ---------- + estimator: Any + Estimator to check. + + Returns + ------- + TypeIs[SingleInstanceTransformerMixin | Pipeline] + True if the estimator supports single instance processing. + + """ + if isinstance(estimator, SingleInstanceTransformerMixin): + return True + return isinstance(estimator, Pipeline) and estimator.supports_single_instance + + +class Pipeline(AdapterPipeline, TransformingPipelineElement): """Defines the pipeline which handles pipeline elements.""" steps: list[AnyStep] @@ -196,7 +218,7 @@ def _iter( with_final: bool = True, filter_passthrough: bool = True, ) -> Generator[ - tuple[int, str, AnyElement] | tuple[list[int], list[str], Pipeline], + tuple[list[int], list[str], Pipeline] | tuple[int, str, AnyElement], Any, None, ]: @@ -219,16 +241,19 @@ def _iter( transformer. 
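The `TypeIs` annotation on `check_single_instance_support` gives static type checkers the same narrowing a plain `isinstance` check would. A self-contained illustration of the mechanism (via the `typing_extensions` backport for interpreters older than Python 3.13; the toy class is not part of molpipeline):

from typing_extensions import TypeIs

class SingleCapable:
    supports_single_instance = True

def is_single_capable(obj: object) -> TypeIs[SingleCapable]:
    """Returning True narrows obj to SingleCapable for the checker."""
    return isinstance(obj, SingleCapable) and obj.supports_single_instance

def describe(obj: object) -> str:
    if is_single_capable(obj):
        # Inside this branch, obj is typed as SingleCapable.
        return f"single-instance support: {obj.supports_single_instance}"
    return "batch only"

assert describe(SingleCapable()) == "single-instance support: True"
assert describe(object()) == "batch only"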
""" + if self.supports_single_instance: + yield from super()._iter( + with_final=with_final, + filter_passthrough=filter_passthrough, + ) + return transformers_to_agg: list[tuple[int, str, AnyTransformer]] transformers_to_agg = [] final_transformer_list: list[ - tuple[int, str, AnyElement] | tuple[list[int], list[str], AnyElement] + tuple[list[int], list[str], Pipeline] | tuple[int, str, AnyElement] ] = [] for i, (name_i, step_i) in enumerate(super()._modified_steps): - if ( - isinstance(step_i, SingleInstanceTransformerMixin) - and not self.supports_single_instance - ): + if isinstance(step_i, SingleInstanceTransformerMixin): transformers_to_agg.append((i, name_i, step_i)) else: if transformers_to_agg: @@ -250,9 +275,11 @@ def _iter( if not with_final: stop -= 1 - for idx, name, trans in islice(final_transformer_list, 0, stop): - if not filter_passthrough or (trans is not None and trans != "passthrough"): - yield idx, name, trans + for step in islice(final_transformer_list, 0, stop): + if not filter_passthrough or ( + step[2] is not None and step[2] != "passthrough" + ): + yield step @property def _estimator_type(self) -> Any: @@ -387,6 +414,8 @@ def fit_transform( ------ AssertionError If PipelineElement is not a SingleInstanceTransformerMixin or Pipeline. + AssertionError + If PipelineElement is not a TransformingPipelineElement. Returns ------- @@ -397,31 +426,30 @@ def fit_transform( if not self.supports_single_instance: return super().fit_transform(X, **params) iter_input = X - removed_rows: dict[ErrorFilter, list[int]] = {} - for error_filter in self._filter_elements: - removed_rows[error_filter] = [] + removed_row_dict: dict[ErrorFilter, list[int]] = { + error_filter: [] for error_filter in self._filter_elements + } iter_idx_array = np.arange(len(iter_input)) # The meta elements merge steps which do not require fitting for idx, _, i_element in self._iter(with_final=True): - if not isinstance( - i_element, - SingleInstanceTransformerMixin, - ) and ( - isinstance(i_element, Pipeline) - and not i_element.supports_single_instance - ): + if not isinstance(i_element, TransformingPipelineElement): raise AssertionError( - "PipelineElement is not a SingleInstanceTransformerMixin.", + "PipelineElement is not a TransformingPipelineElement.", + ) + if not check_single_instance_support(i_element): + raise AssertionError( + "PipelineElement is not a SingleInstanceTransformerMixin or" + " Pipeline with signle instance support.", ) - if isinstance(i_element, (ErrorFilter, FilterReinserter)): + if isinstance(i_element, ErrorFilter): continue i_element.n_jobs = self.n_jobs iter_input = i_element.pretransform(iter_input) for error_filter in self._filter_elements: iter_input = error_filter.transform(iter_input) for idx in error_filter.error_indices: - removed_rows[error_filter].append(int(iter_idx_array[idx])) + removed_row_dict[error_filter].append(int(iter_idx_array[idx])) iter_idx_array = error_filter.co_transform(iter_idx_array) iter_input = i_element.assemble_output(iter_input) i_element.n_jobs = 1 @@ -429,7 +457,7 @@ def fit_transform( # Set removed rows to filter elements to allow for correct co_transform iter_idx_array = np.arange(len(X)) for error_filter in self._filter_elements: - removed_idx_list = removed_rows[error_filter] + removed_idx_list = removed_row_dict[error_filter] error_filter.error_indices = [] for new_idx, _idx in enumerate(iter_idx_array): if _idx in removed_idx_list: @@ -873,7 +901,7 @@ def assemble_output(self, value_list: Iterable[Any]) -> Any: """ final_estimator = 
-        if hasattr(final_estimator, "assemble_output"):
+        if isinstance(final_estimator, TransformingPipelineElement):
             return final_estimator.assemble_output(value_list)
         return list(value_list)

From f1f835078944ad1063258b0cb2a2e76040c32c59 Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Tue, 13 May 2025 17:15:07 +0200
Subject: [PATCH 13/31] fix var name

---
 molpipeline/pipeline/_skl_pipeline.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py
index fa6bb18f..3606b71c 100644
--- a/molpipeline/pipeline/_skl_pipeline.py
+++ b/molpipeline/pipeline/_skl_pipeline.py
@@ -870,12 +870,12 @@ def transform_single(self, input_value: Any) -> Any:
         del log_block
         return iter_value
 
-    def pretransform(self, x_input: Any) -> Any:
+    def pretransform(self, value_list: Any) -> Any:
         """Transform the input according to the sequence without assemble_output step.
 
         Parameters
         ----------
-        x_input: Any
+        value_list: Any
             Molecular representations which are subsequently transformed.
 
         Returns
@@ -884,7 +884,7 @@ def pretransform(self, x_input: Any) -> Any:
             Transformed molecular representations.
 
         """
-        return list(self._transform_iterator(x_input))
+        return list(self._transform_iterator(value_list))
 
     def assemble_output(self, value_list: Iterable[Any]) -> Any:
         """Assemble the output of the pipeline.

From 3d1ed3e76195b96e7e92c5baae51e9847f34a3c4 Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Tue, 13 May 2025 17:43:15 +0200
Subject: [PATCH 14/31] Add type ignore and minor linting

---
 .../test_mol2mol/test_mol2mol_filter.py       | 59 +++++++++++--------
 1 file changed, 36 insertions(+), 23 deletions(-)

diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py
index 881ffc03..ee9cc50f 100644
--- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py
+++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py
@@ -1,9 +1,14 @@
-"""Test MolFilter, which invalidate molecules based on criteria defined in the respective filter."""
+"""Unit tests for MolFilter functionality.
+
+MolFilters flag molecules as invalid based on the criteria defined in the filter.
+ +""" import json import tempfile import unittest from pathlib import Path +from typing import TYPE_CHECKING from molpipeline import ErrorFilter, FilterReinserter, Pipeline from molpipeline.any2mol import SmilesToMol @@ -19,7 +24,9 @@ ) from molpipeline.utils.comparison import compare_recursive from molpipeline.utils.json_operations import recursive_from_json, recursive_to_json -from molpipeline.utils.molpipeline_types import FloatCountRange, IntOrIntCountRange + +if TYPE_CHECKING: + from molpipeline.utils.molpipeline_types import FloatCountRange, IntOrIntCountRange # pylint: disable=duplicate-code # test case molecules are allowed to be duplicated SMILES_ANTIMONY = "[SbH6+3]" @@ -80,7 +87,7 @@ def test_element_filter(self) -> None: 6: 6, 1: (5, 6), 17: (0, 1), - } + }, }, "result": [SMILES_BENZENE, SMILES_CHLOROBENZENE], }, @@ -88,7 +95,7 @@ def test_element_filter(self) -> None: ] for test_params in test_params_list_with_results: - pipeline.set_params(**test_params["params"]) + pipeline.set_params(**test_params["params"]) # type: ignore filtered_smiles = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles, test_params["result"]) @@ -99,14 +106,15 @@ def test_json_roundtrip(self) -> None: ----- It is important to save the ElementFilter as a JSON file and then load it back. This is because json.dumps() sets the keys of the dictionary to strings. + """ element_filter = ElementFilter() json_object = recursive_to_json(element_filter) with tempfile.TemporaryDirectory() as temp_folder: temp_file_path = Path(temp_folder) / "test.json" - with open(temp_file_path, "w", encoding="UTF-8") as out_file: + with temp_file_path.open("w", encoding="UTF-8") as out_file: json.dump(json_object, out_file) - with open(temp_file_path, encoding="UTF-8") as in_file: + with temp_file_path.open(encoding="UTF-8") as in_file: loaded_json_object = json.load(in_file) recreated_element_filter = recursive_from_json(loaded_json_object) @@ -117,7 +125,8 @@ def test_json_roundtrip(self) -> None: with self.subTest(param_name=param_name): self.assertTrue( compare_recursive(original_value, recreated_params[param_name]), - f"Original: {original_value}, Recreated: {recreated_params[param_name]}", + f"Original: {original_value}, " + f"Recreated: {recreated_params[param_name]}", ) @@ -132,6 +141,7 @@ def _create_pipeline() -> Pipeline: ------- Pipeline Pipeline with a complex filter. 
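The `set_params` calls in these tests use scikit-learn's double-underscore routing: each segment selects a step (or a nested element) and the last segment names the parameter. The addressing scheme in miniature, shown on plain scikit-learn estimators rather than the patched classes:

from sklearn.linear_model import LogisticRegression
from sklearn.pipeline import Pipeline as SkPipeline
from sklearn.preprocessing import StandardScaler

pipe = SkPipeline([("scale", StandardScaler()), ("clf", LogisticRegression(C=1.0))])

# <step name>__<parameter>; deeper nesting just repeats the scheme.
pipe.set_params(clf__C=0.1, scale__with_mean=False)

assert pipe.get_params()["clf__C"] == 0.1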
+ """ element_filter_1 = ElementFilter({6: 6, 1: 6}) element_filter_2 = ElementFilter({6: 6, 1: 5, 17: 1}) @@ -140,10 +150,10 @@ def _create_pipeline() -> Pipeline: ( ("element_filter_1", element_filter_1), ("element_filter_2", element_filter_2), - ) + ), ) - pipeline = Pipeline( + return Pipeline( [ ("Smiles2Mol", SmilesToMol()), ("MultiElementFilter", multi_element_filter), @@ -151,7 +161,6 @@ def _create_pipeline() -> Pipeline: ("ErrorFilter", ErrorFilter()), ], ) - return pipeline def test_complex_filter(self) -> None: """Test if molecules are filtered correctly by allowed chemical elements.""" @@ -169,14 +178,14 @@ def test_complex_filter(self) -> None: { "params": { "MultiElementFilter__mode": "any", - "MultiElementFilter__pipeline_filter_elements__element_filter_1__add_hydrogens": False, + "MultiElementFilter__pipeline_filter_elements__element_filter_1__add_hydrogens": False, # noqa: E501 }, "result": [SMILES_CHLOROBENZENE], }, ] for test_params in test_params_list_with_results: - pipeline.set_params(**test_params["params"]) + pipeline.set_params(**test_params["params"]) # type: ignore filtered_smiles = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles, test_params["result"]) @@ -198,7 +207,7 @@ def test_complex_filter_non_unique_names(self) -> None: with self.assertRaises(ValueError): ComplexFilter( - (("filter_1", element_filter_1), ("filter_1", element_filter_2)) + (("filter_1", element_filter_1), ("filter_1", element_filter_2)), ) @@ -264,7 +273,7 @@ def test_smarts_smiles_filter(self) -> None: ] for test_params in test_params_list_with_results: - pipeline.set_params(**test_params["params"]) + pipeline.set_params(**test_params["params"]) # type: ignore filtered_smiles = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles, test_params["result"]) @@ -285,7 +294,11 @@ def test_smarts_smiles_filter_wrong_pattern(self) -> None: SmilesFilter(smiles_pats) def test_smarts_filter_parallel(self) -> None: - """Test if molecules are filtered correctly by allowed SMARTS patterns in parallel.""" + """Test if molecules are filtered correctly. + + This test runs the SmartsFilter in parallel. 
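For reference, the SMARTS criteria these filter tests encode can be reproduced directly with RDKit; the match counts are what the filters compare against the allowed ranges. A small sketch using chlorobenzene, one of the test SMILES:

from rdkit import Chem

mol = Chem.MolFromSmiles("c1ccccc1Cl")  # chlorobenzene
aromatic_carbon = Chem.MolFromSmarts("c")

# Six aromatic carbons match the pattern "c"; a filter with
# {"c": (4, None)} would therefore keep this molecule.
assert len(mol.GetSubstructMatches(aromatic_carbon)) == 6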
+ + """ smarts_pats: dict[str, IntOrIntCountRange] = { "c": (4, None), "Cl": 1, @@ -352,38 +365,38 @@ def test_descriptor_filter(self) -> None: }, { "params": { - "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1.99, 4)} + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1.99, 4)}, }, "result": [SMILES_CL_BR], }, { "params": { - "DescriptorsFilter__filter_elements": {"NumHAcceptors": (2.01, 4)} + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (2.01, 4)}, }, "result": [], }, { "params": { - "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.00)} + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.00)}, }, "result": [SMILES_CL_BR], }, { "params": { - "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.01)} + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 2.01)}, }, "result": [SMILES_CL_BR], }, { "params": { - "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 1.99)} + "DescriptorsFilter__filter_elements": {"NumHAcceptors": (1, 1.99)}, }, "result": [], }, ] for test_params in test_params_list_with_results: - pipeline.set_params(**test_params["params"]) + pipeline.set_params(**test_params["params"]) # type: ignore filtered_smiles = pipeline.fit_transform(SMILES_LIST) self.assertEqual(filtered_smiles, test_params["result"]) @@ -409,7 +422,7 @@ def test_invalidate_mixtures(self) -> None: ("mol2smi", mol2smi), ("error_filter", error_filter), ("error_replacer", error_replacer), - ] + ], ) mols_processed = pipeline.fit_transform(mol_list) self.assertEqual(expected_invalidated_mol_list, mols_processed) @@ -424,7 +437,7 @@ def test_inorganic_filter(self) -> None: inorganics_filter = InorganicsFilter() mol2smiles = MolToSmiles() error_filter = ErrorFilter.from_element_list( - [smiles2mol, inorganics_filter, mol2smiles] + [smiles2mol, inorganics_filter, mol2smiles], ) pipeline = Pipeline( [ From 9cd2354a3979e84ad89f5c5df642ad4dde6360ca Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Tue, 13 May 2025 18:01:52 +0200 Subject: [PATCH 15/31] remove final estimator --- molpipeline/pipeline/_skl_adapter_pipeline.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/molpipeline/pipeline/_skl_adapter_pipeline.py b/molpipeline/pipeline/_skl_adapter_pipeline.py index 82613c04..0815dcc6 100644 --- a/molpipeline/pipeline/_skl_adapter_pipeline.py +++ b/molpipeline/pipeline/_skl_adapter_pipeline.py @@ -3,7 +3,7 @@ from __future__ import annotations from itertools import islice -from typing import TYPE_CHECKING, Any, Literal, Self +from typing import TYPE_CHECKING, Any, Self import numpy as np import numpy.typing as npt @@ -28,9 +28,7 @@ from molpipeline.utils.logging import print_elapsed_time from molpipeline.utils.molpipeline_types import ( AnyElement, - AnyPredictor, AnyStep, - AnyTransformer, ) from molpipeline.utils.value_checks import is_empty @@ -40,10 +38,6 @@ import joblib from sklearn.utils import Bunch - from molpipeline.abstract_pipeline_elements.core import ( - ABCPipelineElement, - ) - _IndexedStep = tuple[int, str, AnyElement] @@ -63,13 +57,6 @@ def _estimator_type(self) -> Any: return self._final_estimator._estimator_type # noqa: SLF001 return None - @property - def _final_estimator( - self, - ) -> Literal["passthrough"] | AnyTransformer | AnyPredictor | ABCPipelineElement: - """Return the lst estimator which is not a PostprocessingTransformer.""" - return self._modified_steps[-1][1] - @property def _modified_steps( self, From 905bd9f781d853fcab0ac4b6fa0465f4a6fbdf05 Mon Sep 17 
00:00:00 2001
From: Christian Feldmann
Date: Wed, 14 May 2025 15:15:31 +0200
Subject: [PATCH 16/31] remove duplicate _estimator_type property

---
 molpipeline/pipeline/_skl_pipeline.py | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py
index 3606b71c..6f71433a 100644
--- a/molpipeline/pipeline/_skl_pipeline.py
+++ b/molpipeline/pipeline/_skl_pipeline.py
@@ -281,15 +281,6 @@ def _iter(
         ):
             yield step
 
-    @property
-    def _estimator_type(self) -> Any:
-        """Return the estimator type."""
-        if self._final_estimator is None or self._final_estimator == "passthrough":
-            return None
-        if hasattr(self._final_estimator, "_estimator_type"):
-            return self._final_estimator._estimator_type  # noqa: SLF001
-        return None
-
     @property
     def _final_estimator(
         self,

From 225858d1484801e3717c782fc8038bc1e5f7deac Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Wed, 14 May 2025 15:22:24 +0200
Subject: [PATCH 17/31] add ignore to duplicate code

---
 molpipeline/pipeline/_skl_adapter_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/molpipeline/pipeline/_skl_adapter_pipeline.py b/molpipeline/pipeline/_skl_adapter_pipeline.py
index 0815dcc6..0fd80b60 100644
--- a/molpipeline/pipeline/_skl_adapter_pipeline.py
+++ b/molpipeline/pipeline/_skl_adapter_pipeline.py
@@ -763,7 +763,7 @@ def predict_proba(
             iter_input = post_element.transform(iter_input)
         return iter_input
 
-    def _can_transform(self) -> bool:
+    def _can_transform(self) -> bool:  # pylint: ignore[duplicate-code]
         """Check if the final estimator can transform or is passthrough.
 
         Returns

From 577ea3895e772273531f063fc8c8cb917de9eebc Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Wed, 14 May 2025 15:26:36 +0200
Subject: [PATCH 18/31] use sklearn native transform

---
 molpipeline/pipeline/_skl_adapter_pipeline.py | 55 -------------------
 molpipeline/pipeline/_skl_pipeline.py         | 22 +-------
 2 files changed, 1 insertion(+), 76 deletions(-)

diff --git a/molpipeline/pipeline/_skl_adapter_pipeline.py b/molpipeline/pipeline/_skl_adapter_pipeline.py
index 0fd80b60..7f72a559 100644
--- a/molpipeline/pipeline/_skl_adapter_pipeline.py
+++ b/molpipeline/pipeline/_skl_adapter_pipeline.py
@@ -777,61 +777,6 @@ def _can_transform(self) -> bool:  # pylint: ignore[duplicate-code]
             "transform",
         )
 
-    @available_if(_can_transform)
-    def transform(
-        self,
-        X: Any,  # noqa: N803
-        **params: Any,
-    ) -> Any:
-        """Transform the data, and apply `transform` with the final estimator.
-
-        Call `transform` of each transformer in the pipeline. The transformed
-        data are finally passed to the final estimator that calls
-        `transform` method. Only valid if the final estimator
-        implements `transform`.
-
-        This also works where final estimator is `None` in which case all prior
-        transformations are applied.
-
-        Parameters
-        ----------
-        X : iterable
-            Data to transform. Must fulfill input requirements of first step
-            of the pipeline.
-        **params : Any
-            Parameters to the ``transform`` method of each estimator.
-
-        Raises
-        ------
-        AssertionError
-            If the final estimator does not implement `transform` or
-            `fit_transform` or is passthrough.
-
-        Returns
-        -------
-        Xt : ndarray of shape (n_samples, n_transformed_features)
-            Transformed data.
- - """ - routed_params = process_routing(self, "transform", **params) - iter_input = X - for _, name, transform in self._iter(): - if transform == "passthrough": - continue - if hasattr(transform, "transform"): - iter_input = transform.transform( - iter_input, - **routed_params[name].transform, - ) - else: - raise AssertionError( - "Non transformer ocurred in transformation step." - "This should have been caught in the validation step.", - ) - for _, post_element in self._post_processing_steps: - iter_input = post_element.transform(iter_input, **params) - return iter_input - @available_if(_can_decision_function) def decision_function( self, diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py index 6f71433a..ec769ce7 100644 --- a/molpipeline/pipeline/_skl_pipeline.py +++ b/molpipeline/pipeline/_skl_pipeline.py @@ -642,12 +642,6 @@ def transform( **params : Any Parameters to the ``transform`` method of each estimator. - Raises - ------ - AssertionError - If the final estimator does not implement `transform` or - `fit_transform` or is passthrough. - Returns ------- Xt : ndarray of shape (n_samples, n_transformed_features) @@ -658,21 +652,7 @@ def transform( output_generator = self._transform_iterator(X) iter_input = self.assemble_output(output_generator) else: - routed_params = process_routing(self, "transform", **params) - iter_input = X - for _, name, transform in self._iter(): - if transform == "passthrough": - continue - if hasattr(transform, "transform"): - iter_input = transform.transform( - iter_input, - **routed_params[name].transform, - ) - else: - raise AssertionError( - "Non transformer ocurred in transformation step." - "This should have been caught in the validation step.", - ) + iter_input = super().transform(X, **params) for _, post_element in self._post_processing_steps: iter_input = post_element.transform(iter_input, **params) return iter_input From dc076cf6f97a6fc96abad1d94b37ec20d20f92bc Mon Sep 17 00:00:00 2001 From: Christian Feldmann Date: Wed, 14 May 2025 15:29:20 +0200 Subject: [PATCH 19/31] delete _can_transform --- molpipeline/pipeline/_skl_adapter_pipeline.py | 14 -------------- 1 file changed, 14 deletions(-) diff --git a/molpipeline/pipeline/_skl_adapter_pipeline.py b/molpipeline/pipeline/_skl_adapter_pipeline.py index 7f72a559..66919b0e 100644 --- a/molpipeline/pipeline/_skl_adapter_pipeline.py +++ b/molpipeline/pipeline/_skl_adapter_pipeline.py @@ -763,20 +763,6 @@ def predict_proba( iter_input = post_element.transform(iter_input) return iter_input - def _can_transform(self) -> bool: # pylint: ignore[duplicate-code] - """Check if the final estimator can transform or is passthrough. - - Returns - ------- - bool - True if the final estimator can transform or is passthrough. 
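`_can_transform` exists only to feed scikit-learn's `available_if`, which removes the decorated attribute entirely when the predicate returns False, so `hasattr` checks and duck typing on the pipeline keep working. A minimal sketch of that mechanism with toy classes:

from sklearn.utils.metaestimators import available_if

class Wrapper:
    def __init__(self, final_estimator):
        self.final_estimator = final_estimator

    def _can_transform(self) -> bool:
        return hasattr(self.final_estimator, "transform")

    @available_if(_can_transform)
    def transform(self, values):
        return self.final_estimator.transform(values)

class Doubler:
    def transform(self, values):
        return [2 * v for v in values]

assert hasattr(Wrapper(Doubler()), "transform")
assert not hasattr(Wrapper(object()), "transform")
assert Wrapper(Doubler()).transform([1, 2]) == [2, 4]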
From dc076cf6f97a6fc96abad1d94b37ec20d20f92bc Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Wed, 14 May 2025 15:29:20 +0200
Subject: [PATCH 19/31] delete _can_transform

---
 molpipeline/pipeline/_skl_adapter_pipeline.py | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/molpipeline/pipeline/_skl_adapter_pipeline.py b/molpipeline/pipeline/_skl_adapter_pipeline.py
index 7f72a559..66919b0e 100644
--- a/molpipeline/pipeline/_skl_adapter_pipeline.py
+++ b/molpipeline/pipeline/_skl_adapter_pipeline.py
@@ -763,20 +763,6 @@ def predict_proba(
             iter_input = post_element.transform(iter_input)
         return iter_input
 
-    def _can_transform(self) -> bool:  # pylint: ignore[duplicate-code]
-        """Check if the final estimator can transform or is passthrough.
-
-        Returns
-        -------
-        bool
-            True if the final estimator can transform or is passthrough.
-
-        """
-        return self._final_estimator == "passthrough" or hasattr(
-            self._final_estimator,
-            "transform",
-        )
-
     @available_if(_can_decision_function)
     def decision_function(
         self,

From b96254cd6d7caa622ea27aba307f978f0ad8d4e1 Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Wed, 14 May 2025 15:34:19 +0200
Subject: [PATCH 20/31] use sklearn native decision function

---
 molpipeline/pipeline/_skl_adapter_pipeline.py | 57 ------------------
 molpipeline/pipeline/_skl_pipeline.py | 32 +---------
 2 files changed, 1 insertion(+), 88 deletions(-)

diff --git a/molpipeline/pipeline/_skl_adapter_pipeline.py b/molpipeline/pipeline/_skl_adapter_pipeline.py
index 66919b0e..25fec796 100644
--- a/molpipeline/pipeline/_skl_adapter_pipeline.py
+++ b/molpipeline/pipeline/_skl_adapter_pipeline.py
@@ -763,63 +763,6 @@ def predict_proba(
             iter_input = post_element.transform(iter_input)
         return iter_input
 
-    @available_if(_can_decision_function)
-    def decision_function(
-        self,
-        X: Any,  # noqa: N803
-        **params: Any,
-    ) -> Any:
-        """Transform the data, and apply `decision_function` with the final estimator.
-
-        Parameters
-        ----------
-        X : iterable
-            Data to transform. Must fulfill input requirements of first step
-            of the pipeline.
-        **params : Any
-            Parameters to the ``decision_function`` method of the final estimator.
-
-        Raises
-        ------
-        AssertionError
-            If the final estimator does not implement `decision_function`.
-
-        Returns
-        -------
-        Any
-            Result of calling `decision_function` on the final estimator.
-
-        """
-        if _routing_enabled():
-            routed_params = process_routing(self, "decision_function", **params)
-        else:
-            routed_params = process_routing(self, "decision_function")
-
-        iter_input = self._transform(X, routed_params)
-        if self._final_estimator == "passthrough":
-            pass
-        elif is_empty(iter_input):
-            iter_input = []
-        elif hasattr(self._final_estimator, "decision_function"):
-            if _routing_enabled():
-                iter_input = self._final_estimator.decision_function(
-                    iter_input,
-                    **routed_params[self._final_estimator].predict,
-                )
-            else:
-                iter_input = self._final_estimator.decision_function(
-                    iter_input,
-                    **params,
-                )
-        else:
-            raise AssertionError(
-                "Final estimator does not implement `decision_function`, "
-                "hence this function should not be available.",
-            )
-        for _, post_element in self._post_processing_steps:
-            iter_input = post_element.transform(iter_input)
-        return iter_input
-
     def get_metadata_routing(self) -> MetadataRouter:
         """Get metadata routing of this object.
 
diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py
index ec769ce7..96bb2f63 100644
--- a/molpipeline/pipeline/_skl_pipeline.py
+++ b/molpipeline/pipeline/_skl_pipeline.py
@@ -673,43 +673,13 @@ def decision_function(
         **params : Any
             Parameters to the ``decision_function`` method of the final estimator.
 
-        Raises
-        ------
-        AssertionError
-            If the final estimator does not implement `decision_function`.
-
         Returns
         -------
         Any
             Result of calling `decision_function` on the final estimator.
 
         """
-        if _routing_enabled():
-            routed_params = process_routing(self, "decision_function", **params)
-        else:
-            routed_params = process_routing(self, "decision_function")
-
-        iter_input = self._transform(X, routed_params)
-        if self._final_estimator == "passthrough":
-            pass
-        elif is_empty(iter_input):
-            iter_input = []
-        elif hasattr(self._final_estimator, "decision_function"):
-            if _routing_enabled():
-                iter_input = self._final_estimator.decision_function(
-                    iter_input,
-                    **routed_params[self._final_estimator].predict,
-                )
-            else:
-                iter_input = self._final_estimator.decision_function(
-                    iter_input,
-                    **params,
-                )
-        else:
-            raise AssertionError(
-                "Final estimator does not implement `decision_function`, "
-                "hence this function should not be available.",
-            )
+        iter_input = super().decision_function(X, **params)
         for _, post_element in self._post_processing_steps:
             iter_input = post_element.transform(iter_input)
         return iter_input

From b500f4b3b972ef152d330ca7fab7244df036ecb2 Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Wed, 14 May 2025 15:53:26 +0200
Subject: [PATCH 21/31] remove duplicate fit_predict function

---
 molpipeline/pipeline/_skl_pipeline.py | 71 ---------------------------
 1 file changed, 71 deletions(-)

diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py
index 96bb2f63..ad58bc7b 100644
--- a/molpipeline/pipeline/_skl_pipeline.py
+++ b/molpipeline/pipeline/_skl_pipeline.py
@@ -533,77 +533,6 @@ def predict(
             iter_input = post_element.transform(iter_input)
         return iter_input
 
-    @available_if(_final_estimator_has("fit_predict"))
-    @_fit_context(
-        # estimators in Pipeline.steps are not validated yet
-        prefer_skip_nested_validation=False,
-    )
-    def fit_predict(
-        self,
-        X: Any,  # noqa: N803
-        y: Any = None,
-        **params: Any,
-    ) -> Any:
-        """Transform the data, and apply `fit_predict` with the final estimator.
-
-        Call `fit_transform` of each transformer in the pipeline. The
-        transformed data are finally passed to the final estimator that calls
-        `fit_predict` method. Only valid if the final estimator implements
-        `fit_predict`.
-
-        Parameters
-        ----------
-        X : iterable
-            Training data. Must fulfill input requirements of first step of
-            the pipeline.
-
-        y : iterable, default=None
-            Training targets. Must fulfill label requirements for all steps
-            of the pipeline.
-
-        **params : dict of string -> object
-            Parameters passed to the ``fit`` method of each step, where
-            each parameter name is prefixed such that parameter ``p`` for step
-            ``s`` has key ``s__p``.
-
-        Raises
-        ------
-        AssertionError
-            If the final estimator does not implement `fit_predict`.
-            In this case this function should not be available.
-
-        Returns
-        -------
-        y_pred : ndarray
-            Result of calling `fit_predict` on the final estimator.
-
-        """
-        routed_params = self._check_method_params(method="fit_predict", props=params)
-        iter_input, iter_label = self._fit(X, y, routed_params)
-
-        params_last_step = routed_params[self._modified_steps[-1][0]]
-        with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
-            if self._final_estimator == "passthrough":
-                y_pred = iter_input
-            elif is_empty(iter_input):
-                logger.warning("All input rows were filtered out! Model is not fitted!")
-                iter_input = []
-                y_pred = []
-            elif hasattr(self._final_estimator, "fit_predict"):
-                y_pred = self._final_estimator.fit_predict(
-                    iter_input,
-                    iter_label,
-                    **params_last_step.get("fit_predict", {}),
-                )
-            else:
-                raise AssertionError(
-                    "Final estimator does not implement fit_predict, "
-                    "hence this function should not be available.",
-                )
-        for _, post_element in self._post_processing_steps:
-            y_pred = post_element.fit_transform(y_pred, iter_label)
-        return y_pred
-
     def _can_transform(self) -> bool:
         """Check if the final estimator can transform or is passthrough.
 
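Both the removed duplicates above and the reworked methods below lean on scikit-learn's `available_if` guard: the decorated method is only visible when its predicate returns true, so `hasattr` checks on the pipeline mirror the final estimator's abilities. A self-contained illustration (toy classes, not molpipeline code):

    from sklearn.utils.metaestimators import available_if


    class Wrapper:
        def __init__(self, final_estimator):
            self.final_estimator = final_estimator

        def _can_decision_function(self) -> bool:
            return hasattr(self.final_estimator, "decision_function")

        @available_if(_can_decision_function)
        def decision_function(self, X):
            return self.final_estimator.decision_function(X)


    class NoScores:
        pass


    # The attribute is hidden when the predicate fails, so this prints False:
    print(hasattr(Wrapper(NoScores()), "decision_function"))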
From d8a717a5fc03a81f2a426ef1409b242ac56b278e Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Wed, 14 May 2025 17:40:11 +0200
Subject: [PATCH 22/31] rework predict function

---
 molpipeline/pipeline/_skl_adapter_pipeline.py | 153 ++++++++++--------
 molpipeline/pipeline/_skl_pipeline.py | 73 ---------
 2 files changed, 84 insertions(+), 142 deletions(-)

diff --git a/molpipeline/pipeline/_skl_adapter_pipeline.py b/molpipeline/pipeline/_skl_adapter_pipeline.py
index 25fec796..df76dc8f 100644
--- a/molpipeline/pipeline/_skl_adapter_pipeline.py
+++ b/molpipeline/pipeline/_skl_adapter_pipeline.py
@@ -10,7 +10,11 @@
 from loguru import logger
 from sklearn.base import _fit_context, clone  # noqa: PLC2701
 from sklearn.pipeline import Pipeline as _Pipeline
-from sklearn.pipeline import _final_estimator_has, _fit_transform_one  # noqa: PLC2701
+from sklearn.pipeline import (
+    _final_estimator_has,  # noqa: PLC2701
+    _fit_transform_one,  # noqa: PLC2701
+    _raise_or_warn_if_not_fitted,  # noqa: PLC2701
+)
 from sklearn.utils.metadata_routing import (
     MetadataRouter,
     MethodMapping,
@@ -562,74 +566,6 @@ def fit_transform(
             iter_input = post_element.fit_transform(iter_input, iter_label)
         return iter_input
 
-    @available_if(_final_estimator_has("predict"))
-    def predict(
-        self,
-        X: Any,  # noqa: N803
-        **params: Any,
-    ) -> Any:
-        """Transform the data, and apply `predict` with the final estimator.
-
-        Call `transform` of each transformer in the pipeline. The transformed
-        data are finally passed to the final estimator that calls `predict`
-        method. Only valid if the final estimator implements `predict`.
-
-        Parameters
-        ----------
-        X : iterable
-            Data to predict on. Must fulfill input requirements of first step
-            of the pipeline.
-
-        **params : dict of string -> object
-            Parameters to the ``predict`` called at the end of all
-            transformations in the pipeline. Note that while this may be
-            used to return uncertainties from some models with return_std
-            or return_cov, uncertainties that are generated by the
-            transformations in the pipeline are not propagated to the
-            final estimator.
-
-            .. versionadded:: 0.20
-
-        Raises
-        ------
-        AssertionError
-            If the final estimator does not implement `predict`.
-            In this case this function should not be available.
-
-        Returns
-        -------
-        y_pred : ndarray
-            Result of calling `predict` on the final estimator.
-
-        """
-        if _routing_enabled():
-            routed_params = process_routing(self, "predict", **params)
-        else:
-            routed_params = process_routing(self, "predict")
-
-        iter_input = self._transform(X, routed_params)
-
-        if self._final_estimator == "passthrough":
-            pass
-        elif is_empty(iter_input):
-            iter_input = []
-        elif hasattr(self._final_estimator, "predict"):
-            if _routing_enabled():
-                iter_input = self._final_estimator.predict(
-                    iter_input,
-                    **routed_params[self._modified_steps[-1][0]].predict,
-                )
-            else:
-                iter_input = self._final_estimator.predict(iter_input, **params)
-        else:
-            raise AssertionError(
-                "Final estimator does not implement predict, "
-                "hence this function should not be available.",
-            )
-        for _, post_element in self._post_processing_steps:
-            iter_input = post_element.transform(iter_input)
-        return iter_input
-
     @available_if(_final_estimator_has("fit_predict"))
     @_fit_context(
         # estimators in Pipeline.steps are not validated yet
@@ -701,6 +637,85 @@ def fit_predict(
             y_pred = post_element.fit_transform(y_pred, iter_label)
         return y_pred
 
+    @available_if(_final_estimator_has("predict"))
+    def predict(
+        self,
+        X: npt.NDArray[Any],  # noqa: N803
+        **params: Any,
+    ) -> npt.NDArray[Any]:
+        """Transform the data, and apply `predict` with the final estimator.
+
+        Call `transform` of each transformer in the pipeline. The transformed
+        data are finally passed to the final estimator that calls `predict`
+        method. Only valid if the final estimator implements `predict`.
+
+        Parameters
+        ----------
+        X : npt.NDArray[Any]
+            Data to predict on. Must fulfill input requirements of first step
+            of the pipeline.
+
+        **params : Any
+            If `enable_metadata_routing=False` (default): Parameters to the
+            ``predict`` called at the end of all transformations in the pipeline.
+            If `enable_metadata_routing=True`: Parameters requested and accepted by
+            steps. Each step must have requested certain metadata for these parameters
+            to be forwarded to them.
+
+        Raises
+        ------
+        AssertionError
+            If a step before the final estimator is 'passthrough' or does not
+            implement `transform`.
+
+        Returns
+        -------
+        y_pred : npt.NDArray[Any]
+            Result of calling `predict` on the final estimator.
+
+        """
+        iter_input = X
+        with _raise_or_warn_if_not_fitted(self):
+            if not _routing_enabled():
+                for _, name, transform in self._iter(with_final=False):
+                    if (
+                        not hasattr(transform, "transform")
+                        or transform == "passthrough"
+                    ):
+                        raise AssertionError(
+                            f"Non transformer occurred in transformation step: {name}.",
+                        )
+                    iter_input = transform.transform(iter_input)
+                if is_empty(iter_input):
+                    iter_input = np.array([])
+                else:
+                    iter_input = self._final_estimator.predict(iter_input, **params)
+            else:
+                # metadata routing enabled
+                routed_params = process_routing(self, "predict", **params)
+                for _, name, transform in self._iter(with_final=False):
+                    if (
+                        not hasattr(transform, "transform")
+                        or transform == "passthrough"
+                    ):
+                        raise AssertionError(
+                            f"Non transformer occurred in transformation step: {name}.",
+                        )
+                    iter_input = transform.transform(
+                        iter_input,
+                        **routed_params[name].transform,
+                    )
+                if is_empty(iter_input):
+                    iter_input = np.array([])
+                else:
+                    iter_input = self._final_estimator.predict(
+                        iter_input,
+                        **routed_params[self.steps[-1][0]].predict,
+                    )
+        for _, post_element in self._post_processing_steps:
+            iter_input = post_element.transform(iter_input, **params)
+        return np.array(iter_input)
+
     @available_if(_final_estimator_has("predict_proba"))
     def predict_proba(
         self,
diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py
index ad58bc7b..3ddc4f02 100644
--- a/molpipeline/pipeline/_skl_pipeline.py
+++ b/molpipeline/pipeline/_skl_pipeline.py
@@ -465,74 +460,6 @@ def fit_transform(
         iter_input = post_element.fit_transform(iter_input, y)
         return iter_input
 
-    @available_if(_final_estimator_has("predict"))
-    def predict(
-        self,
-        X: Any,  # noqa: N803
-        **params: Any,
-    ) -> Any:
-        """Transform the data, and apply `predict` with the final estimator.
-
-        Call `transform` of each transformer in the pipeline. The transformed
-        data are finally passed to the final estimator that calls `predict`
-        method. Only valid if the final estimator implements `predict`.
-
-        Parameters
-        ----------
-        X : iterable
-            Data to predict on. Must fulfill input requirements of first step
-            of the pipeline.
-
-        **params : dict of string -> object
-            Parameters to the ``predict`` called at the end of all
-            transformations in the pipeline. Note that while this may be
-            used to return uncertainties from some models with return_std
-            or return_cov, uncertainties that are generated by the
-            transformations in the pipeline are not propagated to the
-            final estimator.
-
-            .. versionadded:: 0.20
-
-        Raises
-        ------
-        AssertionError
-            If the final estimator does not implement `predict`.
-            In this case this function should not be available.
-
-        Returns
-        -------
-        y_pred : ndarray
-            Result of calling `predict` on the final estimator.
-
-        """
-        if _routing_enabled():
-            routed_params = process_routing(self, "predict", **params)
-        else:
-            routed_params = process_routing(self, "predict")
-
-        iter_input = self._transform(X, routed_params)
-
-        if self._final_estimator == "passthrough":
-            pass
-        elif is_empty(iter_input):
-            iter_input = []
-        elif hasattr(self._final_estimator, "predict"):
-            if _routing_enabled():
-                iter_input = self._final_estimator.predict(
-                    iter_input,
-                    **routed_params[self._modified_steps[-1][0]].predict,
-                )
-            else:
-                iter_input = self._final_estimator.predict(iter_input, **params)
-        else:
-            raise AssertionError(
-                "Final estimator does not implement predict, "
-                "hence this function should not be available.",
-            )
-        for _, post_element in self._post_processing_steps:
-            iter_input = post_element.transform(iter_input)
-        return iter_input
-
     def _can_transform(self) -> bool:
         """Check if the final estimator can transform or is passthrough.
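One behavior the reworked `predict` preserves from upstream scikit-learn: with metadata routing disabled (the default), extra keyword arguments are forwarded untouched to the final estimator's `predict`, which is how uncertainty outputs such as `return_std` reach a model at the end of a pipeline. A plain scikit-learn illustration (not molpipeline code):

    import numpy as np
    from sklearn.gaussian_process import GaussianProcessRegressor
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    X = np.linspace(0.0, 1.0, 20).reshape(-1, 1)
    y = np.sin(4.0 * X).ravel()

    pipe = Pipeline([("scale", StandardScaler()), ("gpr", GaussianProcessRegressor())])
    pipe.fit(X, y)

    # return_std is consumed by GaussianProcessRegressor.predict; the
    # transformers in the pipeline never see it.
    y_mean, y_std = pipe.predict(X, return_std=True)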
From 654f549230686ff5dc827040fe4e700fe3078a6f Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Thu, 15 May 2025 10:22:43 +0200
Subject: [PATCH 23/31] Remove classes property

---
 molpipeline/pipeline/_skl_pipeline.py | 28 ---------------------------
 1 file changed, 28 deletions(-)

diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py
index 3ddc4f02..98e0e3f5 100644
--- a/molpipeline/pipeline/_skl_pipeline.py
+++ b/molpipeline/pipeline/_skl_pipeline.py
@@ -7,7 +7,6 @@ from typing import TYPE_CHECKING, Any, Literal, Self, TypeIs
 
 import numpy as np
-import numpy.typing as npt
 from joblib import Parallel, delayed
 from loguru import logger
@@ -29,9 +28,6 @@
     _MultipleErrorFilter,
 )
 from molpipeline.pipeline._skl_adapter_pipeline import AdapterPipeline
-from molpipeline.post_prediction import (
-    PostPredictionTransformation,
-)
 from molpipeline.utils.logging import print_elapsed_time
 from molpipeline.utils.molpipeline_types import (
@@ -540,30 +536,6 @@ def decision_function(
             iter_input = post_element.transform(iter_input)
         return iter_input
 
-    @property
-    def classes_(self) -> list[Any] | npt.NDArray[Any]:
-        """Return the classes of the last element.
-
-        PostPredictionTransformation elements are not considered as last element.
-
-        Raises
-        ------
-        ValueError
-            If the last step is passthrough or has no classes_ attribute.
-
-        """
-        check_last = [
-            step
-            for step in self.steps
-            if not isinstance(step[1], PostPredictionTransformation)
-        ]
-        last_step = check_last[-1][1]
-        if last_step == "passthrough":
-            raise ValueError("Last step is passthrough.")
-        if hasattr(last_step, "classes_"):
-            return last_step.classes_
-        raise ValueError("Last step has no classes_ attribute.")
-
     def __sklearn_tags__(self) -> Tags:  # noqa: PLW3201
         """Return the sklearn tags.
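Dropping the hand-rolled `classes_` is safe because scikit-learn's own `Pipeline` already exposes the final estimator's `classes_`, assuming the adapter's `_final_estimator` accounts for the post-prediction steps. A minimal check with plain scikit-learn:

    from sklearn.linear_model import LogisticRegression
    from sklearn.pipeline import Pipeline
    from sklearn.preprocessing import StandardScaler

    pipe = Pipeline([("scale", StandardScaler()), ("clf", LogisticRegression())])
    pipe.fit([[0.0], [1.0], [2.0], [3.0]], [0, 0, 1, 1])
    print(pipe.classes_)  # [0 1], read straight from the classifier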
From bf630003c8636d4f31f9c99f40845305876708d4 Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Thu, 15 May 2025 10:42:16 +0200
Subject: [PATCH 24/31] Remove Validate steps

---
 molpipeline/pipeline/_skl_pipeline.py | 48 ---------------------------
 1 file changed, 48 deletions(-)

diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py
index 98e0e3f5..d0e6b5c7 100644
--- a/molpipeline/pipeline/_skl_pipeline.py
+++ b/molpipeline/pipeline/_skl_pipeline.py
@@ -156,54 +156,6 @@ def __init__(
         self.n_jobs = n_jobs
         self._set_error_resinserter()
 
-    def _validate_steps(self) -> None:
-        """Validate the steps.
-
-        Raises
-        ------
-        TypeError
-            If the steps do not implement fit and transform or are not 'passthrough'.
-
-        """
-        names = [name for name, _ in self.steps]
-
-        # validate names
-        self._validate_names(names)
-
-        # validate estimators
-        non_post_processing_steps = [e for _, _, e in self._iter()]
-        transformer_list = non_post_processing_steps[:-1]
-        estimator = non_post_processing_steps[-1]
-
-        for transformer in transformer_list:
-            if transformer is None or transformer == "passthrough":
-                continue
-            if not (
-                hasattr(transformer, "fit") or hasattr(transformer, "fit_transform")
-            ) or not hasattr(transformer, "transform"):
-                raise TypeError(
-                    f"All intermediate steps should be "
-                    f"transformers and implement fit and transform "
-                    f"or be the string 'passthrough' "
-                    f"'{transformer}' (type {type(transformer)}) doesn't",
-                )
-
-        # We allow last estimator to be None as an identity transformation
-        if (
-            estimator is not None
-            and estimator != "passthrough"
-            and not hasattr(estimator, "fit")
-        ):
-            raise TypeError(
-                f"Last step of Pipeline should implement fit "
-                f"or be the string 'passthrough'. "
-                f"'{estimator}' (type {type(estimator)}) doesn't",
-            )
-
-        # validate post-processing steps
-        # Calling steps automatically validates them
-        _ = self._post_processing_steps
-
     def _iter(
         self,
         with_final: bool = True,
From f4893049a890f782056927d2b6b3ead80b4eaee4 Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Thu, 15 May 2025 10:59:24 +0200
Subject: [PATCH 25/31] Switch type casting back and adapt types

---
 molpipeline/pipeline/_skl_adapter_pipeline.py | 10 +++++-----
 molpipeline/utils/molpipeline_types.py | 16 ++++++++--------
 2 files changed, 13 insertions(+), 13 deletions(-)

diff --git a/molpipeline/pipeline/_skl_adapter_pipeline.py b/molpipeline/pipeline/_skl_adapter_pipeline.py
index df76dc8f..485a32d2 100644
--- a/molpipeline/pipeline/_skl_adapter_pipeline.py
+++ b/molpipeline/pipeline/_skl_adapter_pipeline.py
@@ -640,9 +640,9 @@ def fit_predict(
     @available_if(_final_estimator_has("predict"))
     def predict(
         self,
-        X: npt.NDArray[Any],  # noqa: N803
+        X: npt.NDArray[Any] | list[Any],  # noqa: N803
         **params: Any,
-    ) -> npt.NDArray[Any]:
+    ) -> npt.NDArray[Any] | list[Any]:
         """Transform the data, and apply `predict` with the final estimator.
 
         Call `transform` of each transformer in the pipeline. The transformed
@@ -687,7 +687,7 @@ def predict(
                     )
                     iter_input = transform.transform(iter_input)
                 if is_empty(iter_input):
-                    iter_input = np.array([])
+                    iter_input = []
                 else:
                     iter_input = self._final_estimator.predict(iter_input, **params)
             else:
@@ -706,7 +706,7 @@ def predict(
                     **routed_params[name].transform,
                 )
                 if is_empty(iter_input):
-                    iter_input = np.array([])
+                    iter_input = []
                 else:
                     iter_input = self._final_estimator.predict(
                         iter_input,
@@ -714,7 +714,7 @@ def predict(
                 )
         for _, post_element in self._post_processing_steps:
             iter_input = post_element.transform(iter_input, **params)
-        return np.array(iter_input)
+        return iter_input
diff --git a/molpipeline/utils/molpipeline_types.py b/molpipeline/utils/molpipeline_types.py
index 9240dab6..3664df1c 100644
--- a/molpipeline/utils/molpipeline_types.py
+++ b/molpipeline/utils/molpipeline_types.py
@@ -81,7 +81,7 @@ def set_params(self, **params: Any) -> Self:
 
     def fit(
         self,
-        X: npt.NDArray[Any],  # pylint: disable=invalid-name
+        X: npt.ArrayLike,  # pylint: disable=invalid-name  # noqa: N803
         y: npt.NDArray[Any] | None,
         **fit_params: Any,
     ) -> Self:
@@ -89,7 +89,7 @@ def fit(
 
         Parameters
         ----------
-        X: npt.NDArray[Any]
+        X: npt.ArrayLike
            Model input.
         y: npt.NDArray[Any] | None
             Target values.
@@ -110,7 +110,7 @@ class AnyPredictor(AnySklearnEstimator, Protocol):
 
     def fit_predict(
        self,
-        X: npt.NDArray[Any],  # pylint: disable=invalid-name
+        X: npt.ArrayLike,  # pylint: disable=invalid-name  # noqa: N803
         y: npt.NDArray[Any] | None,
         **fit_params: Any,
     ) -> npt.NDArray[Any]:
@@ -118,7 +118,7 @@ def fit_predict(
 
         Parameters
         ----------
-        X: npt.NDArray[Any]
+        X: npt.ArrayLike
             Model input.
         y: npt.NDArray[Any] | None
             Target values.
@@ -138,7 +138,7 @@ class AnyTransformer(AnySklearnEstimator, Protocol):
 
     def fit_transform(
         self,
-        X: npt.NDArray[Any],  # pylint: disable=invalid-name
+        X: npt.ArrayLike,  # pylint: disable=invalid-name  # noqa: N803
         y: npt.NDArray[Any] | None,
         **fit_params: Any,
     ) -> npt.NDArray[Any]:
@@ -146,7 +146,7 @@ def fit_transform(
 
         Parameters
         ----------
-        X: npt.NDArray[Any]
+        X: npt.ArrayLike
             Model input.
         y: npt.NDArray[Any] | None
             Target values.
@@ -163,14 +163,14 @@ def fit_transform(
 
     def transform(
         self,
-        X: npt.NDArray[Any],  # pylint: disable=invalid-name
+        X: npt.ArrayLike,  # pylint: disable=invalid-name  # noqa: N803
         **params: Any,
     ) -> npt.NDArray[Any]:
         """Transform and return X according to object protocol.
 
         Parameters
         ----------
-        X: npt.NDArray[Any]
+        X: npt.ArrayLike
             Model input.
         params: Any
             Additional parameters for transforming.
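The type switch in PATCH 25 is about inputs: `npt.ArrayLike` admits plain Python sequences as well as arrays, which matches what the pipelines actually receive. A standalone illustration (not molpipeline code):

    from typing import Any

    import numpy as np
    import numpy.typing as npt


    def double(X: npt.ArrayLike) -> npt.NDArray[Any]:  # noqa: N803
        """Accept lists or arrays on input, always return an ndarray."""
        return np.asarray(X) * 2.0


    double([1.0, 2.0, 3.0])  # a list satisfies ArrayLike
    double(np.array([1.0, 2.0]))  # so does an ndarray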
From f3124832267fe1d651cacd71f19d22d34ba6627c Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Thu, 15 May 2025 11:11:10 +0200
Subject: [PATCH 26/31] use super.fit

---
 molpipeline/pipeline/_skl_pipeline.py | 16 +---------------
 1 file changed, 1 insertion(+), 15 deletions(-)

diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py
index d0e6b5c7..9b4f6403 100644
--- a/molpipeline/pipeline/_skl_pipeline.py
+++ b/molpipeline/pipeline/_skl_pipeline.py
@@ -8,7 +8,6 @@
 
 import numpy as np
 from joblib import Parallel, delayed
-from loguru import logger
 from rdkit.Chem.rdchem import MolSanitizeException
 from rdkit.rdBase import BlockLogs
 from sklearn.base import _fit_context  # noqa: PLC2701
@@ -28,14 +27,12 @@
     _MultipleErrorFilter,
 )
 from molpipeline.pipeline._skl_adapter_pipeline import AdapterPipeline
-from molpipeline.utils.logging import print_elapsed_time
 from molpipeline.utils.molpipeline_types import (
     AnyElement,
     AnyPredictor,
     AnyStep,
     AnyTransformer,
 )
-from molpipeline.utils.value_checks import is_empty
 
 if TYPE_CHECKING:
     from collections.abc import Generator, Iterable, Sequence
@@ -272,18 +269,7 @@ def fit(
         """
         if self.supports_single_instance:
             return self
-        routed_params = self._check_method_params(method="fit", props=fit_params)
-        xt, yt = self._fit(X, y, routed_params)
-        with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
-            if self._final_estimator != "passthrough":
-                if is_empty(xt):
-                    logger.warning(
-                        "All input rows were filtered out! Model is not fitted!",
-                    )
-                else:
-                    fit_params_last_step = routed_params[self._modified_steps[-1][0]]
-                    self._final_estimator.fit(xt, yt, **fit_params_last_step["fit"])
-
+        super().fit(X, y, **fit_params)
         return self
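The `fit` rewrite in PATCH 26 keeps one scikit-learn idiom worth spelling out: `fit` must return the estimator itself, so the override delegates and then returns its own `self` explicitly rather than trusting the parent's return value. As a toy sketch (not molpipeline code):

    class Base:
        def fit(self, X, y=None):
            # parent training logic would run here
            return self


    class Derived(Base):
        def fit(self, X, y=None, **fit_params):
            # early exits (like the single-instance case above) just return self
            super().fit(X, y)
            return self


    assert isinstance(Derived().fit([[1.0]], [0]), Derived)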
From 763f0dc5aba780fdcdba5bcc834902eb637d36d9 Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Thu, 15 May 2025 11:23:12 +0200
Subject: [PATCH 27/31] remove can decision function

---
 molpipeline/pipeline/_skl_adapter_pipeline.py | 11 -----------
 1 file changed, 11 deletions(-)

diff --git a/molpipeline/pipeline/_skl_adapter_pipeline.py b/molpipeline/pipeline/_skl_adapter_pipeline.py
index 485a32d2..96e32e83 100644
--- a/molpipeline/pipeline/_skl_adapter_pipeline.py
+++ b/molpipeline/pipeline/_skl_adapter_pipeline.py
@@ -479,17 +479,6 @@ def _can_fit_transform(self) -> bool:
             or hasattr(self._final_estimator, "fit_transform")
         )
 
-    def _can_decision_function(self) -> bool:
-        """Check if the final estimator implements decision_function.
-
-        Returns
-        -------
-        bool
-            True if the final estimator implements decision_function.
-
-        """
-        return hasattr(self._final_estimator, "decision_function")
-
     @available_if(_can_fit_transform)
     @_fit_context(
         # estimators in Pipeline.steps are not validated yet

From 389c2fa7a1e2c1399601915870d791dcc17a67e2 Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Thu, 15 May 2025 11:25:34 +0200
Subject: [PATCH 28/31] ignore duplicate code (cannot be inherited)

---
 molpipeline/pipeline/_skl_adapter_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/molpipeline/pipeline/_skl_adapter_pipeline.py b/molpipeline/pipeline/_skl_adapter_pipeline.py
index 96e32e83..e5530dc0 100644
--- a/molpipeline/pipeline/_skl_adapter_pipeline.py
+++ b/molpipeline/pipeline/_skl_adapter_pipeline.py
@@ -464,7 +464,7 @@ def fit(
 
         return self
 
-    def _can_fit_transform(self) -> bool:
+    def _can_fit_transform(self) -> bool:  # pylint: disable=duplicate-code
         """Check if the final estimator can fit_transform or is passthrough.
 
         Returns

From 1527dc575f7141988328ac5628352b1e4f294842 Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Thu, 15 May 2025 11:27:39 +0200
Subject: [PATCH 29/31] pylint ignore

---
 molpipeline/pipeline/_skl_adapter_pipeline.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/molpipeline/pipeline/_skl_adapter_pipeline.py b/molpipeline/pipeline/_skl_adapter_pipeline.py
index e5530dc0..619144d6 100644
--- a/molpipeline/pipeline/_skl_adapter_pipeline.py
+++ b/molpipeline/pipeline/_skl_adapter_pipeline.py
@@ -58,7 +58,7 @@ def _estimator_type(self) -> Any:
         if self._final_estimator is None or self._final_estimator == "passthrough":
             return None
         if hasattr(self._final_estimator, "_estimator_type"):
-            return self._final_estimator._estimator_type  # noqa: SLF001
+            return self._final_estimator._estimator_type  # noqa: SLF001 # pylint: disable=protected-access
         return None
 
     @property
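The remaining two patches deal with a quirk of the `duplicate-code` check (R0801): it is a similarity check computed across modules on whole regions, and a trailing comment on the `def` line is often not honored, which is presumably why the series moves to a standalone disable/enable pair bracketing the region. Illustrative shape only (assumption, not molpipeline code):

    # pylint: disable=duplicate-code
    def can_fit_transform(final_estimator) -> bool:
        """A region that is intentionally similar in two modules."""
        return final_estimator == "passthrough" or hasattr(
            final_estimator, "fit_transform"
        )
    # pylint: enable=duplicate-code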
From 0f1c587ace4b4a6e7df95149574158a31d9a2bc7 Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Thu, 15 May 2025 11:51:38 +0200
Subject: [PATCH 30/31] pylint ignore and move function

---
 molpipeline/pipeline/_skl_pipeline.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py
index 9b4f6403..89e415fd 100644
--- a/molpipeline/pipeline/_skl_pipeline.py
+++ b/molpipeline/pipeline/_skl_pipeline.py
@@ -272,7 +272,7 @@ def fit(
         super().fit(X, y, **fit_params)
         return self
 
-    def _can_fit_transform(self) -> bool:
+    def _can_fit_transform(self) -> bool:  # pylint: disable=duplicate-code
         """Check if the final estimator can fit_transform or is passthrough.
 
         Returns
@@ -287,17 +287,6 @@ def _can_fit_transform(self) -> bool:
             or hasattr(self._final_estimator, "fit_transform")
         )
 
-    def _can_decision_function(self) -> bool:
-        """Check if the final estimator implements decision_function.
-
-        Returns
-        -------
-        bool
-            True if the final estimator implements decision_function.
-
-        """
-        return hasattr(self._final_estimator, "decision_function")
-
     @available_if(_can_fit_transform)
     @_fit_context(
         # estimators in Pipeline.steps are not validated yet
         prefer_skip_nested_validation=False,
     )
@@ -447,6 +436,17 @@ def transform(
         iter_input = post_element.transform(iter_input, **params)
         return iter_input
 
+    def _can_decision_function(self) -> bool:
+        """Check if the final estimator implements decision_function.
+
+        Returns
+        -------
+        bool
+            True if the final estimator implements decision_function.
+
+        """
+        return hasattr(self._final_estimator, "decision_function")
+
     @available_if(_can_decision_function)
     def decision_function(
         self,

From 8cbeb9dbd26bc55ce9d32dc8f2090aa820324c6f Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Thu, 15 May 2025 12:12:30 +0200
Subject: [PATCH 31/31] change ignore statement

---
 molpipeline/pipeline/_skl_adapter_pipeline.py | 5 ++++-
 molpipeline/pipeline/_skl_pipeline.py | 7 +++++--
 2 files changed, 9 insertions(+), 3 deletions(-)

diff --git a/molpipeline/pipeline/_skl_adapter_pipeline.py b/molpipeline/pipeline/_skl_adapter_pipeline.py
index 619144d6..d6299be0 100644
--- a/molpipeline/pipeline/_skl_adapter_pipeline.py
+++ b/molpipeline/pipeline/_skl_adapter_pipeline.py
@@ -464,7 +464,8 @@ def fit(
 
         return self
 
-    def _can_fit_transform(self) -> bool:  # pylint: disable=duplicate-code
+    # pylint: disable=duplicate-code
+    def _can_fit_transform(self) -> bool:
         """Check if the final estimator can fit_transform or is passthrough.
 
         Returns
@@ -479,6 +480,8 @@ def _can_fit_transform(self) -> bool:
             or hasattr(self._final_estimator, "fit_transform")
         )
 
+    # pylint: enable=duplicate-code
+
     @available_if(_can_fit_transform)
     @_fit_context(
         # estimators in Pipeline.steps are not validated yet
diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py
index 89e415fd..6bc558b8 100644
--- a/molpipeline/pipeline/_skl_pipeline.py
+++ b/molpipeline/pipeline/_skl_pipeline.py
@@ -104,7 +104,7 @@ def check_single_instance_support(
     return isinstance(estimator, Pipeline) and estimator.supports_single_instance
 
 
-class Pipeline(AdapterPipeline, TransformingPipelineElement):
+class Pipeline(AdapterPipeline, TransformingPipelineElement):  # pylint: disable=too-many-ancestors
     """Defines the pipeline which handles pipeline elements."""
 
     steps: list[AnyStep]
@@ -272,7 +272,8 @@ def fit(
         super().fit(X, y, **fit_params)
         return self
 
-    def _can_fit_transform(self) -> bool:  # pylint: disable=duplicate-code
+    # pylint: disable=duplicate-code
+    def _can_fit_transform(self) -> bool:
         """Check if the final estimator can fit_transform or is passthrough.
 
         Returns
@@ -287,6 +288,8 @@ def _can_fit_transform(self) -> bool:
             or hasattr(self._final_estimator, "fit_transform")
         )
 
+    # pylint: enable=duplicate-code
+
     @available_if(_can_fit_transform)
     @_fit_context(
         # estimators in Pipeline.steps are not validated yet
         prefer_skip_nested_validation=False,