File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed
molpipeline/estimators/chemprop Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -100,9 +100,7 @@ def __init__(
100
100
self .trainer_params = get_params_trainer (self .lightning_trainer )
101
101
self .set_params (** kwargs )
102
102
103
- def _update_trainer (
104
- self ,
105
- ) -> None :
103
+ def _update_trainer (self ) -> None :
106
104
"""Update the trainer for the model."""
107
105
trainer_params = dict (self .trainer_params )
108
106
if self .model_ckpoint_params :
@@ -139,6 +137,8 @@ def fit(
139
137
X , batch_size = self .batch_size , num_workers = self .n_jobs
140
138
)
141
139
self .lightning_trainer .fit (self .model , training_data )
140
+ # The trainer is reinitalized to avoid storing the training data
141
+ self ._update_trainer ()
142
142
return self
143
143
144
144
def set_params (self , ** params : Any ) -> Self :
You can’t perform that action at this time.
0 commit comments