From eb6b818404978fc5f47b3fafa560d24545e95e76 Mon Sep 17 00:00:00 2001
From: "Christian W. Feldmann" <128160984+c-w-feldmann@users.noreply.github.com>
Date: Tue, 22 Apr 2025 10:41:05 +0200
Subject: [PATCH 1/5] Remove unnecessary methods and attributes (#134)

* remove copy method
* remove additional_attributes and to_json method
* remove mol counter
* remove finish method
* Allow None as `identifier`
---
 .../abstract_pipeline_elements/core.py        | 59 -------------------
 molpipeline/any2mol/sdf2mol.py                | 20 +++----
 molpipeline/pipeline/_molpipeline.py          |  6 --
 .../test_mol2maccs_key_fingerprint.py         |  3 +-
 .../test_mol2morgan_fingerprint.py            |  5 +-
 .../test_mol2any/test_mol2path_fingerprint.py |  7 ++-
 6 files changed, 16 insertions(+), 84 deletions(-)

diff --git a/molpipeline/abstract_pipeline_elements/core.py b/molpipeline/abstract_pipeline_elements/core.py
index 638cb3a7..0863d64f 100644
--- a/molpipeline/abstract_pipeline_elements/core.py
+++ b/molpipeline/abstract_pipeline_elements/core.py
@@ -15,7 +15,6 @@
 from uuid import uuid4
 
-import numpy as np
 from joblib import Parallel, delayed
 from loguru import logger
 from rdkit import Chem
@@ -209,11 +208,6 @@ def set_params(self, **parameters: Any) -> Self:
             setattr(self, att_name, att_value)
         return self
 
-    @property
-    def additional_attributes(self) -> dict[str, Any]:
-        """Any attribute relevant for recreating and exact copy, which is not a parameter."""
-        return {}
-
     @property
     def n_jobs(self) -> int:
         """Get the number of cores."""
@@ -236,12 +230,6 @@ def requires_fitting(self) -> bool:
         """Return whether the object requires fitting or not."""
         return self._requires_fitting
 
-    def finish(self) -> None:
-        """Inform object that iteration has been finished. Does in most cases nothing.
-
-        Called after all transform singles have been processed. From MolPipeline
-        """
-
     def fit(self, values: Any, labels: Any = None) -> Self:
         """Fit object to input_values.
 
@@ -391,29 +379,6 @@ def parameters(self, **parameters: Any) -> None:
         """
         self.set_params(**parameters)
 
-    def copy(self) -> Self:
-        """Copy the object.
-
-        Raises
-        ------
-        AssertionError
-            If the object cannot be copied.
-
-        Returns
-        -------
-        Self
-            Copy of the object.
-
-        """
-        recreated_object = self.__class__(**self.parameters)
-        for key, value in self.additional_attributes.items():
-            if not hasattr(recreated_object, key):
-                raise AssertionError(
-                    f"Cannot set attribute {key} on {self.__class__.__name__}. This should not happen!"
-                )
-            setattr(recreated_object, key, copy.copy(value))
-        return recreated_object
-
     def fit_to_result(self, values: Any) -> Self:
         """Fit object to result of transformed values.
 
@@ -589,32 +554,8 @@ def transform(self, values: Any) -> Any:
         output_rows = self.pretransform(values)
         output_rows = self.finalize_list(output_rows)
         output = self.assemble_output(output_rows)
-        self.finish()
         return output
 
-    def to_json(self) -> dict[str, Any]:
-        """Return all defining attributes of object as dict.
-
-        Returns
-        -------
-        dict[str, Any]
-            A dictionary with all attributes necessary to initialize a object with same parameters.
-        """
-        json_dict: dict[str, Any] = {
-            "__name__": self.__class__.__name__,
-            "__module__": self.__class__.__module__,
-        }
-        json_dict.update(self.parameters)
-        if self.additional_attributes:
-            adittional_attributes = {}
-            for key, value in self.additional_attributes.items():
-                if isinstance(value, np.ndarray):
-                    adittional_attributes[key] = value.tolist()
-                else:
-                    adittional_attributes[key] = value
-            json_dict["additional_attributes"] = adittional_attributes
-        return json_dict
-
 
 class MolToMolPipelineElement(TransformingPipelineElement, abc.ABC):
     """Abstract PipelineElement where input and outputs are molecules."""
diff --git a/molpipeline/any2mol/sdf2mol.py b/molpipeline/any2mol/sdf2mol.py
index 5139917a..a6bea806 100644
--- a/molpipeline/any2mol/sdf2mol.py
+++ b/molpipeline/any2mol/sdf2mol.py
@@ -2,7 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any
+from typing import Any, Literal
 
 try:
     from typing import Self  # type: ignore[attr-defined]
@@ -23,12 +23,11 @@
 class SDFToMol(_StringToMolPipelineElement):
     """PipelineElement transforming a list of SDF strings to mol_objects."""
 
-    identifier: str
-    mol_counter: int
+    identifier: Literal["smiles"] | None
 
     def __init__(
         self,
-        identifier: str = "enumerate",
+        identifier: Literal["smiles"] | None = "smiles",
         name: str = "SDF2Mol",
         n_jobs: int = 1,
         uuid: str | None = None,
@@ -37,8 +36,9 @@
 
         Parameters
         ----------
-        identifier: str, default='enumerate'
-            Method of assigning identifiers to molecules. At the moment molecules are counted.
+        identifier: Literal["smiles"] | None, default='smiles'
+            Method of assigning identifiers to molecules.
+            If None, no identifier is assigned.
         name: str, default='SDF2Mol'
             Name of PipelineElement
        n_jobs: int, default=1
@@ -48,7 +48,6 @@
        """
        super().__init__(name=name, n_jobs=n_jobs, uuid=uuid)
        self.identifier = identifier
-        self.mol_counter = 0
 
    def get_params(self, deep: bool = True) -> dict[str, Any]:
        """Return all parameters defining the object.
@@ -88,10 +87,6 @@ def set_params(self, **parameters: Any) -> Self:
             self.identifier = parameters["identifier"]
         return self
 
-    def finish(self) -> None:
-        """Reset the mol counter which assigns identifiers."""
-        self.mol_counter = 0
-
     def pretransform_single(self, value: str) -> OptionalMol:
         """Transform an SDF-strings to a rdkit molecule.
 
@@ -119,6 +114,5 @@
                 self.name,
             )
         if self.identifier == "smiles":
-            mol.SetProp("identifier", str(self.mol_counter))
-            self.mol_counter += 1
+            mol.SetProp("identifier", Chem.MolToSmiles(mol))
         return mol
diff --git a/molpipeline/pipeline/_molpipeline.py b/molpipeline/pipeline/_molpipeline.py
index 542d5cda..c38e225f 100644
--- a/molpipeline/pipeline/_molpipeline.py
+++ b/molpipeline/pipeline/_molpipeline.py
@@ -394,11 +394,6 @@ def assemble_output(self, value_list: Iterable[Any]) -> Any:
             return last_element.assemble_output(value_list)
         return list(value_list)
 
-    def _finish(self) -> None:
-        """Inform each pipeline element that the iterations have finished."""
-        for p_element in self._element_list:
-            p_element.finish()
-
     def _transform_iterator(self, x_input: Any) -> Any:
         """Transform the input according to the sequence of provided PipelineElements.
 
@@ -430,7 +425,6 @@
             else:
                 yield transformed_value
         agg_filter.set_total(len(x_input))
-        self._finish()
 
     def co_transform(self, x_input: TypeFixedVarSeq) -> TypeFixedVarSeq:
         """Filter flagged rows from the input.
diff --git a/tests/test_elements/test_mol2any/test_mol2maccs_key_fingerprint.py b/tests/test_elements/test_mol2any/test_mol2maccs_key_fingerprint.py
index 2655a435..46603ac1 100644
--- a/tests/test_elements/test_mol2any/test_mol2maccs_key_fingerprint.py
+++ b/tests/test_elements/test_mol2any/test_mol2maccs_key_fingerprint.py
@@ -6,6 +6,7 @@
 from typing import Any
 
 import numpy as np
+from sklearn.base import clone
 
 from molpipeline import Pipeline
 from molpipeline.any2mol import SmilesToMol
@@ -27,7 +28,7 @@ class TestMolToMACCSFP(unittest.TestCase):
     def test_can_be_constructed(self) -> None:
         """Test if the MolToMACCSFP pipeline element can be constructed."""
         mol_fp = MolToMACCSFP()
-        mol_fp_copy = mol_fp.copy()
+        mol_fp_copy = clone(mol_fp)
         self.assertTrue(mol_fp_copy is not mol_fp)
         for key, value in mol_fp.get_params().items():
             self.assertEqual(value, mol_fp_copy.get_params()[key])
diff --git a/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py b/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py
index 17a31bd0..ede05f56 100644
--- a/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py
+++ b/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py
@@ -6,6 +6,7 @@
 from typing import Any
 
 import numpy as np
+from sklearn.base import clone
 
 from molpipeline import Pipeline
 from molpipeline.abstract_pipeline_elements.core import InvalidInstance
@@ -23,10 +24,10 @@
 class TestMol2MorganFingerprint(unittest.TestCase):
     """Unittest for MolToFoldedMorganFingerprint, which calculates folded Morgan Fingerprints."""
 
-    def test_can_be_constructed(self) -> None:
+    def test_clone(self) -> None:
         """Test if the MolToFoldedMorganFingerprint pipeline element can be constructed."""
         mol_fp = MolToMorganFP()
-        mol_fp_copy = mol_fp.copy()
+        mol_fp_copy = clone(mol_fp)
         self.assertTrue(mol_fp_copy is not mol_fp)
         for key, value in mol_fp.get_params().items():
             self.assertEqual(value, mol_fp_copy.get_params()[key])
diff --git a/tests/test_elements/test_mol2any/test_mol2path_fingerprint.py b/tests/test_elements/test_mol2any/test_mol2path_fingerprint.py
index a0903ebe..ff160b5b 100644
--- a/tests/test_elements/test_mol2any/test_mol2path_fingerprint.py
+++ b/tests/test_elements/test_mol2any/test_mol2path_fingerprint.py
@@ -6,6 +6,7 @@
 from typing import Any
 
 import numpy as np
+from sklearn.base import clone
 
 from molpipeline import Pipeline
 from molpipeline.any2mol import SmilesToMol
@@ -23,10 +24,10 @@
 class TestMol2PathFingerprint(unittest.TestCase):
     """Unittest for Mol2PathFP, which calculates the RDKit Path Fingerprint."""
 
-    def test_can_be_constructed(self) -> None:
-        """Test if the Mol2PathFP pipeline element can be constructed."""
+    def test_clone(self) -> None:
+        """Test if the Mol2PathFP pipeline element can be cloned."""
         mol_fp = Mol2PathFP()
-        mol_fp_copy = mol_fp.copy()
+        mol_fp_copy = clone(mol_fp)
         self.assertTrue(mol_fp_copy is not mol_fp)
         for key, value in mol_fp.get_params().items():
             self.assertEqual(value, mol_fp_copy.get_params()[key])
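The `identifier` semantics of `SDFToMol` change as well in this patch: instead of a
running counter that had to be reset via `finish()`, the element now stamps each
parsed molecule with its canonical SMILES (or, with `identifier=None`, sets no
property at all), so no per-iteration state remains. A usage sketch, assuming
`SDFToMol` is exported from `molpipeline.any2mol` and given a single-record SDF
string built with RDKit:

    from rdkit import Chem

    from molpipeline.any2mol import SDFToMol

    # Build a one-record SDF string from a SMILES (ethanol) for illustration.
    sdf_block = Chem.MolToMolBlock(Chem.MolFromSmiles("CCO")) + "$$$$\n"

    mol = SDFToMol(identifier="smiles").transform([sdf_block])[0]
    print(mol.GetProp("identifier"))  # canonical SMILES, e.g. "CCO"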
Feldmann" <128160984+c-w-feldmann@users.noreply.github.com> Date: Wed, 23 Apr 2025 14:43:14 +0200 Subject: [PATCH 2/5] Remove standard scaler from mol2float (#165) * remove standard scaler --- .../mol2any/mol2floatvector.py | 137 ++---------------- molpipeline/mol2any/mol2net_charge.py | 42 +++--- molpipeline/mol2any/mol2rdkit_phys_chem.py | 34 +++-- ruff.toml | 1 + .../test_mol2any/test_mol2concatenated.py | 47 +++--- .../test_mol2any/test_mol2net_charge.py | 26 ++-- .../test_mol2any/test_mol2rdkit_phys_chem.py | 82 ++++------- 7 files changed, 132 insertions(+), 237 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2any/mol2floatvector.py b/molpipeline/abstract_pipeline_elements/mol2any/mol2floatvector.py index fe8aafb6..9e3e7e43 100644 --- a/molpipeline/abstract_pipeline_elements/mol2any/mol2floatvector.py +++ b/molpipeline/abstract_pipeline_elements/mol2any/mol2floatvector.py @@ -3,36 +3,31 @@ from __future__ import annotations import abc -from collections.abc import Iterable -from typing import Any - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self +from typing import TYPE_CHECKING import numpy as np -import numpy.typing as npt -from sklearn.base import clone -from sklearn.preprocessing import StandardScaler from molpipeline.abstract_pipeline_elements.core import ( InvalidInstance, MolToAnyPipelineElement, ) -from molpipeline.utils.molpipeline_types import AnyTransformer, RDKitMol + +if TYPE_CHECKING: + from collections.abc import Iterable + + import numpy.typing as npt + + from molpipeline.utils.molpipeline_types import RDKitMol class MolToDescriptorPipelineElement(MolToAnyPipelineElement): - """PipelineElement which generates a matrix from descriptor-vectors of each molecule.""" + """PipelineElement for generating descriptor-vectors.""" - _standardizer: AnyTransformer | None _output_type = "float" _feature_names: list[str] def __init__( self, - standardizer: AnyTransformer | None = StandardScaler(), name: str = "MolToDescriptorPipelineElement", n_jobs: int = 1, uuid: str | None = None, @@ -41,8 +36,6 @@ def __init__( Parameters ---------- - standardizer: AnyTransformer | None default=StandardScaler() - The output is post_processed according to the standardizer if not None. name: str, default='MolToDescriptorPipelineElement' Name of the PipelineElement. n_jobs: int, default=1 @@ -52,9 +45,6 @@ def __init__( """ super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) - self._standardizer = standardizer - if self._standardizer is not None: - self._requires_fitting = True self._mean = None self._std = None @@ -83,88 +73,9 @@ def assemble_output( ------- npt.NDArray[np.float64] Matrix with descriptor values of each molecule. - """ - return np.vstack(list(value_list)) - - def get_params(self, deep: bool = True) -> dict[str, Any]: - """Return all parameters defined during object initialization. - - Parameters - ---------- - deep: bool, default=True - If True get a deep copy of the parameters. - - Returns - ------- - dict[str, Any] - Dictionary containing all parameters relevant to initialize the object with same properties. - """ - params = super().get_params(deep) - if deep: - if self._standardizer is not None: - params["standardizer"] = clone(self._standardizer) - else: - params["standardizer"] = None - else: - params["standardizer"] = self._standardizer - return params - - def set_params(self, **parameters: Any) -> Self: - """Set parameters. 
-
-        Parameters
-        ----------
-        parameters: Any
-            Dictionary with parameter names and corresponding values.
-
-        Returns
-        -------
-        Self
-            Object with updated parameters.
-        """
-        parameter_copy = dict(parameters)
-        standardizer = parameter_copy.pop("standardizer", None)
-        if standardizer is not None:
-            self._standardizer = standardizer
-        super().set_params(**parameter_copy)
-        return self
-
-    def fit_to_result(self, values: list[npt.NDArray[np.float64]]) -> Self:
-        """Fit object to data.
-
-        Parameters
-        ----------
-        values: list[RDKitMol]
-            List of RDKit molecules to which the Pipeline element is fitted.
-        Returns
-        -------
-        Self
-            Fitted MolToDescriptorPipelineElement.
-        """
-        value_matrix = np.vstack(list(values))
-        if self._standardizer is not None:
-            self._standardizer.fit(value_matrix, None)
-        return self
-
-    def _normalize_matrix(
-        self, value_matrix: npt.NDArray[np.float64]
-    ) -> npt.NDArray[np.float64]:
-        """Normalize matrix with descriptor values.
-
-        Parameters
-        ----------
-        value_matrix: npt.NDArray[np.float64]
-            Matrix with descriptor values of molecules.
-
-        Returns
-        -------
-        npt.NDArray[np.float64]
-            Normalized matrix with descriptor values of molecules.
-        """
-        if self._standardizer is not None:
-            return self._standardizer.transform(value_matrix)
-        return value_matrix
+
+        """
+        return np.vstack(list(value_list))
 
     def transform(self, values: list[RDKitMol]) -> npt.NDArray[np.float64]:
         """Transform the list of molecules to sparse matrix.
@@ -178,35 +89,19 @@
         -------
         npt.NDArray[np.float64]
             Matrix with descriptor values of molecules.
+
         """
         descriptor_matrix: npt.NDArray[np.float64] = super().transform(values)
         return descriptor_matrix
 
-    def finalize_single(
-        self, value: npt.NDArray[np.float64]
-    ) -> npt.NDArray[np.float64]:
-        """Finalize single value. Here: standardize vector.
-
-        Parameters
-        ----------
-        value: Any
-            Single value to be finalized.
-
-        Returns
-        -------
-        Any
-            Finalized value.
-        """
-        if self._standardizer is not None:
-            standadized_value = self._standardizer.transform(value.reshape(1, -1))
-            return standadized_value.reshape(-1)
-        return value
-
     @abc.abstractmethod
     def pretransform_single(
-        self, value: RDKitMol
+        self,
+        value: RDKitMol,
     ) -> npt.NDArray[np.float64] | InvalidInstance:
-        """Transform mol to dict, where items encode columns indices and values, respectively.
+        """Transform mol to dict.
+
+        Items encode column indices and values, respectively.
 
         Parameters
         ----------
diff --git a/molpipeline/mol2any/mol2net_charge.py b/molpipeline/mol2any/mol2net_charge.py
index ee8a20f2..b8469675 100644
--- a/molpipeline/mol2any/mol2net_charge.py
+++ b/molpipeline/mol2any/mol2net_charge.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import copy
-from typing import Any, Literal, TypeAlias
+from typing import TYPE_CHECKING, Any, Literal, TypeAlias
 
 try:
     from typing import Self  # type: ignore[attr-defined]
@@ -13,17 +13,21 @@
 import numpy as np
 import numpy.typing as npt
 from rdkit import Chem
-from sklearn.preprocessing import StandardScaler
+from rdkit.Chem import rdPartialCharges
 
 from molpipeline.abstract_pipeline_elements.core import InvalidInstance
 from molpipeline.abstract_pipeline_elements.mol2any.mol2floatvector import (
     MolToDescriptorPipelineElement,
 )
-from molpipeline.utils.molpipeline_types import AnyTransformer, RDKitMol
+
+if TYPE_CHECKING:
+    from molpipeline.utils.molpipeline_types import RDKitMol
 
 # Methods to compute the net charge of a molecule.
-# - "formal_charge" uses the formal charges of the atoms with rdkit.Chem.rdmolops.GetFormalCharge -# - "gasteiger" uses the Gasteiger charges of the atoms with rdkit.Chem.rdPartialCharges.ComputeGasteigerCharges +# - "formal_charge" uses the formal charges of the atoms +# - uses rdkit.Chem.rdmolops.GetFormalCharge +# - "gasteiger" uses the Gasteiger charges of the atoms with +# - uses rdkit.Chem.rdPartialCharges.ComputeGasteigerCharges MolToNetChargeMethod: TypeAlias = Literal["formal_charge", "gasteiger"] @@ -33,7 +37,6 @@ class MolToNetCharge(MolToDescriptorPipelineElement): def __init__( self, charge_method: MolToNetChargeMethod = "formal_charge", - standardizer: AnyTransformer | None = StandardScaler(), name: str = "MolToNetCharge", n_jobs: int = 1, uuid: str | None = None, @@ -43,24 +46,23 @@ def __init__( Parameters ---------- charge_method: MolToNetChargeMethod, optional (default="formal_charge") - Policy how to compute the net charge of a molecule. Can be "formal_charge" which uses sum - of the formal charges assigned to each atom. "gasteiger" computes the Gasteiger partial - charges and returns the rounded sum over the atoms. - standardizer: AnyTransformer, optional - Standardizer to use, by default StandardScaler() + Policy how to compute the net charge of a molecule. + "formal_charge" uses sum of the formal charges assigned to each atom. + "gasteiger" computes the Gasteiger partial charges and returns the rounded + sum over the atoms. name: str, optional Name of the pipeline element, by default "MolToNetCharge" n_jobs: int, optional Number of jobs to run in parallel, by default 1 uuid: str, optional UUID of the pipeline element, by default None + """ self._descriptor_list = ["NetCharge"] self._feature_names = self._descriptor_list self._charge_method = charge_method # pylint: disable=R0801 super().__init__( - standardizer=standardizer, name=name, n_jobs=n_jobs, uuid=uuid, @@ -77,7 +79,8 @@ def descriptor_list(self) -> list[str]: return self._descriptor_list[:] def _get_net_charge_gasteiger( - self, value: RDKitMol + self, + value: RDKitMol, ) -> npt.NDArray[np.float64] | InvalidInstance: """Transform a single molecule to it's net charge using Gasteiger charges. @@ -92,21 +95,22 @@ def _get_net_charge_gasteiger( ------- Optional[npt.NDArray[np.float64]] Net charge of the given molecule. + """ # copy molecule since ComputeGasteigerCharges modifies the molecule inplace value_copy = Chem.Mol(value) - Chem.rdPartialCharges.ComputeGasteigerCharges(value_copy) + rdPartialCharges.ComputeGasteigerCharges(value_copy) atoms_contributions = np.array( - [atom.GetDoubleProp("_GasteigerCharge") for atom in value_copy.GetAtoms()] + [atom.GetDoubleProp("_GasteigerCharge") for atom in value_copy.GetAtoms()], ) if np.any(np.isnan(atoms_contributions)): return InvalidInstance(self.uuid, "NaN in Gasteiger charges", self.name) # sum up the charges and round to the nearest integer. - net_charge = np.round(np.sum(atoms_contributions, keepdims=True)) - return net_charge + return np.round(np.sum(atoms_contributions, keepdims=True)) def pretransform_single( - self, value: RDKitMol + self, + value: RDKitMol, ) -> npt.NDArray[np.float64] | InvalidInstance: """Transform a single molecule to it's net charge. @@ -144,6 +148,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Parameter of the pipeline element. 
+ """ parent_dict = dict(super().get_params(deep=deep)) if deep: @@ -164,6 +169,7 @@ def set_params(self, **parameters: Any) -> Self: ------- Self Self + """ parameters_shallow_copy = dict(parameters) charge_policy = parameters_shallow_copy.pop("charge_policy", None) diff --git a/molpipeline/mol2any/mol2rdkit_phys_chem.py b/molpipeline/mol2any/mol2rdkit_phys_chem.py index aebfc30c..7e332a1a 100644 --- a/molpipeline/mol2any/mol2rdkit_phys_chem.py +++ b/molpipeline/mol2any/mol2rdkit_phys_chem.py @@ -4,7 +4,7 @@ from __future__ import annotations -from typing import Any, Callable +from typing import TYPE_CHECKING, Any try: from typing import Self # type: ignore[attr-defined] @@ -18,13 +18,16 @@ from loguru import logger from rdkit import Chem, rdBase from rdkit.Chem import Descriptors -from sklearn.preprocessing import StandardScaler from molpipeline.abstract_pipeline_elements.core import InvalidInstance from molpipeline.abstract_pipeline_elements.mol2any.mol2floatvector import ( MolToDescriptorPipelineElement, ) -from molpipeline.utils.molpipeline_types import AnyTransformer, RDKitMol + +if TYPE_CHECKING: + from collections.abc import Callable + + from molpipeline.utils.molpipeline_types import RDKitMol RDKIT_DESCRIPTOR_DICT: dict[str, Callable[[Chem.Mol], float]] RDKIT_DESCRIPTOR_DICT = dict(Descriptors.descList) @@ -37,7 +40,7 @@ class MolToRDKitPhysChem(MolToDescriptorPipelineElement): - """PipelineElement for creating a Descriptor vector based on RDKit phys-chem properties.""" + """Descriptor based on RDKit phys-chem properties.""" _descriptor_list: list[str] @@ -45,7 +48,6 @@ def __init__( self, descriptor_list: list[str] | None = None, return_with_errors: bool = False, - standardizer: AnyTransformer | None = StandardScaler(), log_exceptions: bool = True, name: str = "Mol2RDKitPhysChem", n_jobs: int = 1, @@ -56,12 +58,11 @@ def __init__( Parameters ---------- descriptor_list: list[str] | None, optional - List of descriptor names to calculate. If None, DEFAULT_DESCRIPTORS are used. + List of descriptor names to calculate. + If None, DEFAULT_DESCRIPTORS are used. return_with_errors: bool, default=False False: Returns an InvalidInstance if any error occurs during calculations. True: Returns a vector with NaN values for failed descriptor calculations. - standardizer: AnyTransformer | None, default=StandardScaler() - Standardizer to use. log_exceptions: bool, default=True Log traceback of exceptions occurring during descriptor calculation. name: str, default="Mol2RDKitPhysChem" @@ -70,13 +71,13 @@ def __init__( Number of jobs to use for parallelization. uuid: str | None, optional UUID of the PipelineElement. If None, a new UUID is generated. + """ self.descriptor_list = descriptor_list # type: ignore self._feature_names = self._descriptor_list self._return_with_errors = return_with_errors self._log_exceptions = log_exceptions super().__init__( - standardizer=standardizer, name=name, n_jobs=n_jobs, uuid=uuid, @@ -99,7 +100,8 @@ def descriptor_list(self, descriptor_list: list[str] | None) -> None: Parameters ---------- descriptor_list: list[str] | None - List of descriptor names to calculate. If None, DEFAULT_DESCRIPTORS are used. + List of descriptor names to calculate. + If None, DEFAULT_DESCRIPTORS are used. Raises ------ @@ -107,25 +109,28 @@ def descriptor_list(self, descriptor_list: list[str] | None) -> None: If an unknown descriptor name is used. ValueError If an empty descriptor_list is used. 
+ """ if descriptor_list is None or descriptor_list is DEFAULT_DESCRIPTORS: # if None or DEFAULT_DESCRIPTORS are used, set the default descriptors self._descriptor_list = DEFAULT_DESCRIPTORS elif len(descriptor_list) == 0: raise ValueError( - "Empty descriptor_list is not allowed. Use None for default descriptors." + "Empty descriptor_list is not allowed. " + "Use None for default descriptors.", ) else: # check all user defined descriptors are valid for descriptor_name in descriptor_list: if descriptor_name not in RDKIT_DESCRIPTOR_DICT: raise ValueError( - f"Unknown descriptor function with name: {descriptor_name}" + f"Unknown descriptor function with name: {descriptor_name}", ) self._descriptor_list = descriptor_list def pretransform_single( - self, value: RDKitMol + self, + value: RDKitMol, ) -> npt.NDArray[np.float64] | InvalidInstance: """Transform a single molecule to a descriptor vector. @@ -139,6 +144,7 @@ def pretransform_single( npt.NDArray[np.float64] | InvalidInstance Descriptor vector for given molecule. Failure is indicated by an InvalidInstance. + """ vec = np.full((len(self._descriptor_list),), np.nan) log_block = rdBase.BlockLogs() @@ -166,6 +172,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Parameter of the pipeline element. + """ parent_dict = dict(super().get_params(deep=deep)) if deep: @@ -190,6 +197,7 @@ def set_params(self, **parameters: Any) -> Self: ------- Self Self + """ parameters_shallow_copy = dict(parameters) params_list = ["descriptor_list", "return_with_errors", "log_exceptions"] diff --git a/ruff.toml b/ruff.toml index 04160882..a235c563 100644 --- a/ruff.toml +++ b/ruff.toml @@ -52,6 +52,7 @@ ignore = [ "PLR0913", # too-many-arguments "PGH003", # Blanket type ignore for types "PLW2901", # Redefined loop variable + "PLR6301", # could be a function, class method, or static method (does not respect inheritance) "S311", # suspicious-non-cryptographic-random-usage ] pylint = {max-positional-args=10 } diff --git a/tests/test_elements/test_mol2any/test_mol2concatenated.py b/tests/test_elements/test_mol2any/test_mol2concatenated.py index b3a2a933..ff03b3e7 100644 --- a/tests/test_elements/test_mol2any/test_mol2concatenated.py +++ b/tests/test_elements/test_mol2any/test_mol2concatenated.py @@ -4,14 +4,12 @@ import itertools import unittest -from typing import Any, Literal, get_args +from typing import TYPE_CHECKING, Any, Literal, get_args import numpy as np from rdkit import Chem -from sklearn.preprocessing import StandardScaler from molpipeline import Pipeline -from molpipeline.abstract_pipeline_elements.core import MolToAnyPipelineElement from molpipeline.any2mol import SmilesToMol from molpipeline.mol2any import ( Mol2PathFP, @@ -24,9 +22,15 @@ from tests.utils.fingerprints import fingerprints_to_numpy from tests.utils.logging import capture_logs +if TYPE_CHECKING: + from molpipeline.abstract_pipeline_elements.core import MolToAnyPipelineElement + class TestConcatenatedFingerprint(unittest.TestCase): - """Unittest for MolToConcatenatedVector, which calculates concatenated fingerprints.""" + """Unittest for MolToConcatenatedVector. + + The MolToConcatenatedVector calculates concatenated fingerprints. 
+ """ def test_generation(self) -> None: """Test if the feature concatenation works as expected.""" @@ -35,7 +39,7 @@ def test_generation(self) -> None: "sparse", "dense", "explicit_bit_vect", - ] + ], ) smiles = [ @@ -56,19 +60,19 @@ def test_generation(self) -> None: [ ( "RDKitPhysChem", - MolToRDKitPhysChem(standardizer=StandardScaler()), + MolToRDKitPhysChem(), ), ( "MorganFP", MolToMorganFP(return_as=fp_output_type), ), - ] + ], ) pipeline = Pipeline( [ ("smi2mol", SmilesToMol()), ("concat_vector_element", concat_vector_element), - ] + ], ) output = pipeline.fit_transform(smiles) @@ -79,9 +83,9 @@ def test_generation(self) -> None: [ concat_vector_element.element_list[0][1].transform(mol_list), fingerprints_to_numpy( - concat_vector_element.element_list[1][1].transform(mol_list) + concat_vector_element.element_list[1][1].transform(mol_list), ), - ] + ], ) pyschem_component: MolToRDKitPhysChem pyschem_component = concat_vector_element.element_list[0][1] # type: ignore @@ -107,8 +111,8 @@ def test_empty_element_list(self) -> None: ( "RDKitPhysChem", MolToRDKitPhysChem(), - ) - ] + ), + ], ) with self.assertRaises(ValueError): concat_elem.set_params(element_list=[]) @@ -147,7 +151,7 @@ def test_n_features(self) -> None: ) self.assertEqual( MolToConcatenatedVector( - [net_charge_elem, morgan_elem, physchem_elem] + [net_charge_elem, morgan_elem, physchem_elem], ).n_features, net_charge_elem[1].n_features + 16 + physchem_elem[1].n_features, ) @@ -196,7 +200,7 @@ def test_features_names(self) -> None: # pylint: disable-msg=too-many-locals feature_names = conc_elem.feature_names if use_feature_names_prefix: - # test feature names are unique if prefix is used or only one element is used + # test feature names are unique self.assertEqual( len(feature_names), len(set(feature_names)), @@ -220,7 +224,11 @@ def test_features_names(self) -> None: # pylint: disable-msg=too-many-locals if use_feature_names_prefix: # feature_names should be prefixed with element name prefixes, feat_names = map( - list, zip(*[name.split("__") for name in relevant_names]) + list, + zip( + *[name.split("__") for name in relevant_names], + strict=True, + ), ) # test feature names are the same self.assertListEqual(elem_feature_names, feat_names) @@ -258,7 +266,8 @@ def test_logging_feature_names_uniqueness(self) -> None: self.assertEqual(len(output), 1) message = output[0] self.assertIn( - "Feature names in MolToConcatenatedVector are not unique", message + "Feature names in MolToConcatenatedVector are not unique", + message, ) self.assertEqual(message.record["level"].name, "WARNING") @@ -300,7 +309,8 @@ def test_getter_setter(self) -> None: self.assertEqual(concat_elem.get_params()["use_feature_names_prefix"], True) # test that there are no duplicates in feature names self.assertEqual( - len(concat_elem.feature_names), len(set(concat_elem.feature_names)) + len(concat_elem.feature_names), + len(set(concat_elem.feature_names)), ) params: dict[str, Any] = { "use_feature_names_prefix": False, @@ -309,7 +319,8 @@ def test_getter_setter(self) -> None: self.assertEqual(concat_elem.get_params()["use_feature_names_prefix"], False) # test that there are duplicates in feature names self.assertNotEqual( - len(concat_elem.feature_names), len(set(concat_elem.feature_names)) + len(concat_elem.feature_names), + len(set(concat_elem.feature_names)), ) diff --git a/tests/test_elements/test_mol2any/test_mol2net_charge.py b/tests/test_elements/test_mol2any/test_mol2net_charge.py index 50aa540b..95ef93d1 100644 --- 
a/tests/test_elements/test_mol2any/test_mol2net_charge.py +++ b/tests/test_elements/test_mol2any/test_mol2net_charge.py @@ -16,11 +16,11 @@ "c1cc(c(nc1)Cl)C(=O)Nc2c(c3c(s2)CCCCC3)C(=O)N", "Cc1ccc(cc1)S(=O)(=O)Nc2c(c3c(s2)C[C@@H](CC3)C)C(=O)N", "c1cc(oc1)CN=C2C=C(C(=CC2=C(O)[O-])S(=O)(=O)[NH-])Cl", - "C[C@@H]1[C@@H](OP2(O1)(O[C@H]([C@H](O2)C)C)C[NH+]3CCCCC3)C", # this one fails gasteiger charge computation + "C[C@@H]1[C@@H](OP2(O1)(O[C@H]([C@H](O2)C)C)C[NH+]3CCCCC3)C", # this one fails gasteiger charge computation # noqa: E501 ], "expected_net_charges_formal_charge": [2, 0, 0, -2, 1], "expected_net_charges_gasteiger": [2, -1, -1, -2, np.nan], - } + }, ) @@ -28,15 +28,18 @@ class TestNetChargeCalculator(unittest.TestCase): """Unittest for MolToNetCharge, which calculates net charges of molecules.""" def test_net_charge_calculation_formal_charge(self) -> None: - """Test if the net charge calculation works as expected for formal charges.""" - # we need the error filter and reinserter to handle the case where the charge calculation fails + """Test if the net charge calculation works as expected for formal charges. + + Error filter and reinserter are used to handle the case where the charge + calculation fails. + """ error_filter = ErrorFilter(filter_everything=True) pipeline = Pipeline( [ ("smi2mol", SmilesToMol()), ( "net_charge_element", - MolToNetCharge(charge_method="formal_charge", standardizer=None), + MolToNetCharge(charge_method="formal_charge"), ), ("error_filter", error_filter), ( @@ -54,19 +57,22 @@ def test_net_charge_calculation_formal_charge(self) -> None: .reshape(-1, 1), actual_net_charges, equal_nan=True, - ) + ), ) def test_net_charge_calculation_gasteiger(self) -> None: - """Test if the net charge calculation works as expected for gasteiger charges.""" - # we need the error filter and reinserter to handle the case where the charge calculation fails + """Test if the net charge calculation works as expected for gasteiger charges. + + Error filter and reinserter are used to handle the case where the charge + calculation fails. 
+ """ error_filter = ErrorFilter(filter_everything=True) pipeline = Pipeline( [ ("smi2mol", SmilesToMol()), ( "net_charge_element", - MolToNetCharge(charge_method="gasteiger", standardizer=None), + MolToNetCharge(charge_method="gasteiger"), ), ("error_filter", error_filter), ( @@ -84,7 +90,7 @@ def test_net_charge_calculation_gasteiger(self) -> None: .reshape(-1, 1), actual_net_charges, equal_nan=True, - ) + ), ) diff --git a/tests/test_elements/test_mol2any/test_mol2rdkit_phys_chem.py b/tests/test_elements/test_mol2any/test_mol2rdkit_phys_chem.py index 5b013aaf..a1dc4331 100644 --- a/tests/test_elements/test_mol2any/test_mol2rdkit_phys_chem.py +++ b/tests/test_elements/test_mol2any/test_mol2rdkit_phys_chem.py @@ -5,7 +5,6 @@ import numpy as np import pandas as pd -from sklearn.preprocessing import StandardScaler from molpipeline import ErrorFilter, FilterReinserter, Pipeline from molpipeline.any2mol import SmilesToMol @@ -248,51 +247,19 @@ def test_descriptor_calculation(self) -> None: expected_df = pd.read_csv(data_path, sep="\t") descriptor_names = expected_df.drop(columns=["smiles"]).columns.tolist() smi2mol = SmilesToMol() - property_element = MolToRDKitPhysChem( - standardizer=None, descriptor_list=descriptor_names - ) + property_element = MolToRDKitPhysChem(descriptor_list=descriptor_names) pipeline = Pipeline( [ ("smi2mol", smi2mol), ("property_element", property_element), - ] + ], ) smiles = expected_df["smiles"].tolist() - property_vector = expected_df[descriptor_names].values + property_vector = expected_df[descriptor_names].to_numpy() output = pipeline.fit_transform(smiles) self.assertTrue(np.allclose(output, property_vector)) # add assertion here - def test_descriptor_normalization(self) -> None: - """Test if the normalization of RDKitPhysChem Descriptors works as expected.""" - smi2mol = SmilesToMol() - property_element = MolToRDKitPhysChem(standardizer=StandardScaler()) - pipeline = Pipeline( - [ - ("smi2mol", smi2mol), - ("property_element", property_element), - ] - ) - # pylint: disable=duplicate-code # test case molecules are allowed to be duplicated - smiles = [ - "CC", - "CCC", - "CCCO", - "CCNCO", - "C(C)CCO", - "CCO", - "CCCN", - "CCCC", - "CCOC", - "COO", - ] - output = pipeline.fit_transform(smiles) - non_zero_descriptors = output[:, (np.abs(output).sum(axis=0) != 0)] - self.assertTrue( - np.allclose(non_zero_descriptors.mean(axis=0), 0.0) - ) # add assertion here - self.assertTrue(np.allclose(non_zero_descriptors.std(axis=0), 1.0)) - def test_optional_nan_value_handling(self) -> None: """Test the handling of partly failed descriptor calculations.""" ok_smiles_list = [ @@ -304,45 +271,44 @@ def test_optional_nan_value_handling(self) -> None: ] # test with return_with_errors=False - property_element = MolToRDKitPhysChem( - standardizer=None, return_with_errors=False - ) + property_element = MolToRDKitPhysChem(return_with_errors=False) error_filter = ErrorFilter.from_element_list([property_element]) error_replacer = FilterReinserter.from_error_filter( - error_filter, fill_value=np.nan + error_filter, + fill_value=np.nan, ) - # note that we need the error filter and replacer here. Otherwise, the pipeline would fail on any error - # irrespective of the return_with_errors parameter + # note that we need the error filter and replacer here. 
+        # Otherwise, the pipeline would fail on any error irrespective of the
+        # return_with_errors parameter
         pipeline = Pipeline(
             [
                 ("smi2mol", SmilesToMol()),
                 ("property_element", property_element),
                 ("error_filter", error_filter),
                 ("error_replacer", error_replacer),
-            ]
+            ],
         )
 
         output = pipeline.fit_transform(bad_smiles_list + ok_smiles_list)
         self.assertEqual(len(output), len(bad_smiles_list + ok_smiles_list))
         # check expect-to-fail rows are ALL nan values
         self.assertTrue(
-            np.equal(np.isnan(output).all(axis=1), [True, False, False]).all()
+            np.equal(np.isnan(output).all(axis=1), [True, False, False]).all(),
         )
         # check expected-not-to-fail rows contain zero nan values
         self.assertTrue(
-            np.equal(np.isnan(output).any(axis=1), [True, False, False]).all()
+            np.equal(np.isnan(output).any(axis=1), [True, False, False]).all(),
         )
 
         # test with return_with_errors=True
-        property_element2 = MolToRDKitPhysChem(
-            standardizer=None, return_with_errors=True
-        )
+        property_element2 = MolToRDKitPhysChem(return_with_errors=True)
         error_filter2 = ErrorFilter.from_element_list([property_element2])
         filter_reinserter = FilterReinserter.from_error_filter(
-            error_filter2, fill_value=np.nan
+            error_filter2,
+            fill_value=np.nan,
         )
         pipeline2 = Pipeline(
             [
                 ("smi2mol", SmilesToMol()),
                 ("property_element", property_element2),
                 ("error_filter", error_filter2),
                 ("error_replacer", filter_reinserter),
-            ]
+            ],
         )
         output2 = pipeline2.fit_transform(bad_smiles_list + ok_smiles_list)
         self.assertEqual(len(output2), len(bad_smiles_list + ok_smiles_list))
         # check expect-to-fail rows are ALL nan values
         self.assertTrue(
-            np.equal(np.isnan(output2).all(axis=1), [False, False, False]).all()
+            np.equal(np.isnan(output2).all(axis=1), [False, False, False]).all(),
         )
         # check expected-not-to-fail rows contain zero nan values
         self.assertTrue(
-            np.equal(np.isnan(output2).any(axis=1), [True, False, False]).all()
+            np.equal(np.isnan(output2).any(axis=1), [True, False, False]).all(),
         )
 
     def test_unknown_descriptor_name(self) -> None:
         """Test if an unknown descriptor name raises an error."""
         self.assertRaises(
             ValueError,
             MolToRDKitPhysChem,
-            **{"descriptor_list": ["__NotADescriptor11Name:)"]},
+            descriptor_list=["__NotADescriptor11Name:)"],
         )
 
     def test_exception_handling(self) -> None:
         pipeline = Pipeline(
             [
                 ("smi2mol", SmilesToMol()),
                 (
                     "property_element",
                     MolToRDKitPhysChem(
-                        standardizer=None, return_with_errors=True, log_exceptions=False
+                        return_with_errors=True,
+                        log_exceptions=False,
                     ),
                 ),
-            ]
+            ],
         )
 
-        # Without exception handling [HH] would raise a division-by-zero exception because it has 0 heavy atoms
+        # Without exception handling [HH] would raise a division-by-zero exception
+        # because it has 0 heavy atoms
         output = pipeline.fit_transform(["[HH]"])
         self.assertTrue(output.shape == (1, len(DEFAULT_DESCRIPTORS)))
 
     def test_empty_descriptor_list(self) -> None:
         """Test if an empty descriptor list raises an error."""
         with self.assertRaises(ValueError) as context:
             MolToRDKitPhysChem(descriptor_list=[])
         self.assertTrue(
-            str(context.exception).startswith("Empty descriptor_list is not allowed")
+            str(context.exception).startswith("Empty descriptor_list is not allowed"),
         )
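Migration note for the removed `standardizer` argument: descriptor elements no
longer scale their output internally, so the former default
(`standardizer=StandardScaler()`) is now written as an explicit pipeline step.
A minimal sketch of the equivalent setup (the SMILES list is an arbitrary
example):

    from sklearn.preprocessing import StandardScaler

    from molpipeline import Pipeline
    from molpipeline.any2mol import SmilesToMol
    from molpipeline.mol2any import MolToRDKitPhysChem

    pipeline = Pipeline(
        [
            ("smi2mol", SmilesToMol()),
            ("physchem", MolToRDKitPhysChem()),  # raw, unscaled descriptors
            ("scaler", StandardScaler()),  # standardization as its own step
        ],
    )
    feature_matrix = pipeline.fit_transform(["CCO", "CCC", "CCN"])

Handled this way, the scaler participates in `get_params`/`set_params` and in
cloning like any other pipeline step, instead of the special-cased handling the
descriptor elements had to implement before.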
Feldmann" <128160984+c-w-feldmann@users.noreply.github.com> Date: Wed, 23 Apr 2025 17:25:34 +0200 Subject: [PATCH 3/5] Limit support to p311 and above (#166) * require python 3.11 and higher * remove tests for 3.10 * import Self from typing * make zip strict * Optional -> None * remove typing.Union * import Callable from collections --- .github/workflows/linting.yml | 2 +- .../abstract_pipeline_elements/core.py | 10 +- .../mol2any/mol2bitvector.py | 47 +++++--- .../mol2mol/filter.py | 87 ++++++++------ molpipeline/any2mol/sdf2mol.py | 8 +- molpipeline/any2mol/smiles2mol.py | 7 +- molpipeline/error_handling.py | 7 +- molpipeline/estimators/chemprop/abstract.py | 8 +- .../featurizer_wrapper/graph_wrapper.py | 7 +- molpipeline/estimators/chemprop/models.py | 7 +- .../connected_component_clustering.py | 7 +- .../estimators/leader_picker_clustering.py | 8 +- .../estimators/murcko_scaffold_clustering.py | 8 +- molpipeline/estimators/nearest_neighbor.py | 33 +++--- .../estimators/similarity_transformation.py | 7 +- molpipeline/experimental/custom_filter.py | 8 +- .../experimental/explainability/explainer.py | 9 +- .../explainability/visualization/heatmaps.py | 25 +++-- .../mol2any/mol2concatinated_vector.py | 53 +++++---- molpipeline/mol2any/mol2morgan_fingerprint.py | 8 +- molpipeline/mol2any/mol2net_charge.py | 9 +- molpipeline/mol2any/mol2path_fingerprint.py | 8 +- molpipeline/mol2any/mol2rdkit_phys_chem.py | 8 +- molpipeline/mol2mol/filter.py | 7 +- molpipeline/mol2mol/reaction.py | 8 +- molpipeline/mol2mol/scaffolds.py | 7 +- molpipeline/mol2mol/standardization.py | 7 +- molpipeline/pipeline/_molpipeline.py | 7 +- molpipeline/pipeline/_skl_pipeline.py | 106 +++++++++++------- molpipeline/post_prediction.py | 20 ++-- molpipeline/utils/comparison.py | 5 +- molpipeline/utils/kernel.py | 10 +- molpipeline/utils/molpipeline_types.py | 32 ++---- molpipeline/utils/subpipeline.py | 3 +- molpipeline/utils/value_conversions.py | 2 +- ...caffold_split_with_custom_estimators.ipynb | 2 +- ...ed_03_introduction_to_explainable_ai.ipynb | 2 +- pyproject.toml | 1 + ruff.toml | 2 +- .../test_any2mol/test_auto2mol.py | 72 +++++++----- tests/test_elements/test_error_handling.py | 18 +-- .../test_mol2any/test_mol2bin.py | 12 +- .../test_leader_picker_clustering.py | 9 +- .../test_model_selection/test_splitter.py | 18 ++- tests/test_pipeline.py | 43 ++++--- tests/test_utils/test_comparison.py | 2 +- tests/test_utils/test_json_operations.py | 4 +- tests/utils/execution_count.py | 7 +- tests/utils/mock_element.py | 7 +- 49 files changed, 396 insertions(+), 398 deletions(-) diff --git a/.github/workflows/linting.yml b/.github/workflows/linting.yml index 8b44c8e6..d9693960 100644 --- a/.github/workflows/linting.yml +++ b/.github/workflows/linting.yml @@ -182,7 +182,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.10", "3.11", "3.12", "3.13"] + python-version: ["3.11", "3.12", "3.13"] steps: - uses: actions/checkout@v3 - name: Set up Python ${{ matrix.python-version }} diff --git a/molpipeline/abstract_pipeline_elements/core.py b/molpipeline/abstract_pipeline_elements/core.py index 0863d64f..9e26bce0 100644 --- a/molpipeline/abstract_pipeline_elements/core.py +++ b/molpipeline/abstract_pipeline_elements/core.py @@ -6,13 +6,7 @@ import copy import inspect from collections.abc import Iterable -from typing import Any, NamedTuple, Union - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self - +from typing import Any, NamedTuple, 
Self from uuid import uuid4 from joblib import Parallel, delayed @@ -54,7 +48,7 @@ def __repr__(self) -> str: ) -OptionalMol = Union[RDKitMol, InvalidInstance] +OptionalMol = RDKitMol | InvalidInstance class RemovedInstance: # pylint: disable=too-few-public-methods diff --git a/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py b/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py index b4b0e27c..ee64bc85 100644 --- a/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py +++ b/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py @@ -5,12 +5,7 @@ import abc import copy from collections.abc import Iterable -from typing import Any, Literal, get_args, overload - -try: - from typing import Self, TypeAlias # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self, TypeAlias +from typing import Any, Literal, Self, TypeAlias, get_args, overload import numpy as np import numpy.typing as npt @@ -56,10 +51,11 @@ def __init__( rdkit.DataStructs.cDataStructs.ExplicitBitVect. name: str Name of PipelineElement. - n_jobs: + n_jobs: int, default=1 Number of jobs. - uuid: Optional[str] + uuid: str | None, optional Unique identifier. + """ super().__init__( name=name, @@ -80,17 +76,20 @@ def feature_names(self) -> list[str]: @overload def assemble_output( # type: ignore - self, value_list: Iterable[npt.NDArray[np.int_]] + self, + value_list: Iterable[npt.NDArray[np.int_]], ) -> npt.NDArray[np.int_]: ... @overload def assemble_output( - self, value_list: Iterable[dict[int, int]] + self, + value_list: Iterable[dict[int, int]], ) -> sparse.csr_matrix: ... @overload def assemble_output( - self, value_list: Iterable[ExplicitBitVect] + self, + value_list: Iterable[ExplicitBitVect], ) -> list[ExplicitBitVect]: ... def assemble_output( @@ -115,6 +114,7 @@ def assemble_output( ------- sparse.csr_matrix | npt.NDArray[np.int_] | list[ExplicitBitVect] Matrix of Morgan-fingerprint features. + """ if self._return_as == "explicit_bit_vect": # return as list of RDkit's ExplicitBitVect @@ -138,6 +138,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Dictionary of parameter names and values. + """ parameters = super().get_params(deep) if deep: @@ -172,7 +173,7 @@ def set_params(self, **parameters: Any) -> Self: if return_as not in get_args(OutputDatatype): raise ValueError( f"return_as has to be one of {get_args(OutputDatatype)}! " - f"(Received: {return_as})" + f"(Received: {return_as})", ) self._return_as = return_as super().set_params(**parameter_dict_copy) @@ -190,12 +191,14 @@ def transform(self, values: list[RDKitMol]) -> sparse.csr_matrix: ------- sparse.csr_matrix Sparse matrix of Morgan-fingerprint features. + """ return super().transform(values) @abc.abstractmethod def pretransform_single( - self, value: RDKitMol + self, + value: RDKitMol, ) -> dict[int, int] | npt.NDArray[np.int_] | ExplicitBitVect: """Transform mol to dict, where items encode columns indices and values, respectively. @@ -208,6 +211,7 @@ def pretransform_single( ------- dict[int, int] Dictionary to encode row in matrix. Keys: column index, values: column value. + """ @@ -237,6 +241,7 @@ def __init__( Number of jobs. uuid: str | None, optional Unique identifier. + """ super().__init__( return_as=return_as, @@ -254,10 +259,12 @@ def _get_fp_generator(self) -> rdFingerprintGenerator.FingerprintGenerator64: ------- rdFingerprintGenerator.FingerprintGenerator64 Fingerprint generator. 
+ """ def pretransform_single( - self, value: RDKitMol + self, + value: RDKitMol, ) -> ExplicitBitVect | npt.NDArray[np.int_] | dict[int, int]: """Transform a single compound to a dictionary. @@ -274,6 +281,7 @@ def pretransform_single( If return_as is "explicit_bit_vect" return ExplicitBitVect. If return_as is "dense" return numpy array. If return_as is "sparse" return dictionary with feature-position as key and count as value. + """ fingerprint_generator = self._get_fp_generator() if self._return_as == "dense": @@ -306,6 +314,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Dictionary of parameter names and values. + """ parameters = super().get_params(deep) if deep: @@ -327,6 +336,7 @@ def set_params(self, **parameters: Any) -> Self: ------- Self Copied object with updated parameters. + """ parameter_dict_copy = dict(parameters) counted = parameter_dict_copy.pop("counted", None) @@ -400,7 +410,7 @@ def __init__( self._radius = radius else: raise ValueError( - f"Number of bits has to be a positive integer! (Received: {radius})" + f"Number of bits has to be a positive integer! (Received: {radius})", ) def get_params(self, deep: bool = True) -> dict[str, Any]: @@ -415,6 +425,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Dictionary of parameter names and values. + """ parameters = super().get_params(deep) if deep: @@ -440,6 +451,7 @@ def set_params(self, **parameters: Any) -> Self: ------- Self PipelineElement with updated parameters. + """ parameter_copy = dict(parameters) radius = parameter_copy.pop("radius", None) @@ -477,11 +489,13 @@ def _explain_rdmol(self, mol_obj: RDKitMol) -> dict[int, list[tuple[int, int]]]: ------- dict[int, list[tuple[int, int]]] Dictionary with mapping from bit to atom index and radius. + """ raise NotImplementedError def bit2atom_mapping( - self, mol_obj: RDKitMol + self, + mol_obj: RDKitMol, ) -> dict[int, list[CircularAtomEnvironment]]: """Obtain set of atoms for all features. @@ -494,6 +508,7 @@ def bit2atom_mapping( ------- dict[int, list[CircularAtomEnvironment]] Dictionary with mapping from bit to encoded AtomEnvironments (which contain atom indices). + """ bit2atom_dict = self._explain_rdmol(mol_obj) result_dict: dict[int, list[CircularAtomEnvironment]] = {} diff --git a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py index ea9e3df2..19c3d6bb 100644 --- a/molpipeline/abstract_pipeline_elements/mol2mol/filter.py +++ b/molpipeline/abstract_pipeline_elements/mol2mol/filter.py @@ -2,12 +2,7 @@ import abc from collections.abc import Mapping, Sequence -from typing import Any, Literal, Optional, TypeAlias, Union - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self +from typing import Any, Literal, Self, TypeAlias from molpipeline.abstract_pipeline_elements.core import ( InvalidInstance, @@ -29,7 +24,9 @@ def _within_boundaries( - lower_bound: Optional[float], upper_bound: Optional[float], property_value: float + lower_bound: float | None, + upper_bound: float | None, + property_value: float, ) -> bool: """Check if a value is within the specified boundaries. @@ -37,9 +34,9 @@ def _within_boundaries( Parameters ---------- - lower_bound: Optional[float] + lower_bound: float | None Lower boundary. - upper_bound: Optional[float] + upper_bound: float | None Upper boundary. property_value: float Property value to check. 
@@ -48,6 +45,7 @@ def _within_boundaries( ------- bool True if the value is within the boundaries, else False. + """ if lower_bound is not None and property_value < lower_bound: return False @@ -66,6 +64,7 @@ class BaseKeepMatchesFilter(MolToMolPipelineElement, abc.ABC): - mode = "any" & keep_matches = False: Must not match any filter element. - mode = "all" & keep_matches = True: Needs to match all filter elements. - mode = "all" & keep_matches = False: Must not match all filter elements. + """ keep_matches: bool @@ -73,35 +72,44 @@ class BaseKeepMatchesFilter(MolToMolPipelineElement, abc.ABC): def __init__( self, - filter_elements: Union[ - Mapping[Any, Union[FloatCountRange, IntCountRange, IntOrIntCountRange]], - Sequence[Any], - ], + filter_elements: ( + Mapping[ + Any, + FloatCountRange | IntCountRange | IntOrIntCountRange, + ] + | Sequence[Any] + ), keep_matches: bool = True, mode: FilterModeType = "any", - name: Optional[str] = None, + name: str | None = None, n_jobs: int = 1, - uuid: Optional[str] = None, + uuid: str | None = None, ) -> None: """Initialize BasePatternsFilter. Parameters ---------- - filter_elements: Union[Mapping[Any, Union[FloatCountRange, IntCountRange, IntOrIntCountRange]], Sequence[Any]] - List of filter elements. Typically can be a list of patterns or a dictionary with patterns as keys and - an int for exact count or a tuple of minimum and maximum. - NOTE: for each child class, the type of filter_elements must be specified by the filter_elements setter. - keep_matches: bool, optional (default: True) + filter_elements: Mapping[ + Any, FloatCountRange | IntCountRange | IntOrIntCountRange + ] | Sequence[Any] + List of filter elements. Typically can be a list of patterns or a dictionary + with patterns as keys and an int for exact count or a tuple of minimum and + maximum. + NOTE: for each child class, the type of filter_elements must be specified by + the filter_elements setter. + keep_matches: bool, default=True If True, molecules containing the specified patterns are kept, else removed. - mode: FilterModeType, optional (default: "any") - If "any", at least one of the specified patterns must be present in the molecule. + mode: FilterModeType, default="any" + If "any", at least one of the specified patterns must be present in the + molecule. If "all", all of the specified patterns must be present in the molecule. - name: Optional[str], optional (default: None) + name: str | None, optional Name of the pipeline element. - n_jobs: int, optional (default: 1) + n_jobs: int, default=1 Number of parallel jobs to use. - uuid: str, optional (default: None) + uuid: str, optional Unique identifier of the pipeline element. + """ super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) self.filter_elements = filter_elements # type: ignore @@ -119,14 +127,15 @@ def filter_elements( @abc.abstractmethod def filter_elements( self, - filter_elements: Union[Mapping[Any, FloatCountRange], Sequence[Any]], + filter_elements: Mapping[Any, FloatCountRange] | Sequence[Any], ) -> None: """Set filter elements as dict. Parameters ---------- - filter_elements: Union[Mapping[Any, FloatCountRange], Sequence[Any]] + filter_elements: Mapping[Any, FloatCountRange] | Sequence[Any] List of filter elements. + """ def set_params(self, **parameters: Any) -> Self: @@ -141,6 +150,7 @@ def set_params(self, **parameters: Any) -> Self: ------- Self Self. 
+ """ parameter_copy = dict(parameters) if "keep_matches" in parameter_copy: @@ -164,6 +174,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Parameters of BaseKeepMatchesFilter. + """ params = super().get_params(deep=deep) params["keep_matches"] = self.keep_matches @@ -172,7 +183,8 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: return params def pretransform_single( # pylint: disable=too-many-return-statements - self, value: RDKitMol + self, + value: RDKitMol, ) -> OptionalMol: """Invalidate or validate molecule based on specified filter. @@ -246,7 +258,9 @@ def pretransform_single( # pylint: disable=too-many-return-statements @abc.abstractmethod def _calculate_single_element_value( - self, filter_element: Any, value: RDKitMol + self, + filter_element: Any, + value: RDKitMol, ) -> float: """Calculate the value of a single match. @@ -261,6 +275,7 @@ def _calculate_single_element_value( ------- float Value of the match. + """ @@ -269,7 +284,7 @@ class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC): Attributes ---------- - filter_elements: Union[Sequence[str], Mapping[str, IntOrIntCountRange]] + filter_elements: Mapping[str, IntCountRange] List of patterns to allow in molecules. Alternatively, a dictionary can be passed with patterns as keys and an int for exact count or a tuple of minimum and maximum. @@ -282,6 +297,7 @@ class BasePatternsFilter(BaseKeepMatchesFilter, abc.ABC): - mode = "any" & keep_matches = False: Must not match any filter element. - mode = "all" & keep_matches = True: Needs to match all filter elements. - mode = "all" & keep_matches = False: Must not match all filter elements. + """ _filter_elements: Mapping[str, IntCountRange] @@ -294,14 +310,15 @@ def filter_elements(self) -> Mapping[str, IntCountRange]: @filter_elements.setter def filter_elements( self, - patterns: Union[list[str], Mapping[str, IntOrIntCountRange]], + patterns: list[str] | Mapping[str, IntOrIntCountRange], ) -> None: """Set allowed filter elements (patterns) as dict. Parameters ---------- - patterns: Union[list[str], Mapping[str, IntOrIntCountRange]] + patterns: list[str] | Mapping[str, IntOrIntCountRange] List of patterns. + """ if isinstance(patterns, (list, set)): self._filter_elements = dict.fromkeys(patterns, (1, None)) @@ -351,10 +368,13 @@ def _pattern_to_mol(self, pattern: str) -> RDKitMol: ------- RDKitMol RDKitMol object of the pattern. + """ def _calculate_single_element_value( - self, filter_element: Any, value: RDKitMol + self, + filter_element: Any, + value: RDKitMol, ) -> int: """Calculate a single match count for a molecule. @@ -369,5 +389,6 @@ def _calculate_single_element_value( ------- int smarts match count value. 
+ """ return len(value.GetSubstructMatches(self.patterns_mol_dict[filter_element])) diff --git a/molpipeline/any2mol/sdf2mol.py b/molpipeline/any2mol/sdf2mol.py index a6bea806..549e914e 100644 --- a/molpipeline/any2mol/sdf2mol.py +++ b/molpipeline/any2mol/sdf2mol.py @@ -2,14 +2,8 @@ from __future__ import annotations -from typing import Any, Literal - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self - import copy +from typing import Any, Literal, Self from rdkit import Chem diff --git a/molpipeline/any2mol/smiles2mol.py b/molpipeline/any2mol/smiles2mol.py index 214d1edb..a2834206 100644 --- a/molpipeline/any2mol/smiles2mol.py +++ b/molpipeline/any2mol/smiles2mol.py @@ -2,12 +2,7 @@ from __future__ import annotations -from typing import Any - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self +from typing import Any, Self from rdkit import Chem diff --git a/molpipeline/error_handling.py b/molpipeline/error_handling.py index 7626777a..87cad027 100644 --- a/molpipeline/error_handling.py +++ b/molpipeline/error_handling.py @@ -3,12 +3,7 @@ from __future__ import annotations from collections.abc import Iterable, Sequence -from typing import Any, Generic, TypeVar - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self +from typing import Any, Generic, Self, TypeVar import numpy as np import numpy.typing as npt diff --git a/molpipeline/estimators/chemprop/abstract.py b/molpipeline/estimators/chemprop/abstract.py index 3e8bf9eb..f95167dc 100644 --- a/molpipeline/estimators/chemprop/abstract.py +++ b/molpipeline/estimators/chemprop/abstract.py @@ -2,13 +2,7 @@ import abc from collections.abc import Sequence -from typing import Any - -# pylint: disable=duplicate-code -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self +from typing import Any, Self import numpy as np import numpy.typing as npt diff --git a/molpipeline/estimators/chemprop/featurizer_wrapper/graph_wrapper.py b/molpipeline/estimators/chemprop/featurizer_wrapper/graph_wrapper.py index e5c8f80f..5697c052 100644 --- a/molpipeline/estimators/chemprop/featurizer_wrapper/graph_wrapper.py +++ b/molpipeline/estimators/chemprop/featurizer_wrapper/graph_wrapper.py @@ -1,12 +1,7 @@ """Wrapper for Chemprop GraphFeaturizer.""" from dataclasses import InitVar -from typing import Any - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self +from typing import Any, Self from chemprop.featurizers.molgraph import ( SimpleMoleculeMolGraphFeaturizer as _SimpleMoleculeMolGraphFeaturizer, diff --git a/molpipeline/estimators/chemprop/models.py b/molpipeline/estimators/chemprop/models.py index 25f1f8ac..aa17d093 100644 --- a/molpipeline/estimators/chemprop/models.py +++ b/molpipeline/estimators/chemprop/models.py @@ -1,12 +1,7 @@ """Wrapper for Chemprop to make it compatible with scikit-learn.""" from collections.abc import Sequence -from typing import Any - -try: - from typing import Self -except ImportError: - from typing_extensions import Self +from typing import Any, Self import numpy as np import numpy.typing as npt diff --git a/molpipeline/estimators/connected_component_clustering.py b/molpipeline/estimators/connected_component_clustering.py index ed51f3bb..f8a2cf2e 100644 --- 
a/molpipeline/estimators/connected_component_clustering.py +++ b/molpipeline/estimators/connected_component_clustering.py @@ -3,7 +3,7 @@ from __future__ import annotations from numbers import Real -from typing import Any +from typing import Any, Self import numpy as np import numpy.typing as npt @@ -13,11 +13,6 @@ from sklearn.utils._param_validation import Interval from sklearn.utils.validation import validate_data -try: - from typing import Self -except ImportError: - from typing_extensions import Self - from molpipeline.estimators.algorithm.connected_component_clustering import ( calc_chunk_size_from_memory_requirement, connected_components_iterative_algorithm, diff --git a/molpipeline/estimators/leader_picker_clustering.py b/molpipeline/estimators/leader_picker_clustering.py index 1640a3ea..01ebfa6e 100644 --- a/molpipeline/estimators/leader_picker_clustering.py +++ b/molpipeline/estimators/leader_picker_clustering.py @@ -2,8 +2,10 @@ from __future__ import annotations +from collections.abc import Sequence from itertools import compress from numbers import Real +from typing import Any, Self import numpy as np import numpy.typing as npt @@ -13,12 +15,6 @@ from sklearn.base import BaseEstimator, ClusterMixin, _fit_context from sklearn.utils._param_validation import Interval -try: - from collections.abc import Sequence - from typing import Any, Self -except ImportError: - from typing_extensions import Self - class LeaderPickerClustering(ClusterMixin, BaseEstimator): """LeaderPicker clustering estimator (a sphere exclusion clustering algorithm).""" diff --git a/molpipeline/estimators/murcko_scaffold_clustering.py b/molpipeline/estimators/murcko_scaffold_clustering.py index fee7545c..89d00ba9 100644 --- a/molpipeline/estimators/murcko_scaffold_clustering.py +++ b/molpipeline/estimators/murcko_scaffold_clustering.py @@ -3,7 +3,7 @@ from __future__ import annotations from numbers import Integral -from typing import Any, Literal +from typing import Any, Literal, Self import numpy as np import numpy.typing as npt @@ -17,12 +17,6 @@ from molpipeline.mol2mol import EmptyMoleculeFilter, MakeScaffoldGeneric, MurckoScaffold from molpipeline.utils.molpipeline_types import AnyStep, OptionalMol -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self - - __all__ = [ "MurckoScaffoldClustering", ] diff --git a/molpipeline/estimators/nearest_neighbor.py b/molpipeline/estimators/nearest_neighbor.py index 2bb59741..cccca2ce 100644 --- a/molpipeline/estimators/nearest_neighbor.py +++ b/molpipeline/estimators/nearest_neighbor.py @@ -2,14 +2,8 @@ from __future__ import annotations -from collections.abc import Sequence -from typing import Any, Callable, Literal, Union - -try: - from typing import Self -except ImportError: - from typing_extensions import Self - +from collections.abc import Callable, Sequence +from typing import Any, Literal, Self import numpy as np import numpy.typing as npt @@ -36,10 +30,10 @@ "precomputed", ] -AllMetrics = Union[ - SklearnNativeMetrics, - Callable[[Any, Any], float | npt.NDArray[np.float64] | Sequence[float]], -] +AllMetrics = ( + SklearnNativeMetrics + | Callable[[Any, Any], float | npt.NDArray[np.float64] | Sequence[float]] +) class NamedNearestNeighbors(NearestNeighbors): # pylint: disable=too-many-ancestors @@ -69,18 +63,23 @@ def __init__( algorithm : {'auto', 'ball_tree', 'kd_tree', 'brute'}, optional (default = 'auto') Algorithm used to compute the nearest neighbors. 
leaf_size : int, optional (default = 30) - Leaf size passed to BallTree or KDTree. This can affect the speed of the construction and query, - as well as the memory required to store the tree. The optimal value depends on the nature of the problem. - metric : Union[str, Callable], optional (default = 'minkowski') + Leaf size passed to BallTree or KDTree. + This can affect the speed of the construction and query, as well as the + memory required to store the tree. + The optimal value depends on the nature of the problem. + metric : str | Callable, default='minkowski' The distance metric to use for the tree. - The default metric is minkowski, and with p=2 is equivalent to the standard Euclidean metric. + The default metric is minkowski, and with p=2 is equivalent to the standard + Euclidean metric. p : int, optional (default = 2) Power parameter for the Minkowski metric. metric_params : dict, optional (default = None) Additional keyword arguments for the metric function. n_jobs : int, optional (default = None) - The number of parallel jobs to run for neighbors search. None means 1 unless in a joblib.parallel_backend context. + The number of parallel jobs to run for neighbors search. + None means 1 unless in a joblib.parallel_backend context. -1 means using all processors. + """ super().__init__( n_neighbors=n_neighbors, diff --git a/molpipeline/estimators/similarity_transformation.py b/molpipeline/estimators/similarity_transformation.py index 47be2a78..dd6dab09 100644 --- a/molpipeline/estimators/similarity_transformation.py +++ b/molpipeline/estimators/similarity_transformation.py @@ -2,12 +2,7 @@ from __future__ import annotations -from typing import Any - -try: - from typing import Self -except ImportError: - from typing_extensions import Self +from typing import Any, Self import numpy as np import numpy.typing as npt diff --git a/molpipeline/experimental/custom_filter.py b/molpipeline/experimental/custom_filter.py index aece9c95..a049b4d3 100644 --- a/molpipeline/experimental/custom_filter.py +++ b/molpipeline/experimental/custom_filter.py @@ -2,12 +2,8 @@ from __future__ import annotations -from typing import Any, Callable - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self +from collections.abc import Callable +from typing import Any, Self from molpipeline.abstract_pipeline_elements.core import ( InvalidInstance, diff --git a/molpipeline/experimental/explainability/explainer.py b/molpipeline/experimental/explainability/explainer.py index e2ce9079..85085c31 100644 --- a/molpipeline/experimental/explainability/explainer.py +++ b/molpipeline/experimental/explainability/explainer.py @@ -3,7 +3,8 @@ from __future__ import annotations import abc -from typing import Any, Callable +from collections.abc import Callable +from typing import Any import numpy as np import numpy.typing as npt @@ -11,11 +12,7 @@ import shap from scipy.sparse import issparse, spmatrix from sklearn.base import BaseEstimator - -try: - from typing import override # type: ignore[attr-defined] -except ImportError: - from typing_extensions import override +from typing_extensions import override from molpipeline import Pipeline from molpipeline.abstract_pipeline_elements.core import InvalidInstance, OptionalMol diff --git a/molpipeline/experimental/explainability/visualization/heatmaps.py b/molpipeline/experimental/explainability/visualization/heatmaps.py index 19f65b25..b0d19e8b 100644 --- a/molpipeline/experimental/explainability/visualization/heatmaps.py 
+++ b/molpipeline/experimental/explainability/visualization/heatmaps.py @@ -6,8 +6,7 @@ """ import abc -from collections.abc import Sequence -from typing import Callable +from collections.abc import Callable, Sequence import numpy as np import numpy.typing as npt @@ -85,13 +84,16 @@ def grid_field_center(self, x_idx: int, y_idx: int) -> tuple[float, float]: ------- tuple[float, float] Coordinates of center of cell. + """ x_coord = min(self.x_lim) + self.dx * (x_idx + 0.5) y_coord = min(self.y_lim) + self.dy * (y_idx + 0.5) return x_coord, y_coord def grid_field_lim( - self, x_idx: int, y_idx: int + self, + x_idx: int, + y_idx: int, ) -> tuple[tuple[float, float], tuple[float, float]]: """Get x and y coordinates for the upper left and lower right position of specified pixel. @@ -106,6 +108,7 @@ def grid_field_lim( ------- tuple[tuple[float, float], tuple[float, float]] Coordinates of upper left and lower right corner of cell. + """ upper_left = ( min(self.x_lim) + self.dx * x_idx, @@ -140,6 +143,7 @@ def __init__( Resolution (number of cells) along x-axis. y_res: int Resolution (number of cells) along y-axis. + """ super().__init__(x_lim, y_lim, x_res, y_res) self.color_grid = np.ones((self.x_res, self.y_res, 4)) @@ -179,6 +183,7 @@ def __init__( Resolution (number of cells) along y-axis. function_list: list[Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]]], optional List of functions to be evaluated for each cell, by default None. + """ super().__init__(x_lim, y_lim, x_res, y_res) if function_list is not None: @@ -188,7 +193,8 @@ def __init__( self.values = np.zeros((self.x_res, self.y_res)) def add_function( - self, function: Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]] + self, + function: Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]], ) -> None: """Add a function to the grid which is evaluated for each cell. @@ -216,10 +222,10 @@ def evaluate(self) -> None: """ self.values = np.zeros((self.x_res, self.y_res)) x_y0_list = np.array( - [self.grid_field_center(x, 0)[0] for x in range(self.x_res)] + [self.grid_field_center(x, 0)[0] for x in range(self.x_res)], ) x0_y_list = np.array( - [self.grid_field_center(0, y)[1] for y in range(self.y_res)] + [self.grid_field_center(0, y)[1] for y in range(self.y_res)], ) xv, yv = np.meshgrid(x_y0_list, x0_y_list) xv = xv.ravel() @@ -231,7 +237,7 @@ def evaluate(self) -> None: if values.shape != self.values.shape: raise AssertionError( f"Function does not return correct shape. " - f"Shape was {(values.shape, self.values.shape)}" + f"Shape was {(values.shape, self.values.shape)}", ) self.values += values @@ -253,6 +259,7 @@ def map2color( ------- ColorGrid ColorGrid with colors corresponding to ValueGrid. + """ color_grid = ColorGrid(self.x_lim, self.y_lim, self.x_res, self.y_res) norm = normalizer(self.values) @@ -275,6 +282,7 @@ def get_color_normalizer_from_data( ------- colors.Normalize Normalizer for colors. + """ abs_max = np.max(np.abs(values)) normalizer = colors.Normalize(vmin=-abs_max, vmax=abs_max) @@ -292,12 +300,13 @@ def color_canvas(canvas: Draw.MolDraw2D, color_grid: ColorGrid) -> None: RDKit Draw.MolDraw2D canvas. color_grid: ColorGrid ColorGrid object to be drawn on the canvas. + """ # draw only grid points whose color is not white. # we check for the exact values of white (1,1,1). np.isclose returns almost the same pixels but is slightly slower. 
mask = np.where(~np.all(color_grid.color_grid[:, :, :3] == [1, 1, 1], axis=2)) - for x, y in zip(*mask): + for x, y in zip(*mask, strict=True): upper_left, lower_right = color_grid.grid_field_lim(x, y) upper_left, lower_right = Point2D(*upper_left), Point2D(*lower_right) canvas.SetColour(tuple(color_grid.color_grid[x, y])) diff --git a/molpipeline/mol2any/mol2concatinated_vector.py b/molpipeline/mol2any/mol2concatinated_vector.py index 6289a6f7..5955b4d1 100644 --- a/molpipeline/mol2any/mol2concatinated_vector.py +++ b/molpipeline/mol2any/mol2concatinated_vector.py @@ -3,12 +3,7 @@ from __future__ import annotations from collections.abc import Iterable -from typing import Any - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self +from typing import Any, Self import numpy as np import numpy.typing as npt @@ -74,7 +69,8 @@ def __init__( self._set_element_execution_details(self._element_list) # set feature names self._feature_names = self._create_feature_names( - self._element_list, self._use_feature_names_prefix + self._element_list, + self._use_feature_names_prefix, ) self.set_params(**kwargs) @@ -101,7 +97,7 @@ def n_features(self) -> int: feature_count += element.n_bits else: raise ValueError( - f"Element {element} does not have n_features or n_bits." + f"Element {element} does not have n_features or n_bits.", ) return feature_count @@ -135,18 +131,19 @@ def _create_feature_names( ------- list[str] List of feature names. + """ feature_names = [] for name, element in element_list: if not hasattr(element, "feature_names"): raise ValueError( - f"Element {element} does not have feature_names attribute." + f"Element {element} does not have feature_names attribute.", ) if use_feature_names_prefix: # use element name as prefix feature_names.extend( - [f"{name}__{feature}" for feature in element.feature_names] # type: ignore[attr-defined] + [f"{name}__{feature}" for feature in element.feature_names], # type: ignore[attr-defined] ) else: feature_names.extend(element.feature_names) # type: ignore[attr-defined] @@ -155,12 +152,13 @@ def _create_feature_names( logger.warning( "Feature names in MolToConcatenatedVector are not unique." " Set use_feature_names_prefix=True and use unique pipeline element" - " names to avoid this." + " names to avoid this.", ) return feature_names def _set_element_execution_details( - self, element_list: list[tuple[str, MolToAnyPipelineElement]] + self, + element_list: list[tuple[str, MolToAnyPipelineElement]], ) -> None: """Set output type and requires fitting for the concatenated vector. @@ -168,6 +166,7 @@ def _set_element_execution_details( ---------- element_list: list[tuple[str, MolToAnyPipelineElement]] List of pipeline elements. + """ output_types = set() for _, element in self._element_list: @@ -194,6 +193,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Parameters defining the object. 
+ """ parameters = super().get_params(deep) if deep: @@ -201,7 +201,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: (str(name), clone(ele)) for name, ele in self.element_list ] parameters["use_feature_names_prefix"] = bool( - self._use_feature_names_prefix + self._use_feature_names_prefix, ) else: parameters["element_list"] = self.element_list @@ -213,7 +213,9 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: return parameters def _set_element_list( - self, parameter_copy: dict[str, Any], **parameters: Any + self, + parameter_copy: dict[str, Any], + **parameters: Any, ) -> tuple[dict[str, Any], dict[str, Any]]: """Set the element list and run necessary configurations. @@ -233,6 +235,7 @@ def _set_element_list( ------- tuple[dict[str, Any], dict[str, Any]] Updated parameter_copy and parameters. + """ element_list = parameter_copy.pop("element_list", None) if element_list is not None: @@ -275,12 +278,14 @@ def set_params(self, **parameters: Any) -> Self: ------- Self Mol2ConcatenatedVector object with updated parameters. + """ parameter_copy = dict(parameters) # handle element_list parameter_copy, parameters = self._set_element_list( - parameter_copy, **parameters + parameter_copy, + **parameters, ) # handle use_feature_names_prefix @@ -312,6 +317,7 @@ def assemble_output( ------- npt.NDArray[np.float64] Matrix of shape (n_molecules, n_features) with concatenated features specified during init. + """ return np.vstack(list(value_list)) @@ -327,6 +333,7 @@ def transform(self, values: list[RDKitMol]) -> npt.NDArray[np.float64]: ------- npt.NDArray[np.float64] Matrix of shape (n_molecules, n_features) with concatenated features specified during init. + """ output: npt.NDArray[np.float64] = super().transform(values) return output @@ -349,13 +356,15 @@ def fit( ------- Self Fitted pipeline element. + """ for pipeline_element in self._element_list: pipeline_element[1].fit(values) return self def pretransform_single( - self, value: RDKitMol + self, + value: RDKitMol, ) -> list[npt.NDArray[np.float64] | dict[int, int]] | InvalidInstance: """Get pretransform of each element and concatenate for output. @@ -369,6 +378,7 @@ def pretransform_single( list[npt.NDArray[np.float64] | dict[int, int]] | InvalidInstance List of pretransformed values of each pipeline element. If any element returns None, InvalidInstance is returned. + """ final_vector = [] error_message = "" @@ -395,12 +405,14 @@ def finalize_single(self, value: Any) -> Any: ------- Any Finalized output. + """ final_vector_list = [] - for (_, element), sub_value in zip(self._element_list, value): + for (_, element), sub_value in zip(self._element_list, value, strict=True): final_value = element.finalize_single(sub_value) if isinstance(element, MolToFingerprintPipelineElement) and isinstance( - final_value, dict + final_value, + dict, ): vector = np.zeros(element.n_bits) vector[list(final_value.keys())] = np.array(list(final_value.values())) @@ -422,7 +434,10 @@ def fit_to_result(self, values: Any) -> Self: ------- Self Fitted pipeline element. 
+ """ - for element, value in zip(self._element_list, zip(*values)): + for element, value in zip( + self._element_list, zip(*values, strict=True), strict=True + ): element[1].fit_to_result(value) return self diff --git a/molpipeline/mol2any/mol2morgan_fingerprint.py b/molpipeline/mol2any/mol2morgan_fingerprint.py index 62b974c2..7c40bdd2 100644 --- a/molpipeline/mol2any/mol2morgan_fingerprint.py +++ b/molpipeline/mol2any/mol2morgan_fingerprint.py @@ -2,14 +2,8 @@ from __future__ import annotations # for all the python 3.8 users out there. -from typing import Any, Literal - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self - import copy +from typing import Any, Literal, Self from rdkit.Chem import AllChem, rdFingerprintGenerator diff --git a/molpipeline/mol2any/mol2net_charge.py b/molpipeline/mol2any/mol2net_charge.py index b8469675..5eb36916 100644 --- a/molpipeline/mol2any/mol2net_charge.py +++ b/molpipeline/mol2any/mol2net_charge.py @@ -3,12 +3,7 @@ from __future__ import annotations import copy -from typing import TYPE_CHECKING, Any, Literal, TypeAlias - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self +from typing import TYPE_CHECKING, Any, Literal, Self, TypeAlias import numpy as np import numpy.typing as npt @@ -93,7 +88,7 @@ def _get_net_charge_gasteiger( Returns ------- - Optional[npt.NDArray[np.float64]] + npt.NDArray[np.float64] | InvalidInstance Net charge of the given molecule. """ diff --git a/molpipeline/mol2any/mol2path_fingerprint.py b/molpipeline/mol2any/mol2path_fingerprint.py index ac9c5ba3..19938a03 100644 --- a/molpipeline/mol2any/mol2path_fingerprint.py +++ b/molpipeline/mol2any/mol2path_fingerprint.py @@ -2,14 +2,8 @@ from __future__ import annotations # for all the python 3.8 users out there. 
-from typing import Any, Literal - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self - import copy +from typing import Any, Literal, Self from rdkit.Chem import rdFingerprintGenerator diff --git a/molpipeline/mol2any/mol2rdkit_phys_chem.py b/molpipeline/mol2any/mol2rdkit_phys_chem.py index 7e332a1a..3f127678 100644 --- a/molpipeline/mol2any/mol2rdkit_phys_chem.py +++ b/molpipeline/mol2any/mol2rdkit_phys_chem.py @@ -4,14 +4,8 @@ from __future__ import annotations -from typing import TYPE_CHECKING, Any - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self - import copy +from typing import TYPE_CHECKING, Any, Self import numpy as np import numpy.typing as npt diff --git a/molpipeline/mol2mol/filter.py b/molpipeline/mol2mol/filter.py index c50c01a3..5df39043 100644 --- a/molpipeline/mol2mol/filter.py +++ b/molpipeline/mol2mol/filter.py @@ -4,12 +4,7 @@ from collections import Counter from collections.abc import Mapping, Sequence -from typing import Any - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self +from typing import Any, Self from loguru import logger from rdkit import Chem diff --git a/molpipeline/mol2mol/reaction.py b/molpipeline/mol2mol/reaction.py index 9bc04e60..82371878 100644 --- a/molpipeline/mol2mol/reaction.py +++ b/molpipeline/mol2mol/reaction.py @@ -4,15 +4,9 @@ from __future__ import annotations -from typing import Any, Literal - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self - import copy import warnings +from typing import Any, Literal, Self from rdkit.Chem import AllChem diff --git a/molpipeline/mol2mol/scaffolds.py b/molpipeline/mol2mol/scaffolds.py index fdeac067..55ff4e17 100644 --- a/molpipeline/mol2mol/scaffolds.py +++ b/molpipeline/mol2mol/scaffolds.py @@ -2,12 +2,7 @@ from __future__ import annotations -from typing import Any - -try: - from typing import Self # pylint: disable=no-name-in-module -except ImportError: - from typing_extensions import Self +from typing import Any, Self from rdkit import Chem from rdkit.Chem.Scaffolds import MurckoScaffold as RDKIT_MurckoScaffold diff --git a/molpipeline/mol2mol/standardization.py b/molpipeline/mol2mol/standardization.py index fef9ddbd..f3c6465b 100644 --- a/molpipeline/mol2mol/standardization.py +++ b/molpipeline/mol2mol/standardization.py @@ -2,12 +2,7 @@ from __future__ import annotations -from typing import Any, Union - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self +from typing import Any, Self, Union from rdkit import Chem from rdkit.Chem import SaltRemover as rdkit_SaltRemover diff --git a/molpipeline/pipeline/_molpipeline.py b/molpipeline/pipeline/_molpipeline.py index c38e225f..3a9f2def 100644 --- a/molpipeline/pipeline/_molpipeline.py +++ b/molpipeline/pipeline/_molpipeline.py @@ -3,12 +3,7 @@ from __future__ import annotations from collections.abc import Iterable -from typing import Any - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self +from typing import Any, Self import numpy as np from joblib import Parallel, delayed diff --git a/molpipeline/pipeline/_skl_pipeline.py b/molpipeline/pipeline/_skl_pipeline.py index a4029fb8..8a9d0507 100644 --- a/molpipeline/pipeline/_skl_pipeline.py +++ 
b/molpipeline/pipeline/_skl_pipeline.py @@ -6,12 +6,7 @@ from collections.abc import Iterable from copy import deepcopy -from typing import Any, Literal, TypeVar, Union - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self +from typing import Any, Literal, Self, TypeVar import joblib import numpy as np @@ -56,7 +51,7 @@ _IndexedStep = tuple[int, str, AnyElement] _AggStep = tuple[list[int], list[str], _MolPipeline] -_AggregatedPipelineStep = Union[_IndexedStep, _AggStep] +_AggregatedPipelineStep = _IndexedStep | _AggStep class Pipeline(_Pipeline): @@ -77,7 +72,7 @@ def __init__( Parameters ---------- - steps: list[tuple[str, Union[AnyTransformer, AnyPredictor, ABCPipelineElement]]] + steps: list[tuple[str, AnyTransformer | AnyPredictor | ABCPipelineElement]] List of (name, Estimator) tuples. memory: str | joblib.Memory | None, optional Path to cache transformers. @@ -85,6 +80,7 @@ def __init__( If True, print additional information. n_jobs: int, optional Number of cores used for aggregated steps. + """ super().__init__(steps, memory=memory, verbose=verbose) self.n_jobs = n_jobs @@ -136,7 +132,7 @@ def _validate_steps(self) -> None: f"All intermediate steps should be " f"transformers and implement fit and transform " f"or be the string 'passthrough' " - f"'{transformer}' (type {type(transformer)}) doesn't" + f"'{transformer}' (type {type(transformer)}) doesn't", ) # We allow last estimator to be None as an identity transformation @@ -148,7 +144,7 @@ def _validate_steps(self) -> None: raise TypeError( f"Last step of Pipeline should implement fit " f"or be the string 'passthrough'. " - f"'{estimator}' (type {type(estimator)}) doesn't" + f"'{estimator}' (type {type(estimator)}) doesn't", ) # validate post-processing steps @@ -156,7 +152,9 @@ def _validate_steps(self) -> None: _ = self._post_processing_steps() def _iter( - self, with_final: bool = True, filter_passthrough: bool = True + self, + with_final: bool = True, + filter_passthrough: bool = True, ) -> Iterable[_AggregatedPipelineStep]: """Iterate over all non post-processing steps. 
@@ -189,9 +187,9 @@ def _iter(
         if last_element is None:
             last_element = step
             continue
-        if not filter_passthrough:
-            yield last_element
-        elif step[2] is not None and step[2] != "passthrough":
+        if not filter_passthrough or (
+            step[2] is not None and step[2] != "passthrough"
+        ):
             yield last_element
         last_element = step
 
@@ -292,13 +290,13 @@ def _fit(
             if isinstance(cloned_transformer, _MolPipeline):
                 if routed_params:
                     step_params = {
-                        "element_parameters": [routed_params[n] for n in name]
+                        "element_parameters": [routed_params[n] for n in name],
                     }
                 else:
                     step_params = {}
             elif isinstance(name, list):
                 raise AssertionError(
-                    "Names should not be a list, when the step is not a Pipeline"
+                    "Names should not be a list, when the step is not a Pipeline",
                 )
             else:
                 step_params = self._get_metadata_for_step(
@@ -326,11 +324,11 @@ def _fit(
                     raise AssertionError()
                 if not len(name) == len(step_idx) == len(ele_list):
                     raise AssertionError()
-                for idx_i, name_i, ele_i in zip(step_idx, name, ele_list):
+                for idx_i, name_i, ele_i in zip(step_idx, name, ele_list, strict=True):
                     self.steps[idx_i] = (name_i, ele_i)
                 if y is not None:
                     y = fitted_transformer.co_transform(y)
-                for idx_i, name_i, ele_i in zip(step_idx, name, ele_list):
+                for idx_i, name_i, ele_i in zip(step_idx, name, ele_list, strict=True):
                     self.steps[idx_i] = (name_i, ele_i)
                 self._set_error_resinserter()
         elif isinstance(name, list) or isinstance(step_idx, list):
@@ -369,6 +367,7 @@ def _transform(
         -------
         Any
             Result of calling `transform` on the second last estimator.
+
         """
         iter_input = X
         do_routing = _routing_enabled()
@@ -386,13 +385,14 @@ def _transform(
             if hasattr(transform, "transform"):
                 if do_routing:
                     iter_input = transform.transform(  # type: ignore[call-arg]
-                        iter_input, routed_params[name].transform
+                        iter_input,
+                        routed_params[name].transform,
                     )
                 else:
                     iter_input = transform.transform(iter_input)
             else:
                 raise AssertionError(
-                    f"Non transformer ocurred in transformation step: {transform}."
+                    f"Non transformer occurred in transformation step: {transform}.",
                 )
         return iter_input
 
@@ -411,6 +411,7 @@ def _non_post_processing_steps(
         -------
         list[AnyStep]
             List of steps before the first PostPredictionTransformation.
+
         """
         non_post_processing_steps: list[AnyStep] = []
         start_adding = False
@@ -421,7 +422,7 @@ def _non_post_processing_steps(
                 if isinstance(step_estimator, PostPredictionTransformation):
                     raise AssertionError(
                         "PipelineElement of type PostPredictionTransformation occured "
-                        "before the last step."
+                        "before the last step.",
                     )
                 non_post_processing_steps.append((step_name, step_estimator))
         return list(non_post_processing_steps[::-1])
@@ -433,6 +434,7 @@ def _post_processing_steps(self) -> list[tuple[str, PostPredictionTransformation
         -------
         list[tuple[str, PostPredictionTransformation]]
             List of tuples containing the name and the PostPredictionTransformation.
+
         """
         post_processing_steps = []
         for step_name, step_estimator in self.steps[::-1]:
@@ -489,7 +491,7 @@ def _agg_non_postpred_steps(
 
     @_fit_context(
         # estimators in Pipeline.steps are not validated yet
-        prefer_skip_nested_validation=False
+        prefer_skip_nested_validation=False,
    )
    def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self:
        """Fit the model.
@@ -516,6 +518,7 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self:
        -------
        self : object
            Pipeline with fitted steps.
+ """ routed_params = self._check_method_params(method="fit", props=fit_params) Xt, yt = self._fit(X, y, routed_params) # pylint: disable=invalid-name @@ -523,7 +526,7 @@ def fit(self, X: Any, y: Any = None, **fit_params: Any) -> Self: if self._final_estimator != "passthrough": if is_empty(Xt): logger.warning( - "All input rows were filtered out! Model is not fitted!" + "All input rows were filtered out! Model is not fitted!", ) else: fit_params_last_step = routed_params[ @@ -540,6 +543,7 @@ def _can_fit_transform(self) -> bool: ------- bool True if the final estimator can fit_transform or is passthrough. + """ return ( self._final_estimator == "passthrough" @@ -554,13 +558,14 @@ def _can_decision_function(self) -> bool: ------- bool True if the final estimator implements decision_function. + """ return hasattr(self._final_estimator, "decision_function") @available_if(_can_fit_transform) @_fit_context( # estimators in Pipeline.steps are not validated yet - prefer_skip_nested_validation=False + prefer_skip_nested_validation=False, ) def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any: """Fit the model and transform with the final estimator. @@ -610,18 +615,21 @@ def fit_transform(self, X: Any, y: Any = None, **params: Any) -> Any: ] if hasattr(last_step, "fit_transform"): iter_input = last_step.fit_transform( - iter_input, iter_label, **last_step_params["fit_transform"] + iter_input, + iter_label, + **last_step_params["fit_transform"], ) elif hasattr(last_step, "transform") and hasattr(last_step, "fit"): last_step.fit(iter_input, iter_label, **last_step_params["fit"]) iter_input = last_step.transform( - iter_input, **last_step_params["transform"] + iter_input, + **last_step_params["transform"], ) else: raise TypeError( f"fit_transform of the final estimator" f" {last_step.__class__.__name__} {last_step_params} does not " - f"match fit_transform of Pipeline {self.__class__.__name__}" + f"match fit_transform of Pipeline {self.__class__.__name__}", ) for _, post_element in self._post_processing_steps(): iter_input = post_element.fit_transform(iter_input, iter_label) @@ -661,11 +669,12 @@ def predict(self, X: Any, **params: Any) -> Any: ------- y_pred : ndarray Result of calling `predict` on the final estimator. + """ if _routing_enabled(): routed_params = process_routing(self, "predict", **params) else: - routed_params = process_routing(self, "predict", **{}) + routed_params = process_routing(self, "predict") iter_input = self._transform(X, routed_params) @@ -683,7 +692,7 @@ def predict(self, X: Any, **params: Any) -> Any: iter_input = self._final_estimator.predict(iter_input, **params) else: raise AssertionError( - "Final estimator does not implement predict, hence this function should not be available." + "Final estimator does not implement predict, hence this function should not be available.", ) for _, post_element in self._post_processing_steps(): iter_input = post_element.transform(iter_input) @@ -692,7 +701,7 @@ def predict(self, X: Any, **params: Any) -> Any: @available_if(_final_estimator_has("fit_predict")) @_fit_context( # estimators in Pipeline.steps are not validated yet - prefer_skip_nested_validation=False + prefer_skip_nested_validation=False, ) def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any: """Transform the data, and apply `fit_predict` with the final estimator. @@ -727,6 +736,7 @@ def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any: ------- y_pred : ndarray Result of calling `fit_predict` on the final estimator. 
+ """ routed_params = self._check_method_params(method="fit_predict", props=params) iter_input, iter_label = self._fit(X, y, routed_params) @@ -741,12 +751,14 @@ def fit_predict(self, X: Any, y: Any = None, **params: Any) -> Any: y_pred = [] elif hasattr(self._final_estimator, "fit_predict"): y_pred = self._final_estimator.fit_predict( - iter_input, iter_label, **params_last_step.get("fit_predict", {}) + iter_input, + iter_label, + **params_last_step.get("fit_predict", {}), ) else: raise AssertionError( "Final estimator does not implement fit_predict, " - "hence this function should not be available." + "hence this function should not be available.", ) for _, post_element in self._post_processing_steps(): y_pred = post_element.fit_transform(y_pred, iter_label) @@ -784,6 +796,7 @@ def predict_proba(self, X: Any, **params: Any) -> Any: ------- y_pred : ndarray Result of calling `predict_proba` on the final estimator. + """ routed_params = process_routing(self, "predict_proba", **params) iter_input = self._transform(X, routed_params) @@ -805,7 +818,7 @@ def predict_proba(self, X: Any, **params: Any) -> Any: else: raise AssertionError( "Final estimator does not implement predict_proba, " - "hence this function should not be available." + "hence this function should not be available.", ) for _, post_element in self._post_processing_steps(): iter_input = post_element.transform(iter_input) @@ -818,9 +831,11 @@ def _can_transform(self) -> bool: ------- bool True if the final estimator can transform or is passthrough. + """ return self._final_estimator == "passthrough" or hasattr( - self._final_estimator, "transform" + self._final_estimator, + "transform", ) @available_if(_can_transform) @@ -868,12 +883,13 @@ def transform(self, X: Any, **params: Any) -> Any: break if hasattr(transform, "transform"): iter_input = transform.transform( - iter_input, **routed_params[name].transform + iter_input, + **routed_params[name].transform, ) else: raise AssertionError( "Non transformer ocurred in transformation step." - "This should have been caught in the validation step." + "This should have been caught in the validation step.", ) for _, post_element in self._post_processing_steps(): iter_input = post_element.transform(iter_input, **params) @@ -900,11 +916,12 @@ def decision_function(self, X: Any, **params: Any) -> Any: ------- Any Result of calling `decision_function` on the final estimator. + """ if _routing_enabled(): routed_params = process_routing(self, "decision_function", **params) else: - routed_params = process_routing(self, "decision_function", **{}) + routed_params = process_routing(self, "decision_function") iter_input = self._transform(X, routed_params) if self._final_estimator == "passthrough": @@ -914,16 +931,18 @@ def decision_function(self, X: Any, **params: Any) -> Any: elif hasattr(self._final_estimator, "decision_function"): if _routing_enabled(): iter_input = self._final_estimator.decision_function( - iter_input, **routed_params[self._final_estimator].predict + iter_input, + **routed_params[self._final_estimator].predict, ) else: iter_input = self._final_estimator.decision_function( - iter_input, **params + iter_input, + **params, ) else: raise AssertionError( "Final estimator does not implement `decision_function`, " - "hence this function should not be available." 
+ "hence this function should not be available.", ) for _, post_element in self._post_processing_steps(): iter_input = post_element.transform(iter_input) @@ -965,6 +984,7 @@ def __sklearn_tags__(self) -> Tags: ------- Tags The sklearn tags. + """ tags = super().__sklearn_tags__() @@ -974,7 +994,7 @@ def __sklearn_tags__(self) -> Tags: try: if self.steps[0][1] is not None and self.steps[0][1] != "passthrough": tags.input_tags.pairwise = get_tags( - self.steps[0][1] + self.steps[0][1], ).input_tags.pairwise # WARNING: the sparse tag can be incorrect. # Some Pipelines accepting sparse data are wrongly tagged sparse=False. @@ -1025,6 +1045,7 @@ def get_metadata_routing(self) -> MetadataRouter: MetadataRouter A :class:`~sklearn.utils.metadata_routing.MetadataRouter` encapsulating routing information. + """ router = MetadataRouter(owner=self.__class__.__name__) @@ -1073,7 +1094,8 @@ def get_metadata_routing(self) -> MetadataRouter: method_mapping.add(caller="fit_transform", callee="fit_transform") else: method_mapping.add(caller="fit", callee="fit").add( - caller="fit", callee="transform" + caller="fit", + callee="transform", ) ( method_mapping.add(caller="fit", callee="fit") diff --git a/molpipeline/post_prediction.py b/molpipeline/post_prediction.py index 38a7b65a..c428be71 100644 --- a/molpipeline/post_prediction.py +++ b/molpipeline/post_prediction.py @@ -3,12 +3,7 @@ from __future__ import annotations import abc -from typing import Any - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self +from typing import Any, Self from numpy import typing as npt from sklearn.base import BaseEstimator, TransformerMixin @@ -40,6 +35,7 @@ def transform(self, X: Any, **params: Any) -> Any: # pylint: disable=invalid-na ------- npt.NDArray[Any] Transformed data. + """ @@ -83,7 +79,7 @@ def fit( ---------- X : npt.NDArray[Any] Input data. - y : Optional[npt.NDArray[Any]] + y : npt.NDArray[Any] | None, optional Target data. **params : Any Additional parameters for fitting. @@ -92,6 +88,7 @@ def fit( ------- Self Fitted PostPredictionWrapper. + """ if isinstance(self.wrapped_estimator, FilterReinserter): self.wrapped_estimator.fit(X, **params) @@ -129,7 +126,7 @@ def transform( if hasattr(self.wrapped_estimator, "transform"): return self.wrapped_estimator.transform(X, **params) raise AttributeError( - f"Estimator {self.wrapped_estimator} has neither predict nor transform method." + f"Estimator {self.wrapped_estimator} has neither predict nor transform method.", ) def fit_transform( @@ -167,7 +164,7 @@ def fit_transform( return self.wrapped_estimator.fit_transform(X) return self.wrapped_estimator.fit_transform(X, y, **params) raise AttributeError( - f"Estimator {self.wrapped_estimator} has neither fit_predict nor fit_transform method." + f"Estimator {self.wrapped_estimator} has neither fit_predict nor fit_transform method.", ) def inverse_transform( @@ -190,11 +187,12 @@ def inverse_transform( ------- npt.NDArray[Any] Inverse transformed data. + """ if hasattr(self.wrapped_estimator, "inverse_transform"): return self.wrapped_estimator.inverse_transform(X) raise AttributeError( - f"Estimator {self.wrapped_estimator} has no inverse_transform method." + f"Estimator {self.wrapped_estimator} has no inverse_transform method.", ) def get_params(self, deep: bool = True) -> dict[str, Any]: @@ -209,6 +207,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Parameters. 
+ """ param_dict = {"wrapped_estimator": self.wrapped_estimator} if deep: @@ -228,6 +227,7 @@ def set_params(self, **params: Any) -> Self: ------- dict[str, Any] Parameters. + """ param_copy = dict(params) if "wrapped_estimator" in param_copy: diff --git a/molpipeline/utils/comparison.py b/molpipeline/utils/comparison.py index 42cf32bf..dbd422c8 100644 --- a/molpipeline/utils/comparison.py +++ b/molpipeline/utils/comparison.py @@ -43,7 +43,8 @@ def remove_irrelevant_params(params: _T) -> _T: def compare_recursive( # pylint: disable=too-many-return-statements - value_a: Any, value_b: Any + value_a: Any, + value_b: Any, ) -> bool: """Compare two values recursively. @@ -74,7 +75,7 @@ def compare_recursive( # pylint: disable=too-many-return-statements if isinstance(value_a, (list, tuple)): if len(value_a) != len(value_b): return False - for val_a, val_b in zip(value_a, value_b): + for val_a, val_b in zip(value_a, value_b, strict=True): if not compare_recursive(val_a, val_b): return False return True diff --git a/molpipeline/utils/kernel.py b/molpipeline/utils/kernel.py index 7cf179ce..23d20fe3 100644 --- a/molpipeline/utils/kernel.py +++ b/molpipeline/utils/kernel.py @@ -1,7 +1,5 @@ """Contains functions for molecular similarity.""" -from typing import Union - import numpy as np import numpy.typing as npt from scipy import sparse @@ -62,13 +60,13 @@ def tanimoto_distance_sparse( def self_tanimoto_similarity( - matrix_a: Union[sparse.csr_matrix, npt.NDArray[np.int_]], + matrix_a: sparse.csr_matrix | npt.NDArray[np.int_], ) -> npt.NDArray[np.float64]: """Calculate a matrix of tanimoto similarity between feature matrix a and itself. Parameters ---------- - matrix_a: Union[sparse.csr_matrix, npt.NDArray[np.int_]] + matrix_a: sparse.csr_matrix | npt.NDArray[np.int_] Feature matrix. Raises @@ -92,13 +90,13 @@ def self_tanimoto_similarity( def self_tanimoto_distance( - matrix_a: Union[sparse.csr_matrix, npt.NDArray[np.int_]], + matrix_a: sparse.csr_matrix | npt.NDArray[np.int_], ) -> npt.NDArray[np.float64]: """Calculate a matrix of tanimoto distance between feature matrix a and itself. Parameters ---------- - matrix_a: Union[sparse.csr_matrix, npt.NDArray[np.int_]] + matrix_a: sparse.csr_matrix | npt.NDArray[np.int_] Feature matrix. 
Returns diff --git a/molpipeline/utils/molpipeline_types.py b/molpipeline/utils/molpipeline_types.py index a39229b6..9240dab6 100644 --- a/molpipeline/utils/molpipeline_types.py +++ b/molpipeline/utils/molpipeline_types.py @@ -4,20 +4,7 @@ from collections.abc import Sequence from numbers import Number -from typing import ( - Any, - Literal, - Optional, - Protocol, - TypeAlias, - TypeVar, - Union, -) - -try: - from typing import Self # type: ignore[attr-defined] -except ImportError: - from typing_extensions import Self +from typing import Any, Literal, Protocol, Self, TypeAlias, TypeVar import numpy as np import numpy.typing as npt @@ -49,14 +36,14 @@ TypeFixedVarSeq = TypeVar("TypeFixedVarSeq", bound=Sequence[_T] | npt.NDArray[_NT]) # type: ignore AnyVarSeq = TypeVar("AnyVarSeq", bound=Sequence[Any] | npt.NDArray[Any]) -FloatCountRange: TypeAlias = tuple[Optional[float], Optional[float]] -IntCountRange: TypeAlias = tuple[Optional[int], Optional[int]] +FloatCountRange: TypeAlias = tuple[float | None, float | None] +IntCountRange: TypeAlias = tuple[int | None, int | None] # IntOrIntCountRange for Typing of count ranges # - a single int for an exact value match # - a range given as a tuple with a lower and upper bound # - both limits are optional -IntOrIntCountRange: TypeAlias = Union[int, IntCountRange] +IntOrIntCountRange: TypeAlias = int | IntCountRange class AnySklearnEstimator(Protocol): @@ -74,6 +61,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]: ------- dict[str, Any] Parameter names mapped to their values. + """ def set_params(self, **params: Any) -> Self: @@ -88,6 +76,7 @@ def set_params(self, **params: Any) -> Self: ------- Self Estimator with updated parameters. + """ def fit( @@ -112,6 +101,7 @@ def fit( ------- Self Fitted estimator. + """ @@ -139,6 +129,7 @@ def fit_predict( ------- npt.NDArray[Any] Predictions. + """ @@ -167,6 +158,7 @@ def fit_transform( ------- npt.NDArray[Any] Transformed array. + """ def transform( @@ -187,10 +179,10 @@ def transform( ------- npt.NDArray[Any] Transformed array. + """ -AnyElement = Union[ - AnyTransformer, AnyPredictor, ABCPipelineElement, Literal["passthrough"] -] +AnyElement = AnyTransformer | AnyPredictor | ABCPipelineElement | Literal["passthrough"] + AnyStep = tuple[str, AnyElement] diff --git a/molpipeline/utils/subpipeline.py b/molpipeline/utils/subpipeline.py index 46462125..407436c1 100644 --- a/molpipeline/utils/subpipeline.py +++ b/molpipeline/utils/subpipeline.py @@ -2,7 +2,8 @@ from __future__ import annotations -from typing import Any, Callable +from collections.abc import Callable +from typing import Any from sklearn.base import BaseEstimator diff --git a/molpipeline/utils/value_conversions.py b/molpipeline/utils/value_conversions.py index 8b25c17e..3bab21f4 100644 --- a/molpipeline/utils/value_conversions.py +++ b/molpipeline/utils/value_conversions.py @@ -10,7 +10,7 @@ def count_value_to_tuple(count: IntOrIntCountRange) -> IntCountRange: Parameters ---------- - count: Union[int, IntCountRange] + count: IntOrIntCountRange Count value. Can be a single int or a tuple of two values. 
Raises diff --git a/notebooks/02_scaffold_split_with_custom_estimators.ipynb b/notebooks/02_scaffold_split_with_custom_estimators.ipynb index ec4e5e2b..96338ccc 100644 --- a/notebooks/02_scaffold_split_with_custom_estimators.ipynb +++ b/notebooks/02_scaffold_split_with_custom_estimators.ipynb @@ -420,7 +420,7 @@ "\n", " # print the performance for predicting the presence of nitrogens on the test set\n", " for smi, pred, label in zip(\n", - " smiles_data[test], predictions[:, 1], has_nitrogen_label[test]\n", + " smiles_data[test], predictions[:, 1], has_nitrogen_label[test], strict=True\n", " ):\n", " print(f\"fold {i}:\", smi, f\"prediction={pred:.2f}\", f\"label={label}\")\n", " print(\n", diff --git a/notebooks/advanced_03_introduction_to_explainable_ai.ipynb b/notebooks/advanced_03_introduction_to_explainable_ai.ipynb index 3b6236c0..dfa217d4 100644 --- a/notebooks/advanced_03_introduction_to_explainable_ai.ipynb +++ b/notebooks/advanced_03_introduction_to_explainable_ai.ipynb @@ -310,7 +310,7 @@ "source": [ "mols = [Chem.MolFromSmiles(smiles) for smiles in df[\"pubchem_smiles\"]]\n", "for prop_name in [\"name\", \"origin\", \"pIC50\"]:\n", - " for mol, prop in zip(mols, df[prop_name]):\n", + " for mol, prop in zip(mols, df[prop_name], strict=True):\n", " mol.SetProp(prop_name, str(prop))\n", "mols[1]" ] diff --git a/pyproject.toml b/pyproject.toml index 2eb94220..ac56f021 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -12,6 +12,7 @@ authors = [ ] description = "Integration of rdkit functionality into sklearn pipelines." readme = "README.md" +requires-python = ">=3.11" dependencies = [ "joblib>=1.3.0", "loguru>=0.7.3", diff --git a/ruff.toml b/ruff.toml index a235c563..ac50abba 100644 --- a/ruff.toml +++ b/ruff.toml @@ -1,4 +1,4 @@ - target-version = "py310" + target-version = "py311" [lint] preview = true select = [ diff --git a/tests/test_elements/test_any2mol/test_auto2mol.py b/tests/test_elements/test_any2mol/test_auto2mol.py index cd6f5948..506fdea2 100644 --- a/tests/test_elements/test_any2mol/test_auto2mol.py +++ b/tests/test_elements/test_any2mol/test_auto2mol.py @@ -68,10 +68,10 @@ def test_auto2mol_for_smiles(self) -> None: SmilesToMol(), BinaryToMol(), SDFToMol(), - ) + ), ), ), - ] + ], ) log_block = rdBase.BlockLogs() actual_mols = pipeline.fit_transform(test_smiles) @@ -79,8 +79,12 @@ def test_auto2mol_for_smiles(self) -> None: self.assertTrue( all( Chem.MolToInchi(smiles_mol) == Chem.MolToInchi(original_mol) - for smiles_mol, original_mol in zip(actual_mols, expected_mols) - ) + for smiles_mol, original_mol in zip( + actual_mols, + expected_mols, + strict=True, + ) + ), ) del log_block @@ -95,7 +99,7 @@ def test_auto2mol_for_inchi(self) -> None: "Auto2Mol", AutoToMol(), ), - ] + ], ) log_block = rdBase.BlockLogs() actual_mols = pipeline.fit_transform(test_inchis) @@ -103,8 +107,12 @@ def test_auto2mol_for_inchi(self) -> None: self.assertTrue( all( Chem.MolToInchi(smiles_mol) == Chem.MolToInchi(original_mol) - for smiles_mol, original_mol in zip(actual_mols, expected_mols) - ) + for smiles_mol, original_mol in zip( + actual_mols, + expected_mols, + strict=True, + ) + ), ) del log_block @@ -119,10 +127,10 @@ def test_auto2mol_for_sdf(self) -> None: SmilesToMol(), BinaryToMol(), SDFToMol(), - ) + ), ), ), - ] + ], ) log_block = rdBase.BlockLogs() actual_mols = pipeline.fit_transform([SDF_P86_B_400]) @@ -158,10 +166,10 @@ def test_auto2mol_for_binary(self) -> None: SmilesToMol(), BinaryToMol(), SDFToMol(), - ) + ), ), ), - ] + ], ) log_block = rdBase.BlockLogs() 
actual_mols = pipeline.fit_transform(test_bin_mols) @@ -169,8 +177,12 @@ def test_auto2mol_for_binary(self) -> None: self.assertTrue( all( Chem.MolToInchi(smiles_mol) == Chem.MolToInchi(original_mol) - for smiles_mol, original_mol in zip(actual_mols, expected_mols) - ) + for smiles_mol, original_mol in zip( + actual_mols, + expected_mols, + strict=True, + ) + ), ) del log_block @@ -200,10 +212,10 @@ def test_auto2mol_for_molecule(self) -> None: SmilesToMol(), BinaryToMol(), SDFToMol(), - ) + ), ), ), - ] + ], ) log_block = rdBase.BlockLogs() actual_mols = pipeline.fit_transform(test_mols) @@ -211,8 +223,12 @@ def test_auto2mol_for_molecule(self) -> None: self.assertTrue( all( Chem.MolToInchi(smiles_mol) == Chem.MolToInchi(original_mol) - for smiles_mol, original_mol in zip(actual_mols, expected_mols) - ) + for smiles_mol, original_mol in zip( + actual_mols, + expected_mols, + strict=True, + ) + ), ) del log_block @@ -244,10 +260,10 @@ def test_auto2mol_mixed_inputs(self) -> None: SmilesToMol(), BinaryToMol(), SDFToMol(), - ) + ), ), ), - ] + ], ) log_block = rdBase.BlockLogs() actual_mols = pipeline.fit_transform(test_inputs) @@ -255,8 +271,8 @@ def test_auto2mol_mixed_inputs(self) -> None: self.assertTrue( all( Chem.MolToInchi(smiles_mol) == Chem.MolToInchi(original_mol) - for smiles_mol, original_mol in zip(actual_mols, test_mols) - ) + for smiles_mol, original_mol in zip(actual_mols, test_mols, strict=True) + ), ) del log_block @@ -278,17 +294,18 @@ def test_auto2mol_invalid_input_nones(self) -> None: SmilesToMol(), BinaryToMol(), SDFToMol(), - ) + ), ), ), - ] + ], ) log_block = rdBase.BlockLogs() actual_mols = pipeline.fit_transform(test_inputs) self.assertEqual(len(test_inputs), len(actual_mols)) self.assertEqual( - Chem.MolToInchi(actual_mols[0]), Chem.MolToInchi(MOL_P86_LIGAND) + Chem.MolToInchi(actual_mols[0]), + Chem.MolToInchi(MOL_P86_LIGAND), ) self.assertTrue(isinstance(actual_mols[1], InvalidInstance)) self.assertEqual(Chem.MolToInchi(actual_mols[2]), Chem.MolToInchi(MOL_BENZENE)) @@ -310,7 +327,7 @@ def test_auto2mol_invalid_input_no_matching_reader(self) -> None: "Auto2Mol", AutoToMol(elements=(SmilesToMol(),)), ), - ] + ], ) log_block = rdBase.BlockLogs() actual_mols = pipeline.fit_transform(test_inputs) @@ -320,7 +337,8 @@ def test_auto2mol_invalid_input_no_matching_reader(self) -> None: self.assertEqual(Chem.MolToInchi(actual_mols[1]), Chem.MolToInchi(MOL_BENZENE)) self.assertTrue(isinstance(actual_mols[2], InvalidInstance)) self.assertEqual( - Chem.MolToInchi(actual_mols[3]), Chem.MolToInchi(MOL_CHLOROBENZENE) + Chem.MolToInchi(actual_mols[3]), + Chem.MolToInchi(MOL_CHLOROBENZENE), ) del log_block @@ -339,7 +357,7 @@ def test_auto2mol_invalid_input_empty_elements(self) -> None: "Auto2Mol", AutoToMol(elements=()), ), - ] + ], ) log_block = rdBase.BlockLogs() actual_mols = pipeline.fit_transform(test_inputs) diff --git a/tests/test_elements/test_error_handling.py b/tests/test_elements/test_error_handling.py index 9405dad4..d6a79245 100644 --- a/tests/test_elements/test_error_handling.py +++ b/tests/test_elements/test_error_handling.py @@ -32,7 +32,7 @@ def test_error_dummy_fill_molpipeline(self) -> None: mol2smi = MolToSmiles() remove_error = ErrorFilter.from_element_list([smi2mol, mol2smi]) replace_error = PostPredictionWrapper( - FilterReinserter.from_error_filter(remove_error, fill_value=None) + FilterReinserter.from_error_filter(remove_error, fill_value=None), ) pipeline = Pipeline( @@ -41,10 +41,10 @@ def test_error_dummy_fill_molpipeline(self) -> None: ("mol2smi", 
mol2smi),
                 ("remove_error", remove_error),
                 ("replace_error", replace_error),
-            ]
+            ],
         )
         out = pipeline.fit_transform(TEST_SMILES)
-        for pred_val, true_val in zip(out, EXPECTED_OUTPUT):
+        for pred_val, true_val in zip(out, EXPECTED_OUTPUT, strict=True):
             self.assertEqual(pred_val, true_val)
 
     def test_error_dummy_remove_record_molpipeline(self) -> None:
@@ -57,7 +57,7 @@ def test_error_dummy_remove_record_molpipeline(self) -> None:
                 ("smi2mol", smi2mol),
                 ("mol2smi", mol2smi),
                 ("error_filter", error_filter),
-            ]
+            ],
         )
         out = pipeline.transform(TEST_SMILES)
         self.assertEqual(len(out), 2)
@@ -123,7 +123,7 @@ def test_dummy_fill_physchem_record_molpipeline(self) -> None:
         mol2physchem = MolToRDKitPhysChem()
         remove_none = ErrorFilter.from_element_list([smi2mol, mol2physchem])
         fill_none = PostPredictionWrapper(
-            FilterReinserter.from_error_filter(remove_none, fill_value=10)
+            FilterReinserter.from_error_filter(remove_none, fill_value=10),
         )
 
         pipeline = Pipeline(
@@ -171,13 +171,14 @@ def test_replace_mixed_datatypes(self) -> None:
                 mock2mock = MockTransformingPipelineElement(
                     invalid_values={
-                        test_values[1]
+                        test_values[1],
                     },  # replaces element at index 1 with an invalid instance
                     return_as_numpy_array=as_numpy_array,
                 )
                 error_filter = ErrorFilter.from_element_list([mock2mock])
                 error_replacer = FilterReinserter.from_error_filter(
-                    error_filter=error_filter, fill_value=fill_value
+                    error_filter=error_filter,
+                    fill_value=fill_value,
                 )
                 pipeline = Pipeline(
                     [
@@ -232,7 +233,8 @@ def test_replace_mixed_datatypes_expected_failures(self) -> None:
         error_filter = ErrorFilter.from_element_list([mock2mock])
         fill_value: list[Any] = []
         error_replacer = FilterReinserter.from_error_filter(
-            error_filter=error_filter, fill_value=fill_value
+            error_filter=error_filter,
+            fill_value=fill_value,
         )
         pipeline = Pipeline(
             [
diff --git a/tests/test_elements/test_mol2any/test_mol2bin.py b/tests/test_elements/test_mol2any/test_mol2bin.py
index 3af9aed2..1bd99d78 100644
--- a/tests/test_elements/test_mol2any/test_mol2bin.py
+++ b/tests/test_elements/test_mol2any/test_mol2bin.py
@@ -46,7 +46,7 @@ def test_mol_to_binary(self) -> None:
             [
                 ("Smiles2Mol", SmilesToMol()),
                 ("Mol2Binary", MolToBinary()),
-            ]
+            ],
         )
         log_block = rdBase.BlockLogs()
         binary_mols = pipeline.fit_transform(test_smiles)
@@ -55,8 +55,12 @@ def test_mol_to_binary(self) -> None:
         self.assertTrue(
             all(
                 Chem.MolToInchi(smiles_mol) == Chem.MolToInchi(original_mol)
-                for smiles_mol, original_mol in zip(actual_mols, expected_mols)
-            )
+                for smiles_mol, original_mol in zip(
+                    actual_mols,
+                    expected_mols,
+                    strict=True,
+                )
+            ),
         )
         del log_block
 
@@ -65,7 +69,7 @@ def test_mol_to_binary_invalid_input(self) -> None:
         pipeline = Pipeline(
             [
                 ("Mol2Binary", MolToBinary()),
-            ]
+            ],
         )
 
         # test empty molecule
diff --git a/tests/test_estimators/test_leader_picker_clustering.py b/tests/test_estimators/test_leader_picker_clustering.py
index d2011ad0..ce228f71 100644
--- a/tests/test_estimators/test_leader_picker_clustering.py
+++ b/tests/test_estimators/test_leader_picker_clustering.py
@@ -85,7 +85,10 @@ def test_leader_picker_pipeline(self) -> None:
         expected_centroids = [[0, 1, 2, 3, 4, 5], [0, 1, 2, 4]]
 
         for dist, exp_labels, exp_centroids in zip(
-            distances, expected_labels, expected_centroids
+            distances,
+            expected_labels,
+            expected_centroids,
+            strict=True,
         ):
             leader_picker = LeaderPickerClustering(distance_threshold=dist)
             pipeline = Pipeline(
                 [
                     ("smi2mol", SmilesToMol()),
                     (
                         "morgan2",
                         MolToMorganFP(
-                            return_as="explicit_bit_vect", n_bits=1024, radius=2
+                            return_as="explicit_bit_vect",
+                            n_bits=1024,
+                            radius=2,
                         ),
                     ),
                     ("leader_picker", leader_picker),
diff --git a/tests/test_experimental/test_model_selection/test_splitter.py b/tests/test_experimental/test_model_selection/test_splitter.py
index 563725c8..86b685b1 100644
--- a/tests/test_experimental/test_model_selection/test_splitter.py
+++ b/tests/test_experimental/test_model_selection/test_splitter.py
@@ -39,7 +39,8 @@ def test_splitting_produces_expected_sizes_when_data_allows_it(self) -> None:
         groups = range(10)
         for split_mode in get_args(SplitModeOption):
             split_generator = GroupShuffleSplit(
-                train_size=train_size, split_mode=split_mode
+                train_size=train_size,
+                split_mode=split_mode,
             ).split(X, y, groups)
             X_train, X_test = next(split_generator)  # pylint: disable=invalid-name
             self.assertEqual(len(X_train), exp_train)
@@ -78,7 +79,10 @@ def test_different_input(self) -> None:
         test_size = 1.0 / 3
         for split_mode in get_args(SplitModeOption):
             gss = GroupShuffleSplit(
-                n_splits, test_size=test_size, random_state=0, split_mode=split_mode
+                n_splits,
+                test_size=test_size,
+                random_state=0,
+                split_mode=split_mode,
             )
 
             # Make sure the repr works
@@ -95,10 +99,10 @@ def test_different_input(self) -> None:
                 l_train_unique = np.unique(groups_i_array[train])
                 l_test_unique = np.unique(groups_i_array[test])
                 self.assertFalse(
-                    np.any(np.isin(groups_i_array[train], l_test_unique))
+                    np.any(np.isin(groups_i_array[train], l_test_unique)),
                 )
                 self.assertFalse(
-                    np.any(np.isin(groups_i_array[test], l_train_unique))
+                    np.any(np.isin(groups_i_array[test], l_train_unique)),
                 )
 
                 # Second test: train and test add up to all the data
@@ -113,12 +117,13 @@ def test_different_input(self) -> None:
                 # Fourth test:
                 # unique train and test groups are correct, +- 1 for rounding error
                 self.assertLessEqual(
-                    abs(len(l_test_unique) - round(test_size * len(l_unique))), 1
+                    abs(len(l_test_unique) - round(test_size * len(l_unique))),
+                    1,
                 )
                 self.assertLessEqual(
                     abs(
                         len(l_train_unique)
-                        - round((1.0 - test_size) * len(l_unique))
+                        - round((1.0 - test_size) * len(l_unique)),
                     ),
                     1,
                 )
@@ -145,6 +150,7 @@ def test_compare_to_sklearn_implementation(self) -> None:
         for (train_mp, test_mp), (train_sk, test_sk) in zip(
             gss_molpipeline.split(X, y, groups=groups_i),
             gss_sklearn.split(X, y, groups=groups_i),
+            strict=True,
         ):
             # test that for the same seed the exact same splits are produced
             assert_array_equal(train_mp, train_sk)
diff --git a/tests/test_pipeline.py b/tests/test_pipeline.py
index 84eb6ae4..a8348b3f 100644
--- a/tests/test_pipeline.py
+++ b/tests/test_pipeline.py
@@ -54,7 +54,7 @@ def test_fit_transform_single_core(self) -> None:
             [
                 ("smi2mol", smi2mol),
                 ("morgan", mol2morgan),
-            ]
+            ],
         )
 
         # Run pipeline
@@ -73,11 +73,11 @@ def test_sklearn_pipeline(self) -> None:
                 ("smi2mol", smi2mol),
                 ("morgan", mol2morgan),
                 ("decision_tree", d_tree),
-            ]
+            ],
         )
         s_pipeline.fit(TEST_SMILES, CONTAINS_OX)
         predicted_value_array = s_pipeline.predict(TEST_SMILES)
-        for pred_val, true_val in zip(predicted_value_array, CONTAINS_OX):
+        for pred_val, true_val in zip(predicted_value_array, CONTAINS_OX, strict=True):
             self.assertEqual(pred_val, true_val)
 
     def test_sklearn_pipeline_parallel(self) -> None:
@@ -96,7 +96,7 @@ def test_sklearn_pipeline_parallel(self) -> None:
         s_pipeline.fit(TEST_SMILES, CONTAINS_OX)
         out = s_pipeline.predict(TEST_SMILES)
         self.assertEqual(len(out), len(CONTAINS_OX))
-        for pred_val, true_val in zip(out, CONTAINS_OX):
+        for pred_val, true_val in zip(out, CONTAINS_OX, strict=True):
             self.assertEqual(pred_val, true_val)
 
     def test_salt_removal(self) -> None:
@@ -119,11 +119,13 @@ def test_salt_removal(self) -> None:
                 ("empty_mol_filter", empty_mol_filter),
                 ("remove_charge", remove_charge),
                 ("mol2smi", mol2smi),
-            ]
+            ],
         )
         generated_smiles = salt_remover_pipeline.transform(smiles_with_salt_list)
         for generated_smiles, smiles_without_salt in zip(
-            generated_smiles, smiles_without_salt_list
+            generated_smiles,
+            smiles_without_salt_list,
+            strict=True,
         ):
             self.assertEqual(generated_smiles, smiles_without_salt)
 
@@ -146,7 +148,7 @@ def test_json_generation(self) -> None:
             ("metal_disconnector", metal_disconnector),
             ("salt_remover", salt_remover),
             ("physchem", physchem),
-        ]
+        ],
         )
 
         # Convert pipeline to json
@@ -156,7 +158,9 @@ def test_json_generation(self) -> None:
         self.assertTrue(isinstance(loaded_pipeline, Pipeline))
         # Compare pipeline elements
         for loaded_element, original_element in zip(
-            loaded_pipeline.steps, pipeline_element_list
+            loaded_pipeline.steps,
+            pipeline_element_list,
+            strict=True,
         ):
             if loaded_element[1] == "passthrough":
                 self.assertEqual(loaded_element[1], original_element)
@@ -176,7 +180,7 @@ def test_fit_transform_record_remove_nones(self) -> None:
         mol2morgan = MolToMorganFP(radius=FP_RADIUS, n_bits=FP_SIZE)
         empty_mol_filter = EmptyMoleculeFilter()
         remove_none = ErrorFilter.from_element_list(
-            [smi2mol, salt_remover, mol2morgan, empty_mol_filter]
+            [smi2mol, salt_remover, mol2morgan, empty_mol_filter],
         )
         # Create pipeline
         pipeline = Pipeline(
@@ -197,7 +201,9 @@ def test_fit_transform_record_remove_nones(self) -> None:
     def test_caching(self) -> None:
        """Test if the caching gives the same results and is faster on the second run."""
         molecule_net_logd_df = pd.read_csv(
-            TEST_DATA_DIR / "molecule_net_logd.tsv.gz", sep="\t", nrows=20
+            TEST_DATA_DIR / "molecule_net_logd.tsv.gz",
+            sep="\t",
+            nrows=20,
         )
         prediction_list = []
         for cache_activated in [False, True]:
@@ -263,7 +269,7 @@ def test_gridsearchcv(self) -> None:
                 "physchem__descriptor_list": [
                     ["HeavyAtomMolWt"],
                     ["HeavyAtomMolWt", "HeavyAtomCount"],
-                ]
+                ],
             },
         },
     ]
@@ -313,7 +319,9 @@ def test_gridsearch_cache(self) -> None:
         }
         # First without caching
         data_df = pd.read_csv(
-            TEST_DATA_DIR / "molecule_net_logd.tsv.gz", sep="\t", nrows=20
+            TEST_DATA_DIR / "molecule_net_logd.tsv.gz",
+            sep="\t",
+            nrows=20,
         )
         best_param_dict = {}
         prediction_dict = {}
@@ -339,7 +347,7 @@ def test_gridsearch_cache(self) -> None:
             grid_search_cv.fit(data_df["smiles"].tolist(), data_df["exp"].tolist())
             best_param_dict[cache_activated] = grid_search_cv.best_params_
             prediction_dict[cache_activated] = grid_search_cv.predict(
-                data_df["smiles"].tolist()
+                data_df["smiles"].tolist(),
             )
             mem.clear(warn=False)
         self.assertEqual(best_param_dict[True], best_param_dict[False])
@@ -360,13 +368,16 @@ def test_calibrated_classifier(self) -> None:
                 (
                     "error_replacer",
                     PostPredictionWrapper(
-                        FilterReinserter.from_error_filter(error_filter, np.nan)
+                        FilterReinserter.from_error_filter(error_filter, np.nan),
                     ),
                 ),
-            ]
+            ],
         )
         calibrated_pipeline = CalibratedClassifierCV(
-            s_pipeline, cv=2, ensemble=True, method="isotonic"
+            s_pipeline,
+            cv=2,
+            ensemble=True,
+            method="isotonic",
         )
         calibrated_pipeline.fit(TEST_SMILES, CONTAINS_OX)
         predicted_value_array = calibrated_pipeline.predict(TEST_SMILES)
diff --git a/tests/test_utils/test_comparison.py b/tests/test_utils/test_comparison.py
index 636faf72..d99e82f0 100644
--- a/tests/test_utils/test_comparison.py
+++ b/tests/test_utils/test_comparison.py
@@ -1,6 +1,6 @@
 """Test the comparison functions."""
 
-from typing import Callable
+from collections.abc import Callable
 from unittest import TestCase
 
 from molpipeline import Pipeline
diff --git a/tests/test_utils/test_json_operations.py b/tests/test_utils/test_json_operations.py
index 24f612ac..29a6ad15 100644
--- a/tests/test_utils/test_json_operations.py
+++ b/tests/test_utils/test_json_operations.py
@@ -44,7 +44,9 @@ def test_pipeline_reconstruction(self) -> None:
 
         # Separate comparison of the steps as models cannot be compared directly
         for (orig_name, orig_obj), (recreated_name, recreated_obj) in zip(
-            original_steps, recreated_steps
+            original_steps,
+            recreated_steps,
+            strict=True,
         ):
             # Remove the model from the original params
             del original_params[orig_name]
diff --git a/tests/utils/execution_count.py b/tests/utils/execution_count.py
index abf37d6a..6263a54f 100644
--- a/tests/utils/execution_count.py
+++ b/tests/utils/execution_count.py
@@ -2,12 +2,7 @@
 
 from __future__ import annotations
 
-from typing import Any
-
-try:
-    from typing import Self  # type: ignore[attr-defined]
-except ImportError:
-    from typing_extensions import Self
+from typing import Any, Self
 
 from sklearn.base import BaseEstimator
 from sklearn.ensemble import RandomForestRegressor
diff --git a/tests/utils/mock_element.py b/tests/utils/mock_element.py
index 2a1eddec..2b89987c 100644
--- a/tests/utils/mock_element.py
+++ b/tests/utils/mock_element.py
@@ -4,15 +4,10 @@
 
 import copy
 from collections.abc import Iterable
-from typing import Any
+from typing import Any, Self
 
 import numpy as np
 
-try:
-    from typing import Self  # type: ignore[attr-defined]
-except ImportError:
-    from typing_extensions import Self
-
 from molpipeline.abstract_pipeline_elements.core import (
     InvalidInstance,
     TransformingPipelineElement,

From 9a472e3ca6a4dee555f058a9195f7202ace8eb26 Mon Sep 17 00:00:00 2001
Feldmann" <128160984+c-w-feldmann@users.noreply.github.com> Date: Tue, 6 May 2025 17:23:08 +0200 Subject: [PATCH 4/5] 168 refactor moltofingerprint pipeline element (#169) * Remove ABCMorganFingerprintPipelineElement * fix types and imports * add explainablity to PathFP * fix additional output * remove unnecessary comment * deduplicate code * use properties * use properties * move properties to MolToRDKitGenFPElement * rework test without modifying the functionallity * add test to assure that mappings and original fp match * move bitmapping to individual classes * test bitmapping for mol2path * broaden expected types * black * negate return statement to restore intended behaviour * extend tests to PathFP --- .../mol2any/mol2bitvector.py | 258 ++++++------------ .../experimental/explainability/explainer.py | 92 ++++--- .../explainability/fingerprint_utils.py | 35 ++- molpipeline/mol2any/mol2morgan_fingerprint.py | 131 ++++++--- molpipeline/mol2any/mol2path_fingerprint.py | 56 +++- .../test_mol2morgan_fingerprint.py | 42 ++- .../test_mol2any/test_mol2path_fingerprint.py | 34 ++- .../test_shap_explainers.py | 148 ++++++---- 8 files changed, 456 insertions(+), 340 deletions(-) diff --git a/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py b/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py index ee64bc85..deb6b915 100644 --- a/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py +++ b/molpipeline/abstract_pipeline_elements/mol2any/mol2bitvector.py @@ -4,19 +4,31 @@ import abc import copy -from collections.abc import Iterable -from typing import Any, Literal, Self, TypeAlias, get_args, overload +from typing import ( + TYPE_CHECKING, + Any, + Literal, + Self, + TypeAlias, + get_args, + overload, +) import numpy as np import numpy.typing as npt -from rdkit.Chem import rdFingerprintGenerator -from rdkit.DataStructs import ExplicitBitVect -from scipy import sparse from molpipeline.abstract_pipeline_elements.core import MolToAnyPipelineElement from molpipeline.utils.matrices import sparse_from_index_value_dicts -from molpipeline.utils.molpipeline_types import RDKitMol -from molpipeline.utils.substructure_handling import CircularAtomEnvironment + +if TYPE_CHECKING: + from collections.abc import Iterable, Mapping, Sequence + + from rdkit.Chem import rdFingerprintGenerator + from rdkit.DataStructs.cDataStructs import ExplicitBitVect + from scipy import sparse + + from molpipeline.utils.molpipeline_types import RDKitMol + from molpipeline.utils.substructure_handling import AtomEnvironment # possible output types for a fingerprint: # - "sparse" is a sparse csr_matrix @@ -26,7 +38,7 @@ class MolToFingerprintPipelineElement(MolToAnyPipelineElement, abc.ABC): - """Abstract class for PipelineElements which transform molecules to integer vectors.""" + """Abstract PipelineElement which transform molecules to integer vectors.""" _n_bits: int _feature_names: list[str] @@ -45,10 +57,11 @@ def __init__( Parameters ---------- return_as: Literal["sparse", "dense", "explicit_bit_vect"] - Type of output. When "sparse" the fingerprints will be returned as a scipy.sparse.csr_matrix - holding a sparse representation of the bit vectors. With "dense" a numpy matrix will be returned. - With "explicit_bit_vect" the fingerprints will be returned as a list of RDKit's - rdkit.DataStructs.cDataStructs.ExplicitBitVect. + Type of output. When "sparse" the fingerprints will be returned as a + scipy.sparse.csr_matrix holding a sparse representation of the bit vectors. 
+ With "dense" a numpy matrix will be returned. + With "explicit_bit_vect" the fingerprints will be returned as a list of + RDKit's rdkit.DataStructs.cDataStructs.ExplicitBitVect. name: str Name of PipelineElement. n_jobs: int, default=1 @@ -104,11 +117,12 @@ def assemble_output( Parameters ---------- - value_list: Iterable[dict[int, int]] | Iterable[npt.NDArray[np.int_]] | Iterable[ExplicitBitVect] + value_list: Iterable[dict[int, int]] | Iterable[npt.NDArray[np.int_]] | + Iterable[ExplicitBitVect] Either Iterable of dicts which encode the rows of the feature matrix. Keys: column index, values: column value. Each dict represents one molecule. - Or an Iterable of RDKit's ExplicitBitVect or an Iterable of numpy arrays representing the - fingerprint list. + Or an Iterable of RDKit's ExplicitBitVect or an Iterable of numpy arrays + representing the fingerprint list. Returns ------- @@ -180,7 +194,7 @@ def set_params(self, **parameters: Any) -> Self: return self def transform(self, values: list[RDKitMol]) -> sparse.csr_matrix: - """Transform the list of molecules to sparse matrix of Morgan-fingerprint features. + """Transform the list of molecules to a sparse matrix. Parameters ---------- @@ -200,7 +214,9 @@ def pretransform_single( self, value: RDKitMol, ) -> dict[int, int] | npt.NDArray[np.int_] | ExplicitBitVect: - """Transform mol to dict, where items encode columns indices and values, respectively. + """Transform mol to dict. + + Items encode columns indices and values, respectively. Parameters ---------- @@ -210,7 +226,9 @@ def pretransform_single( Returns ------- dict[int, int] - Dictionary to encode row in matrix. Keys: column index, values: column value. + Dictionary to encode row in matrix. + Keys: column index + Values: column value """ @@ -218,8 +236,42 @@ def pretransform_single( class MolToRDKitGenFPElement(MolToFingerprintPipelineElement, abc.ABC): """Abstract class for PipelineElements using the FingeprintGenerator64.""" + @property + def n_bits(self) -> int: + """Get number of bits in (or size of) fingerprint.""" + return self._n_bits + + @n_bits.setter + def n_bits(self, value: int) -> None: + """Set number of bits in Morgan fingerprint. + + Parameters + ---------- + value: int + Number of bits in Morgan fingerprint. + + Raises + ------ + ValueError + If value is not a positive integer. + + """ + if not isinstance(value, int) or value < 1: + raise ValueError( + f"Number of bits has to be a integer > 0! (Received: {value})", + ) + self._n_bits = value + + @property + def output_type(self) -> str: + """Get output type.""" + if self.counted: + return "integer" + return "binary" + def __init__( self, + n_bits: int = 2048, counted: bool = False, return_as: OutputDatatype = "sparse", name: str = "MolToRDKitGenFin", @@ -230,6 +282,8 @@ def __init__( Parameters ---------- + n_bits: int, default=2048 + Number of bits in fingerprint. counted: bool, default=False Whether to count the bits or not. return_as: Literal["sparse", "dense", "explicit_bit_vect"], default="sparse" @@ -249,6 +303,7 @@ def __init__( n_jobs=n_jobs, uuid=uuid, ) + self.n_bits = n_bits self.counted = counted @abc.abstractmethod @@ -280,7 +335,8 @@ def pretransform_single( ExplicitBitVect | npt.NDArray[np.int_] | dict[int, int] If return_as is "explicit_bit_vect" return ExplicitBitVect. If return_as is "dense" return numpy array. - If return_as is "sparse" return dictionary with feature-position as key and count as value. + If return_as is "sparse" return dictionary with feature-position as key and + count as value. 
""" fingerprint_generator = self._get_fp_generator() @@ -345,158 +401,11 @@ def set_params(self, **parameters: Any) -> Self: super().set_params(**parameter_dict_copy) return self - -class ABCMorganFingerprintPipelineElement(MolToRDKitGenFPElement, abc.ABC): - """Abstract Class for Morgan fingerprints.""" - - @property - def output_type(self) -> str: - """Get output type.""" - if self.counted: - return "integer" - return "binary" - - # pylint: disable=R0913 - def __init__( - self, - radius: int = 2, - use_features: bool = False, - counted: bool = False, - return_as: Literal["sparse", "dense", "explicit_bit_vect"] = "sparse", - name: str = "AbstractMorgan", - n_jobs: int = 1, - uuid: str | None = None, - ): - """Initialize abstract class. - - Parameters - ---------- - radius: int, default=2 - Radius of fingerprint. - use_features: bool, default=False - Whether to represent atoms by element or category (donor, acceptor. etc.) - counted: bool, default=False - Whether to count the bits or not. - return_as: Literal["sparse", "dense", "explicit_bit_vect"], default="sparse" - Type of output. - When "sparse" the fingerprints will be returned as a scipy.sparse.csr_matrix - holding a sparse representation of the bit vectors. - With "dense" a numpy matrix will be returned. - With "explicit_bit_vect" the fingerprints will be returned as a list of - RDKit's rdkit.DataStructs.cDataStructs.ExplicitBitVect. - name: str, default="AbstractMorgan" - Name of PipelineElement. - n_jobs: int, default=1 - Number of jobs. - uuid: str | None, optional - Unique identifier. - - Raises - ------ - ValueError - If radius is not a positive integer. - - """ - # pylint: disable=R0801 - super().__init__( - return_as=return_as, - counted=counted, - name=name, - n_jobs=n_jobs, - uuid=uuid, - ) - self._use_features = use_features - if isinstance(radius, int) and radius >= 0: - self._radius = radius - else: - raise ValueError( - f"Number of bits has to be a positive integer! (Received: {radius})", - ) - - def get_params(self, deep: bool = True) -> dict[str, Any]: - """Get object parameters relevant for copying the class. - - Parameters - ---------- - deep: bool - If True get a deep copy of the parameters. - - Returns - ------- - dict[str, Any] - Dictionary of parameter names and values. - - """ - parameters = super().get_params(deep) - if deep: - parameters["radius"] = copy.copy(self.radius) - parameters["use_features"] = copy.copy(self.use_features) - else: - parameters["radius"] = self.radius - parameters["use_features"] = self.use_features - - # remove fill_value from parameters - parameters.pop("fill_value", None) - return parameters - - def set_params(self, **parameters: Any) -> Self: - """Set parameters. - - Parameters - ---------- - parameters: Any - Dictionary of parameter names and values. - - Returns - ------- - Self - PipelineElement with updated parameters. 
- - """ - parameter_copy = dict(parameters) - radius = parameter_copy.pop("radius", None) - use_features = parameter_copy.pop("use_features", None) - - # explicitly check for None, since 0 is a valid value - if radius is not None: - self._radius = radius - # explicitly check for None, since False is a valid value - if use_features is not None: - self._use_features = bool(use_features) - super().set_params(**parameter_copy) - return self - - @property - def radius(self) -> int: - """Get radius of Morgan fingerprint.""" - return self._radius - - @property - def use_features(self) -> bool: - """Get whether to encode atoms by features or not.""" - return self._use_features - @abc.abstractmethod - def _explain_rdmol(self, mol_obj: RDKitMol) -> dict[int, list[tuple[int, int]]]: - """Get central atom and radius of all features in molecule. - - Parameters - ---------- - mol_obj: RDKitMol - RDKit molecule to be encoded. - - Returns - ------- - dict[int, list[tuple[int, int]]] - Dictionary with mapping from bit to atom index and radius. - - """ - raise NotImplementedError - def bit2atom_mapping( self, mol_obj: RDKitMol, - ) -> dict[int, list[CircularAtomEnvironment]]: + ) -> Mapping[int, Sequence[AtomEnvironment]]: """Obtain set of atoms for all features. Parameters @@ -506,17 +415,8 @@ def bit2atom_mapping( Returns ------- - dict[int, list[CircularAtomEnvironment]] - Dictionary with mapping from bit to encoded AtomEnvironments (which contain atom indices). + Mapping[int, Sequence[AtomEnvironment]] + Dictionary with mapping from bit to encoded + AtomEnvironments (which contain atom indices). """ - bit2atom_dict = self._explain_rdmol(mol_obj) - result_dict: dict[int, list[CircularAtomEnvironment]] = {} - # Iterating over all present bits and respective matches - for bit, matches in bit2atom_dict.items(): # type: int, list[tuple[int, int]] - result_dict[bit] = [] - for central_atom, radius in matches: # type: int, int - env = CircularAtomEnvironment.from_mol(mol_obj, central_atom, radius) - result_dict[bit].append(env) - # Transforming default dict to dict - return result_dict diff --git a/molpipeline/experimental/explainability/explainer.py b/molpipeline/experimental/explainability/explainer.py index 85085c31..b88921d7 100644 --- a/molpipeline/experimental/explainability/explainer.py +++ b/molpipeline/experimental/explainability/explainer.py @@ -3,19 +3,19 @@ from __future__ import annotations import abc -from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any import numpy as np import numpy.typing as npt import pandas as pd import shap from scipy.sparse import issparse, spmatrix -from sklearn.base import BaseEstimator from typing_extensions import override -from molpipeline import Pipeline from molpipeline.abstract_pipeline_elements.core import InvalidInstance, OptionalMol +from molpipeline.abstract_pipeline_elements.mol2any.mol2bitvector import ( + MolToRDKitGenFPElement, +) from molpipeline.experimental.explainability.explanation import ( AtomExplanationMixin, BondExplanationMixin, @@ -28,9 +28,15 @@ from molpipeline.experimental.explainability.fingerprint_utils import ( fingerprint_shap_to_atomweights, ) -from molpipeline.mol2any import MolToMorganFP from molpipeline.utils.subpipeline import SubpipelineExtractor, get_model_from_pipeline +if TYPE_CHECKING: + from collections.abc import Callable + + from sklearn.base import BaseEstimator + + from molpipeline import Pipeline + def _to_dense( feature_matrix: npt.NDArray[Any] | spmatrix, @@ -46,6 +52,7 @@ 
def _to_dense( ------- Any The input features in a compatible format. + """ if issparse(feature_matrix): return feature_matrix.todense() # type: ignore[union-attr] @@ -84,7 +91,8 @@ def _get_prediction_function( # This function might also be put at a more central position in the lib. def _get_predictions( - pipeline: Pipeline, feature_matrix: npt.NDArray[Any] | spmatrix + pipeline: Pipeline, + feature_matrix: npt.NDArray[Any] | spmatrix, ) -> npt.NDArray[np.float64]: """Get the predictions of a model. @@ -101,6 +109,7 @@ def _get_predictions( ------- npt.NDArray[np.float64] The predictions. + """ prediction_function = _get_prediction_function(pipeline) prediction = prediction_function(feature_matrix) @@ -110,7 +119,7 @@ def _get_predictions( def _convert_shap_feature_weights_to_atom_weights( feature_weights: npt.NDArray[np.float64], molecule: OptionalMol, - featurization_element: MolToMorganFP, + featurization_element: MolToRDKitGenFPElement, feature_vector: npt.NDArray[np.float64], ) -> npt.NDArray[np.float64]: """Convert SHAP feature weights to atom weights. @@ -121,7 +130,7 @@ def _convert_shap_feature_weights_to_atom_weights( The feature weights. molecule : OptionalMol The molecule. - featurization_element : MolToMorganFP + featurization_element : MolToRDKitGenFPElement The featurization element. feature_vector : npt.NDArray[np.float64] The feature vector. @@ -141,23 +150,23 @@ def _convert_shap_feature_weights_to_atom_weights( """ if isinstance(molecule, InvalidInstance): raise ValueError( - "Molecule is None. Cannot convert SHAP values to atom weights." + "Molecule is None. Cannot convert SHAP values to atom weights.", ) if feature_weights.ndim == 1: # regression case feature_weights_present_bits_only = feature_weights.copy() - elif feature_weights.ndim == 2: + elif feature_weights.ndim == 2: # noqa: PLR2004 # binary classification case. Take the weights for the positive class. feature_weights_present_bits_only = feature_weights[:, 1].copy() else: raise ValueError( - "Unsupported number of dimensions for feature weights. Expected 1 or 2." + "Unsupported number of dimensions for feature weights. Expected 1 or 2.", ) # reset shap values for bits that are not present in the molecule feature_weights_present_bits_only[feature_vector == 0] = 0 - atom_weights = np.array( + return np.array( fingerprint_shap_to_atomweights( molecule, featurization_element, @@ -165,7 +174,6 @@ def _convert_shap_feature_weights_to_atom_weights( ), dtype=np.float64, ) - return atom_weights class AbstractSHAPExplainer(abc.ABC): # pylint: disable=too-few-public-methods @@ -174,7 +182,7 @@ class AbstractSHAPExplainer(abc.ABC): # pylint: disable=too-few-public-methods @abc.abstractmethod def explain( self, - X: Any, # pylint: disable=invalid-name + X: Any, # pylint: disable=invalid-name # noqa: N803 **kwargs: Any, ) -> list[SHAPFeatureExplanation | SHAPFeatureAndAtomExplanation]: """Explain the predictions for the input data. @@ -190,12 +198,14 @@ def explain( ------- list[SHAPFeatureExplanation | SHAPFeatureAndAtomExplanation] List of explanations corresponding to the input samples. 
+ """ -class SHAPExplainerAdapter( - AbstractSHAPExplainer, abc.ABC -): # pylint: disable=too-few-public-methods +class SHAPExplainerAdapter( # pylint: disable=too-few-public-methods + AbstractSHAPExplainer, + abc.ABC, +): """Adapter for SHAP explainer wrappers for handling molecules and pipelines.""" # used for dynamically defining the return type of the explain method @@ -244,7 +254,7 @@ def __init__( # determine type of returned explanation featurization_element = self.featurization_subpipeline.steps[-1][1] # type: ignore[union-attr] - if isinstance(featurization_element, MolToMorganFP): + if isinstance(featurization_element, MolToRDKitGenFPElement): self.return_element_type_ = SHAPFeatureAndAtomExplanation else: self.return_element_type_ = SHAPFeatureExplanation @@ -264,26 +274,27 @@ def _prediction_is_valid(prediction: Any) -> bool: ------- bool Whether the prediction is valid. + """ - # if no prediction could be obtained (length is 0); the prediction guaranteed failed. + # if no prediction could be obtained (length is 0); the prediction guaranteed + # failed. if len(prediction) == 0: return False # use pandas.isna function to check for invalid predictions, e.g. None, np.nan, # pd.NA. Note that fill values like 0 will be considered as valid predictions. - if pd.isna(prediction).any(): - return False - - return True + return not pd.isna(prediction).any() @override def explain( - self, X: Any, **kwargs: Any + self, + X: Any, + **kwargs: Any, ) -> list[SHAPFeatureExplanation | SHAPFeatureAndAtomExplanation]: """Explain the predictions for the input data. - If the calculation of the SHAP values for an input sample fails, the explanation will be invalid. - This can be checked with the Explanation.is_valid() method. + If the calculation of the SHAP values for an input sample fails, the explanation + will be invalid. This can be checked with the Explanation.is_valid() method. Parameters ---------- @@ -314,7 +325,8 @@ def explain( # get predictions prediction = _get_predictions(self.pipeline, input_sample) if not self._prediction_is_valid(prediction): - # we use the prediction to check if the input is valid. If not, we cannot explain it. + # we use the prediction to check if the input is valid. + # If not, we cannot explain it. explanation_results.append(self.return_element_type_()) continue @@ -344,8 +356,9 @@ def explain( bond_weights = None if issubclass( - self.return_element_type_, AtomExplanationMixin - ) and isinstance(featurization_element, MolToMorganFP): + self.return_element_type_, + AtomExplanationMixin, + ) and isinstance(featurization_element, MolToRDKitGenFPElement): # for Morgan fingerprint, we can map the shap values to atom weights atom_weights = _convert_shap_feature_weights_to_atom_weights( feature_weights, @@ -363,7 +376,8 @@ def explain( explanation_data["feature_vector"] = feature_vector if not hasattr(featurization_element, "feature_names"): raise ValueError( - "Featurization element does not have a get_feature_names method." 
+ "Featurization element does not have a get_feature_names " + "method.", ) explanation_data["feature_names"] = featurization_element.feature_names # type: ignore[union-attr] @@ -375,7 +389,7 @@ def explain( explanation_data["bond_weights"] = bond_weights if issubclass(self.return_element_type_, SHAPExplanationMixin): explanation_data["expected_value"] = np.atleast_1d( - self.explainer.expected_value + self.explainer.expected_value, ) explanation_results.append(self.return_element_type_(**explanation_data)) @@ -409,6 +423,7 @@ def __init__( The pipeline containing the model to explain. kwargs : Any Additional keyword arguments for SHAP's Explainer. + """ explainer = self._create_explainer(pipeline, **kwargs) super().__init__(pipeline, explainer) @@ -428,18 +443,18 @@ def _create_explainer(pipeline: Pipeline, **kwargs: Any) -> shap.TreeExplainer: ------- shap.TreeExplainer The explainer object. + """ model = get_model_from_pipeline(pipeline, raise_not_found=True) - explainer = shap.TreeExplainer( + return shap.TreeExplainer( model, **kwargs, ) - return explainer -class SHAPKernelExplainer( - SHAPExplainerAdapter -): # pylint: disable=too-few-public-methods +class SHAPKernelExplainer( # pylint: disable=too-few-public-methods + SHAPExplainerAdapter, +): """Wrapper for SHAP's KernelExplainer that can handle pipelines and molecules.""" def __init__( @@ -455,6 +470,7 @@ def __init__( The pipeline containing the model to explain. kwargs : Any Additional keyword arguments for SHAP's Explainer. + """ explainer = self._create_explainer(pipeline, **kwargs) super().__init__(pipeline, explainer) @@ -474,11 +490,11 @@ def _create_explainer(pipeline: Pipeline, **kwargs: Any) -> shap.KernelExplainer ------- shap.KernelExplainer The explainer object. + """ model = get_model_from_pipeline(pipeline, raise_not_found=True) prediction_function = _get_prediction_function(model) - explainer = shap.KernelExplainer( + return shap.KernelExplainer( prediction_function, **kwargs, ) - return explainer diff --git a/molpipeline/experimental/explainability/fingerprint_utils.py b/molpipeline/experimental/explainability/fingerprint_utils.py index 755d4c7d..17687d7f 100644 --- a/molpipeline/experimental/explainability/fingerprint_utils.py +++ b/molpipeline/experimental/explainability/fingerprint_utils.py @@ -3,18 +3,24 @@ from __future__ import annotations from collections import defaultdict -from collections.abc import Sequence +from typing import TYPE_CHECKING import numpy as np import numpy.typing as npt -from molpipeline.abstract_pipeline_elements.core import RDKitMol -from molpipeline.mol2any import MolToMorganFP -from molpipeline.utils.substructure_handling import AtomEnvironment +if TYPE_CHECKING: + from collections.abc import Sequence + + from molpipeline.abstract_pipeline_elements.core import RDKitMol + from molpipeline.abstract_pipeline_elements.mol2any.mol2bitvector import ( + MolToRDKitGenFPElement, + ) + from molpipeline.utils.substructure_handling import AtomEnvironment def assign_prediction_importance( - bit_dict: dict[int, Sequence[AtomEnvironment]], weights: npt.NDArray[np.float64] + bit_dict: dict[int, Sequence[AtomEnvironment]], + weights: npt.NDArray[np.float64], ) -> dict[int, float]: """Assign the prediction importance. 
@@ -40,7 +46,7 @@ def assign_prediction_importance(
 
     """
     atom_contribution: dict[int, float] = defaultdict(lambda: 0)
-    for bit, atom_env_list in bit_dict.items():  # type: int, Sequence[AtomEnvironment]
+    for bit, atom_env_list in bit_dict.items():
         n_machtes = len(atom_env_list)
         for atom_set in atom_env_list:
             for atom in atom_set.environment_atoms:
@@ -50,13 +56,15 @@ def assign_prediction_importance(
     if not np.isclose(sum(weights), sum(atom_contribution.values())).all():
         raise AssertionError(
             f"Weights and atom contributions don't sum to the same value:"
-            f" {weights.sum()} != {sum(atom_contribution.values())}"
+            f" {weights.sum()} != {sum(atom_contribution.values())}",
         )
     return atom_contribution
 
 
 def fingerprint_shap_to_atomweights(
-    mol: RDKitMol, fingerprint_element: MolToMorganFP, shap_mat: npt.NDArray[np.float64]
+    mol: RDKitMol,
+    fingerprint_element: MolToRDKitGenFPElement,
+    shap_mat: npt.NDArray[np.float64],
 ) -> list[float]:
     """Convert SHAP values to atom weights.
 
@@ -67,7 +75,7 @@ def fingerprint_shap_to_atomweights(
     ----------
     mol : RDKitMol
         The molecule.
-    fingerprint_element : MolToMorganFP
+    fingerprint_element : MolToRDKitGenFPElement
         The fingerprint element.
     shap_mat : npt.NDArray[np.float64]
         The SHAP values.
@@ -76,14 +84,11 @@ def fingerprint_shap_to_atomweights(
     -------
     list[float]
         The atom weights.
+
     """
     bit_atom_env_dict: dict[int, Sequence[AtomEnvironment]]
     bit_atom_env_dict = dict(
-        fingerprint_element.bit2atom_mapping(mol)
+        fingerprint_element.bit2atom_mapping(mol),
     )  # MyPy invariants make me do this.
     atom_weight_dict = assign_prediction_importance(bit_atom_env_dict, shap_mat)
-    atom_weight_list = [
-        atom_weight_dict[a_idx] if a_idx in atom_weight_dict else 0
-        for a_idx in range(mol.GetNumAtoms())
-    ]
-    return atom_weight_list
+    return [atom_weight_dict.get(a_idx, 0) for a_idx in range(mol.GetNumAtoms())]
diff --git a/molpipeline/mol2any/mol2morgan_fingerprint.py b/molpipeline/mol2any/mol2morgan_fingerprint.py
index 7c40bdd2..bc46d121 100644
--- a/molpipeline/mol2any/mol2morgan_fingerprint.py
+++ b/molpipeline/mol2any/mol2morgan_fingerprint.py
@@ -2,24 +2,81 @@
 
 from __future__ import annotations  # for all the python 3.8 users out there.
 
-import copy
-from typing import Any, Literal, Self
+from typing import TYPE_CHECKING, Any, Literal, Self
 
-from rdkit.Chem import AllChem, rdFingerprintGenerator
+from rdkit.Chem import rdFingerprintGenerator
 
 from molpipeline.abstract_pipeline_elements.mol2any.mol2bitvector import (
-    ABCMorganFingerprintPipelineElement,
+    MolToRDKitGenFPElement,
 )
-from molpipeline.utils.molpipeline_types import RDKitMol
+from molpipeline.utils.substructure_handling import CircularAtomEnvironment
 
+if TYPE_CHECKING:
+    from collections.abc import Mapping, Sequence
 
+    from molpipeline.utils.molpipeline_types import RDKitMol
 
-class MolToMorganFP(ABCMorganFingerprintPipelineElement):
+
+class MolToMorganFP(MolToRDKitGenFPElement):
     """Folded Morgan Fingerprint.
 
     Feature-mapping to vector-positions is arbitrary.
 
     """
 
+    _radius: int
+    _use_features: bool
+
+    @property
+    def radius(self) -> int:
+        """Get radius of Morgan fingerprint."""
+        return self._radius
+
+    @radius.setter
+    def radius(self, value: int) -> None:
+        """Set radius of Morgan fingerprint.
+
+        Parameters
+        ----------
+        value: int
+            Radius of Morgan fingerprint.
+
+        Raises
+        ------
+        ValueError
+            If value is not a non-negative integer.
+
+        """
+        if not isinstance(value, int) or value < 0:
+            raise ValueError(
+                f"Radius has to be a non-negative integer! (Received: {value})",
+            )
+        self._radius = value
+
+    @property
+    def use_features(self) -> bool:
+        """Get whether to encode atoms by features or not."""
+        return self._use_features
+
+    @use_features.setter
+    def use_features(self, value: bool) -> None:
+        """Set whether to encode atoms by features or not.
+
+        Parameters
+        ----------
+        value: bool
+            Whether to encode atoms by features or not.
+
+        Raises
+        ------
+        ValueError
+            If value is not a boolean.
+
+        """
+        if not isinstance(value, bool):
+            raise ValueError(f"Use features has to be a boolean! (Received: {value})")
+        self._use_features = value
+
     # pylint: disable=R0913
     def __init__(
         self,
@@ -64,28 +121,19 @@ def __init__(
         [1] https://rdkit.org/docs/GettingStartedInPython.html#morgan-fingerprints-circular-fingerprints
         [2] https://rdkit.org/docs/GettingStartedInPython.html#feature-definitions-used-in-the-morgan-fingerprints
 
-        Raises
-        ------
-        ValueError
-            If n_bits is not a positive integer.
-
         """
         # pylint: disable=R0801
         super().__init__(
-            radius=radius,
-            use_features=use_features,
+            n_bits=n_bits,
             counted=counted,
             return_as=return_as,
             name=name,
             n_jobs=n_jobs,
             uuid=uuid,
         )
-        if not isinstance(n_bits, int) or n_bits < 1:
-            raise ValueError(
-                f"Number of bits has to be a integer > 0! (Received: {n_bits})"
-            )
-        self._n_bits = n_bits
-        self._feature_names = [f"morgan_{i}" for i in range(self._n_bits)]
+        self.use_features = use_features
+        self.radius = radius
+        self._feature_names = [f"morgan_{i}" for i in range(self.n_bits)]
 
     def get_params(self, deep: bool = True) -> dict[str, Any]:
         """Return all parameters defining the object.
@@ -99,12 +147,12 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
         -------
         dict[str, Any]
             Dictionary of parameters.
+
         """
         parameters = super().get_params(deep)
-        if deep:
-            parameters["n_bits"] = copy.copy(self._n_bits)
-        else:
-            parameters["n_bits"] = self._n_bits
+        parameters["n_bits"] = self.n_bits
+        parameters["radius"] = self.radius
+        parameters["use_features"] = self.use_features
         return parameters
 
     def set_params(self, **parameters: Any) -> Self:
@@ -119,11 +167,15 @@ def set_params(self, **parameters: Any) -> Self:
         -------
         Self
             MolToMorganFP pipeline element with updated parameters.
+
         """
         parameter_copy = dict(parameters)
-        n_bits = parameter_copy.pop("n_bits", None)
-        if n_bits is not None:
-            self._n_bits = n_bits
+        if "n_bits" in parameter_copy:
+            self.n_bits = parameter_copy.pop("n_bits")
+        if "radius" in parameter_copy:
+            self.radius = parameter_copy.pop("radius")
+        if "use_features" in parameter_copy:
+            self.use_features = parameter_copy.pop("use_features")
         super().set_params(**parameter_copy)
         return self
 
@@ -137,13 +189,17 @@ def _get_fp_generator(
         -------
         rdFingerprintGenerator.FingerprintGenerator
             RDKit fingerprint generator.
+
         """
         return rdFingerprintGenerator.GetMorganGenerator(
             radius=self.radius,
-            fpSize=self._n_bits,
+            fpSize=self.n_bits,
         )
 
-    def _explain_rdmol(self, mol_obj: RDKitMol) -> dict[int, list[tuple[int, int]]]:
+    def bit2atom_mapping(
+        self,
+        mol_obj: RDKitMol,
+    ) -> Mapping[int, Sequence[CircularAtomEnvironment]]:
         """Get central atom and radius of all features in molecule.
 
         Parameters
@@ -153,13 +209,20 @@
 
         Returns
         -------
-        dict[int, list[tuple[int, int]]]
-            Dictionary with bit position as key and list of tuples with atom index and radius as value.
+        Mapping[int, Sequence[CircularAtomEnvironment]]
+            Dictionary with bit position as key and the list of circular atom
+            environments that set the bit as value.
+
         """
         fp_generator = self._get_fp_generator()
-        additional_output = AllChem.AdditionalOutput()
+        additional_output = rdFingerprintGenerator.AdditionalOutput()
         additional_output.AllocateBitInfoMap()
-        # using the dense fingerprint here, to get indices after folding
         _ = fp_generator.GetFingerprint(mol_obj, additionalOutput=additional_output)
-        bit_info = additional_output.GetBitInfoMap()
-        return bit_info
+        result_dict: dict[int, list[CircularAtomEnvironment]] = {}
+        # Iterating over all present bits and respective matches
+        for bit, matches in additional_output.GetBitInfoMap().items():
+            result_dict[bit] = []
+            for central_atom, radius in matches:
+                env = CircularAtomEnvironment.from_mol(mol_obj, central_atom, radius)
+                result_dict[bit].append(env)
+        return result_dict
diff --git a/molpipeline/mol2any/mol2path_fingerprint.py b/molpipeline/mol2any/mol2path_fingerprint.py
index 19938a03..dd5c1ea6 100644
--- a/molpipeline/mol2any/mol2path_fingerprint.py
+++ b/molpipeline/mol2any/mol2path_fingerprint.py
@@ -3,17 +3,23 @@
 
 from __future__ import annotations  # for all the python 3.8 users out there.
 
 import copy
-from typing import Any, Literal, Self
+from typing import TYPE_CHECKING, Any, Literal, Self
 
 from rdkit.Chem import rdFingerprintGenerator
 
 from molpipeline.abstract_pipeline_elements.mol2any.mol2bitvector import (
     MolToRDKitGenFPElement,
 )
+from molpipeline.utils.substructure_handling import AtomEnvironment
+
+if TYPE_CHECKING:
+    from collections.abc import Mapping, Sequence
+
+    from molpipeline.utils.molpipeline_types import RDKitMol
 
 
 class Mol2PathFP(
-    MolToRDKitGenFPElement
+    MolToRDKitGenFPElement,
 ):  # pylint: disable=too-many-instance-attributes
     """Folded Path Fingerprint.
 
@@ -22,7 +28,7 @@ class Mol2PathFP(
     """
 
     # pylint: disable=too-many-arguments,too-many-locals,too-many-positional-arguments
-    def __init__(
+    def __init__(  # noqa: PLR0917
         self,
         min_path: int = 1,
         max_path: int = 7,
@@ -88,9 +94,11 @@ def __init__(
         ------
         ValueError
             If the number of bits is not a positive integer.
+
         """
         # pylint: disable=R0801
         super().__init__(
+            n_bits=n_bits,
             counted=counted,
             return_as=return_as,
             name=name,
@@ -99,7 +107,7 @@ def __init__(
         )
         if not isinstance(n_bits, int) or n_bits < 1:
             raise ValueError(
-                f"Number of bits has to be an integer > 0! (Received: {n_bits})"
+                f"Number of bits has to be an integer > 0! (Received: {n_bits})",
             )
         self._n_bits = n_bits
         self._feature_names = [f"path_{i}" for i in range(self._n_bits)]
@@ -125,6 +133,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
         -------
         dict[str, Any]
             Dictionary of parameters.
+
         """
         parameters = super().get_params(deep)
         if deep:
@@ -137,7 +146,7 @@ def get_params(self, deep: bool = True) -> dict[str, Any]:
             parameters["count_bounds"] = copy.copy(self._count_bounds)
             parameters["num_bits_per_feature"] = int(self._num_bits_per_feature)
             parameters["atom_invariants_generator"] = copy.copy(
-                self._atom_invariants_generator
+                self._atom_invariants_generator,
             )
             parameters["n_bits"] = int(self._n_bits)
         else:
@@ -165,6 +174,7 @@ def set_params(self, **parameters: Any) -> Self:
         -------
         Self
             MolToMorganFP pipeline element with updated parameters.
+
         """
         parameter_copy = dict(parameters)
         min_path = parameter_copy.pop("min_path", None)
@@ -192,7 +202,8 @@ def set_params(self, **parameters: Any) -> Self:
         if num_bits_per_feature is not None:
             self._num_bits_per_feature = num_bits_per_feature
         atom_invariants_generator = parameter_copy.pop(
-            "atom_invariants_generator", None
+            "atom_invariants_generator",
+            None,
         )
         if atom_invariants_generator is not None:
             self._atom_invariants_generator = atom_invariants_generator
@@ -209,6 +220,7 @@ def _get_fp_generator(self) -> rdFingerprintGenerator.FingerprintGenerator64:
         -------
         rdFingerprintGenerator.GetRDKitFPGenerator
             RDKit Path fingerprint generator.
+
         """
         return rdFingerprintGenerator.GetRDKitFPGenerator(
             minPath=self._min_path,
@@ -222,3 +234,35 @@ def _get_fp_generator(self) -> rdFingerprintGenerator.FingerprintGenerator64:
             numBitsPerFeature=self._num_bits_per_feature,
             atomInvariantsGenerator=self._atom_invariants_generator,
         )
+
+    def bit2atom_mapping(
+        self,
+        mol_obj: RDKitMol,
+    ) -> Mapping[int, Sequence[AtomEnvironment]]:
+        """Get the atom environments of all features in the molecule.
+
+        Parameters
+        ----------
+        mol_obj: RDKitMol
+            RDKit molecule object
+
+        Returns
+        -------
+        Mapping[int, Sequence[AtomEnvironment]]
+            Dictionary with bit position as key and the list of atom environments
+            (the atom indices of the matched paths) as value.
+
+        """
+        fp_generator = self._get_fp_generator()
+        additional_output = rdFingerprintGenerator.AdditionalOutput()
+        additional_output.AllocateBitInfoMap()
+        additional_output.AllocateBitPaths()
+        _ = fp_generator.GetFingerprint(mol_obj, additionalOutput=additional_output)
+        result_dict: dict[int, list[AtomEnvironment]] = {}
+        # Iterating over all present bits and respective matches
+        for bit, matches in additional_output.GetBitPaths().items():
+            result_dict[bit] = []
+            for atom_sequence in matches:
+                env = AtomEnvironment(set(atom_sequence))
+                result_dict[bit].append(env)
+        return result_dict
diff --git a/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py b/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py
index ede05f56..36f819b5 100644
--- a/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py
+++ b/tests/test_elements/test_mol2any/test_mol2morgan_fingerprint.py
@@ -22,10 +22,10 @@
 
 
 class TestMol2MorganFingerprint(unittest.TestCase):
-    """Unittest for MolToFoldedMorganFingerprint, which calculates folded Morgan Fingerprints."""
+    """Unittest for MolToFoldedMorganFingerprint."""
 
     def test_clone(self) -> None:
-        """Test if the MolToFoldedMorganFingerprint pipeline element can be constructed."""
+        """Test cloning MolToFoldedMorganFingerprint."""
         mol_fp = MolToMorganFP()
         mol_fp_copy = clone(mol_fp)
         self.assertTrue(mol_fp_copy is not mol_fp)
@@ -49,7 +49,7 @@ def test_counted_bits(self) -> None:
         pipeline.set_params(mol_fp__counted=True)
         output_counted = pipeline.fit_transform(test_smiles)
         self.assertTrue(
-            np.all(np.flatnonzero(output_counted) == np.flatnonzero(output_binary))
+            np.all(np.flatnonzero(output_counted) == np.flatnonzero(output_binary)),
         )
         self.assertTrue(np.all(output_counted >= output_binary))
         self.assertTrue(np.any(output_counted > output_binary))
@@ -67,7 +67,9 @@ def test_return_value_types(self) -> None:
         sparse_morgan = MolToMorganFP(radius=2, n_bits=1024, return_as="sparse")
         dense_morgan = MolToMorganFP(radius=2, n_bits=1024, return_as="dense")
         explicit_bit_vect_morgan = MolToMorganFP(
-            radius=2, n_bits=1024, return_as="explicit_bit_vect"
+            radius=2,
+            n_bits=1024,
+            return_as="explicit_bit_vect",
         )
         sparse_pipeline = Pipeline(
             [
@@ -91,7 +93,7 @@ def test_return_value_types(self) -> None:
         sparse_output = sparse_pipeline.fit_transform(test_smiles)
         dense_output = dense_pipeline.fit_transform(test_smiles)
         explicit_bit_vect_morgan_output = explicit_bit_vect_pipeline.fit_transform(
-            test_smiles
+            test_smiles,
         )
 
         self.assertTrue(np.all(sparse_output.toarray() == dense_output))
@@ -100,7 +102,7 @@ def test_return_value_types(self) -> None:
             np.equal(
                 dense_output,
                 np.array(explicit_bit_vect_morgan_output),
-            ).all()
+            ).all(),
         )
 
     def test_setter_getter(self) -> None:
@@ -116,8 +118,8 @@ def test_setter_getter(self) -> None:
         self.assertEqual(mol_fp.get_params()["n_bits"], 1024)
         self.assertEqual(mol_fp.get_params()["return_as"], "dense")
 
-    def test_setter_getter_error_handling(self) -> None:
-        """Test if the setters and getters work as expected when errors are encountered."""
+    def test_setter_invalid_input(self) -> None:
+        """Test if the setters raise an error for invalid input."""
         mol_fp = MolToMorganFP()
         params: dict[str, Any] = {
             "radius": 2,
@@ -139,7 +141,9 @@ def test_bit2atom_mapping(self) -> None:
         sparse_morgan = MolToMorganFP(radius=2, n_bits=n_bits, return_as="sparse")
         dense_morgan = MolToMorganFP(radius=2, n_bits=n_bits, return_as="dense")
         explicit_bit_vect_morgan = MolToMorganFP(
-            radius=2, n_bits=n_bits, return_as="explicit_bit_vect"
+            radius=2,
+            n_bits=n_bits,
+            return_as="explicit_bit_vect",
         )
         smi2mol = SmilesToMol()
 
@@ -163,6 +167,26 @@ def test_feature_names(self) -> None:
         # feature names should be unique
         self.assertEqual(len(feature_names), len(set(feature_names)))
 
+    def test_bit_mapping(self) -> None:
+        """Test if the mapped bits are identical to the original bits.
+
+        Raises
+        ------
+        AssertionError
+            If a SMILES provided by the unit test is invalid.
+
+        """
+        mol_fp = MolToMorganFP(n_bits=1024)
+
+        for smiles in test_smiles:
+            mol = SmilesToMol().transform([smiles])[0]
+            if isinstance(mol, InvalidInstance):
+                raise AssertionError(f"Invalid molecule: {smiles}")
+            fp = mol_fp.transform([mol])
+            explained_bits = mol_fp.bit2atom_mapping(mol)
+            self.assertEqual(fp[0].nonzero()[1].shape[0], len(explained_bits))
+            self.assertEqual(sorted(fp[0].nonzero()[1]), sorted(explained_bits.keys()))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_elements/test_mol2any/test_mol2path_fingerprint.py b/tests/test_elements/test_mol2any/test_mol2path_fingerprint.py
index ff160b5b..208f8cf9 100644
--- a/tests/test_elements/test_mol2any/test_mol2path_fingerprint.py
+++ b/tests/test_elements/test_mol2any/test_mol2path_fingerprint.py
@@ -9,6 +9,7 @@
 from sklearn.base import clone
 
 from molpipeline import Pipeline
+from molpipeline.abstract_pipeline_elements.core import InvalidInstance
 from molpipeline.any2mol import SmilesToMol
 from molpipeline.mol2any import Mol2PathFP
 
@@ -41,7 +42,8 @@ def test_output_types(self) -> None:
         sparse_path_fp = Mol2PathFP(n_bits=1024, return_as="sparse")
         dense_path_fp = Mol2PathFP(n_bits=1024, return_as="dense")
         explicit_bit_vect_path_fp = Mol2PathFP(
-            n_bits=1024, return_as="explicit_bit_vect"
+            n_bits=1024,
+            return_as="explicit_bit_vect",
         )
         sparse_pipeline = Pipeline(
             [
@@ -65,7 +67,7 @@ def test_output_types(self) -> None:
         sparse_output = sparse_pipeline.fit_transform(test_smiles)
         dense_output = dense_pipeline.fit_transform(test_smiles)
         explicit_bit_vect_path_fp_output = explicit_bit_vect_pipeline.fit_transform(
-            test_smiles
+            test_smiles,
         )
 
         self.assertTrue(np.all(sparse_output.toarray() == dense_output))
@@ -74,7 +76,7 @@ def test_output_types(self) -> None:
             np.equal(
                 dense_output,
                 np.array(explicit_bit_vect_path_fp_output),
-            ).all()
+            ).all(),
         )
 
     def test_counted_bits(self) -> None:
@@ -91,7 +93,7 @@ def test_counted_bits(self) -> None:
         pipeline.set_params(mol_fp__counted=True)
         output_counted = pipeline.fit_transform(test_smiles)
         self.assertTrue(
-            np.all(np.flatnonzero(output_counted) == np.flatnonzero(output_binary))
+            np.all(np.flatnonzero(output_counted) == np.flatnonzero(output_binary)),
         )
         self.assertTrue(np.all(output_counted >= output_binary))
         self.assertTrue(np.any(output_counted > output_binary))
@@ -121,8 +123,8 @@ def test_setter_getter(self) -> None:
         self.assertEqual(mol_fp.get_params()["counted"], True)
         self.assertEqual(mol_fp.get_params()["n_bits"], 1024)
 
-    def test_setter_getter_error_handling(self) -> None:
-        """Test if the setters and getters work as expected when errors are encountered."""
+    def test_setter_invalid_input(self) -> None:
+        """Test if the setters raise an error for invalid input."""
         mol_fp = Mol2PathFP()
         params: dict[str, Any] = {
             "min_path": 2,
@@ -139,6 +141,26 @@ def test_feature_names(self) -> None:
         # feature names should be unique
         self.assertEqual(len(feature_names), len(set(feature_names)))
 
+    def test_bit_mapping(self) -> None:
+        """Test if the mapped bits are identical to the original bits.
+
+        Raises
+        ------
+        AssertionError
+            If a SMILES provided by the unit test is invalid.
+
+        """
+        mol_fp = Mol2PathFP(n_bits=1024)
+
+        for smiles in test_smiles:
+            mol = SmilesToMol().transform([smiles])[0]
+            if isinstance(mol, InvalidInstance):
+                raise AssertionError(f"Invalid molecule: {smiles}")
+            fp = mol_fp.transform([mol])
+            explained_bits = mol_fp.bit2atom_mapping(mol)
+            self.assertEqual(fp[0].nonzero()[1].shape[0], len(explained_bits))
+            self.assertEqual(sorted(fp[0].nonzero()[1]), sorted(explained_bits.keys()))
+
 
 if __name__ == "__main__":
     unittest.main()
diff --git a/tests/test_experimental/test_explainability/test_shap_explainers.py b/tests/test_experimental/test_explainability/test_shap_explainers.py
index 97519511..245801aa 100644
--- a/tests/test_experimental/test_explainability/test_shap_explainers.py
+++ b/tests/test_experimental/test_explainability/test_shap_explainers.py
@@ -1,8 +1,10 @@
 """Test SHAP's TreeExplainer wrapper."""
 
 import unittest
+from itertools import product
 
 import numpy as np
+import numpy.typing as npt
 import pandas as pd
 from rdkit import Chem, rdBase
 from sklearn.base import BaseEstimator, is_classifier, is_regressor
@@ -24,8 +26,10 @@
     SHAPKernelExplainer,
     SHAPTreeExplainer,
 )
+from molpipeline.experimental.explainability.explainer import SHAPExplainerAdapter
 from molpipeline.experimental.explainability.explanation import AtomExplanationMixin
 from molpipeline.mol2any import (
+    Mol2PathFP,
     MolToConcatenatedVector,
     MolToMorganFP,
     MolToRDKitPhysChem,
 )
@@ -102,7 +106,7 @@ def _test_valid_explanation(
             all(
                 isinstance(name, str) and len(name) > 0
                 for name in explanation.feature_names  # type: ignore[union-attr]
-            )
+            ),
         )
         self.assertEqual(
             len(explanation.feature_names),  # type: ignore
@@ -126,18 +130,20 @@ def _test_valid_explanation(
         elif is_classifier(estimator):
             self.assertTrue((2,), explanation.prediction.shape)  # type: ignore[union-attr]
             if isinstance(explainer, SHAPTreeExplainer) and isinstance(
-                estimator, GradientBoostingClassifier
+                estimator,
+                GradientBoostingClassifier,
             ):
-                # there is currently a bug in SHAP's TreeExplainer for GradientBoostingClassifier
-                # https://github.com/shap/shap/issues/3177 returning only one feature weight
-                # which is also based on log odds. This check is a workaround until the bug is fixed.
+                # there is currently a bug in SHAP's TreeExplainer for
+                # GradientBoostingClassifier https://github.com/shap/shap/issues/3177
+                # returning only one feature weight which is also based on log odds.
+                # This check is a workaround until the bug is fixed.
                 self.assertEqual(
                     (nof_features,),
                     explanation.feature_weights.shape,  # type: ignore[union-attr]
                 )
             elif isinstance(estimator, SVC):
-                # SVC seems to be handled differently by SHAP. It returns only a one dimensional
-                # feature array for binary classification.
+                # SVC seems to be handled differently by SHAP. It returns only a
+                # one-dimensional feature array for binary classification.
                 self.assertTrue(
                     (1,),
                     explanation.prediction.shape,  # type: ignore[union-attr]
@@ -162,10 +168,62 @@ def _test_valid_explanation(
                 (explanation.molecule.GetNumAtoms(),),  # type: ignore[union-attr]
             )
 
+    def _test_pipeline_explanation(
+        self,
+        pipeline: Pipeline,
+        explainer_type: type[SHAPExplainerAdapter],
+        test_smiles: list[str],
+        labels: npt.ArrayLike,
+    ) -> None:
+        """Fit the pipeline and validate the explanations of the given explainer.
+
+        Parameters
+        ----------
+        pipeline : Pipeline
+            The pipeline to be tested.
+        explainer_type : type[SHAPExplainerAdapter]
+            The explainer used to generate the explanation.
+        test_smiles : list[str]
+            The SMILES strings of the molecules.
+        labels : npt.ArrayLike
+            The labels of the molecules.
+
+        """
+        pipeline.fit(test_smiles, labels)
+
+        # some explainers require additional kwargs
+        explainer_kwargs = {}
+        if explainer_type == SHAPKernelExplainer:
+            explainer_kwargs = construct_kernel_shap_kwargs(pipeline, test_smiles)
+
+        explainer = explainer_type(pipeline, **explainer_kwargs)
+        explanations = explainer.explain(test_smiles)
+        self.assertEqual(len(explanations), len(test_smiles))
+
+        self.assertTrue(
+            issubclass(explainer.return_element_type_, AtomExplanationMixin),
+        )
+
+        # get the subpipeline that extracts the molecule from the input data
+        mol_reader_subpipeline = SubpipelineExtractor(
+            pipeline,
+        ).get_molecule_reader_subpipeline()
+        self.assertIsInstance(mol_reader_subpipeline, Pipeline)
+
+        for i, explanation in enumerate(explanations):
+            self._test_valid_explanation(
+                explanation,
+                pipeline.named_steps["model"],
+                mol_reader_subpipeline,  # type: ignore[arg-type]
+                pipeline.named_steps["encoding"].n_bits,
+                test_smiles[i],
+                explainer=explainer,  # type: ignore[arg-type]
+            )
+
     def test_explanations_fingerprint_pipeline(  # pylint: disable=too-many-locals
         self,
     ) -> None:
-        """Test SHAP's TreeExplainer wrapper on MolPipeline's pipelines with fingerprints."""
+        """Test SHAP explainer wrappers on pipelines with fingerprints."""
         tree_estimators = [
             RandomForestClassifier(n_estimators=2, random_state=_RANDOM_STATE),
             RandomForestRegressor(n_estimators=2, random_state=_RANDOM_STATE),
@@ -185,50 +243,33 @@ def test_explanations_fingerprint_pipeline(  # pylint: disable=too-many-locals
             SHAPTreeExplainer,
         ]
         explainer_estimators = [tree_estimators + other_estimators, tree_estimators]
+        mol_encoder_list = [
+            MolToMorganFP(radius=1, n_bits=n_bits),
+            Mol2PathFP(n_bits=n_bits, min_path=1, max_path=2),
+        ]
 
         for estimators, explainer_type in zip(
-            explainer_estimators, explainer_types, strict=True
+            explainer_estimators,
+            explainer_types,
+            strict=True,
         ):
             # test explanations with different estimators
-            for estimator in estimators:
+            for estimator, encoding in product(estimators, mol_encoder_list):
                 pipeline = Pipeline(
                     [
                         ("smi2mol", SmilesToMol()),
-                        ("morgan", MolToMorganFP(radius=1, n_bits=n_bits)),
+                        ("encoding", encoding),
                         ("model", estimator),
-                    ]
+                    ],
                 )
-                pipeline.fit(TEST_SMILES, CONTAINS_OX)
-
-                # some explainers require additional kwargs
-                explainer_kwargs = {}
-                if explainer_type == SHAPKernelExplainer:
-                    explainer_kwargs = construct_kernel_shap_kwargs(
-                        pipeline, TEST_SMILES
-                    )
-
-                explainer = explainer_type(pipeline, **explainer_kwargs)
-                explanations = explainer.explain(TEST_SMILES)
-                self.assertEqual(len(explanations), len(TEST_SMILES))
-
-                self.assertTrue(
-                    issubclass(explainer.return_element_type_, AtomExplanationMixin)
-                )
-
-                # get the subpipeline that extracts the molecule from the input data
-                mol_reader_subpipeline = SubpipelineExtractor(
-                    pipeline
-                ).get_molecule_reader_subpipeline()
-                self.assertIsInstance(mol_reader_subpipeline, Pipeline)
-
-                for i, explanation in enumerate(explanations):
-                    self._test_valid_explanation(
-                        explanation,
-                        estimator,
-                        mol_reader_subpipeline,  # type: ignore[arg-type]
-                        n_bits,
-                        TEST_SMILES[i],
-                        explainer=explainer,  # type: ignore[arg-type]
+                self.assertEqual(pipeline.named_steps["encoding"].n_bits, n_bits)
+                self.assertIs(pipeline.named_steps["model"], estimator)
+                with self.subTest(estimator=estimator, encoding=encoding):
+                    self._test_pipeline_explanation(
+                        pipeline,
+                        explainer_type,
+                        TEST_SMILES,
+                        CONTAINS_OX,
                     )
 
     # pylint: disable=too-many-locals
@@ -262,13 +303,13 @@ def test_explanations_pipeline_with_invalid_inputs(self) -> None:
                     ("error_filter", error_filter1),
                     ("morgan", MolToMorganFP(radius=1, n_bits=64)),
                     ("model", estimator),
-                ]
+                ],
             )
 
             # pipeline with ErrorFilter and FilterReinserter
             error_filter2 = ErrorFilter()
             error_reinserter2 = PostPredictionWrapper(
-                FilterReinserter.from_error_filter(error_filter2, fill_value)
+                FilterReinserter.from_error_filter(error_filter2, fill_value),
             )
             pipeline2 = Pipeline(
                 [
@@ -278,7 +319,7 @@ def test_explanations_pipeline_with_invalid_inputs(self) -> None:
                     ("morgan", MolToMorganFP(radius=1, n_bits=n_bits)),
                     ("model", estimator),
                     ("error_reinserter", error_reinserter2),
-                ]
+                ],
             )
 
             for pipeline in [pipeline1, pipeline2]:
@@ -289,12 +330,13 @@ def test_explanations_pipeline_with_invalid_inputs(self) -> None:
                 explanations = explainer.explain(TEST_SMILES_WITH_BAD_SMILES)
                 del log_block
                 self.assertEqual(
-                    len(explanations), len(TEST_SMILES_WITH_BAD_SMILES)
+                    len(explanations),
+                    len(TEST_SMILES_WITH_BAD_SMILES),
                 )
 
                 # get the subpipeline that extracts the molecule from the input data
                 mol_reader_subpipeline = SubpipelineExtractor(
-                    pipeline
+                    pipeline,
                 ).get_molecule_reader_subpipeline()
                 self.assertIsNotNone(mol_reader_subpipeline)
 
@@ -328,7 +370,7 @@ def test_explanations_pipeline_with_physchem(self) -> None:
                     ("smi2mol", SmilesToMol()),
                     ("physchem", MolToRDKitPhysChem()),
                     ("model", estimator),
-                ]
+                ],
             )
 
             pipeline.fit(TEST_SMILES, CONTAINS_OX)
@@ -339,7 +381,7 @@ def test_explanations_pipeline_with_physchem(self) -> None:
 
             # get the subpipeline that extracts the molecule from the input data
             mol_reader_subpipeline = SubpipelineExtractor(
-                pipeline
+                pipeline,
             ).get_molecule_reader_subpipeline()
             self.assertIsNotNone(mol_reader_subpipeline)
 
@@ -386,11 +428,11 @@ def test_explanations_pipeline_with_concatenated_features(self) -> None:
                                 "MorganFP",
                                 MolToMorganFP(radius=1, n_bits=n_bits),
                             ),
-                        ]
+                        ],
                     ),
                 ),
                 ("model", estimator),
-            ]
+            ],
         )
 
         pipeline.fit(TEST_SMILES, CONTAINS_OX)
@@ -401,7 +443,7 @@ def test_explanations_pipeline_with_concatenated_features(self) -> None:
 
         # get the subpipeline that extracts the molecule from the input data
         mol_reader_subpipeline = SubpipelineExtractor(
-            pipeline
+            pipeline,
         ).get_molecule_reader_subpipeline()
         self.assertIsNotNone(mol_reader_subpipeline)

From 3486648213ea95525fd83042bf5f52212c66defe Mon Sep 17 00:00:00 2001
From: Christian Feldmann
Date: Tue, 13 May 2025 18:15:02 +0200
Subject: [PATCH 5/5] remove mol_counter from testing

---
 tests/test_elements/test_any2mol/test_sdf2mol.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_elements/test_any2mol/test_sdf2mol.py b/tests/test_elements/test_any2mol/test_sdf2mol.py
index 3fd56572..afdbbe95 100644
--- a/tests/test_elements/test_any2mol/test_sdf2mol.py
+++ b/tests/test_elements/test_any2mol/test_sdf2mol.py
@@ -124,7 +124,6 @@ def test_initialization(self) -> None:
         self.assertEqual(sdf2mol.identifier, "smiles")
         self.assertEqual(sdf2mol.name, "CustomName")
         self.assertEqual(sdf2mol.n_jobs, 2)
-        self.assertEqual(sdf2mol.mol_counter, 0)
 
     def test_pretransform_valid_sdf(self) -> None:
         """Test transformation of valid SDF string.