Skip to content

Commit 6514d12

Browse files
authored
reinitialize trainer after fitting (#112)
1 parent 6280407 commit 6514d12

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

molpipeline/estimators/chemprop/abstract.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -100,9 +100,7 @@ def __init__(
100100
self.trainer_params = get_params_trainer(self.lightning_trainer)
101101
self.set_params(**kwargs)
102102

103-
def _update_trainer(
104-
self,
105-
) -> None:
103+
def _update_trainer(self) -> None:
106104
"""Update the trainer for the model."""
107105
trainer_params = dict(self.trainer_params)
108106
if self.model_ckpoint_params:
@@ -139,6 +137,8 @@ def fit(
139137
X, batch_size=self.batch_size, num_workers=self.n_jobs
140138
)
141139
self.lightning_trainer.fit(self.model, training_data)
140+
# The trainer is reinitalized to avoid storing the training data
141+
self._update_trainer()
142142
return self
143143

144144
def set_params(self, **params: Any) -> Self:

0 commit comments

Comments
 (0)