
Fix chemprop init with trainer #181

Open · wants to merge 6 commits into main
29 changes: 21 additions & 8 deletions molpipeline/estimators/chemprop/lightning_wrapper.py
@@ -32,7 +32,7 @@ def get_enable_progress_bar(trainer: pl.Trainer) -> bool:
     return False


-def get_device(trainer: pl.Trainer) -> str | Accelerator:
+def get_device(trainer: pl.Trainer) -> str:
     """Get the device used by the lightning trainer.

     Parameters
@@ -44,15 +44,28 @@ def get_device(trainer: pl.Trainer) -> str | Accelerator:
     -------
     str
         The device used by the lightning trainer.
+
+    Raises
+    ------
+    NotImplementedError
+        If the accelerator type is not supported. Currently only GPU and CPU are supported.
+    ValueError
+        If pytorch_lightning is used instead of lightning. Use `from lightning import pytorch as pl` instead of `import pytorch_lightning as pl`.
+
     """
-    devices: str | Accelerator
     if isinstance(trainer.accelerator, str):
         return trainer.accelerator
     if isinstance(trainer.accelerator, CPUAccelerator):
-        devices = "cpu"
-    elif isinstance(trainer.accelerator, CUDAAccelerator):
-        devices = "gpu"
-    else:
-        devices = trainer.accelerator
-    return devices
+        return "cpu"
+    if isinstance(trainer.accelerator, CUDAAccelerator):
+        return "gpu"
+    if isinstance(trainer.accelerator, Accelerator):
+        raise NotImplementedError(
+            "The accelerator type is not supported. Currently only GPU and CPU accelerators are supported."
+        )
+    raise ValueError(
+        "Unsupported accelerator type. Please use `from lightning import pytorch as pl` instead of `import pytorch_lightning as pl`."
+    )


 TRAINER_DEFAULT_PARAMS = {
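A minimal usage sketch of the reworked get_device, assuming the module path shown in this diff and that the function is importable from it:

# Sketch only: Trainer(accelerator="cpu") resolves to a lightning.pytorch
# CPUAccelerator instance, so the new get_device returns "cpu".
from lightning import pytorch as pl

from molpipeline.estimators.chemprop.lightning_wrapper import get_device

trainer = pl.Trainer(accelerator="cpu")
print(get_device(trainer))  # expected output: cpu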
1 change: 1 addition & 0 deletions pyproject.toml
@@ -216,6 +216,7 @@ dev = [
     "pydocstyle>=6.3.0",
     "pylint>=3.3.6",
     "pyright>=1.1.399",
+    "pytorch-lightning>=2.5.1.post0",
     "rdkit-stubs>=0.8",
     "ruff>=0.11.4",
 ]
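Context for this dev dependency: the wrong-import test below needs the standalone pytorch_lightning package installed alongside lightning. A sketch of why the two imports are not interchangeable (assuming both packages are installed; the class path is taken from the test below):

# Sketch only: pytorch_lightning ships a parallel class hierarchy, so its
# objects fail isinstance checks against lightning.pytorch classes.
import pytorch_lightning as ptl
from lightning import pytorch as pl

cpu_accel = ptl.Trainer(accelerator="cpu").accelerator
# Expected False, which is why get_device falls through to its ValueError branch.
print(isinstance(cpu_accel, pl.accelerators.cpu.CPUAccelerator))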
33 changes: 33 additions & 0 deletions test_extras/test_chemprop/test_chemprop_pipeline.py
@@ -525,3 +525,36 @@ def test_prediction(self) -> None:
             mols,
             test_data_df["Label"].to_numpy(),
         )
+
+
+class TestChempropModelTrainerInit(unittest.TestCase):
+    """Test the Chemprop model initialization."""
+
+    def test_chemprop_model_init(self) -> None:
+        """Test the Chemprop model initialization."""
+        trainer = pl.Trainer(
+            accelerator="cpu",
+        )
+        mpnn = MPNN(
+            message_passing=BondMessagePassing(),
+            agg=SumAggregation(),
+            predictor=BinaryClassificationFFN(),
+        )
+        chemprop_model = ChempropModel(model=mpnn, lightning_trainer=trainer)
+        self.assertIsInstance(chemprop_model, ChempropModel)
+        self.assertIsInstance(chemprop_model.lightning_trainer.accelerator, pl.accelerators.cpu.CPUAccelerator)
+
+    def test_wrong_import_init(self) -> None:
+        """Test the Chemprop model initialization with wrong import."""
+        import pytorch_lightning as pl  # noqa # pylint: disable=redefined-outer-name, import-outside-toplevel
+        trainer = pl.Trainer(
+            accelerator="cpu",
+        )
+        mpnn = MPNN(
+            message_passing=BondMessagePassing(),
+            agg=SumAggregation(),
+            predictor=BinaryClassificationFFN(),
+        )
+        with self.assertRaises(ValueError):
+            # This should raise a ValueError because ChempropModel expects lightning.pytorch and not pytorch_lightning.
+            ChempropModel(model=mpnn, lightning_trainer=trainer)
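For illustration, the error path this test exercises can also be reached by calling get_device directly with a pytorch_lightning trainer. A sketch based on the branches in the first diff hunk (it assumes ChempropModel validates its trainer through get_device, which is what the test above relies on):

# Sketch only: a pytorch_lightning accelerator is neither a str nor a
# subclass of lightning.pytorch's Accelerator, so no branch in get_device
# matches and the final ValueError is raised.
import pytorch_lightning as ptl

from molpipeline.estimators.chemprop.lightning_wrapper import get_device

wrong_trainer = ptl.Trainer(accelerator="cpu")
try:
    get_device(wrong_trainer)
except ValueError as err:
    print(err)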