diff --git a/molpipeline/abstract_pipeline_elements/core.py b/molpipeline/abstract_pipeline_elements/core.py index 1666a9c8..169b10ad 100644 --- a/molpipeline/abstract_pipeline_elements/core.py +++ b/molpipeline/abstract_pipeline_elements/core.py @@ -5,8 +5,7 @@ import abc import copy import inspect -from collections.abc import Iterable -from typing import Any, NamedTuple, Self +from typing import TYPE_CHECKING, Any, NamedTuple, Self from uuid import uuid4 from joblib import Parallel, delayed @@ -16,6 +15,9 @@ from molpipeline.utils.multi_proc import check_available_cores +if TYPE_CHECKING: + from collections.abc import Iterable + class InvalidInstance(NamedTuple): """Object which is returned when an instance cannot be processed. @@ -29,6 +31,7 @@ class InvalidInstance(NamedTuple): element_name: str | None Optional name of the element which could not be processed. The name of the pipeline element is often more descriptive than the id. + """ element_id: str @@ -42,6 +45,7 @@ def __repr__(self) -> str: ------- str String representation of InvalidInstance. + """ return ( f"InvalidInstance({self.element_name or self.element_id}, {self.message})" @@ -75,6 +79,7 @@ def __repr__(self) -> str: ------- str String representation of RemovedInstance. + """ return f"RemovedInstance({self.filter_element_id}, {self.message})" @@ -83,7 +88,6 @@ class ABCPipelineElement(abc.ABC): """Ancestor of all PipelineElements.""" name: str - _requires_fitting: bool = False uuid: str def __init__( @@ -102,6 +106,7 @@ def __init__( Number of cores used for processing. uuid: str | None, optional Unique identifier of the PipelineElement. + """ if name is None: name = self.__class__.__name__ @@ -119,13 +124,13 @@ def __repr__(self) -> str: ------- str String representation of object. + """ parm_list = [] for key, value in self._get_non_default_parameters().items(): parm_list.append(f"{key}={value}") parm_str = ", ".join(parm_list) - repr_str = f"{self.__class__.__name__}({parm_str})" - return repr_str + return f"{self.__class__.__name__}({parm_str})" def _get_non_default_parameters(self) -> dict[str, Any]: """Return all parameters which are not default. @@ -134,6 +139,7 @@ def _get_non_default_parameters(self) -> dict[str, Any]: ------- dict[str, Any] Dictionary of non default parameters. + """ signature = inspect.signature(self.__class__.__init__) self_params = self.get_params() @@ -144,9 +150,8 @@ def _get_non_default_parameters(self) -> dict[str, Any]: if value_default is inspect.Parameter.empty: non_default_params[parm] = self_params[parm] continue - if parm in self_params: - if value_default.default != self_params[parm]: - non_default_params[parm] = self_params[parm] + if parm in self_params and value_default.default != self_params[parm]: + non_default_params[parm] = self_params[parm] return non_default_params def get_params(self, deep: bool = True) -> dict[str, Any]: @@ -161,6 +166,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Parameters of the object. + """ if deep: return { @@ -176,7 +182,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: } def set_params(self, **parameters: Any) -> Self: - """As the setter function cannot be assessed with super(), this method is implemented for inheritance. + """Set the parameters of the object. 
Parameters ---------- @@ -197,7 +203,7 @@ def set_params(self, **parameters: Any) -> Self: for att_name, att_value in parameters.items(): if not hasattr(self, att_name): raise ValueError( - f"Cannot set attribute {att_name} on {self.__class__.__name__}" + f"Cannot set attribute {att_name} on {self.__class__.__name__}", ) setattr(self, att_name, att_value) return self @@ -219,45 +225,25 @@ def n_jobs(self, n_jobs: int) -> None: """ self._n_jobs = check_available_cores(n_jobs) - @property - def requires_fitting(self) -> bool: - """Return whether the object requires fitting or not.""" - return self._requires_fitting - - def fit(self, values: Any, labels: Any = None) -> Self: + def fit( + self, + values: Any, # noqa: ARG002 # pylint: disable=unused-argument + labels: Any = None, # noqa: ARG002 # pylint: disable=unused-argument + ) -> Self: """Fit object to input_values. - Most objects might not need fitting, but it is implemented for consitency for all PipelineElements. - Parameters ---------- values: Any List of molecule representations. - labels: Any + labels: Any, optional Optional label for fitting. Returns ------- Self Fitted object. - """ - _ = self.fit_transform(values, labels) - return self - - def fit_to_result(self, values: Any) -> Self: # pylint: disable=unused-argument - """Fit object to result of transformed values. - Fit object to the result of the transform function. This is useful catching nones and removed molecules. - - Parameters - ---------- - values: Any - List of molecule representations. - - Returns - ------- - Self - Fitted object. """ return self @@ -272,7 +258,7 @@ def fit_transform( Parameters ---------- values: Any - Apply transformation specified in transform_single to all molecules in the value_list. + Apply transform_single to all molecules in the value_list. labels: Any Optional label for fitting. @@ -280,6 +266,7 @@ def fit_transform( ------- Any List of instances in new representation. + """ @abc.abstractmethod @@ -289,28 +276,15 @@ def transform(self, values: Any) -> Any: Parameters ---------- values: Any - Iterable of molecule representations (SMILES, MolBlocks RDKit Molecules, PhysChem vectors etc.). + Iterable of molecule representations (SMILES, MolBlocks RDKit Molecules, + PhysChem vectors etc.). Input depends on the concrete PipelineElement. Returns ------- Any Transformed input_values. - """ - - @abc.abstractmethod - def transform_single(self, value: Any) -> Any: - """Transform a single value. - Parameters - ---------- - value: Any - Value to be transformed. - - Returns - ------- - Any - Transformed value. """ @@ -337,6 +311,7 @@ def __init__( Number of cores used for processing. uuid: str | None, optional Unique identifier of the PipelineElement. + """ super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) self._is_fitted = False @@ -346,11 +321,6 @@ def input_type(self) -> str: """Return the input type.""" return self._input_type - @property - def is_fitted(self) -> bool: - """Return whether the object is fitted or not.""" - return self._is_fitted - @property def output_type(self) -> str: """Return the output type.""" @@ -373,24 +343,6 @@ def parameters(self, parameters: Any) -> None: """ self.set_params(**parameters) - def fit_to_result(self, values: Any) -> Self: - """Fit object to result of transformed values. - - Fit object to the result of the transform function. This is useful catching nones and removed molecules. - - Parameters - ---------- - values: Any - List of molecule representations. - - Returns - ------- - Self - Fitted object. 
- """ - self._is_fitted = True - return super().fit_to_result(values) - def fit_transform( self, values: Any, @@ -401,7 +353,7 @@ def fit_transform( Parameters ---------- values: Any - Apply transformation specified in transform_single to all molecules in the value_list. + Apply transform_single to all molecules in the value_list. labels: Any Optional label for fitting. @@ -409,43 +361,16 @@ def fit_transform( ------- Any List of molecules in new representation. - """ - self._is_fitted = True - if self.requires_fitting: - pre_value_list = self.pretransform(values) - self.fit_to_result(pre_value_list) - output_list = self.finalize_list(pre_value_list) - if hasattr(self, "assemble_output"): - return self.assemble_output(output_list) - return self.transform(values) - - def transform_single(self, value: Any) -> Any: - """Transform a single molecule to the new representation. - RemovedMolecule objects are passed without change, as no transformations are applicable. - - Parameters - ---------- - value: Any - Current representation of the molecule. (Eg. SMILES, RDKit Mol, ...) - - Returns - ------- - Any - New representation of the molecule. (Eg. SMILES, RDKit Mol, Descriptor-Vector, ...) """ - if isinstance(value, InvalidInstance): - return value - pre_value = self.pretransform_single(value) - if isinstance(pre_value, InvalidInstance): - return pre_value - return self.finalize_single(pre_value) + self.fit(values, labels) + return self.transform(values) def assemble_output(self, value_list: Iterable[Any]) -> Any: """Aggregate rows, which in most cases is just return the list. - Some representations might be better representd as a single object. For example a list of vectors can - be transformed to a matrix. + Some representations might be better represented as a single object. + For example a list of vectors can be transformed to a matrix. Parameters ---------- value_list: Iterable[Any] @@ -456,45 +381,64 @@ ------- Any Aggregated output. This can also be the original input. + """ return list(value_list) @abc.abstractmethod - def pretransform_single(self, value: Any) -> Any: - """Transform the instance, but skips parameters learned during fitting. - - This is the first step for the full transformation. - It is followed by the finalize_single method and assemble output which collects all single transformations. - These functions are split as they need to be accessed separately from outside the object. + def pretransform(self, value_list: Iterable[Any]) -> list[Any]: + """Transform input_values according to object rules. Parameters ---------- - value: Any - Value to be pretransformed. + value_list: Iterable[Any] + Iterable of instances to be pretransformed. Returns ------- - Any - Pretransformed value. (Skips applying parameters learned during fitting) + list[Any] + Transformed input_values. + """ - def finalize_single(self, value: Any) -> Any: - """Apply parameters learned during fitting to a single instance. + def transform(self, values: Any) -> Any: + """Transform input_values according to object rules. Parameters ---------- - value: Any - Value obtained from pretransform_single. + values: Any + Iterable of molecule representations (SMILES, MolBlocks RDKit Molecules, + PhysChem vectors etc.). + Input depends on the concrete PipelineElement. Returns ------- Any - Finalized value. + Transformed input_values.
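# Reviewer note: a minimal sketch of the simplified contract after this change,
# assuming the base classes behave exactly as shown in this diff. The element
# below is hypothetical and only for illustration: a stateless element now just
# implements pretransform, inherits the no-op fit, and gets transform and
# fit_transform (fit followed by transform) for free.
from molpipeline.abstract_pipeline_elements.core import TransformingPipelineElement

class UpperCaseElement(TransformingPipelineElement):
    """Hypothetical element which upper-cases SMILES strings."""

    def pretransform(self, value_list):
        # A pure function of the input; no fitted state is involved.
        return [value.upper() for value in value_list]

element = UpperCaseElement()
assert element.fit_transform(["ccco"]) == ["CCCO"]  # fit is a no-op here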
+ + """ + output_values = self.pretransform(values) + return self.assemble_output(output_values) + + +class SingleInstanceTransformerMixin(abc.ABC): + """Mixin for single instance processing.""" + + def __init__(self, *args: Any, **kwargs: Any) -> None: + """Initialize SingleInstanceTransformerMixin. + + Parameters + ---------- + args: Any + Arguments for the object. + kwargs: Any + Keyword arguments for the object. + """ - return value + super().__init__(*args, **kwargs) def pretransform(self, value_list: Iterable[Any]) -> list[Any]: - """Transform input_values according to object rules without fitting specifics. + """Transform input_values according to object rules. Parameters ---------- @@ -505,53 +449,61 @@ def pretransform(self, value_list: Iterable[Any]) -> list[Any]: ------- list[Any] Transformed input_values. + """ - parallel = Parallel(n_jobs=self.n_jobs) - output_values = parallel( + parallel = Parallel(n_jobs=self.n_jobs) # type: ignore[attr-defined] + return parallel( delayed(self.pretransform_single)(value) for value in value_list ) - return output_values - def finalize_list(self, value_list: Iterable[Any]) -> list[Any]: - """Transform list of values according to parameters learned during fitting. + def transform_single(self, value: Any) -> Any: + """Transform a single molecule to the new representation. + + RemovedMolecule objects are passed without change. Parameters ---------- - value_list: Iterable[Any] - List of values to be transformed. + value: Any + Current representation of the molecule. (Eg. SMILES, RDKit Mol, ...) Returns ------- - list[Any] - List of transformed values. + Any + New representation of the molecule. + (Eg. SMILES, RDKit Mol, Descriptor-Vector, ...) + """ - parallel = Parallel(n_jobs=self.n_jobs) - output_values = parallel( - delayed(self.finalize_single)(value) for value in value_list - ) - return output_values + if isinstance(value, InvalidInstance): + return value + return self.pretransform_single(value) - def transform(self, values: Any) -> Any: - """Transform input_values according to object rules. + @abc.abstractmethod + def pretransform_single(self, value: Any) -> Any: + """Transform the instance. + + This is the first step for the full transformation. + It is followed by the finalize_single method and assemble output which collects + all single transformations. These functions are split as they need to be + accessed separately from outside the object. Parameters ---------- - values: Any - Iterable of molecule representations (SMILES, MolBlocks RDKit Molecules, PhysChem vectors etc.). - Input depends on the concrete PipelineElement. + value: Any + Value to be pretransformed. Returns ------- Any - Transformed input_values. + Pretransformed value. (Skips applying parameters learned during fitting) + """ - output_rows = self.pretransform(values) - output_rows = self.finalize_list(output_rows) - output = self.assemble_output(output_rows) - return output -class MolToMolPipelineElement(TransformingPipelineElement, abc.ABC): +class MolToMolPipelineElement( + SingleInstanceTransformerMixin, + TransformingPipelineElement, + abc.ABC, +): """Abstract PipelineElement where input and outputs are molecules.""" _input_type = "RDKitMol" @@ -568,7 +520,8 @@ def transform(self, values: list[OptionalMol]) -> list[OptionalMol]: Returns ------- list[OptionalMol] - List of molecules or InvalidInstances, if corresponding transformation was not successful. + List of molecules or InvalidInstances. 
+ """ mol_list: list[OptionalMol] = super().transform(values) # Stupid mypy... return mol_list @@ -585,17 +538,22 @@ def transform_single(self, value: OptionalMol) -> OptionalMol: ------- OptionalMol Transformed molecule if transformation was successful, else InvalidInstance. + """ + if isinstance(value, InvalidInstance): + return value try: return super().transform_single(value) except Exception as exc: if isinstance(value, Chem.Mol): logger.error( - f"Failed to process: {Chem.MolToSmiles(value)} | ({self.name}) ({self.uuid})" + f"Failed to process: {Chem.MolToSmiles(value)} | ({self.name}) " + f"({self.uuid})", ) else: logger.error( - f"Failed to process: {value} ({type(value)}) | ({self.name}) ({self.uuid})" + f"Failed to process: {value} ({type(value)}) | ({self.name}) " + f"({self.uuid})", ) raise exc @@ -614,10 +572,15 @@ def pretransform_single(self, value: RDKitMol) -> OptionalMol: ------- OptionalMol Transformed molecule if transformation was successful, else InvalidInstance. + """ -class AnyToMolPipelineElement(TransformingPipelineElement, abc.ABC): +class AnyToMolPipelineElement( + SingleInstanceTransformerMixin, + TransformingPipelineElement, + abc.ABC, +): """Abstract PipelineElement which creates molecules from different inputs.""" _output_type = "RDKitMol" @@ -633,14 +596,15 @@ def transform(self, values: Any) -> list[OptionalMol]: Returns ------- list[OptionalMol] - List of molecules or InvalidInstances, if corresponding representation was invalid. + List of molecules or InvalidInstances. + """ mol_list: list[OptionalMol] = super().transform(values) # Stupid mypy... return mol_list @abc.abstractmethod def pretransform_single(self, value: Any) -> OptionalMol: - """Transform the instance to a molecule, but skip parameters learned during fitting. + """Transform the instance to a molecule. Parameters ---------- @@ -651,17 +615,22 @@ def pretransform_single(self, value: Any) -> OptionalMol: ------- OptionalMol Obtained molecule if valid representation, else InvalidInstance. + """ -class MolToAnyPipelineElement(TransformingPipelineElement, abc.ABC): +class MolToAnyPipelineElement( + SingleInstanceTransformerMixin, + TransformingPipelineElement, + abc.ABC, +): """Abstract PipelineElement which creates molecules from different inputs.""" _input_type = "RDKitMol" @abc.abstractmethod def pretransform_single(self, value: RDKitMol) -> Any: - """Transform the molecule, but skip parameters learned during fitting. + """Transform the molecule. Parameters ---------- @@ -672,4 +641,5 @@ def pretransform_single(self, value: RDKitMol) -> Any: ------- Any Transformed molecule. 
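# Reviewer note: with SingleInstanceTransformerMixin in the bases, a concrete
# element only implements pretransform_single; the mixin supplies the batched
# (joblib-parallel) pretransform, and MolToMolPipelineElement.transform_single
# adds InvalidInstance passthrough plus error logging. Hypothetical element,
# assuming the classes as defined in this diff:
from rdkit import Chem
from molpipeline.abstract_pipeline_elements.core import MolToMolPipelineElement

class StripStereoElement(MolToMolPipelineElement):
    """Hypothetical element which removes stereochemistry from molecules."""

    def pretransform_single(self, value):
        mol = Chem.Mol(value)  # work on a copy instead of mutating the input
        Chem.RemoveStereochemistry(mol)
        return mol

element = StripStereoElement(n_jobs=1)
mols = element.transform([Chem.MolFromSmiles("C[C@H](N)C(=O)O")])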
+ """ diff --git a/molpipeline/error_handling.py b/molpipeline/error_handling.py index 87cad027..16923fd1 100644 --- a/molpipeline/error_handling.py +++ b/molpipeline/error_handling.py @@ -2,19 +2,24 @@ from __future__ import annotations -from collections.abc import Iterable, Sequence -from typing import Any, Generic, Self, TypeVar +from typing import TYPE_CHECKING, Any, Generic, Self, TypeVar import numpy as np import numpy.typing as npt +from loguru import logger from molpipeline.abstract_pipeline_elements.core import ( ABCPipelineElement, InvalidInstance, RemovedInstance, + SingleInstanceTransformerMixin, TransformingPipelineElement, ) -from molpipeline.utils.molpipeline_types import AnyVarSeq, TypeFixedVarSeq + +if TYPE_CHECKING: + from collections.abc import Iterable, Sequence + + from molpipeline.utils.molpipeline_types import AnyVarSeq, TypeFixedVarSeq __all__ = ["ErrorFilter", "FilterReinserter", "_MultipleErrorFilter"] @@ -22,8 +27,8 @@ _S = TypeVar("_S") -class ErrorFilter(ABCPipelineElement): - """Collects None values and can fill Dummy values to matrices where None values were removed.""" +class ErrorFilter(SingleInstanceTransformerMixin, TransformingPipelineElement): + """Filter to remove InvalidInstances from a list of values.""" element_ids: set[str] error_indices: list[int] @@ -64,7 +69,7 @@ def __init__( if element_ids is None: if not filter_everything: raise ValueError( - "If element_ids is None, filter_everything must be True" + "If element_ids is None, filter_everything must be True", ) element_ids = set() if not isinstance(element_ids, set): @@ -72,7 +77,6 @@ def __init__( self.element_ids = element_ids self.filter_everything = filter_everything self.n_total = 0 - self._requires_fitting = True @classmethod def from_element_list( @@ -99,10 +103,15 @@ def from_element_list( ------- Self Constructed ErrorFilter object. + """ element_ids = {element.uuid for element in element_list} return cls( - element_ids, filter_everything=False, name=name, n_jobs=n_jobs, uuid=uuid + element_ids, + filter_everything=False, + name=name, + n_jobs=n_jobs, + uuid=uuid, ) def get_params(self, deep: bool = True) -> dict[str, Any]: @@ -117,6 +126,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Parameter names mapped to their values. + """ params = super().get_params(deep=deep) params["filter_everything"] = self.filter_everything @@ -168,14 +178,17 @@ def check_removal(self, value: Any) -> bool: ------- bool True if value should be removed. + """ if not isinstance(value, InvalidInstance): return False - if self.filter_everything or value.element_id in self.element_ids: - return True - return False + return self.filter_everything or value.element_id in self.element_ids - def fit(self, values: AnyVarSeq, labels: Any = None) -> Self: + def fit( + self, + values: AnyVarSeq, # noqa: ARG002 + labels: Any = None, # noqa: ARG002 + ) -> Self: """Fit to input values. Only for compatibility with sklearn Pipelines. @@ -191,11 +204,14 @@ def fit(self, values: AnyVarSeq, labels: Any = None) -> Self: ------- Self Fitted ErrorFilter. + """ return self def fit_transform( - self, values: TypeFixedVarSeq, labels: Any = None + self, + values: TypeFixedVarSeq, + labels: Any = None, ) -> TypeFixedVarSeq: """Transform values and return a list without the None values. @@ -205,22 +221,24 @@ def fit_transform( ---------- values: TypeFixedVarSeq Iterable to which element is fitted and which is subsequently transformed. - labels: Any - Label used for fitting. 
(Not used, but required for compatibility with sklearn) + labels: Any, optional + Label used for fitting. + (Not used, but required for compatibility with sklearn) Returns ------- TypeFixedVarSeq Iterable where invalid instances were removed. + """ self.fit(values, labels) return self.transform(values) def co_transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: - """Remove rows at positions which contained discarded values during transformation. + """Remove rows at positions which contained discarded values in `transform`. - This ensures that rows of this instance maintain a one to one correspondence with the rows of data seen during - transformation. + This ensures that rows of this instance maintain a one to one correspondence + with the rows of data seen during transformation. Parameters ---------- @@ -241,6 +259,9 @@ def co_transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: """ if self.n_total != len(values): + logger.error( + f"Expected {self.n_total} values, but got {len(values)}", + ) raise ValueError("Length of values does not match length of values in fit") if isinstance(values, list): out_list = [] @@ -266,6 +287,7 @@ def transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: ------- TypeFixedVarSeq Iterable where invalid instances were removed. + """ self.n_total = len(values) self.error_indices = [] @@ -277,6 +299,8 @@ def transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: def transform_single(self, value: Any) -> Any: """Transform a single value. + Overrides parent method to not skip InvalidInstances. + Parameters ---------- value: Any @@ -285,7 +309,8 @@ def transform_single(self, value: Any) -> Any: Returns ------- Any - Transformed value. + The original value or a RemovedInstance. + """ return self.pretransform_single(value) @@ -301,6 +326,7 @@ def pretransform_single(self, value: Any) -> Any: ------- Any Transformed value. + """ if self.check_removal(value): return RemovedInstance( @@ -322,6 +348,7 @@ def __init__(self, error_filter_list: list[ErrorFilter]) -> None: ---------- error_filter_list: list[ErrorFilter] List of ErrorFilter objects. + """ self.error_filter_list = error_filter_list prior_remover_dict = {} @@ -341,13 +368,14 @@ def transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: ------- TypeFixedVarSeq Iterable where invalid instances were removed. + """ for error_filter in self.error_filter_list: values = error_filter.transform(values) return values def co_transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: - """Remove rows at positions which contained discarded values during transformation. + """Remove rows at positions which contained discarded values in `transform`. Parameters ---------- @@ -357,7 +385,8 @@ def co_transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: Returns ------- TypeFixedVarSeq - Iterable where rows are removed which were removed during the transformation. + Iterable where rows are removed which were removed in `transform`. + """ for error_filter in self.error_filter_list: values = error_filter.co_transform(values) @@ -375,6 +404,7 @@ def fit_transform(self, values: TypeFixedVarSeq) -> TypeFixedVarSeq: ------- TypeFixedVarSeq Iterable where invalid instances were removed. 
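# Reviewer note: a small usage sketch for the filter/co_transform pairing
# documented above; the element id and the data are invented for illustration.
from molpipeline.abstract_pipeline_elements.core import InvalidInstance
from molpipeline.error_handling import ErrorFilter

error_filter = ErrorFilter()  # assuming the defaults filter all InvalidInstances
values = ["CCO", InvalidInstance("some-uuid", "could not parse", None), "CCN"]
labels = [0, 1, 0]
filtered = error_filter.fit_transform(values)       # -> ["CCO", "CCN"]
aligned_labels = error_filter.co_transform(labels)  # -> [0, 0], rows stay paired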
+ """ for error_filter in self.error_filter_list: values = error_filter.fit_transform(values) @@ -410,7 +440,8 @@ def register_removed(self, index: int, value: RemovedInstance) -> None: break else: raise ValueError( - f"Invalid instance not captured by any ErrorFilter: {value.filter_element_id}" + f"Invalid instance not captured by any ErrorFilter: " + f"{value.filter_element_id}", ) def set_total(self, total: int) -> None: @@ -457,6 +488,7 @@ def __init__( Number of parallel jobs to use. uuid: str | None, optional UUID of the pipeline element. + """ super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) self.error_filter_id = error_filter_id @@ -491,6 +523,7 @@ def from_error_filter( ------- Self Constructed FilterReinserter object. + """ filler = cls( error_filter_id=error_filter.uuid, @@ -514,6 +547,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Parameter names mapped to their values. + """ params = super().get_params(deep=deep) if deep: @@ -539,6 +573,7 @@ def set_params(self, **parameters: Any) -> Self: ------- Self The instance itself. + """ parameter_copy = dict(parameters) if "error_filter_id" in parameter_copy: @@ -570,6 +605,7 @@ def error_filter(self, error_filter: ErrorFilter) -> None: ---------- error_filter: ErrorFilter ErrorFilter to set. + """ self._error_filter = error_filter @@ -603,9 +639,9 @@ def select_error_filter(self, error_filter_list: list[ErrorFilter]) -> Self: # pylint: disable=unused-argument def fit( self, - values: TypeFixedVarSeq, - labels: Any = None, - **params: Any, + values: TypeFixedVarSeq, # noqa: ARG002 + labels: Any = None, # noqa: ARG002 + **params: Any, # noqa: ARG002 ) -> Self: """Fit to input values. @@ -615,8 +651,9 @@ def fit( ---------- values: TypeFixedVarSeq Values used for fitting. - labels: Any - Label used for fitting. (Not used, but required for compatibility with sklearn) + labels: Any, optional + Label used for fitting. + (Not used, but required for compatibility with sklearn) **params: Any Additional keyword arguments. (Not used) @@ -624,6 +661,7 @@ def fit( ------- Self Fitted FilterReinserter. + """ return self @@ -631,8 +669,8 @@ def fit( def fit_transform( self, values: TypeFixedVarSeq, - labels: Any = None, - **params: Any, + labels: Any = None, # noqa: ARG002 + **params: Any, # noqa: ARG002 ) -> TypeFixedVarSeq: """Transform values and return a list without the Invalid values. @@ -642,8 +680,9 @@ def fit_transform( ---------- values: TypeFixedVarSeq Iterable to which element is fitted and which is subsequently transformed. - labels: Any - Label used for fitting. (Not used, but required for compatibility with sklearn) + labels: Any, optional + Label used for fitting. + (Not used, but required for compatibility with sklearn) **params: Any Additional keyword arguments. (Not used) @@ -651,49 +690,15 @@ def fit_transform( ------- TypeFixedVarSeq Iterable where invalid instances were removed. + """ self.fit(values) return self.transform(values) - def transform_single(self, value: Any) -> Any: - """Transform a single value. - - Parameters - ---------- - value: Any - Value to be transformed. - - Returns - ------- - Any - Transformed value. - """ - return self.pretransform_single(value) - - def pretransform_single(self, value: Any) -> Any: - """Transform a single value. - - Parameters - ---------- - value: Any - Value to be transformed. - - Returns - ------- - Any - Transformed value. 
- """ - if ( - isinstance(value, RemovedInstance) - and value.filter_element_id == self.error_filter.uuid - ): - return self.fill_value - return value - def transform( self, values: TypeFixedVarSeq, - **params: Any, + **params: Any, # noqa: ARG002 ) -> TypeFixedVarSeq: """Transform iterable of values by removing invalid instances. @@ -718,11 +723,13 @@ def transform( """ if len(values) != self.error_filter.n_total - len( - self.error_filter.error_indices + self.error_filter.error_indices, ): raise ValueError( f"Length of values does not match length of values in fit. " - f"Expected: {self.error_filter.n_total - len(self.error_filter.error_indices)} - Received :{len(values)}" + f"Expected: " + f"{self.error_filter.n_total - len(self.error_filter.error_indices)} " + f"- Received :{len(values)}", ) return self.fill_with_dummy(values) @@ -756,7 +763,7 @@ def _fill_list(self, list_to_fill: Sequence[_S]) -> Sequence[_S | _T]: next_value_pos += 1 if len(list_to_fill) != next_value_pos: raise AssertionError( - "Length of list does not match length of values in fit" + "Length of list does not match length of values in fit", ) return filled_list @@ -771,7 +778,9 @@ def _fill_numpy_arr(self, value_array: npt.NDArray[Any]) -> npt.NDArray[Any]: Returns ------- npt.NDArray[Any] - Numpy array where dummy values were inserted to replace instances which could not be processed. + Numpy array where dummy values were inserted to replace instances which + could not be processed. + """ fill_value = self.fill_value output_shape = list(value_array.shape) @@ -808,7 +817,8 @@ def fill_with_dummy( Returns ------- AnyVarSeq - Iterable where dummy values were inserted to replace molecules which could not be processed. + Iterable where dummy values were inserted to replace molecules which could + not be processed. """ if isinstance(value_container, list): diff --git a/molpipeline/mol2any/mol2bool.py b/molpipeline/mol2any/mol2bool.py index c605a6ba..1529e79b 100644 --- a/molpipeline/mol2any/mol2bool.py +++ b/molpipeline/mol2any/mol2bool.py @@ -26,16 +26,16 @@ def pretransform_single(self, value: Any) -> bool: ------- str Binary representation of molecule. + """ - if isinstance(value, InvalidInstance): - return False - return True + return not isinstance(value, InvalidInstance) def transform_single(self, value: Any) -> Any: """Transform a single molecule to a bool representation. Valid molecules are passed as True, InvalidInstances are passed as False. - RemovedMolecule objects are passed without change, as no transformations are applicable. + RemovedMolecule objects are passed without change, as no transformations are + applicable. Parameters ---------- @@ -46,6 +46,6 @@ def transform_single(self, value: Any) -> Any: ------- Any Bool representation of the molecule. 
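# Reviewer note: FilterReinserter is the inverse half of the pair; it restores
# the original row count by writing fill_value at the removed positions. Sketch
# only, assuming from_error_filter wires up the filter as shown above.
import numpy as np
from molpipeline.abstract_pipeline_elements.core import InvalidInstance
from molpipeline.error_handling import ErrorFilter, FilterReinserter

error_filter = ErrorFilter()
reinserter = FilterReinserter.from_error_filter(error_filter, fill_value=np.nan)
values = ["CCO", InvalidInstance("some-uuid", "could not parse", None), "CCN"]
kept = error_filter.fit_transform(values)      # -> ["CCO", "CCN"]
predictions = [0.1, 0.7]                       # one prediction per kept row
restored = reinserter.transform(predictions)   # -> [0.1, nan, 0.7]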
+ """ - pre_value = self.pretransform_single(value) - return self.finalize_single(pre_value) + return self.pretransform_single(value) diff --git a/molpipeline/mol2any/mol2concatinated_vector.py index 5955b4d1..fc5a5dd0 100644 --- a/molpipeline/mol2any/mol2concatinated_vector.py +++ b/molpipeline/mol2any/mol2concatinated_vector.py @@ -1,9 +1,8 @@ -"""Classes for creating arrays from multiple concatenated descriptors or fingerprints.""" +"""Classes for building vectors from multiple descriptors or fingerprints.""" from __future__ import annotations -from collections.abc import Iterable -from typing import Any, Self +from typing import TYPE_CHECKING, Any, Self import numpy as np import numpy.typing as npt @@ -17,11 +16,15 @@ from molpipeline.abstract_pipeline_elements.mol2any.mol2bitvector import ( MolToFingerprintPipelineElement, ) -from molpipeline.utils.molpipeline_types import RDKitMol + +if TYPE_CHECKING: + from collections.abc import Iterable + + from molpipeline.utils.molpipeline_types import RDKitMol class MolToConcatenatedVector(MolToAnyPipelineElement): - """Creates a concatenated descriptor vectored from multiple MolToAny PipelineElements.""" + """A concatenated descriptor vector from multiple MolToAny PipelineElements.""" _element_list: list[tuple[str, MolToAnyPipelineElement]] def __init__( @@ -52,7 +55,8 @@ def __init__( uuid: str | None, optional (default=None) UUID of the pipeline element. If None, a random UUID is generated. kwargs: Any - Additional keyword arguments. Can be used to set parameters of the pipeline elements. + Additional keyword arguments. + Can be used to set parameters of the pipeline elements. Raises ------ @@ -66,7 +70,7 @@ def __init__( self._use_feature_names_prefix = use_feature_names_prefix super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) # set element execution details - self._set_element_execution_details(self._element_list) + self._set_element_execution_details() # set feature names self._feature_names = self._create_feature_names( self._element_list, @@ -156,18 +160,8 @@ def _create_feature_names( ) return feature_names - def _set_element_execution_details( - self, - element_list: list[tuple[str, MolToAnyPipelineElement]], - ) -> None: - """Set output type and requires fitting for the concatenated vector. - - Parameters - ---------- - element_list: list[tuple[str, MolToAnyPipelineElement]] - List of pipeline elements. - - """ + def _set_element_execution_details(self) -> None: + """Set the output type and propagate n_jobs to the elements.""" output_types = set() for _, element in self._element_list: element.n_jobs = self.n_jobs @@ -176,10 +170,6 @@ if len(output_types) == 1: self._output_type = output_types.pop() else: self._output_type = "mixed" - self._requires_fitting = any( - element[1]._requires_fitting # pylint: disable=protected-access - for element in element_list - ) def get_params(self, deep: bool = True) -> dict[str, Any]: """Return all parameters defining the object.
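# Reviewer note: hypothetical instantiation showing the (name, element) tuples
# consumed above; MolToMorganFP and MolToRDKitPhysChem are assumed to be the
# concrete MolToAny elements shipped elsewhere in molpipeline.
from molpipeline.mol2any import (
    MolToConcatenatedVector,
    MolToMorganFP,
    MolToRDKitPhysChem,
)

concat_element = MolToConcatenatedVector(
    element_list=[
        ("morgan", MolToMorganFP()),
        ("physchem", MolToRDKitPhysChem()),
    ],
)
# pretransform_single concatenates both outputs into one flat numpy vector.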
@@ -243,7 +233,7 @@ if len(element_list) == 0: raise ValueError("element_list must contain at least one element.") # reset element execution details - self._set_element_execution_details(self._element_list) + self._set_element_execution_details() step_params: dict[str, dict[str, Any]] = {} step_dict = dict(self._element_list) to_delete_list = [] @@ -311,12 +301,14 @@ def assemble_output( Parameters ---------- value_list: Iterable[npt.NDArray[np.float64]] - List of molecular descriptors or fingerprints which are concatenated to a single matrix. + List of molecular descriptors or fingerprints which are concatenated to a + single matrix. Returns ------- npt.NDArray[np.float64] - Matrix of shape (n_molecules, n_features) with concatenated features specified during init. + Matrix of shape (n_molecules, n_features) with concatenated features + specified during init. """ return np.vstack(list(value_list)) @@ -332,7 +324,8 @@ def transform(self, values: list[RDKitMol]) -> npt.NDArray[np.float64]: Returns ------- npt.NDArray[np.float64] - Matrix of shape (n_molecules, n_features) with concatenated features specified during init. + Matrix of shape (n_molecules, n_features) with concatenated features + specified during init. """ output: npt.NDArray[np.float64] = super().transform(values) @@ -341,14 +334,14 @@ def fit( self, values: list[RDKitMol], - labels: Any = None, + labels: Any = None, # noqa: ARG002 ) -> Self: """Fit each pipeline element. Parameters ---------- values: list[RDKitMol] - List of molecules used to fit the pipeline elements creating the concatenated vector. + List of molecules used to fit the pipeline elements. labels: Any Labels for the molecules. Not used. @@ -365,7 +358,7 @@ def pretransform_single( self, value: RDKitMol, - ) -> list[npt.NDArray[np.float64] | dict[int, int]] | InvalidInstance: + ) -> npt.NDArray[np.float64] | InvalidInstance: """Get pretransform of each element and concatenate for output. Parameters ---------- @@ -380,64 +373,29 @@ If any element returns None, InvalidInstance is returned. """ - final_vector = [] + transformed_list = [] error_message = "" for name, pipeline_element in self._element_list: - vector = pipeline_element.pretransform_single(value) - if isinstance(vector, InvalidInstance): + transformed_value = pipeline_element.pretransform_single(value) + if isinstance(transformed_value, InvalidInstance): error_message += f"{self.name}__{name} returned an InvalidInstance." break - - final_vector.append(vector) - else: # no break - return final_vector - return InvalidInstance(self.uuid, error_message, self.name) - - def finalize_single(self, value: Any) -> Any: - """Finalize the output of transform_single. - - Parameters - ---------- - value: Any - Output of transform_single. - - Returns - ------- - Any - Finalized output.
- - """ - final_vector_list = [] - for (_, element), sub_value in zip(self._element_list, value, strict=True): - final_value = element.finalize_single(sub_value) - if isinstance(element, MolToFingerprintPipelineElement) and isinstance( - final_value, + if isinstance( + pipeline_element, + MolToFingerprintPipelineElement, + ) and isinstance( + transformed_value, dict, ): - vector = np.zeros(element.n_bits) - vector[list(final_value.keys())] = np.array(list(final_value.values())) + vector = np.zeros(pipeline_element.n_bits) + vector[list(transformed_value.keys())] = np.array( + list(transformed_value.values()), + ) final_value = vector - if not isinstance(final_value, np.ndarray): - final_value = np.array(final_value) - final_vector_list.append(final_value) - return np.hstack(final_vector_list) - - def fit_to_result(self, values: Any) -> Self: - """Fit the pipeline element to the result of transform_single. - - Parameters - ---------- - values: Any - Output of transform_single. - - Returns - ------- - Self - Fitted pipeline element. + else: + final_value = np.array(transformed_value) - """ - for element, value in zip( - self._element_list, zip(*values, strict=True), strict=True - ): - element[1].fit_to_result(value) - return self + transformed_list.append(final_value) + else: # no break + return np.hstack(transformed_list) + return InvalidInstance(self.uuid, error_message, self.name) diff --git a/molpipeline/pipeline/_molpipeline.py deleted file mode 100644 index 3a9f2def..00000000 --- a/molpipeline/pipeline/_molpipeline.py +++ /dev/null @@ -1,437 +0,0 @@ -"""Defines the pipeline which handles pipeline elements for molecular operations.""" - -from __future__ import annotations - -from collections.abc import Iterable -from typing import Any, Self - -import numpy as np -from joblib import Parallel, delayed -from rdkit.Chem.rdchem import MolSanitizeException -from rdkit.rdBase import BlockLogs - -from molpipeline.abstract_pipeline_elements.core import ( - ABCPipelineElement, - InvalidInstance, - RemovedInstance, - TransformingPipelineElement, -) -from molpipeline.error_handling import ( - ErrorFilter, - FilterReinserter, - _MultipleErrorFilter, -) -from molpipeline.utils.molpipeline_types import TypeFixedVarSeq -from molpipeline.utils.multi_proc import check_available_cores - - -class _MolPipeline: - """Contains the PipeElements which describe the functionality of the pipeline.""" - - _n_jobs: int - _element_list: list[ABCPipelineElement] - _requires_fitting: bool - - def __init__( - self, - element_list: list[ABCPipelineElement], - n_jobs: int = 1, - name: str = "MolPipeline", - ) -> None: - """Initialize MolPipeline. - - Parameters - ---------- - element_list: list[ABCPipelineElement] - List of Pipeline Elements which form the pipeline. - n_jobs: - Number of cores used. - name: str - Name of pipeline.
- - """ - self._element_list = element_list - self.n_jobs = n_jobs - self.name = name - self._requires_fitting = any( - element.requires_fitting for element in self._element_list - ) - - @property - def _filter_elements(self) -> list[ErrorFilter]: - """Get the elements which filter the input.""" - return [ - element - for element in self._element_list - if isinstance(element, ErrorFilter) - ] - - @property - def _filter_elements_agg(self) -> _MultipleErrorFilter: - """Get the aggregated filter element.""" - return _MultipleErrorFilter(self._filter_elements) - - @property - def _transforming_elements( - self, - ) -> list[TransformingPipelineElement | _MolPipeline]: - """Get the elements which transform the input.""" - return [ - element - for element in self._element_list - if isinstance(element, (TransformingPipelineElement, _MolPipeline)) - ] - - @property - def n_jobs(self) -> int: - """Return the number of cores to use in transformation step.""" - return self._n_jobs - - @n_jobs.setter - def n_jobs(self, requested_jobs: int) -> None: - """Set the number of cores to use in transformation step. - - Parameters - ---------- - requested_jobs: int - Number of cores requested for transformation steps. - If fewer cores than requested are available, the number of cores is set to maximum available. - - """ - self._n_jobs = check_available_cores(requested_jobs) - - @property - def parameters(self) -> dict[str, Any]: - """Get all parameters defining the object.""" - return self.get_params() - - @parameters.setter - def parameters(self, parameter_dict: dict[str, Any]) -> None: - """Set parameters of the pipeline and pipeline elements. - - Parameters - ---------- - parameter_dict: dict[str, Any] - Dictionary containing the parameter names and corresponding values to be set. - - """ - self.set_params(**parameter_dict) - - @property - def requires_fitting(self) -> bool: - """Return whether the pipeline requires fitting.""" - return self._requires_fitting - - def get_params(self, deep: bool = True) -> dict[str, Any]: - """Get all parameters defining the object. - - Parameters - ---------- - deep: bool - If True get a deep copy of the parameters. - - Returns - ------- - dict[str, Any] - Dictionary containing the parameter names and corresponding values. - """ - if deep: - return { - "element_list": self.element_list, - "n_jobs": self.n_jobs, - "name": self.name, - } - return { - "element_list": self._element_list, - "n_jobs": self.n_jobs, - "name": self.name, - } - - def set_params(self, **parameter_dict: Any) -> Self: - """Set parameters of the pipeline and pipeline elements. - - Parameters - ---------- - parameter_dict: Any - Dictionary containing the parameter names and corresponding values to be set. - - Returns - ------- - Self - MolPipeline object with updated parameters. - """ - if "element_list" in parameter_dict: - self._element_list = parameter_dict["element_list"] - if "n_jobs" in parameter_dict: - self.n_jobs = int(parameter_dict["n_jobs"]) - if "name" in parameter_dict: - self.name = str(parameter_dict["name"]) - return self - - @property - def element_list(self) -> list[ABCPipelineElement]: - """Get a shallow copy from the list of pipeline elements.""" - return self._element_list[:] # [:] to create shallow copy. - - def _get_meta_element_list( - self, - ) -> list[ABCPipelineElement | _MolPipeline]: - """Merge elements which do not require fitting to a meta element which improves parallelization. 
- - Returns - ------- - list[ABCPipelineElement | _MolPipeline] - List of pipeline elements and meta elements. - """ - meta_element_list: list[ABCPipelineElement | _MolPipeline] = [] - no_fit_element_list: list[ABCPipelineElement] = [] - for element in self._element_list: - if ( - isinstance(element, TransformingPipelineElement) - and not element.requires_fitting - ): - no_fit_element_list.append(element) - else: - if len(no_fit_element_list) == 1: - meta_element_list.append(no_fit_element_list[0]) - elif len(no_fit_element_list) > 1: - meta_element_list.append( - _MolPipeline(no_fit_element_list, n_jobs=self.n_jobs) - ) - no_fit_element_list = [] - meta_element_list.append(element) - if len(no_fit_element_list) == 1: - meta_element_list.append(no_fit_element_list[0]) - elif len(no_fit_element_list) > 1: - meta_element_list.append( - _MolPipeline(no_fit_element_list, n_jobs=self.n_jobs) - ) - return meta_element_list - - def fit( - self, - x_input: Any, - y: Any = None, - **fit_params: dict[Any, Any], - ) -> Self: - """Fit the MolPipeline according to x_input. - - Parameters - ---------- - x_input: Any - Molecular representations which are subsequently processed. - y: Any - Optional label of input. Only for SKlearn compatibility. - fit_params: Any - Parameters. Only for SKlearn compatibility. - - Returns - ------- - Self - Fitted MolPipeline. - """ - _ = y # Making pylint happy - _ = fit_params # Making pylint happy - if self.requires_fitting: - self.fit_transform(x_input) - return self - - def fit_transform( # pylint: disable=unused-argument - self, - x_input: Any, - y: Any = None, - **fit_params: dict[str, Any], - ) -> Any: - """Fit the MolPipeline according to x_input and return the transformed molecules. - - Parameters - ---------- - x_input: Any - Molecular representations which are subsequently processed. - y: Any - Optional label of input. Only for SKlearn compatibility. - fit_params: Any - Parameters. Only for SKlearn compatibility. - - Raises - ------ - AssertionError - If a subpipeline requires fitting, which by definition should not be the - case. - - Returns - ------- - Any - Transformed molecules. 
- """ - iter_input = x_input - - removed_rows: dict[ErrorFilter, list[int]] = {} - for error_filter in self._filter_elements: - removed_rows[error_filter] = [] - iter_idx_array = np.arange(len(iter_input)) - - # The meta elements merge steps which do not require fitting to improve parallelization - for i_element in self._get_meta_element_list(): - if not isinstance(i_element, (TransformingPipelineElement, _MolPipeline)): - continue - i_element.n_jobs = self.n_jobs - iter_input = i_element.pretransform(iter_input) - for error_filter in self._filter_elements: - iter_input = error_filter.transform(iter_input) - for idx in error_filter.error_indices: - idx = iter_idx_array[idx] - removed_rows[error_filter].append(idx) - iter_idx_array = error_filter.co_transform(iter_idx_array) - if i_element.requires_fitting: - if isinstance(i_element, _MolPipeline): - raise AssertionError("No subpipline should require fitting!") - i_element.fit_to_result(iter_input) - if isinstance(i_element, TransformingPipelineElement): - iter_input = i_element.finalize_list(iter_input) - iter_input = i_element.assemble_output(iter_input) - i_element.n_jobs = 1 - - # Set removed rows to filter elements to allow for correct co_transform - iter_idx_array = np.arange(len(x_input)) - for error_filter in self._filter_elements: - removed_idx_list = removed_rows[error_filter] - error_filter.error_indices = [] - for new_idx, _idx in enumerate(iter_idx_array): - if _idx in removed_idx_list: - error_filter.error_indices.append(new_idx) - error_filter.n_total = len(iter_idx_array) - iter_idx_array = error_filter.co_transform(iter_idx_array) - error_replacer_list = [ - ele for ele in self._element_list if isinstance(ele, FilterReinserter) - ] - for error_replacer in error_replacer_list: - error_replacer.select_error_filter(self._filter_elements) - iter_input = error_replacer.transform(iter_input) - return iter_input - - def transform_single(self, input_value: Any) -> Any: - """Transform a single input according to the sequence of provided PipelineElements. - - Parameters - ---------- - input_value: Any - Molecular representation which is subsequently transformed. - - Returns - ------- - Any - Transformed molecular representation. - """ - log_block = BlockLogs() - iter_value = input_value - for p_element in self._element_list: - try: - if not isinstance(iter_value, RemovedInstance): - iter_value = p_element.transform_single(iter_value) - elif isinstance(p_element, FilterReinserter): - iter_value = p_element.transform_single(iter_value) - except MolSanitizeException as err: - iter_value = InvalidInstance( - p_element.uuid, - f"RDKit MolSanitizeException: {err.args}", - p_element.name, - ) - del log_block - return iter_value - - def pretransform(self, x_input: Any) -> Any: - """Transform the input according to the sequence BUT skip the assemble output step. - - Parameters - ---------- - x_input: Any - Molecular representations which are subsequently transformed. - - Returns - ------- - Any - Transformed molecular representations. - """ - return list(self._transform_iterator(x_input)) - - def transform(self, x_input: Any) -> Any: - """Transform the input according to the sequence of provided PipelineElements. - - Parameters - ---------- - x_input: Any - Molecular representations which are subsequently transformed. - - Returns - ------- - Any - Transformed molecular representations. 
- """ - output_generator = self._transform_iterator(x_input) - return self.assemble_output(output_generator) - - def assemble_output(self, value_list: Iterable[Any]) -> Any: - """Assemble the output of the pipeline. - - Parameters - ---------- - value_list: Iterable[Any] - Generator which yields the output of the pipeline. - - Returns - ------- - Any - Assembled output. - """ - last_element = self._transforming_elements[-1] - if hasattr(last_element, "assemble_output"): - return last_element.assemble_output(value_list) - return list(value_list) - - def _transform_iterator(self, x_input: Any) -> Any: - """Transform the input according to the sequence of provided PipelineElements. - - Parameters - ---------- - x_input: Any - Molecular representations which are subsequently transformed. - - Yields - ------ - Any - Transformed molecular representations. - - """ - agg_filter = self._filter_elements_agg - for filter_element in self._filter_elements: - filter_element.error_indices = [] - parallel = Parallel( - n_jobs=self.n_jobs, - return_as="generator", - batch_size="auto", - ) - output_generator = parallel( - delayed(self.transform_single)(value) for value in x_input - ) - for i, transformed_value in enumerate(output_generator): - if isinstance(transformed_value, RemovedInstance): - agg_filter.register_removed(i, transformed_value) - else: - yield transformed_value - agg_filter.set_total(len(x_input)) - - def co_transform(self, x_input: TypeFixedVarSeq) -> TypeFixedVarSeq: - """Filter flagged rows from the input. - - Parameters - ---------- - x_input: Any - Molecular representations which are subsequently filtered. - - Returns - ------- - Any - Filtered molecular representations. - """ - return self._filter_elements_agg.co_transform(x_input) diff --git a/molpipeline/pipeline/_skl_adapter_pipeline.py b/molpipeline/pipeline/_skl_adapter_pipeline.py new file mode 100644 index 00000000..d6299be0 --- /dev/null +++ b/molpipeline/pipeline/_skl_adapter_pipeline.py @@ -0,0 +1,855 @@ +"""Module to change functions of the sklearn pipeline.""" + +from __future__ import annotations + +from itertools import islice +from typing import TYPE_CHECKING, Any, Self + +import numpy as np +import numpy.typing as npt +from loguru import logger +from sklearn.base import _fit_context, clone # noqa: PLC2701 +from sklearn.pipeline import Pipeline as _Pipeline +from sklearn.pipeline import ( + _final_estimator_has, # noqa: PLC2701 + _fit_transform_one, # noqa: PLC2701 + _raise_or_warn_if_not_fitted, # noqa: PLC2701 +) +from sklearn.utils.metadata_routing import ( + MetadataRouter, + MethodMapping, + _routing_enabled, # noqa: PLC2701 + process_routing, +) +from sklearn.utils.metaestimators import available_if +from sklearn.utils.validation import check_memory + +from molpipeline.error_handling import ErrorFilter, FilterReinserter +from molpipeline.post_prediction import ( + PostPredictionTransformation, + PostPredictionWrapper, +) +from molpipeline.utils.logging import print_elapsed_time +from molpipeline.utils.molpipeline_types import ( + AnyElement, + AnyStep, +) +from molpipeline.utils.value_checks import is_empty + +if TYPE_CHECKING: + from collections.abc import Generator + + import joblib + from sklearn.utils import Bunch + + +_IndexedStep = tuple[int, str, AnyElement] + + +class AdapterPipeline(_Pipeline): + """Defines the pipeline which handles pipeline elements.""" + + steps: list[AnyStep] + # * Adapted methods from sklearn.pipeline.Pipeline * + + @property + def _estimator_type(self) -> Any: + """Return the 
estimator type.""" + if self._final_estimator is None or self._final_estimator == "passthrough": + return None + if hasattr(self._final_estimator, "_estimator_type"): + return self._final_estimator._estimator_type # noqa: SLF001 # pylint: disable=protected-access + return None + + @property + def _modified_steps( + self, + ) -> list[AnyStep]: + """Return modified version of steps. + + Returns only steps before the first PostPredictionTransformation. + + Raises + ------ + AssertionError + If a PostPredictionTransformation is found before the last step. + + """ + non_post_processing_steps: list[AnyStep] = [] + start_adding = False + for step_name, step_estimator in self.steps[::-1]: + if not isinstance(step_estimator, PostPredictionTransformation): + start_adding = True + if start_adding: + if isinstance(step_estimator, PostPredictionTransformation): + raise AssertionError( + "PipelineElement of type PostPredictionTransformation occurred " + "before the last step.", + ) + non_post_processing_steps.append((step_name, step_estimator)) + return list(non_post_processing_steps[::-1]) + + @property + def _post_processing_steps(self) -> list[tuple[str, PostPredictionTransformation]]: + """Return last steps which are PostPredictionTransformation.""" + post_processing_steps = [] + for step_name, step_estimator in self.steps[::-1]: + if isinstance(step_estimator, PostPredictionTransformation): + post_processing_steps.append((step_name, step_estimator)) + else: + break + return list(post_processing_steps[::-1]) + + @property + def classes_(self) -> list[Any] | npt.NDArray[Any]: + """Return the classes of the last element. + + PostPredictionTransformation elements are not considered as last element. + + Raises + ------ + ValueError + If the last step is passthrough or has no classes_ attribute. + + """ + check_last = [ + step + for step in self.steps + if not isinstance(step[1], PostPredictionTransformation) + ] + last_step = check_last[-1][1] + if last_step == "passthrough": + raise ValueError("Last step is passthrough.") + if hasattr(last_step, "classes_"): + return last_step.classes_ + raise ValueError("Last step has no classes_ attribute.") + + def __init__( + self, + steps: list[AnyStep], + *, + memory: str | joblib.Memory | None = None, + verbose: bool = False, + n_jobs: int = 1, + ): + """Initialize Pipeline. + + Parameters + ---------- + steps: list[tuple[str, AnyTransformer | AnyPredictor | ABCPipelineElement]] + List of (name, Estimator) tuples. + memory: str | joblib.Memory | None, optional + Path to cache transformers. + verbose: bool, optional + If True, print additional information. + n_jobs: int, optional + Number of cores used for aggregated steps. + + """ + super().__init__(steps, memory=memory, verbose=verbose) + self.n_jobs = n_jobs + self._set_error_resinserter() + + def _set_error_resinserter(self) -> None: + """Connect the error reinserters with the error filters.""" + error_replacer_list = [ + e_filler + for _, e_filler in self.steps + if isinstance(e_filler, FilterReinserter) + ] + error_filter_list = [ + n_filter for _, n_filter in self.steps if isinstance(n_filter, ErrorFilter) + ] + for step in self.steps: + if isinstance(step[1], PostPredictionWrapper): # noqa: SIM102 + if isinstance(step[1].wrapped_estimator, FilterReinserter): + error_replacer_list.append(step[1].wrapped_estimator) + for error_replacer in error_replacer_list: + error_replacer.select_error_filter(error_filter_list) + + def _validate_steps(self) -> None: + """Validate the steps.
+ + Raises + ------ + TypeError + If the steps do not implement fit and transform or are not 'passthrough'. + + """ + names = [name for name, _ in self.steps] + + # validate names + self._validate_names(names) + + # validate estimators + estimator = self._modified_steps[-1][1] + + for _, transformer in self._modified_steps[:-1]: + if transformer is None or transformer == "passthrough": + continue + if not ( + hasattr(transformer, "fit") or hasattr(transformer, "fit_transform") + ) or not hasattr(transformer, "transform"): + raise TypeError( + f"All intermediate steps should be " + f"transformers and implement fit and transform " + f"or be the string 'passthrough' " + f"'{transformer}' (type {type(transformer)}) doesn't", + ) + + # We allow last estimator to be None as an identity transformation + if ( + estimator is not None + and estimator != "passthrough" + and not hasattr(estimator, "fit") + ): + raise TypeError( + f"Last step of Pipeline should implement fit " + f"or be the string 'passthrough'. " + f"'{estimator}' (type {type(estimator)}) doesn't", + ) + + # validate post-processing steps + # Calling steps automatically validates them + _ = self._post_processing_steps + + def _iter( + self, + with_final: bool = True, + filter_passthrough: bool = True, + ) -> Generator[ + tuple[int, str, AnyElement] | tuple[list[int], list[str], AdapterPipeline], + Any, + None, + ]: + """Iterate over all non post-processing steps. + + Steps which are children of an ABCPipelineElement are aggregated into a + MolPipeline. + + Parameters + ---------- + with_final: bool, optional + If True, the final estimator is included. + filter_passthrough: bool, optional + If True, passthrough steps are filtered out. + + Yields + ------ + _AggregatedPipelineStep + The _AggregatedPipelineStep is composed of the index, the name and the + transformer. + + """ + stop = len(self._modified_steps) + if not with_final: + stop -= 1 + + for idx, (name, trans) in enumerate(islice(self._modified_steps, 0, stop)): + if not filter_passthrough or (trans is not None and trans != "passthrough"): + yield idx, name, trans + + # pylint: disable=too-many-locals,too-many-branches + def _fit( # noqa: PLR0912 + self, + X: Any, # noqa: N803 + y: Any = None, + routed_params: dict[str, Any] | None = None, + raw_params: dict[str, Any] | None = None, + ) -> tuple[Any, Any]: + """Fit the model by fitting all transformers except the final estimator. + + Data can be subsetted by the transformers. + + Parameters + ---------- + X : Any + Training data. + y : Any, optional (default=None) + Training objectives. + routed_params : dict[str, Any], optional + Parameters for each step as returned by process_routing. + Although this is marked as optional, it should not be None. + The awkward (argward?) typing is due to inheritance from sklearn. + Can be an empty dictionary. + raw_params : dict[str, Any], optional + Parameters passed by the user, used when `transform_input` is set. + + Raises + ------ + AssertionError + If routed_params is None or if the transformer is 'passthrough'. + AssertionError + If the names are a list and the step is not a Pipeline. + + Returns + ------- + tuple[Any, Any] + The transformed data and the transformed objectives.
+ + """ + # shallow copy of steps - this should really be steps_ + self.steps = list(self.steps) + self._validate_steps() + if routed_params is None: + raise AssertionError("routed_params should not be None.") + + # Set up the memory + memory: joblib.Memory = check_memory(self.memory) + + fit_transform_one_cached = memory.cache(_fit_transform_one) + for step in self._iter(with_final=False, filter_passthrough=False): + step_idx, name, transformer = step + if transformer is None or transformer == "passthrough": + with print_elapsed_time("Pipeline", self._log_message(step_idx)): + continue + + if hasattr(memory, "location") and memory.location is None: + # we do not clone when caching is disabled to + # preserve backward compatibility + cloned_transformer = transformer + else: + cloned_transformer = clone(transformer) + if isinstance(name, list): + if not isinstance(transformer, _Pipeline): + raise AssertionError( + "If the name is a list, the transformer must be a Pipeline.", + ) + if routed_params: + step_params = { + "element_parameters": [routed_params[n] for n in name], + } + else: + step_params = {} + else: + step_params = self._get_metadata_for_step( + step_idx=step_idx, + step_params=routed_params[name], + all_params=raw_params, + ) + + # Fit or load from cache the current transformer + X, fitted_transformer = fit_transform_one_cached( # noqa: N806 # type: ignore + cloned_transformer, + X, + y, + None, + message_clsname="Pipeline", + message=self._log_message(step_idx), + params=step_params, + ) + # Replace the transformer of the step with the fitted + # transformer. This is necessary when loading the transformer + # from the cache. + if isinstance(fitted_transformer, AdapterPipeline): + ele_list = [step[1] for step in fitted_transformer.steps] + if not isinstance(name, list) or not isinstance(step_idx, list): + raise AssertionError() + if not len(name) == len(step_idx) == len(ele_list): + raise AssertionError() + for idx_i, name_i, ele_i in zip(step_idx, name, ele_list, strict=True): + self.steps[idx_i] = (name_i, ele_i) + if y is not None: + y = fitted_transformer.co_transform(y) + for idx_i, name_i, ele_i in zip(step_idx, name, ele_list, strict=True): + self.steps[idx_i] = (name_i, ele_i) + self._set_error_resinserter() + elif isinstance(name, list) or isinstance(step_idx, list): + raise AssertionError() + else: + self.steps[step_idx] = (name, fitted_transformer) + if is_empty(X): + return np.array([]), np.array([]) + return X, y + + def _transform( + self, + X: Any, # pylint: disable=invalid-name # noqa: N803 + routed_params: Bunch, + ) -> Any: + """Transform the data, and skip final estimator. + + Call `transform` of each transformer in the pipeline except the last one, + + Parameters + ---------- + X : iterable + Data to predict on. Must fulfill input requirements of first step + of the pipeline. + + routed_params: Bunch + parameters for each step as returned by process_routing + + Raises + ------ + AssertionError + If one of the transformers is 'passthrough' or does not implement + `transform`. + + Returns + ------- + Any + Result of calling `transform` on the second last estimator. 
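
A small sketch of the caching behaviour handled in `_fit` above, assuming `memory` is forwarded to sklearn/joblib caching as the constructor signature suggests; the cache path is hypothetical.

    # With memory=None (the default) transformers are fitted in place; with a
    # cache location each transformer is cloned before fitting so cached
    # results stay consistent, and the fitted clone is written back into
    # self.steps afterwards.
    uncached = Pipeline([("auto_to_mol", AutoToMol())])
    cached = Pipeline([("auto_to_mol", AutoToMol())], memory="./pipeline_cache")
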
+
+        """
+        iter_input = X
+        do_routing = _routing_enabled()
+        if do_routing:
+            logger.warning("Routing is enabled and NOT fully tested!")
+
+        for _, name, transform in self._iter(with_final=False):
+            if transform == "passthrough":
+                raise AssertionError("Passthrough should have been filtered out.")
+            if hasattr(transform, "transform"):
+                if do_routing:
+                    iter_input = transform.transform(  # type: ignore[call-arg]
+                        iter_input,
+                        routed_params[name].transform,
+                    )
+                else:
+                    iter_input = transform.transform(iter_input)
+            else:
+                raise AssertionError(
+                    f"Non transformer occurred in transformation step: {transform}.",
+                )
+        return iter_input
+
+    @_fit_context(
+        # estimators in Pipeline.steps are not validated yet
+        prefer_skip_nested_validation=False,
+    )
+    def fit(
+        self,
+        X: Any,  # noqa: N803
+        y: Any = None,
+        **fit_params: Any,
+    ) -> Self:
+        """Fit the model.
+
+        Fit all the transformers one after the other and transform the
+        data. Finally, fit the transformed data using the final estimator.
+
+        Parameters
+        ----------
+        X : iterable
+            Training data. Must fulfill input requirements of first step of the
+            pipeline.
+
+        y : iterable, default=None
+            Training targets. Must fulfill label requirements for all steps of
+            the pipeline.
+
+        **fit_params : dict of string -> object
+            Parameters passed to the ``fit`` method of each step, where
+            each parameter name is prefixed such that parameter ``p`` for step
+            ``s`` has key ``s__p``.
+
+        Returns
+        -------
+        self : object
+            Pipeline with fitted steps.
+
+        """
+        routed_params = self._check_method_params(method="fit", props=fit_params)
+        xt, yt = self._fit(X, y, routed_params)
+        with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)):
+            if self._final_estimator != "passthrough":
+                if is_empty(xt):
+                    logger.warning(
+                        "All input rows were filtered out! Model is not fitted!",
+                    )
+                else:
+                    fit_params_last_step = routed_params[self._modified_steps[-1][0]]
+                    self._final_estimator.fit(xt, yt, **fit_params_last_step["fit"])
+
+        return self
+
+    # pylint: disable=duplicate-code
+    def _can_fit_transform(self) -> bool:
+        """Check if the final estimator can fit_transform or is passthrough.
+
+        Returns
+        -------
+        bool
+            True if the final estimator can fit_transform or is passthrough.
+
+        """
+        return (
+            self._final_estimator == "passthrough"
+            or hasattr(self._final_estimator, "transform")
+            or hasattr(self._final_estimator, "fit_transform")
+        )
+
+    # pylint: enable=duplicate-code
+
+    @available_if(_can_fit_transform)
+    @_fit_context(
+        # estimators in Pipeline.steps are not validated yet
+        prefer_skip_nested_validation=False,
+    )
+    def fit_transform(
+        self,
+        X: Any,  # noqa: N803
+        y: Any = None,
+        **params: Any,
+    ) -> Any:
+        """Fit the model and transform with the final estimator.
+
+        Fits all the transformers one after the other and transforms the
+        data. Then uses `fit_transform` on transformed data with the final
+        estimator.
+
+        Parameters
+        ----------
+        X : iterable
+            Training data. Must fulfill input requirements of first step of the
+            pipeline.
+
+        y : iterable, default=None
+            Training targets. Must fulfill label requirements for all steps of
+            the pipeline.
+
+        **params : Any
+            Parameters passed to the ``fit`` method of each step, where
+            each parameter name is prefixed such that parameter ``p`` for step
+            ``s`` has key ``s__p``.
+
+        Raises
+        ------
+        TypeError
+            If the last step does not implement `fit_transform` or `fit` and
+            `transform`.
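
As a usage sketch of the `fit` path above, reusing the illustrative pipeline from earlier ("abc" is the invalid SMILES also used in the test suite); the degenerate case follows the warning branch shown in the code.

    # Normal fit: the ErrorFilter drops the invalid row and y is
    # co-transformed to match before the final estimator is fitted.
    pipeline.fit(["CCO", "c1ccccc1", "abc"], [0, 1, 0])

    # Degenerate fit: if every row is filtered out, the warning branch is
    # taken and the final estimator is left unfitted.
    pipeline.fit(["abc", "abc"], [0, 1])
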
+ + Returns + ------- + Xt : ndarray of shape (n_samples, n_transformed_features) + Transformed samples. + + """ + routed_params = self._check_method_params(method="fit_transform", props=params) + iter_input, iter_label = self._fit(X, y, routed_params) + last_step = self._final_estimator + with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): + if last_step == "passthrough": + pass + elif is_empty(iter_input): + logger.warning("All input rows were filtered out! Model is not fitted!") + else: + last_step_params = routed_params[self._modified_steps[-1][0]] + if hasattr(last_step, "fit_transform"): + iter_input = last_step.fit_transform( + iter_input, + iter_label, + **last_step_params["fit_transform"], + ) + elif hasattr(last_step, "transform") and hasattr(last_step, "fit"): + last_step.fit(iter_input, iter_label, **last_step_params["fit"]) + iter_input = last_step.transform( + iter_input, + **last_step_params["transform"], + ) + else: + raise TypeError( + f"fit_transform of the final estimator" + f" {last_step.__class__.__name__} {last_step_params} does not " + f"match fit_transform of Pipeline {self.__class__.__name__}", + ) + for _, post_element in self._post_processing_steps: + iter_input = post_element.fit_transform(iter_input, iter_label) + return iter_input + + @available_if(_final_estimator_has("fit_predict")) + @_fit_context( + # estimators in Pipeline.steps are not validated yet + prefer_skip_nested_validation=False, + ) + def fit_predict( + self, + X: Any, # noqa: N803 + y: Any = None, + **params: Any, + ) -> Any: + """Transform the data, and apply `fit_predict` with the final estimator. + + Call `fit_transform` of each transformer in the pipeline. The + transformed data are finally passed to the final estimator that calls + `fit_predict` method. Only valid if the final estimator implements + `fit_predict`. + + Parameters + ---------- + X : iterable + Training data. Must fulfill input requirements of first step of + the pipeline. + + y : iterable, default=None + Training targets. Must fulfill label requirements for all steps + of the pipeline. + + **params : dict of string -> object + Parameters passed to the ``fit`` method of each step, where + each parameter name is prefixed such that parameter ``p`` for step + ``s`` has key ``s__p``. + + Raises + ------ + AssertionError + If the final estimator does not implement `fit_predict`. + In this case this function should not be available. + + Returns + ------- + y_pred : ndarray + Result of calling `fit_predict` on the final estimator. + + """ + routed_params = self._check_method_params(method="fit_predict", props=params) + iter_input, iter_label = self._fit(X, y, routed_params) + + params_last_step = routed_params[self._modified_steps[-1][0]] + with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): + if self._final_estimator == "passthrough": + y_pred = iter_input + elif is_empty(iter_input): + logger.warning("All input rows were filtered out! 
Model is not fitted!") + iter_input = [] + y_pred = [] + elif hasattr(self._final_estimator, "fit_predict"): + y_pred = self._final_estimator.fit_predict( + iter_input, + iter_label, + **params_last_step.get("fit_predict", {}), + ) + else: + raise AssertionError( + "Final estimator does not implement fit_predict, " + "hence this function should not be available.", + ) + for _, post_element in self._post_processing_steps: + y_pred = post_element.fit_transform(y_pred, iter_label) + return y_pred + + @available_if(_final_estimator_has("predict")) + def predict( + self, + X: npt.NDArray[Any] | list[Any], # noqa: N803 + **params: Any, + ) -> npt.NDArray[Any] | list[Any]: + """Transform the data, and apply `predict` with the final estimator. + + Call `transform` of each transformer in the pipeline. The transformed + data are finally passed to the final estimator that calls `predict` + method. Only valid if the final estimator implements `predict`. + + Parameters + ---------- + X : npt.NDArray[Any] + Data to predict on. Must fulfill input requirements of first step + of the pipeline. + + **params : Any + If `enable_metadata_routing=False` (default): Parameters to the + ``predict`` called at the end of all transformations in the pipeline. + If `enable_metadata_routing=True`: Parameters requested and accepted by + steps. Each step must have requested certain metadata for these parameters + to be forwarded to them. + + Raises + ------ + AssertionError + If a step before the final estimator is 'passthrough' or does not + implement `transform`. + + Returns + ------- + y_pred : npt.NDArray[Any] + Result of calling `predict` on the final estimator. + + """ + iter_input = X + with _raise_or_warn_if_not_fitted(self): + if not _routing_enabled(): + for _, name, transform in self._iter(with_final=False): + if ( + not hasattr(transform, "transform") + or transform == "passthrough" + ): + raise AssertionError( + f"Non transformer occurred in transformation step: {name}.", + ) + iter_input = transform.transform(iter_input) + if is_empty(iter_input): + iter_input = [] + else: + iter_input = self._final_estimator.predict(iter_input, **params) + else: + # metadata routing enabled + routed_params = process_routing(self, "predict", **params) + for _, name, transform in self._iter(with_final=False): + if ( + not hasattr(transform, "transform") + or transform == "passthrough" + ): + raise AssertionError( + f"Non transformer occurred in transformation step: {name}.", + ) + iter_input = transform.transform( + iter_input, + **routed_params[name].transform, + ) + if is_empty(iter_input): + iter_input = [] + else: + iter_input = self._final_estimator.predict( + iter_input, + **routed_params[self.steps[-1][0]].predict, + ) + for _, post_element in self._post_processing_steps: + iter_input = post_element.transform(iter_input, **params) + return iter_input + + @available_if(_final_estimator_has("predict_proba")) + def predict_proba( + self, + X: Any, # noqa: N803 + **params: Any, + ) -> Any: + """Transform the data, and apply `predict_proba` with the final estimator. + + Call `transform` of each transformer in the pipeline. The transformed + data are finally passed to the final estimator that calls `predict_proba` + method. Only valid if the final estimator implements `predict_proba`. + + Parameters + ---------- + X : iterable + Data to predict on. Must fulfill input requirements of first step + of the pipeline. 
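
A prediction sketch matching the post-processing loop above (same illustrative pipeline as before; the NaN fill value comes from the FilterReinserter configured earlier):

    # The PostPredictionWrapper re-inserts np.nan at the position of the
    # filtered row, so the output length matches the input length again.
    predictions = pipeline.predict(["CCO", "abc", "c1ccccc1"])
    assert len(predictions) == 3  # the "abc" slot is filled with np.nan
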
+
+        **params : dict of string -> object
+            Parameters to the ``predict_proba`` called at the end of all
+            transformations in the pipeline. Note that while this may be
+            used to return uncertainties from some models with return_std
+            or return_cov, uncertainties that are generated by the
+            transformations in the pipeline are not propagated to the
+            final estimator.
+
+        Raises
+        ------
+        AssertionError
+            If the final estimator does not implement `predict_proba`.
+            In this case this function should not be available.
+
+        Returns
+        -------
+        y_pred : ndarray
+            Result of calling `predict_proba` on the final estimator.
+
+        """
+        routed_params = process_routing(self, "predict_proba", **params)
+        iter_input = self._transform(X, routed_params)
+
+        if self._final_estimator == "passthrough":
+            pass
+        elif is_empty(iter_input):
+            iter_input = []
+        elif hasattr(self._final_estimator, "predict_proba"):
+            if _routing_enabled():
+                iter_input = self._final_estimator.predict_proba(
+                    iter_input,
+                    **routed_params[self._modified_steps[-1][0]].predict_proba,
+                )
+            else:
+                iter_input = self._final_estimator.predict_proba(iter_input, **params)
+        else:
+            raise AssertionError(
+                "Final estimator does not implement predict_proba, "
+                "hence this function should not be available.",
+            )
+        for _, post_element in self._post_processing_steps:
+            iter_input = post_element.transform(iter_input)
+        return iter_input
+
+    def get_metadata_routing(self) -> MetadataRouter:
+        """Get metadata routing of this object.
+
+        Please check :ref:`User Guide <metadata_routing>` on how the routing
+        mechanism works.
+
+        Notes
+        -----
+        This method is copied from the original sklearn implementation.
+        Changes are marked with a comment.
+
+        Returns
+        -------
+        MetadataRouter
+            A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating
+            routing information.
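
Metadata routing only takes effect when it is switched on globally; otherwise the `_routing_enabled()` guards above fall back to the plain call paths (and `_transform` logs that routing is not fully tested). A sketch using sklearn's documented switch:

    import sklearn

    sklearn.set_config(enable_metadata_routing=True)   # opt in, experimental
    # ... the process_routing(...) based branches are taken ...
    sklearn.set_config(enable_metadata_routing=False)  # back to the default
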
+
+        """
+        router = MetadataRouter(owner=self.__class__.__name__)
+
+        # first we add all steps except the last one
+        for _, name, trans in self._iter(with_final=False, filter_passthrough=True):
+            method_mapping = MethodMapping()
+            # fit, fit_predict, and fit_transform call fit_transform if it
+            # exists, or else fit and transform
+            if hasattr(trans, "fit_transform"):
+                (
+                    method_mapping.add(caller="fit", callee="fit_transform")
+                    .add(caller="fit_transform", callee="fit_transform")
+                    .add(caller="fit_predict", callee="fit_transform")
+                )
+            else:
+                (
+                    method_mapping.add(caller="fit", callee="fit")
+                    .add(caller="fit", callee="transform")
+                    .add(caller="fit_transform", callee="fit")
+                    .add(caller="fit_transform", callee="transform")
+                    .add(caller="fit_predict", callee="fit")
+                    .add(caller="fit_predict", callee="transform")
+                )
+
+            (
+                method_mapping.add(caller="predict", callee="transform")
+                .add(caller="predict", callee="transform")
+                .add(caller="predict_proba", callee="transform")
+                .add(caller="decision_function", callee="transform")
+                .add(caller="predict_log_proba", callee="transform")
+                .add(caller="transform", callee="transform")
+                .add(caller="inverse_transform", callee="inverse_transform")
+                .add(caller="score", callee="transform")
+            )
+
+            router.add(method_mapping=method_mapping, **{name: trans})
+
+        # Only the lookup of the final step (via _modified_steps) is changed
+        # from the original implementation in the following line
+        final_name, final_est = self._modified_steps[-1]
+        if final_est is None or final_est == "passthrough":
+            return router
+
+        # then we add the last step
+        method_mapping = MethodMapping()
+        if hasattr(final_est, "fit_transform"):
+            method_mapping.add(caller="fit_transform", callee="fit_transform")
+        else:
+            method_mapping.add(caller="fit", callee="fit").add(
+                caller="fit",
+                callee="transform",
+            )
+        (
+            method_mapping.add(caller="fit", callee="fit")
+            .add(caller="predict", callee="predict")
+            .add(caller="fit_predict", callee="fit_predict")
+            .add(caller="predict_proba", callee="predict_proba")
+            .add(caller="decision_function", callee="decision_function")
+            .add(caller="predict_log_proba", callee="predict_log_proba")
+            .add(caller="transform", callee="transform")
+            .add(caller="inverse_transform", callee="inverse_transform")
+            .add(caller="score", callee="score")
+        )
+
+        router.add(method_mapping=method_mapping, **{final_name: final_est})
+        return router
diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py
index 8a9d0507..6bc558b8 100644
--- a/molpipeline/pipeline/_skl_pipeline.py
+++ b/molpipeline/pipeline/_skl_pipeline.py
@@ -1,64 +1,131 @@
 """Defines a pipeline is exposed to the user, accessible via pipeline."""
-# pylint: disable=too-many-lines
-
 from __future__ import annotations
 
-from collections.abc import Iterable
 from copy import deepcopy
-from typing import Any, Literal, Self, TypeVar
+from itertools import islice
+from typing import TYPE_CHECKING, Any, Literal, Self, TypeIs
 
-import joblib
 import numpy as np
-import numpy.typing as npt
-from loguru import logger
-from sklearn.base import _fit_context, clone
-from sklearn.pipeline import Pipeline as _Pipeline
-from sklearn.pipeline import _final_estimator_has, _fit_transform_one
-from sklearn.utils import Bunch
-from sklearn.utils._tags import Tags, get_tags
-from sklearn.utils.metadata_routing import (
-    MetadataRouter,
-    MethodMapping,
-    _routing_enabled,
-    process_routing,
-)
+from joblib import Parallel, delayed
+from rdkit.Chem.rdchem import
MolSanitizeException +from rdkit.rdBase import BlockLogs +from sklearn.base import _fit_context # noqa: PLC2701 +from sklearn.utils._tags import Tags, get_tags # noqa: PLC2701 from sklearn.utils.metaestimators import available_if -from sklearn.utils.validation import check_memory - -from molpipeline.abstract_pipeline_elements.core import ABCPipelineElement -from molpipeline.error_handling import ErrorFilter, FilterReinserter -from molpipeline.pipeline._molpipeline import _MolPipeline -from molpipeline.post_prediction import ( - PostPredictionTransformation, - PostPredictionWrapper, + +from molpipeline.abstract_pipeline_elements.core import ( + ABCPipelineElement, + InvalidInstance, + RemovedInstance, + SingleInstanceTransformerMixin, + TransformingPipelineElement, ) -from molpipeline.utils.logging import print_elapsed_time +from molpipeline.error_handling import ( + ErrorFilter, + FilterReinserter, + _MultipleErrorFilter, +) +from molpipeline.pipeline._skl_adapter_pipeline import AdapterPipeline from molpipeline.utils.molpipeline_types import ( AnyElement, AnyPredictor, AnyStep, AnyTransformer, ) -from molpipeline.utils.value_checks import is_empty -__all__ = ["Pipeline"] +if TYPE_CHECKING: + from collections.abc import Generator, Iterable, Sequence -# Type definitions -_T = TypeVar("_T") -# Cannot be moved to utils.molpipeline_types due to circular imports + import joblib + + from molpipeline.utils.molpipeline_types import TypeFixedVarSeq + +__all__ = ["Pipeline"] _IndexedStep = tuple[int, str, AnyElement] -_AggStep = tuple[list[int], list[str], _MolPipeline] +_AggStep = tuple[list[int], list[str], "Pipeline"] _AggregatedPipelineStep = _IndexedStep | _AggStep -class Pipeline(_Pipeline): +def _agg_transformers( + transformer_list: Sequence[tuple[int, str, AnyElement]], + n_jobs: int = 1, +) -> tuple[list[int], list[str], Pipeline] | tuple[int, str, AnyElement]: + """Aggregate transformers to a single step. + + Parameters + ---------- + transformer_list: list[tuple[int, str, AnyElement]] + List of transformers to aggregate. + n_jobs: int, optional + Number of cores used for aggregated steps. + + Returns + ------- + tuple[list[int], list[str], Pipeline] | tuple[int, str, AnyElement] + Aggregated transformer. + If the list contains only one transformer, it is returned as is. + + """ + index_list = [step[0] for step in transformer_list] + name_list = [step[1] for step in transformer_list] + if len(transformer_list) == 1: + return transformer_list[0] + return ( + index_list, + name_list, + Pipeline( + [(step[1], step[2]) for step in transformer_list], + n_jobs=n_jobs, + ), + ) + + +def check_single_instance_support( + estimator: Any, +) -> TypeIs[SingleInstanceTransformerMixin | Pipeline]: + """Check if the estimator supports single instance processing. + + Parameters + ---------- + estimator: Any + Estimator to check. + + Returns + ------- + TypeIs[SingleInstanceTransformerMixin | Pipeline] + True if the estimator supports single instance processing. 
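
Because the return type is `TypeIs` (available from `typing` on Python 3.13, or from `typing_extensions` on older versions), static type checkers narrow the checked argument in both branches. A minimal usage sketch; `transform_single` on the mixin is inferred from the class name and the methods added further below:

    from typing import Any

    def transform_one(element: Any, value: Any) -> Any:
        # In the True branch, type checkers treat `element` as
        # SingleInstanceTransformerMixin | Pipeline.
        if check_single_instance_support(element):
            return element.transform_single(value)
        raise TypeError(f"{element!r} cannot process single instances")
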
+
+    """
+    if isinstance(estimator, SingleInstanceTransformerMixin):
+        return True
+    return isinstance(estimator, Pipeline) and estimator.supports_single_instance
+
+
+class Pipeline(AdapterPipeline, TransformingPipelineElement):  # pylint: disable=too-many-ancestors
     """Defines the pipeline which handles pipeline elements."""
 
     steps: list[AnyStep]
-    # * Adapted methods from sklearn.pipeline.Pipeline *
+
+    @property
+    def _filter_elements(self) -> list[ErrorFilter]:
+        """Get the elements which filter the input."""
+        return [step[1] for step in self.steps if isinstance(step[1], ErrorFilter)]
+
+    @property
+    def _filter_elements_agg(self) -> _MultipleErrorFilter:
+        """Get the aggregated filter element."""
+        return _MultipleErrorFilter(self._filter_elements)
+
+    @property
+    def supports_single_instance(self) -> bool:
+        """Check if the pipeline supports single instance."""
+        return all(
+            isinstance(step[1], SingleInstanceTransformerMixin)
+            for step in self._modified_steps
+        )
 
     def __init__(
         self,
@@ -86,79 +153,19 @@ def __init__
         self.n_jobs = n_jobs
         self._set_error_resinserter()
 
-    def _set_error_resinserter(self) -> None:
-        """Connect the error resinserters with the error filters."""
-        error_replacer_list = [
-            e_filler
-            for _, e_filler in self.steps
-            if isinstance(e_filler, FilterReinserter)
-        ]
-        error_filter_list = [
-            n_filter for _, n_filter in self.steps if isinstance(n_filter, ErrorFilter)
-        ]
-        for step in self.steps:
-            if isinstance(step[1], PostPredictionWrapper):
-                if isinstance(step[1].wrapped_estimator, FilterReinserter):
-                    error_replacer_list.append(step[1].wrapped_estimator)
-        for error_replacer in error_replacer_list:
-            error_replacer.select_error_filter(error_filter_list)
-
-    def _validate_steps(self) -> None:
-        """Validate the steps.
-
-        Raises
-        ------
-        TypeError
-            If the steps do not implement fit and transform or are not 'passthrough'.
-
-        """
-        names = [name for name, _ in self.steps]
-
-        # validate names
-        self._validate_names(names)
-
-        # validate estimators
-        non_post_processing_steps = [e for _, _, e in self._agg_non_postpred_steps()]
-        transformer_list = non_post_processing_steps[:-1]
-        estimator = non_post_processing_steps[-1]
-
-        for transformer in transformer_list:
-            if transformer is None or transformer == "passthrough":
-                continue
-            if not (
-                hasattr(transformer, "fit") or hasattr(transformer, "fit_transform")
-            ) or not hasattr(transformer, "transform"):
-                raise TypeError(
-                    f"All intermediate steps should be "
-                    f"transformers and implement fit and transform "
-                    f"or be the string 'passthrough' "
-                    f"'{transformer}' (type {type(transformer)}) doesn't",
-                )
-
-        # We allow last estimator to be None as an identity transformation
-        if (
-            estimator is not None
-            and estimator != "passthrough"
-            and not hasattr(estimator, "fit")
-        ):
-            raise TypeError(
-                f"Last step of Pipeline should implement fit "
-                f"or be the string 'passthrough'. "
-                f"'{estimator}' (type {type(estimator)}) doesn't",
-            )
-
-        # validate post-processing steps
-        # Calling steps automatically validates them
-        _ = self._post_processing_steps()
-
     def _iter(
         self,
         with_final: bool = True,
         filter_passthrough: bool = True,
-    ) -> Iterable[_AggregatedPipelineStep]:
+    ) -> Generator[
+        tuple[list[int], list[str], Pipeline] | tuple[int, str, AnyElement],
+        Any,
+        None,
+    ]:
         """Iterate over all non post-processing steps.
 
-        Steps which are children of a ABCPipelineElement were aggregated to a MolPipeline.
+        Steps which are children of an ABCPipelineElement are aggregated into a
+        sub-pipeline.
Parameters ---------- @@ -167,11 +174,6 @@ def _iter( filter_passthrough: bool, optional If True, passthrough steps are filtered out. - Raises - ------ - AssertionError - If the pipeline has no steps. - Yields ------ _AggregatedPipelineStep @@ -179,321 +181,66 @@ def _iter( transformer. """ - last_element: _AggregatedPipelineStep | None = None + if self.supports_single_instance: + yield from super()._iter( + with_final=with_final, + filter_passthrough=filter_passthrough, + ) + return + transformers_to_agg: list[tuple[int, str, AnyTransformer]] + transformers_to_agg = [] + final_transformer_list: list[ + tuple[list[int], list[str], Pipeline] | tuple[int, str, AnyElement] + ] = [] + for i, (name_i, step_i) in enumerate(super()._modified_steps): + if isinstance(step_i, SingleInstanceTransformerMixin): + transformers_to_agg.append((i, name_i, step_i)) + else: + if transformers_to_agg: + if len(transformers_to_agg) == 1: + final_transformer_list.append(transformers_to_agg[0]) + else: + final_transformer_list.append( + _agg_transformers(transformers_to_agg, self.n_jobs), + ) + transformers_to_agg = [] + final_transformer_list.append((i, name_i, step_i)) - # This loop delays the output by one in order to identify the last step - for step in self._agg_non_postpred_steps(): - # Only happens for the first step - if last_element is None: - last_element = step - continue + # yield last step if anything remains + if transformers_to_agg: + final_transformer_list.append( + _agg_transformers(transformers_to_agg, self.n_jobs), + ) + stop = len(final_transformer_list) + if not with_final: + stop -= 1 + + for step in islice(final_transformer_list, 0, stop): if not filter_passthrough or ( step[2] is not None and step[2] != "passthrough" ): - yield last_element - last_element = step - - # This can only happen if no steps are set. - if last_element is None: - raise AssertionError("Pipeline needs to have at least one step!") - - if with_final and last_element[2] is not None: - if last_element[2] != "passthrough": - yield last_element - - @property - def _estimator_type(self) -> Any: - """Return the estimator type.""" - if self._final_estimator is None or self._final_estimator == "passthrough": - return None - if hasattr(self._final_estimator, "_estimator_type"): - # pylint: disable=protected-access - return self._final_estimator._estimator_type - return None + yield step @property def _final_estimator( self, - ) -> ( - Literal["passthrough"] - | AnyTransformer - | AnyPredictor - | _MolPipeline - | ABCPipelineElement - ): + ) -> Literal["passthrough"] | AnyTransformer | AnyPredictor | ABCPipelineElement: """Return the lst estimator which is not a PostprocessingTransformer.""" - element_list = list(self._agg_non_postpred_steps()) - last_element = element_list[-1] + element_list = list(self._iter(with_final=True)) + steps = [s for s in element_list if not isinstance(s[2], ErrorFilter)] + last_element = steps[-1] return last_element[2] - # pylint: disable=too-many-locals,too-many-branches - def _fit( - self, - X: Any, - y: Any = None, - routed_params: dict[str, Any] | None = None, - raw_params: dict[str, Any] | None = None, - ) -> tuple[Any, Any]: - """Fit the model by fitting all transformers except the final estimator. - - Data can be subsetted by the transformers. - - Parameters - ---------- - X : Any - Training data. - y : Any, optional (default=None) - Training objectives. - routed_params : dict[str, Any], optional - Parameters for each step as returned by process_routing. 
- Although this is marked as optional, it should not be None. - The awkward (argward?) typing is due to inheritance from sklearn. - Can be an empty dictionary. - raw_params : dict[str, Any], optional - Parameters passed by the user, used when `transform_input` - - Raises - ------ - AssertionError - If routed_params is None or if the transformer is 'passthrough'. - AssertionError - If the names are a list and the step is not a Pipeline. - - Returns - ------- - tuple[Any, Any] - The transformed data and the transformed objectives. - - """ - # shallow copy of steps - this should really be steps_ - self.steps = list(self.steps) - self._validate_steps() - if routed_params is None: - raise AssertionError("routed_params should not be None.") - - # Set up the memory - memory: joblib.Memory = check_memory(self.memory) - - fit_transform_one_cached = memory.cache(_fit_transform_one) - for step in self._iter(with_final=False, filter_passthrough=False): - step_idx, name, transformer = step - if transformer is None or transformer == "passthrough": - with print_elapsed_time("Pipeline", self._log_message(step_idx)): - continue - - if hasattr(memory, "location") and memory.location is None: - # we do not clone when caching is disabled to - # preserve backward compatibility - cloned_transformer = transformer - else: - cloned_transformer = clone(transformer) - if isinstance(cloned_transformer, _MolPipeline): - if routed_params: - step_params = { - "element_parameters": [routed_params[n] for n in name], - } - else: - step_params = {} - elif isinstance(name, list): - raise AssertionError( - "Names should not be a list, when the step is not a Pipeline", - ) - else: - step_params = self._get_metadata_for_step( - step_idx=step_idx, - step_params=routed_params[name], - all_params=raw_params, - ) - - # Fit or load from cache the current transformer - X, fitted_transformer = fit_transform_one_cached( - cloned_transformer, - X, - y, - None, - message_clsname="Pipeline", - message=self._log_message(step_idx), - params=step_params, - ) - # Replace the transformer of the step with the fitted - # transformer. This is necessary when loading the transformer - # from the cache. - if isinstance(fitted_transformer, _MolPipeline): - ele_list = fitted_transformer.element_list - if not isinstance(name, list) or not isinstance(step_idx, list): - raise AssertionError() - if not len(name) == len(step_idx) == len(ele_list): - raise AssertionError() - for idx_i, name_i, ele_i in zip(step_idx, name, ele_list, strict=True): - self.steps[idx_i] = (name_i, ele_i) - if y is not None: - y = fitted_transformer.co_transform(y) - for idx_i, name_i, ele_i in zip(step_idx, name, ele_list, strict=True): - self.steps[idx_i] = (name_i, ele_i) - self._set_error_resinserter() - elif isinstance(name, list) or isinstance(step_idx, list): - raise AssertionError() - else: - self.steps[step_idx] = (name, fitted_transformer) - if is_empty(X): - return np.array([]), np.array([]) - return X, y - - def _transform( - self, - X: Any, # pylint: disable=invalid-name - routed_params: Bunch, - ) -> Any: - """Transform the data, and skip final estimator. - - Call `transform` of each transformer in the pipeline except the last one, - - Parameters - ---------- - X : iterable - Data to predict on. Must fulfill input requirements of first step - of the pipeline. - - routed_params: Bunch - parameters for each step as returned by process_routing - - Raises - ------ - AssertionError - If one of the transformers is 'passthrough' or does not implement - `transform`. 
- - Returns - ------- - Any - Result of calling `transform` on the second last estimator. - - """ - iter_input = X - do_routing = _routing_enabled() - if do_routing: - logger.warning("Routing is enabled and NOT fully tested!") - - for _, name, transform in self._iter(with_final=False): - if is_empty(iter_input): - if isinstance(transform, _MolPipeline): - _ = transform.transform(iter_input) - iter_input = [] - break - if transform == "passthrough": - raise AssertionError("Passthrough should have been filtered out.") - if hasattr(transform, "transform"): - if do_routing: - iter_input = transform.transform( # type: ignore[call-arg] - iter_input, - routed_params[name].transform, - ) - else: - iter_input = transform.transform(iter_input) - else: - raise AssertionError( - f"Non transformer ocurred in transformation step: {transform}.", - ) - return iter_input - - # * New implemented methods * - def _non_post_processing_steps( - self, - ) -> list[AnyStep]: - """Return all steps before the first PostPredictionTransformation. - - Raises - ------ - AssertionError - If a PostPredictionTransformation is found before the last step. - - Returns - ------- - list[AnyStep] - List of steps before the first PostPredictionTransformation. - - """ - non_post_processing_steps: list[AnyStep] = [] - start_adding = False - for step_name, step_estimator in self.steps[::-1]: - if not isinstance(step_estimator, PostPredictionTransformation): - start_adding = True - if start_adding: - if isinstance(step_estimator, PostPredictionTransformation): - raise AssertionError( - "PipelineElement of type PostPredictionTransformation occured " - "before the last step.", - ) - non_post_processing_steps.append((step_name, step_estimator)) - return list(non_post_processing_steps[::-1]) - - def _post_processing_steps(self) -> list[tuple[str, PostPredictionTransformation]]: - """Return last steps which are PostPredictionTransformation. - - Returns - ------- - list[tuple[str, PostPredictionTransformation]] - List of tuples containing the name and the PostPredictionTransformation. - - """ - post_processing_steps = [] - for step_name, step_estimator in self.steps[::-1]: - if isinstance(step_estimator, PostPredictionTransformation): - post_processing_steps.append((step_name, step_estimator)) - else: - break - return list(post_processing_steps[::-1]) - - def _agg_non_postpred_steps( - self, - ) -> Iterable[_AggregatedPipelineStep]: - """Generate (idx, (name, trans)) tuples from self.steps. - - When filter_passthrough is True, 'passthrough' and None transformers - are filtered out. - - Yields - ------ - _AggregatedPipelineStep - The _AggregatedPipelineStep is composed of the index, the name and the - transformer. 
- - """ - aggregated_transformer_list = [] - for i, (name_i, step_i) in enumerate(self._non_post_processing_steps()): - if isinstance(step_i, ABCPipelineElement): - aggregated_transformer_list.append((i, name_i, step_i)) - else: - if aggregated_transformer_list: - index_list = [step[0] for step in aggregated_transformer_list] - name_list = [step[1] for step in aggregated_transformer_list] - transformer_list = [step[2] for step in aggregated_transformer_list] - if len(aggregated_transformer_list) == 1: - yield index_list[0], name_list[0], transformer_list[0] - else: - pipeline = _MolPipeline(transformer_list, n_jobs=self.n_jobs) - yield index_list, name_list, pipeline - aggregated_transformer_list = [] - yield i, name_i, step_i - - # yield last step if anything remains - if aggregated_transformer_list: - index_list = [step[0] for step in aggregated_transformer_list] - name_list = [step[1] for step in aggregated_transformer_list] - transformer_list = [step[2] for step in aggregated_transformer_list] - - if len(aggregated_transformer_list) == 1: - yield index_list[0], name_list[0], transformer_list[0] - - elif len(aggregated_transformer_list) > 1: - pipeline = _MolPipeline(transformer_list, n_jobs=self.n_jobs) - yield index_list, name_list, pipeline - @_fit_context( # estimators in Pipeline.steps are not validated yet prefer_skip_nested_validation=False, ) - def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self: + def fit( + self, + X: Any, # noqa: N803 + y: Any = None, + **fit_params: Any, + ) -> Self: """Fit the model. Fit all the transformers one after the other and transform the @@ -520,22 +267,12 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self: Pipeline with fitted steps. """ - routed_params = self._check_method_params(method="fit", props=fit_params) - Xt, yt = self._fit(X, y, routed_params) # pylint: disable=invalid-name - with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): - if self._final_estimator != "passthrough": - if is_empty(Xt): - logger.warning( - "All input rows were filtered out! Model is not fitted!", - ) - else: - fit_params_last_step = routed_params[ - self._non_post_processing_steps()[-1][0] - ] - self._final_estimator.fit(Xt, yt, **fit_params_last_step["fit"]) - + if self.supports_single_instance: + return self + super().fit(X, y, **fit_params) return self + # pylint: disable=duplicate-code def _can_fit_transform(self) -> bool: """Check if the final estimator can fit_transform or is passthrough. @@ -551,23 +288,19 @@ def _can_fit_transform(self) -> bool: or hasattr(self._final_estimator, "fit_transform") ) - def _can_decision_function(self) -> bool: - """Check if the final estimator implements decision_function. - - Returns - ------- - bool - True if the final estimator implements decision_function. - - """ - return hasattr(self._final_estimator, "decision_function") + # pylint: enable=duplicate-code @available_if(_can_fit_transform) @_fit_context( # estimators in Pipeline.steps are not validated yet prefer_skip_nested_validation=False, ) - def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any: + def fit_transform( + self, + X: Any, # noqa: N803 + y: Any = None, + **params: Any, + ) -> Any: """Fit the model and transform with the final estimator. Fits all the transformers one after the other and transform the @@ -589,239 +322,68 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any: each parameter name is prefixed such that parameter ``p`` for step ``s`` has key ``s__p``. 
- Raises - ------ - TypeError - If the last step does not implement `fit_transform` or `fit` and - `transform`. - - Returns - ------- - Xt : ndarray of shape (n_samples, n_transformed_features) - Transformed samples. - - """ - routed_params = self._check_method_params(method="fit_transform", props=params) - iter_input, iter_label = self._fit(X, y, routed_params) - last_step = self._final_estimator - with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): - if last_step == "passthrough": - pass - elif is_empty(iter_input): - logger.warning("All input rows were filtered out! Model is not fitted!") - else: - last_step_params = routed_params[ - self._non_post_processing_steps()[-1][0] - ] - if hasattr(last_step, "fit_transform"): - iter_input = last_step.fit_transform( - iter_input, - iter_label, - **last_step_params["fit_transform"], - ) - elif hasattr(last_step, "transform") and hasattr(last_step, "fit"): - last_step.fit(iter_input, iter_label, **last_step_params["fit"]) - iter_input = last_step.transform( - iter_input, - **last_step_params["transform"], - ) - else: - raise TypeError( - f"fit_transform of the final estimator" - f" {last_step.__class__.__name__} {last_step_params} does not " - f"match fit_transform of Pipeline {self.__class__.__name__}", - ) - for _, post_element in self._post_processing_steps(): - iter_input = post_element.fit_transform(iter_input, iter_label) - return iter_input - - @available_if(_final_estimator_has("predict")) - def predict(self, X: Any, **params: Any) -> Any: - """Transform the data, and apply `predict` with the final estimator. - - Call `transform` of each transformer in the pipeline. The transformed - data are finally passed to the final estimator that calls `predict` - method. Only valid if the final estimator implements `predict`. - - Parameters - ---------- - X : iterable - Data to predict on. Must fulfill input requirements of first step - of the pipeline. - - **params : dict of string -> object - Parameters to the ``predict`` called at the end of all - transformations in the pipeline. Note that while this may be - used to return uncertainties from some models with return_std - or return_cov, uncertainties that are generated by the - transformations in the pipeline are not propagated to the - final estimator. - - .. versionadded:: 0.20 - Raises ------ AssertionError - If the final estimator does not implement `predict`. - In this case this function should not be available. - - Returns - ------- - y_pred : ndarray - Result of calling `predict` on the final estimator. 
- - """ - if _routing_enabled(): - routed_params = process_routing(self, "predict", **params) - else: - routed_params = process_routing(self, "predict") - - iter_input = self._transform(X, routed_params) - - if self._final_estimator == "passthrough": - pass - elif is_empty(iter_input): - iter_input = [] - elif hasattr(self._final_estimator, "predict"): - if _routing_enabled(): - iter_input = self._final_estimator.predict( - iter_input, - **routed_params[self._non_post_processing_steps()[-1][0]].predict, - ) - else: - iter_input = self._final_estimator.predict(iter_input, **params) - else: - raise AssertionError( - "Final estimator does not implement predict, hence this function should not be available.", - ) - for _, post_element in self._post_processing_steps(): - iter_input = post_element.transform(iter_input) - return iter_input - - @available_if(_final_estimator_has("fit_predict")) - @_fit_context( - # estimators in Pipeline.steps are not validated yet - prefer_skip_nested_validation=False, - ) - def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any: - """Transform the data, and apply `fit_predict` with the final estimator. - - Call `fit_transform` of each transformer in the pipeline. The - transformed data are finally passed to the final estimator that calls - `fit_predict` method. Only valid if the final estimator implements - `fit_predict`. - - Parameters - ---------- - X : iterable - Training data. Must fulfill input requirements of first step of - the pipeline. - - y : iterable, default=None - Training targets. Must fulfill label requirements for all steps - of the pipeline. - - **params : dict of string -> object - Parameters passed to the ``fit`` method of each step, where - each parameter name is prefixed such that parameter ``p`` for step - ``s`` has key ``s__p``. - - Raises - ------ + If PipelineElement is not a SingleInstanceTransformerMixin or Pipeline. AssertionError - If the final estimator does not implement `fit_predict`. - In this case this function should not be available. + If PipelineElement is not a TransformingPipelineElement. Returns ------- - y_pred : ndarray - Result of calling `fit_predict` on the final estimator. + Xt : ndarray of shape (n_samples, n_transformed_features) + Transformed samples. """ - routed_params = self._check_method_params(method="fit_predict", props=params) - iter_input, iter_label = self._fit(X, y, routed_params) - - params_last_step = routed_params[self._non_post_processing_steps()[-1][0]] - with print_elapsed_time("Pipeline", self._log_message(len(self.steps) - 1)): - if self._final_estimator == "passthrough": - y_pred = iter_input - elif is_empty(iter_input): - logger.warning("All input rows were filtered out! 
Model is not fitted!") - iter_input = [] - y_pred = [] - elif hasattr(self._final_estimator, "fit_predict"): - y_pred = self._final_estimator.fit_predict( - iter_input, - iter_label, - **params_last_step.get("fit_predict", {}), - ) - else: + if not self.supports_single_instance: + return super().fit_transform(X, **params) + iter_input = X + removed_row_dict: dict[ErrorFilter, list[int]] = { + error_filter: [] for error_filter in self._filter_elements + } + iter_idx_array = np.arange(len(iter_input)) + + # The meta elements merge steps which do not require fitting + for idx, _, i_element in self._iter(with_final=True): + if not isinstance(i_element, TransformingPipelineElement): raise AssertionError( - "Final estimator does not implement fit_predict, " - "hence this function should not be available.", + "PipelineElement is not a TransformingPipelineElement.", ) - for _, post_element in self._post_processing_steps(): - y_pred = post_element.fit_transform(y_pred, iter_label) - return y_pred - - @available_if(_final_estimator_has("predict_proba")) - def predict_proba(self, X: Any, **params: Any) -> Any: - """Transform the data, and apply `predict_proba` with the final estimator. - - Call `transform` of each transformer in the pipeline. The transformed - data are finally passed to the final estimator that calls `predict_proba` - method. Only valid if the final estimator implements `predict_proba`. - - Parameters - ---------- - X : iterable - Data to predict on. Must fulfill input requirements of first step - of the pipeline. - - **params : dict of string -> object - Parameters to the ``predict`` called at the end of all - transformations in the pipeline. Note that while this may be - used to return uncertainties from some models with return_std - or return_cov, uncertainties that are generated by the - transformations in the pipeline are not propagated to the - final estimator. - - Raises - ------ - AssertionError - If the final estimator does not implement `predict_proba`. - In this case this function should not be available. - - Returns - ------- - y_pred : ndarray - Result of calling `predict_proba` on the final estimator. 
- - """ - routed_params = process_routing(self, "predict_proba", **params) - iter_input = self._transform(X, routed_params) - - if self._final_estimator == "passthrough": - pass - elif is_empty(iter_input): - iter_input = [] - elif hasattr(self._final_estimator, "predict_proba"): - if _routing_enabled(): - iter_input = self._final_estimator.predict_proba( - iter_input, - **routed_params[ - self._non_post_processing_steps()[-1][0] - ].predict_proba, + if not check_single_instance_support(i_element): + raise AssertionError( + "PipelineElement is not a SingleInstanceTransformerMixin or" + " Pipeline with signle instance support.", ) - else: - iter_input = self._final_estimator.predict_proba(iter_input, **params) - else: - raise AssertionError( - "Final estimator does not implement predict_proba, " - "hence this function should not be available.", - ) - for _, post_element in self._post_processing_steps(): - iter_input = post_element.transform(iter_input) + if isinstance(i_element, ErrorFilter): + continue + i_element.n_jobs = self.n_jobs + iter_input = i_element.pretransform(iter_input) + for error_filter in self._filter_elements: + iter_input = error_filter.transform(iter_input) + for idx in error_filter.error_indices: + removed_row_dict[error_filter].append(int(iter_idx_array[idx])) + iter_idx_array = error_filter.co_transform(iter_idx_array) + iter_input = i_element.assemble_output(iter_input) + i_element.n_jobs = 1 + + # Set removed rows to filter elements to allow for correct co_transform + iter_idx_array = np.arange(len(X)) + for error_filter in self._filter_elements: + removed_idx_list = removed_row_dict[error_filter] + error_filter.error_indices = [] + for new_idx, _idx in enumerate(iter_idx_array): + if _idx in removed_idx_list: + error_filter.error_indices.append(new_idx) + error_filter.n_total = len(iter_idx_array) + iter_idx_array = error_filter.co_transform(iter_idx_array) + error_replacer_list = [ + ele for _, ele in self.steps if isinstance(ele, FilterReinserter) + ] + for error_replacer in error_replacer_list: + error_replacer.select_error_filter(self._filter_elements) + iter_input = error_replacer.transform(iter_input) + for _, post_element in self._post_processing_steps: + iter_input = post_element.fit_transform(iter_input, y) return iter_input def _can_transform(self) -> bool: @@ -839,7 +401,11 @@ def _can_transform(self) -> bool: ) @available_if(_can_transform) - def transform(self, X: Any, **params: Any) -> Any: + def transform( + self, + X: Any, # noqa: N803 + **params: Any, + ) -> Any: """Transform the data, and apply `transform` with the final estimator. Call `transform` of each transformer in the pipeline. The transformed @@ -858,45 +424,38 @@ def transform(self, X: Any, **params: Any) -> Any: **params : Any Parameters to the ``transform`` method of each estimator. - Raises - ------ - AssertionError - If the final estimator does not implement `transform` or - `fit_transform` or is passthrough. - Returns ------- Xt : ndarray of shape (n_samples, n_transformed_features) Transformed data. 
""" - routed_params = process_routing(self, "transform", **params) - iter_input = X - for _, name, transform in self._iter(): - if transform == "passthrough": - continue - if is_empty(iter_input): - # This is done to prime the error filters - if isinstance(transform, _MolPipeline): - _ = transform.transform(iter_input) - iter_input = [] - break - if hasattr(transform, "transform"): - iter_input = transform.transform( - iter_input, - **routed_params[name].transform, - ) - else: - raise AssertionError( - "Non transformer ocurred in transformation step." - "This should have been caught in the validation step.", - ) - for _, post_element in self._post_processing_steps(): + if self.supports_single_instance: + output_generator = self._transform_iterator(X) + iter_input = self.assemble_output(output_generator) + else: + iter_input = super().transform(X, **params) + for _, post_element in self._post_processing_steps: iter_input = post_element.transform(iter_input, **params) return iter_input + def _can_decision_function(self) -> bool: + """Check if the final estimator implements decision_function. + + Returns + ------- + bool + True if the final estimator implements decision_function. + + """ + return hasattr(self._final_estimator, "decision_function") + @available_if(_can_decision_function) - def decision_function(self, X: Any, **params: Any) -> Any: + def decision_function( + self, + X: Any, # noqa: N803 + **params: Any, + ) -> Any: """Transform the data, and apply `decision_function` with the final estimator. Parameters @@ -907,72 +466,18 @@ def decision_function(self, X: Any, **params: Any) -> Any: **params : Any Parameters to the ``decision_function`` method of the final estimator. - Raises - ------ - AssertionError - If the final estimator does not implement `decision_function`. - Returns ------- Any Result of calling `decision_function` on the final estimator. """ - if _routing_enabled(): - routed_params = process_routing(self, "decision_function", **params) - else: - routed_params = process_routing(self, "decision_function") - - iter_input = self._transform(X, routed_params) - if self._final_estimator == "passthrough": - pass - elif is_empty(iter_input): - iter_input = [] - elif hasattr(self._final_estimator, "decision_function"): - if _routing_enabled(): - iter_input = self._final_estimator.decision_function( - iter_input, - **routed_params[self._final_estimator].predict, - ) - else: - iter_input = self._final_estimator.decision_function( - iter_input, - **params, - ) - else: - raise AssertionError( - "Final estimator does not implement `decision_function`, " - "hence this function should not be available.", - ) - for _, post_element in self._post_processing_steps(): + iter_input = super().decision_function(X, **params) + for _, post_element in self._post_processing_steps: iter_input = post_element.transform(iter_input) return iter_input - @property - def classes_(self) -> list[Any] | npt.NDArray[Any]: - """Return the classes of the last element. - - PostPredictionTransformation elements are not considered as last element. - - Raises - ------ - ValueError - If the last step is passthrough or has no classes_ attribute. 
- - """ - check_last = [ - step - for step in self.steps - if not isinstance(step[1], PostPredictionTransformation) - ] - last_step = check_last[-1][1] - if last_step == "passthrough": - raise ValueError("Last step is passthrough.") - if hasattr(last_step, "classes_"): - return last_step.classes_ - raise ValueError("Last step has no classes_ attribute.") - - def __sklearn_tags__(self) -> Tags: + def __sklearn_tags__(self) -> Tags: # noqa: PLW3201 """Return the sklearn tags. Notes @@ -1011,7 +516,8 @@ def __sklearn_tags__(self) -> Tags: pass try: - # Only the _final_estimator is changed from the original implementation is changed in the following 2 lines + # Only the _final_estimator is changed from the original implementation is + # changed in the following 2 lines if ( self._final_estimator is not None and self._final_estimator != "passthrough" @@ -1029,85 +535,130 @@ def __sklearn_tags__(self) -> Tags: return tags - def get_metadata_routing(self) -> MetadataRouter: - """Get metadata routing of this object. + def transform_single(self, input_value: Any) -> Any: + """Transform a single input according to the sequence of PipelineElements. - Please check :ref:`User Guide ` on how the routing - mechanism works. + Parameters + ---------- + input_value: Any + Molecular representation which is subsequently transformed. - Notes - ----- - This method is copied from the original sklearn implementation. - Changes are marked with a comment. + Raises + ------ + AssertionError + If the PipelineElement is not a SingleInstanceTransformerMixin. Returns ------- - MetadataRouter - A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating - routing information. + Any + Transformed molecular representation. """ - router = MetadataRouter(owner=self.__class__.__name__) - - # first we add all steps except the last one - for _, name, trans in self._iter(with_final=False, filter_passthrough=True): - method_mapping = MethodMapping() - # fit, fit_predict, and fit_transform call fit_transform if it - # exists, or else fit and transform - if hasattr(trans, "fit_transform"): - ( - method_mapping.add(caller="fit", callee="fit_transform") - .add(caller="fit_transform", callee="fit_transform") - .add(caller="fit_predict", callee="fit_transform") + log_block = BlockLogs() + iter_value = input_value + for _, p_element in self._modified_steps: + if not isinstance(p_element, SingleInstanceTransformerMixin): + raise AssertionError( + "PipelineElement is not a SingleInstanceTransformerMixin.", ) - else: - ( - method_mapping.add(caller="fit", callee="fit") - .add(caller="fit", callee="transform") - .add(caller="fit_transform", callee="fit") - .add(caller="fit_transform", callee="transform") - .add(caller="fit_predict", callee="fit") - .add(caller="fit_predict", callee="transform") + if not isinstance(p_element, ABCPipelineElement): + raise AssertionError( + "PipelineElement is not a ABCPipelineElement.", + ) + try: + if not isinstance(iter_value, RemovedInstance) or isinstance( + p_element, + FilterReinserter, + ): + iter_value = p_element.transform_single(iter_value) + except MolSanitizeException as err: + iter_value = InvalidInstance( + p_element.uuid, + f"RDKit MolSanitizeException: {err.args}", + p_element.name, ) + del log_block + return iter_value - ( - method_mapping.add(caller="predict", callee="transform") - .add(caller="predict", callee="transform") - .add(caller="predict_proba", callee="transform") - .add(caller="decision_function", callee="transform") - .add(caller="predict_log_proba", 
callee="transform") - .add(caller="transform", callee="transform") - .add(caller="inverse_transform", callee="inverse_transform") - .add(caller="score", callee="transform") - ) + def pretransform(self, value_list: Any) -> Any: + """Transform the input according to the sequence without assemble_output step. - router.add(method_mapping=method_mapping, **{name: trans}) + Parameters + ---------- + value_list: Any + Molecular representations which are subsequently transformed. - # Only the _non_post_processing_steps is changed from the original implementation is changed in the following line - final_name, final_est = self._non_post_processing_steps()[-1] - if final_est is None or final_est == "passthrough": - return router + Returns + ------- + Any + Transformed molecular representations. - # then we add the last step - method_mapping = MethodMapping() - if hasattr(final_est, "fit_transform"): - method_mapping.add(caller="fit_transform", callee="fit_transform") - else: - method_mapping.add(caller="fit", callee="fit").add( - caller="fit", - callee="transform", - ) - ( - method_mapping.add(caller="fit", callee="fit") - .add(caller="predict", callee="predict") - .add(caller="fit_predict", callee="fit_predict") - .add(caller="predict_proba", callee="predict_proba") - .add(caller="decision_function", callee="decision_function") - .add(caller="predict_log_proba", callee="predict_log_proba") - .add(caller="transform", callee="transform") - .add(caller="inverse_transform", callee="inverse_transform") - .add(caller="score", callee="score") + """ + return list(self._transform_iterator(value_list)) + + def assemble_output(self, value_list: Iterable[Any]) -> Any: + """Assemble the output of the pipeline. + + Parameters + ---------- + value_list: Iterable[Any] + Generator which yields the output of the pipeline. + + Returns + ------- + Any + Assembled output. + + """ + final_estimator = self._final_estimator + if isinstance(final_estimator, TransformingPipelineElement): + return final_estimator.assemble_output(value_list) + return list(value_list) + + def _transform_iterator(self, x_input: Any) -> Any: + """Transform the input according to the sequence of provided PipelineElements. + + Parameters + ---------- + x_input: Any + Molecular representations which are subsequently transformed. + + Yields + ------ + Any + Transformed molecular representations. + + """ + agg_filter = self._filter_elements_agg + for filter_element in self._filter_elements: + filter_element.error_indices = [] + parallel = Parallel( + n_jobs=self.n_jobs, + return_as="generator", + batch_size="auto", + ) + output_generator = parallel( + delayed(self.transform_single)(value) for value in x_input ) + for i, transformed_value in enumerate(output_generator): + if isinstance(transformed_value, RemovedInstance): + agg_filter.register_removed(i, transformed_value) + else: + yield transformed_value + agg_filter.set_total(len(x_input)) + + def co_transform(self, x_input: TypeFixedVarSeq) -> TypeFixedVarSeq: + """Filter flagged rows from the input. + + Parameters + ---------- + x_input: Any + Molecular representations which are subsequently filtered. - router.add(method_mapping=method_mapping, **{final_name: final_est}) - return router + Returns + ------- + Any + Filtered molecular representations. 
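
A per-instance sketch of the methods above, assuming AutoToMol and MolToMorganFP implement SingleInstanceTransformerMixin (the step names are illustrative):

    fp_pipeline = Pipeline(
        [
            ("auto_to_mol", AutoToMol()),
            ("morgan", MolToMorganFP(radius=2, n_bits=2048)),
        ]
    )
    if fp_pipeline.supports_single_instance:
        # One value flows through transform_single of every step without
        # batching; invalid input would come back as an InvalidInstance.
        single_fp = fp_pipeline.transform_single("CCO")
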
+ + """ + return self._filter_elements_agg.co_transform(x_input) diff --git a/molpipeline/utils/molpipeline_types.py b/molpipeline/utils/molpipeline_types.py index 9240dab6..3664df1c 100644 --- a/molpipeline/utils/molpipeline_types.py +++ b/molpipeline/utils/molpipeline_types.py @@ -81,7 +81,7 @@ def set_params(self, **params: Any) -> Self: def fit( self, - X: npt.NDArray[Any], # pylint: disable=invalid-name + X: npt.ArrayLike, # pylint: disable=invalid-name # noqa: N803 y: npt.NDArray[Any] | None, **fit_params: Any, ) -> Self: @@ -89,7 +89,7 @@ def fit( Parameters ---------- - X: npt.NDArray[Any] + X: npt.ArrayLike Model input. y: npt.NDArray[Any] | None Target values. @@ -110,7 +110,7 @@ class AnyPredictor(AnySklearnEstimator, Protocol): def fit_predict( self, - X: npt.NDArray[Any], # pylint: disable=invalid-name + X: npt.ArrayLike, # pylint: disable=invalid-name # noqa: N803 y: npt.NDArray[Any] | None, **fit_params: Any, ) -> npt.NDArray[Any]: @@ -118,7 +118,7 @@ def fit_predict( Parameters ---------- - X: npt.NDArray[Any] + X: npt.ArrayLike Model input. y: npt.NDArray[Any] | None Target values. @@ -138,7 +138,7 @@ class AnyTransformer(AnySklearnEstimator, Protocol): def fit_transform( self, - X: npt.NDArray[Any], # pylint: disable=invalid-name + X: npt.ArrayLike, # pylint: disable=invalid-name # noqa: N803 y: npt.NDArray[Any] | None, **fit_params: Any, ) -> npt.NDArray[Any]: @@ -146,7 +146,7 @@ def fit_transform( Parameters ---------- - X: npt.NDArray[Any] + X: npt.ArrayLike Model input. y: npt.NDArray[Any] | None Target values. @@ -163,14 +163,14 @@ def fit_transform( def transform( self, - X: npt.NDArray[Any], # pylint: disable=invalid-name + X: npt.ArrayLike, # pylint: disable=invalid-name # noqa: N803 **params: Any, ) -> npt.NDArray[Any]: """Transform and return X according to object protocol. Parameters ---------- - X: npt.NDArray[Any] + X: npt.ArrayLike Model input. params: Any Additional parameters for transforming. 
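
Widening the `X` annotations above from `npt.NDArray[Any]` to `npt.ArrayLike` means plain Python sequences now satisfy the `AnySklearnEstimator`, `AnyPredictor`, and `AnyTransformer` protocols. A minimal sketch of a structurally conforming transformer; the class name and its centering logic are illustrative assumptions, not part of molpipeline:

```python
from typing import Any, Self

import numpy as np
import numpy.typing as npt


class CenteringTransformer:
    """Hypothetical transformer; structurally matches the AnyTransformer protocol."""

    def __init__(self) -> None:
        self.mean_: npt.NDArray[np.float64] | None = None

    def get_params(self, deep: bool = True) -> dict[str, Any]:
        """Return parameters; this toy transformer has none."""
        return {}

    def set_params(self, **params: Any) -> Self:
        """Accept no parameters; present only to satisfy the protocol."""
        return self

    def fit(
        self,
        X: npt.ArrayLike,
        y: npt.NDArray[Any] | None = None,
        **fit_params: Any,
    ) -> Self:
        # np.asarray accepts lists, tuples, and ndarrays alike; this is exactly
        # what widening the annotation from NDArray[Any] to ArrayLike expresses.
        self.mean_ = np.asarray(X, dtype=np.float64).mean(axis=0)
        return self

    def transform(self, X: npt.ArrayLike, **params: Any) -> npt.NDArray[Any]:
        if self.mean_ is None:
            raise RuntimeError("fit must be called before transform.")
        return np.asarray(X, dtype=np.float64) - self.mean_

    def fit_transform(
        self,
        X: npt.ArrayLike,
        y: npt.NDArray[Any] | None = None,
        **fit_params: Any,
    ) -> npt.NDArray[Any]:
        return self.fit(X, y, **fit_params).transform(X)


# A nested list is a valid ArrayLike, so no np.asarray wrapping is needed:
centered = CenteringTransformer().fit_transform([[1.0, 2.0], [3.0, 4.0]])
```

Because these are `typing.Protocol` classes, conformance is structural: the sketch never imports from molpipeline_types, yet a type checker accepts it wherever `AnyTransformer` is required.
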
diff --git a/tests/test_elements/test_error_handling.py b/tests/test_elements/test_error_handling.py
index d6a79245..8ee8bf53 100644
--- a/tests/test_elements/test_error_handling.py
+++ b/tests/test_elements/test_error_handling.py
@@ -22,6 +22,8 @@
 TEST_SMILES = ["NCCCO", "abc", "c1ccccc1"]
 EXPECTED_OUTPUT = ["NCCCO", None, "c1ccccc1"]
 
+TOLERANCE = 0.000001
+
 
 class NoneTest(unittest.TestCase):
     """Unittest for None Handling."""
@@ -95,7 +97,7 @@ def test_dummy_remove_physchem_record_molpipeline(self) -> None:
         out = pipeline.transform(TEST_SMILES)
         out2 = pipeline2.fit_transform(TEST_SMILES)
         self.assertEqual(out.shape, out2.shape)
-        self.assertTrue(np.max(np.abs(out - out2)) < 0.000001)
+        self.assertTrue(np.max(np.abs(out - out2)) < TOLERANCE)
 
     def test_dummy_remove_physchem_record_autodetect_molpipeline(self) -> None:
         """Assert that invalid smiles are transformed to None."""
@@ -108,14 +110,14 @@ def test_dummy_remove_physchem_record_autodetect_molpipeline(self) -> None:
                 ("mol2physchem", mol2physchem),
                 ("remove_none", remove_none),
             ],
+            n_jobs=1,
         )
         pipeline2 = clone(pipeline)
         pipeline.fit(TEST_SMILES)
         out = pipeline.transform(TEST_SMILES)
-        print(pipeline2["remove_none"].filter_everything)
         out2 = pipeline2.fit_transform(TEST_SMILES)
         self.assertEqual(out.shape, out2.shape)
-        self.assertTrue(np.max(np.abs(out - out2)) < 0.000001)
+        self.assertTrue(np.max(np.abs(out - out2)) < TOLERANCE)
 
     def test_dummy_fill_physchem_record_molpipeline(self) -> None:
         """Assert that invalid smiles are transformed to None."""
@@ -141,7 +143,7 @@ def test_dummy_fill_physchem_record_molpipeline(self) -> None:
         out2 = pipeline2.fit_transform(TEST_SMILES)
         self.assertEqual(out.shape, out2.shape)
         self.assertEqual(out.shape, (3, 215))
-        self.assertTrue(np.nanmax(np.abs(out - out2)) < 0.000001)
+        self.assertTrue(np.nanmax(np.abs(out - out2)) < TOLERANCE)
 
     def test_replace_mixed_datatypes(self) -> None:
         """Assert that invalid values are replaced by fill value."""
@@ -246,7 +248,7 @@ def test_replace_mixed_datatypes_expected_failures(self) -> None:
         )
         pipeline2 = clone(pipeline)
 
-        self.assertRaises(ValueError, pipeline.fit, test_values)
+        pipeline.fit(test_values)
         self.assertRaises(ValueError, pipeline.transform, test_values)
         self.assertRaises(ValueError, pipeline2.fit_transform, test_values)
 
diff --git a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py
index 4f11fef5..19f20c9f 100644
--- a/tests/test_elements/test_mol2mol/test_mol2mol_filter.py
+++ b/tests/test_elements/test_mol2mol/test_mol2mol_filter.py
@@ -95,7 +95,7 @@ def test_element_filter(self) -> None:
         ]
 
         for test_params in test_params_list_with_results:
-            pipeline.set_params(**test_params["params"])
+            pipeline.set_params(**test_params["params"])  # type: ignore
             filtered_smiles = pipeline.fit_transform(SMILES_LIST)
             self.assertEqual(filtered_smiles, test_params["result"])
 
@@ -185,7 +185,7 @@ def test_complex_filter(self) -> None:
         ]
 
         for test_params in test_params_list_with_results:
-            pipeline.set_params(**test_params["params"])
+            pipeline.set_params(**test_params["params"])  # type: ignore
             filtered_smiles = pipeline.fit_transform(SMILES_LIST)
             self.assertEqual(filtered_smiles, test_params["result"])
 
@@ -273,7 +273,7 @@ def test_smarts_smiles_filter(self) -> None:
         ]
 
         for test_params in test_params_list_with_results:
-            pipeline.set_params(**test_params["params"])
+            pipeline.set_params(**test_params["params"])  # type: ignore
             filtered_smiles = pipeline.fit_transform(SMILES_LIST)
             self.assertEqual(filtered_smiles, test_params["result"])
 
@@ -396,7 +396,7 @@ def test_descriptor_filter(self) -> None:
         ]
 
         for test_params in test_params_list_with_results:
-            pipeline.set_params(**test_params["params"])
+            pipeline.set_params(**test_params["params"])  # type: ignore
             filtered_smiles = pipeline.fit_transform(SMILES_LIST)
             self.assertEqual(filtered_smiles, test_params["result"])
 
diff --git a/tests/utils/mock_element.py b/tests/utils/mock_element.py
index 2b89987c..5e7c6c65 100644
--- a/tests/utils/mock_element.py
+++ b/tests/utils/mock_element.py
@@ -3,18 +3,24 @@
 
 from __future__ import annotations
 
 import copy
-from collections.abc import Iterable
-from typing import Any, Self
+from typing import TYPE_CHECKING, Any, Self
 
 import numpy as np
 
 from molpipeline.abstract_pipeline_elements.core import (
     InvalidInstance,
+    SingleInstanceTransformerMixin,
     TransformingPipelineElement,
 )
 
+if TYPE_CHECKING:
+    from collections.abc import Iterable
 
-class MockTransformingPipelineElement(TransformingPipelineElement):
+
+class MockTransformingPipelineElement(
+    SingleInstanceTransformerMixin,
+    TransformingPipelineElement,
+):
     """Mock element for testing."""
 
     def __init__(
@@ -40,6 +46,7 @@ def __init__(
             Unique identifier of PipelineElement.
         n_jobs: int, default=1
             Number of jobs to run in parallel.
+
         """
         super().__init__(name=name, uuid=uuid, n_jobs=n_jobs)
         if invalid_values is None:
@@ -59,6 +66,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
         -------
         dict[str, Any]
             Dictionary containing all parameters defining the object.
+
         """
         params = super().get_params(deep)
         if deep:
@@ -81,6 +89,7 @@ def set_params(self, **parameters: Any) -> Self:
         -------
         Self
             MockTransformingPipelineElement with updated parameters.
+
         """
         super().set_params(**parameters)
         if "invalid_values" in parameters:
@@ -101,6 +110,7 @@ def pretransform_single(self, value: Any) -> Any:
         -------
         Any
             Other value.
+
         """
         if value in self.invalid_values:
             return InvalidInstance(
@@ -113,8 +123,8 @@ def assemble_output(self, value_list: Iterable[Any]) -> Any:
         """Aggregate rows, which in most cases is just return the list.
 
-        Some representations might be better representd as a single object. For example a list of vectors can
-        be transformed to a matrix.
+        Some representations might be better represented as a single object.
+        For example a list of vectors can be transformed to a matrix.
 
         Parameters
         ----------
@@ -125,6 +135,7 @@ def assemble_output(self, value_list: Iterable[Any]) -> Any:
         -------
         Any
             Aggregated output. This can also be the original input.
+
         """
        if self.return_as_numpy_array:
             return np.array(list(value_list))
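
Since the mock now mixes in `SingleInstanceTransformerMixin`, it passes the `isinstance` checks that `Pipeline.transform_single` performs above. A short usage sketch under stated assumptions: only `invalid_values` and the `return_as_numpy_array` attribute are visible in this diff, so treating both as directly settable is an assumption about the full signature.

```python
import numpy as np

from molpipeline.abstract_pipeline_elements.core import InvalidInstance
from tests.utils.mock_element import MockTransformingPipelineElement

# `invalid_values` appears in __init__ and set_params above; passing it as a
# keyword argument is assumed from context.
mock = MockTransformingPipelineElement(invalid_values={"abc"})

# Values contained in `invalid_values` are flagged rather than raised, so
# downstream error filters can remove them and FilterReinserter can later
# fill the gaps.
flagged = mock.pretransform_single("abc")
assert isinstance(flagged, InvalidInstance)

# With the numpy flag set, assemble_output stacks per-instance rows into a
# matrix via np.array, as shown in the method body above.
mock.return_as_numpy_array = True  # plain attribute; assumed settable directly
matrix = mock.assemble_output([[1.0, 2.0], [3.0, 4.0]])
assert isinstance(matrix, np.ndarray)
assert matrix.shape == (2, 2)
```
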