From e3ea8002943abdc2efe21dbdd31c0157c38d27e6 Mon Sep 17 00:00:00 2001 From: Jochen Sieg Date: Tue, 6 May 2025 21:35:03 +0200 Subject: [PATCH] running version --- molpipeline/mol2any/mol2data_from_csv.py | 250 ++++++++++++++++++ .../test_mol2any/test_mol2data_from_csv.py | 215 +++++++++++++++ 2 files changed, 465 insertions(+) create mode 100644 molpipeline/mol2any/mol2data_from_csv.py create mode 100644 tests/test_elements/test_mol2any/test_mol2data_from_csv.py diff --git a/molpipeline/mol2any/mol2data_from_csv.py b/molpipeline/mol2any/mol2data_from_csv.py new file mode 100644 index 00000000..8784eec6 --- /dev/null +++ b/molpipeline/mol2any/mol2data_from_csv.py @@ -0,0 +1,250 @@ +"""Element that reads features/descriptors from a file.""" + +from __future__ import annotations + +import warnings +from collections.abc import Iterable +from pathlib import Path +from typing import Any, Literal, Sequence + +import numpy as np +import numpy.typing as npt +import pandas as pd + +from molpipeline.abstract_pipeline_elements.core import ( + MolToAnyPipelineElement, + InvalidInstance, +) + +from molpipeline.mol2any import MolToSmiles, MolToInchi, MolToInchiKey +from molpipeline.utils.molpipeline_types import RDKitMol + + +def _mol_to_identifier(mol: RDKitMol, id_type: str) -> str: + """Convert a molecule to its identifier. + + Parameters + ---------- + mol: RDKitMol + Molecule to convert + id_type: str + Type of identifier to use. Can be "smiles", "inchi", or "inchikey". + + Returns + ------- + str + Identifier for the molecule + + Raises + ------ + ValueError + If id_type is not one of "smiles", "inchi", or "inchikey" + """ + if id_type == "smiles": + return MolToSmiles().transform_single(mol) + elif id_type == "inchi": + return MolToInchi().transform_single(mol) + elif id_type == "inchikey": + return MolToInchiKey().transform_single(mol) + else: + raise ValueError(f"Invalid id_type: {id_type}") + + +class MolToDataFromCSV(MolToAnyPipelineElement): + """Pipeline element that reads precalculated descriptors from a CSV file. + + Maps molecules to their descriptors using an identifier column (e.g. SMILES, InChI). + """ + + def __init__( + self, + feature_file_path: str | Path, + identifier_column: str, + feature_columns: list[str], + id_type: Literal["smiles", "inchi", "inchikey"] = "smiles", + missing_value_strategy: Literal["invalid_instance", "nan"] = "invalid_instance", + name: str = "MolToFeaturesFromFile", + n_jobs: int = 1, + uuid: str | None = None, + ) -> None: + """Initialize MolToFeaturesFromFile. + + Parameters + ---------- + feature_file_path: str | Path + Path to the file containing precalculated features + identifier_column: str + Name of the column containing molecule identifiers + feature_columns: list[str] + List of column names to extract as features + id_type: Literal["smiles", "inchi", "inchikey"], optional + Type of identifier to use for molecule matching. Default is "smiles" + missing_value_strategy: Literal["invalid_instance", "nan"], optional + Strategy for handling missing values. Default is "invalid_instance" + name: str, optional + Name of the pipeline element. Default is "MolToFeaturesFromFile" + n_jobs: int, optional + Number of parallel jobs. Default is 1 + uuid: str | None, optional + UUID of the pipeline element + + Raises + ------ + ValueError + If feature_columns is empty + FileNotFoundError + If feature_file_path doesn't exist + """ + if not feature_columns: + raise ValueError("Empty feature_columns is not allowed") + + self.feature_file_path = Path(feature_file_path) + self.identifier_column = identifier_column + self.feature_columns = feature_columns + self.id_type = id_type + self.missing_value_strategy = missing_value_strategy + + if not self.feature_file_path.exists(): + raise FileNotFoundError(f"Feature file not found: {self.feature_file_path}") + + self.features_df = MolToDataFromCSV._read_data_table( + self.feature_file_path, + self.identifier_column, + self.feature_columns, + ) + + # TODO check for uniqueness of identifier_column. Drop duplicates if necessary + + # Validate columns existence + missing_cols = set(self.feature_columns) - set(self.features_df.columns) + if missing_cols: + raise ValueError(f"Missing columns in feature file: {missing_cols}") + + # # Create lookup dictionary for faster access + # self.id_to_features = { + # id_val: self.features_df.loc[idx, self.feature_columns].values + # for idx, id_val in enumerate(self.features_df[self.identifier_column]) + # } + + super().__init__(name=name, n_jobs=n_jobs, uuid=uuid) + + @staticmethod + def _read_data_table( + feature_file_path: Path, + identifier_column: str, + feature_columns: Sequence[str], + ) -> pd.DataFrame: + sep = "," + if feature_file_path.name.endswith(".tsv"): + sep = "\t" + + try: + dtype_dict: dict[str, Any] = {col: np.float64 for col in feature_columns} + dtype_dict[identifier_column] = str + usecols = list(dtype_dict.keys()) + return pd.read_csv( + feature_file_path, + index_col=identifier_column, + usecols=usecols, + dtype=dtype_dict, + sep=sep, + ) + except Exception as e: + raise ValueError(f"Error reading feature file: {e}") from e + + def pretransform_single( + self, value: RDKitMol + ) -> npt.NDArray[np.float64] | InvalidInstance: + """Transform a molecule to its features from the file. + + Parameters + ---------- + value: RDKitMol + Molecule to transform + + Returns + ------- + npt.NDArray[np.float64] | InvalidInstance + Features for the molecule or InvalidInstance if not found and missing_value_strategy is "invalid_instance" + """ + try: + # Convert molecule to identifier + mol_id = _mol_to_identifier(value, self.id_type) + + # Look up features + if mol_id in self.features_df.index: + # Get features as numpy array + return self.features_df.loc[mol_id, self.feature_columns].to_numpy( + dtype=np.float64 + ) + + # Handle missing values + if self.missing_value_strategy == "invalid_instance": + return InvalidInstance( + self.uuid, + f"No features found for molecule with {self.id_type}: {mol_id}", + ) + else: # "nan" + return np.full(len(self.feature_columns), np.nan) + + except Exception as e: + warnings.warn(f"Error processing molecule: {e}", UserWarning, stacklevel=2) + return InvalidInstance(self.uuid, f"Error processing molecule: {str(e)}") + + def assemble_output( + self, value_list: Iterable[npt.NDArray[np.float64] | InvalidInstance] + ) -> list[npt.NDArray[np.float64] | InvalidInstance]: + """Assemble output from pretransform_single. + + Parameters + ---------- + value_list: Iterable + List of transformed values + + Returns + ------- + list + List of features arrays or InvalidInstance objects + """ + return np.array(list(value_list)) + + def get_params(self, deep: bool = True) -> dict[str, Any]: + """Get parameters for this pipeline element. + + Parameters + ---------- + deep: bool, default=True + If True, return parameters of subobjects + + Returns + ------- + dict + Parameters of this pipeline element + """ + params = super().get_params(deep=deep) + params.update( + { + "feature_file_path": self.feature_file_path, + "identifier_column": self.identifier_column, + "feature_columns": self.feature_columns, + "id_type": self.id_type, + "missing_value_strategy": self.missing_value_strategy, + } + ) + return params + + def set_params(self, **parameters: Any) -> MolToDataFromCSV: + """Set parameters of this pipeline element. + + Parameters + ---------- + **parameters + Parameters to set + + Returns + ------- + MolToDataFromCSV + Pipeline element with parameters set + """ + super().set_params(**parameters) + return self diff --git a/tests/test_elements/test_mol2any/test_mol2data_from_csv.py b/tests/test_elements/test_mol2any/test_mol2data_from_csv.py new file mode 100644 index 00000000..0954f85b --- /dev/null +++ b/tests/test_elements/test_mol2any/test_mol2data_from_csv.py @@ -0,0 +1,215 @@ +"""Test MolToFeaturesFromFile pipeline element.""" + +import unittest +import tempfile +import os +import pandas as pd +import numpy as np +from pathlib import Path +from rdkit import Chem + +from molpipeline.abstract_pipeline_elements.core import InvalidInstance +from molpipeline.mol2any.mol2data_from_csv import MolToDataFromCSV + + +class TestMolToFeaturesFromCSV(unittest.TestCase): + """Test the MolToFeaturesFromFile pipeline element.""" + + def setUp(self): + """Set up test data and molecules.""" + # Create test molecules with known identifiers + self.mols = [ + Chem.MolFromSmiles("CCO"), # ethanol + Chem.MolFromSmiles("CC(=O)O"), # acetic acid + Chem.MolFromSmiles("c1ccccc1"), # benzene + Chem.MolFromSmiles("CCCCCC"), # hexane - not in test data + ] + + # Create CSV files with test features + self.temp_dir = tempfile.TemporaryDirectory() + + # SMILES data + self.smiles_list = ["CCO", "CC(=O)O", "c1ccccc1"] + self.features_df = pd.DataFrame( + { + "smiles": self.smiles_list, + "feature1": [1.0, 2.0, 3.0], + "feature2": [4.0, 5.0, 6.0], + "feature3": [7.0, 8.0, 9.0], + } + ) + self.feature_file_path = Path(self.temp_dir.name) / "features.csv" + self.features_df.to_csv(self.feature_file_path, index=False) + + # InChI data + self.inchis = [Chem.MolToInchi(mol) for mol in self.mols[:3]] + self.features_df_inchi = pd.DataFrame( + { + "inchi": self.inchis, + "feature1": [1.0, 2.0, 3.0], + "feature2": [4.0, 5.0, 6.0], + } + ) + self.inchi_file_path = os.path.join(self.temp_dir.name, "features_inchi.csv") + self.features_df_inchi.to_csv(self.inchi_file_path, index=False) + + # InChIKey data + self.inchikeys = [Chem.MolToInchiKey(mol) for mol in self.mols[:3]] + self.features_df_inchikey = pd.DataFrame( + { + "inchikey": self.inchikeys, + "feature1": [1.0, 2.0, 3.0], + "feature2": [4.0, 5.0, 6.0], + } + ) + self.inchikey_file_path = os.path.join( + self.temp_dir.name, "features_inchikey.csv" + ) + self.features_df_inchikey.to_csv(self.inchikey_file_path, index=False) + + def tearDown(self): + """Clean up temporary files.""" + self.temp_dir.cleanup() + + def test_basic_functionality(self): + """Test that features are correctly returned for molecules.""" + mol2feat = MolToDataFromCSV( + feature_file_path=self.feature_file_path, + identifier_column="smiles", + feature_columns=["feature1", "feature2", "feature3"], + ) + + results = mol2feat.transform(self.mols[:3]) + + self.assertEqual(len(results), 3) + np.testing.assert_array_equal(results[0], np.array([1.0, 4.0, 7.0])) + np.testing.assert_array_equal(results[1], np.array([2.0, 5.0, 8.0])) + np.testing.assert_array_equal(results[2], np.array([3.0, 6.0, 9.0])) + + def test_missing_molecule_invalid_instance(self): + """Test handling of missing molecules with invalid_instance strategy.""" + mol2feat = MolToDataFromCSV( + feature_file_path=self.feature_file_path, + identifier_column="smiles", + feature_columns=["feature1", "feature2", "feature3"], + missing_value_strategy="invalid_instance", + ) + + results = mol2feat.transform(self.mols) # Include hexane (not in data) + + self.assertEqual(len(results), 4) + self.assertIsInstance(results[3], InvalidInstance) + + def test_missing_molecule_nan(self): + """Test handling of missing molecules with nan strategy.""" + mol2feat = MolToDataFromCSV( + feature_file_path=self.feature_file_path, + identifier_column="smiles", + feature_columns=["feature1", "feature2", "feature3"], + missing_value_strategy="nan", + ) + + results = mol2feat.transform(self.mols) # Include hexane (not in data) + + self.assertEqual(len(results), 4) + self.assertTrue(np.isnan(results[3]).all()) + + def test_inchi_identifier(self): + """Test using InChI as the identifier.""" + mol2feat = MolToDataFromCSV( + feature_file_path=self.inchi_file_path, + identifier_column="inchi", + feature_columns=["feature1", "feature2"], + id_type="inchi", + ) + + results = mol2feat.transform(self.mols[:3]) + + self.assertEqual(len(results), 3) + np.testing.assert_array_equal(results[0], np.array([1.0, 4.0])) + + def test_inchikey_identifier(self): + """Test using InChIKey as the identifier.""" + mol2feat = MolToDataFromCSV( + feature_file_path=self.inchikey_file_path, + identifier_column="inchikey", + feature_columns=["feature1", "feature2"], + id_type="inchikey", + ) + + results = mol2feat.transform(self.mols[:3]) + + self.assertEqual(len(results), 3) + np.testing.assert_array_equal(results[0], np.array([1.0, 4.0])) + + def test_empty_feature_columns(self): + """Test that an empty feature_columns list raises ValueError.""" + with self.assertRaises(ValueError) as context: + MolToDataFromCSV( + feature_file_path=self.feature_file_path, + identifier_column="smiles", + feature_columns=[], + ) + self.assertTrue( + str(context.exception).startswith("Empty feature_columns is not allowed") + ) + + def test_nonexistent_file(self): + """Test that a nonexistent feature file raises FileNotFoundError.""" + with self.assertRaises(FileNotFoundError): + MolToDataFromCSV( + feature_file_path="nonexistent_file.csv", + identifier_column="smiles", + feature_columns=["feature1"], + ) + + def test_missing_columns(self): + """Test that missing columns in the feature file raise ValueError.""" + with self.assertRaises(ValueError) as context: + MolToDataFromCSV( + feature_file_path=self.feature_file_path, + identifier_column="smiles", + feature_columns=["feature1", "nonexistent_feature"], + ) + self.assertTrue(str(context.exception).startswith("Error reading feature file")) + + def test_get_params(self): + """Test that get_params returns the correct parameters.""" + mol2feat = MolToDataFromCSV( + feature_file_path=self.feature_file_path, + identifier_column="smiles", + feature_columns=["feature1", "feature2"], + id_type="smiles", + missing_value_strategy="invalid_instance", + name="TestElement", + n_jobs=2, + ) + + params = mol2feat.get_params() + + self.assertEqual(params["feature_file_path"], Path(self.feature_file_path)) + self.assertEqual(params["identifier_column"], "smiles") + self.assertEqual(params["feature_columns"], ["feature1", "feature2"]) + self.assertEqual(params["id_type"], "smiles") + self.assertEqual(params["missing_value_strategy"], "invalid_instance") + self.assertEqual(params["name"], "TestElement") + self.assertEqual(params["n_jobs"], 2) + + def test_set_params(self): + """Test that set_params correctly sets parameters.""" + mol2feat = MolToDataFromCSV( + feature_file_path=self.feature_file_path, + identifier_column="smiles", + feature_columns=["feature1"], + name="OriginalName", + ) + + mol2feat.set_params(name="NewName", n_jobs=4) + + params = mol2feat.get_params() + self.assertEqual(params["name"], "NewName") + self.assertEqual(params["n_jobs"], 4) + + +if __name__ == "__main__": + unittest.main()