
Commit 5ae04d5

Feature/model persistence (#289)
Allow transformer models to be saved and loaded multiple times. Load trainer state when loading a model.
1 parent 589c7ca commit 5ae04d5
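
To illustrate what this enables, here is a minimal usage sketch (not part of the diff; the interaction data, config values and file names are invented, while `SASRecModel`, `Dataset.construct`, `save`, `load` and `fit_partial` are the RecTools APIs exercised by the tests below):

```python
# Minimal sketch of the new persistence behaviour; data and paths are illustrative.
import pandas as pd

from rectools import Columns
from rectools.dataset import Dataset
from rectools.models import SASRecModel

interactions = pd.DataFrame(
    [
        [10, 1, 1.0, "2024-01-01"],
        [10, 2, 1.0, "2024-01-02"],
        [10, 3, 1.0, "2024-01-03"],
        [20, 2, 1.0, "2024-01-02"],
        [20, 3, 1.0, "2024-01-03"],
        [20, 4, 1.0, "2024-01-04"],
    ],
    columns=Columns.Interactions,
)
interactions[Columns.Datetime] = pd.to_datetime(interactions[Columns.Datetime])
dataset = Dataset.construct(interactions)

model = SASRecModel.from_config({"epochs": 1, "deterministic": True})
model.fit(dataset)

# Saving, loading and saving again now round-trips: previously a model loaded
# from disk could not be saved again without being fitted from scratch.
model.save("sasrec_v1")
restored = SASRecModel.load("sasrec_v1")
restored.save("sasrec_v2")
restored_again = SASRecModel.load("sasrec_v2")

# Trainer (and optimizer) state is restored on load, so training can continue.
restored_again.fit_partial(dataset, min_epochs=1, max_epochs=1)
```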

File tree

3 files changed: +194, -38 lines

CHANGELOG.md
rectools/models/nn/transformers/base.py
tests/models/nn/transformers/test_base.py

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
 ## Unreleased
 ### Added
 - `extras` argument to `SequenceDataset`, `extra_cols` argument to `TransformerDataPreparatorBase`, `session_tower_forward` and `item_tower_forward` methods to `SimilarityModuleBase` ([#287](https://github.com/MobileTeleSystems/RecTools/pull/287))
+- Support for resaving transformer models multiple times and loading trainer state ([#289](https://github.com/MobileTeleSystems/RecTools/pull/289))
 
 ### Fixed
 - [Breaking] Now `LastNSplitter` guarantees taking the last ordered interaction in dataframe in case of identical timestamps ([#288](https://github.com/MobileTeleSystems/RecTools/pull/288))

rectools/models/nn/transformers/base.py

Lines changed: 57 additions & 14 deletions
@@ -24,6 +24,7 @@
 import typing_extensions as tpe
 from pydantic import BeforeValidator, PlainSerializer
 from pytorch_lightning import Trainer
+from torch.utils.data import DataLoader, TensorDataset
 
 from rectools import ExternalIds
 from rectools.dataset.dataset import Dataset, DatasetSchema, DatasetSchemaDict, IdMap
@@ -505,16 +506,25 @@ def _fit_partial(
         if not self.is_fitted:
             self._build_model_from_dataset(dataset)
             self.fit_trainer = deepcopy(self._trainer)
-        elif self.fit_trainer is None:
+        else:
+            # The dataset is assumed to be the same as in `fit` or as in the first call to `fit_partial`.
+            # New datasets are currently not supported due to difficulties with
+            # handling id maps and item (user) features.
             self.data_preparator.process_dataset_train(dataset)
-            self.fit_trainer = deepcopy(self._trainer)
+        if self.fit_trainer is None:
+            raise RuntimeError("expected to have fit_trainer set")
 
         train_dataloader = self.data_preparator.get_dataloader_train()
         val_dataloader = self.data_preparator.get_dataloader_val()
 
         self.lightning_model.train()
-        self.fit_trainer.fit_loop.max_epochs = self.fit_trainer.current_epoch + max_epochs
-        self.fit_trainer.fit_loop.min_epochs = self.fit_trainer.current_epoch + min_epochs
+
+        # If the checkpoint comes from a ModelCheckpoint callback (saved at the end of an epoch),
+        # its epoch value equals the number of data epochs - 1, because the epoch is not yet finished
+        # at checkpoint time. So instead of `fit_trainer.current_epoch` we use the count of ready epochs.
+        current_epoch = self.fit_trainer.fit_loop.epoch_progress.current.ready
+        self.fit_trainer.fit_loop.max_epochs = current_epoch + max_epochs
+        self.fit_trainer.fit_loop.min_epochs = current_epoch + min_epochs
         self.fit_trainer.fit(self.lightning_model, train_dataloader, val_dataloader)
 
     def _recommend_u2i(
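
To make the epoch arithmetic in the hunk above concrete, here is a hypothetical example (the numbers are invented; the attribute names are the ones used in the hunk):

```python
# Hypothetical numbers mirroring the comment above: a ModelCheckpoint file saved
# "at the end" of the 3rd data epoch restores into the trainer as follows.
epochs_trained = 3
current_epoch_in_checkpoint = epochs_trained - 1  # trainer.current_epoch -> 2
ready_epochs = epochs_trained                     # fit_loop.epoch_progress.current.ready -> 3

max_epochs_arg = 2  # `max_epochs` passed to fit_partial

assert ready_epochs + max_epochs_arg == 5                 # runs 2 more data epochs, as requested
assert current_epoch_in_checkpoint + max_epochs_arg == 4  # would stop one epoch short
```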
@@ -574,8 +584,25 @@ def _get_config(self) -> TransformerModelConfig_T:
         return self.config_class(**params)
 
     @classmethod
-    def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:
-        """Create model from loaded Lightning checkpoint."""
+    def _model_from_checkpoint(
+        cls, checkpoint: tp.Dict[str, tp.Any], ckpt_path: tp.Optional[tp.Union[str, Path]] = None
+    ) -> tpe.Self:
+        """
+        Create model from loaded Lightning checkpoint.
+
+        Parameters
+        ----------
+        checkpoint: Dict[str, tp.Any]
+            Checkpoint object (PyTorch Lightning / torch-like).
+        ckpt_path: Union[str, Path], optional
+            Path to the checkpoint location.
+            If specified, it should be the path to the file that `checkpoint` was loaded from.
+            If not specified, `checkpoint` is saved to a temporary file.
+
+        Returns
+        -------
+        Model instance.
+        """
         model_config = checkpoint["hyper_parameters"]["model_config"]
         loaded = cls.from_config(model_config)
         loaded.is_fitted = True
@@ -596,20 +623,36 @@ def _model_from_checkpoint(cls, checkpoint: tp.Dict[str, tp.Any]) -> tpe.Self:
             item_external_ids=item_external_ids,
             model_config=model_config,
         )
+
+        try:
+            temp_file = None
+            actual_ckpt_path = ckpt_path
+            if actual_ckpt_path is None:
+                temp_file = NamedTemporaryFile()  # pylint: disable=consider-using-with
+                actual_ckpt_path = temp_file.name
+                torch.save(checkpoint, actual_ckpt_path)
+
+            loaded.fit_trainer = deepcopy(loaded._trainer)
+            # use stub dataset to load trainer state
+            loaded.fit_trainer.fit(
+                loaded.lightning_model,
+                ckpt_path=actual_ckpt_path,
+                train_dataloaders=DataLoader(TensorDataset(torch.Tensor())),
+            )
+
+        finally:
+            if temp_file is not None:
+                temp_file.close()
+
         loaded.lightning_model.is_fitted = True
-        loaded.lightning_model.load_state_dict(checkpoint["state_dict"])
 
         return loaded
 
     def __getstate__(self) -> object:
         if self.is_fitted:
             if self.fit_trainer is None:
-                explanation = """
-                Model is fitted but has no `fit_trainer`. Most likely it was just loaded from the
-                checkpoint. Model that was loaded from checkpoint cannot be saved without being
-                fitted again.
-                """
-                raise RuntimeError(explanation)
+                raise RuntimeError("Fitted model is expected to have `fit_trainer` set")
+
             with NamedTemporaryFile() as f:
                 self.fit_trainer.save_checkpoint(f.name)
                 checkpoint = Path(f.name).read_bytes()
@@ -658,7 +701,7 @@ def load_from_checkpoint(
         prev_config_flatten = make_dict_flat(prev_model_config)
         prev_config_flatten.update(model_params_update)
         checkpoint["hyper_parameters"]["model_config"] = unflatten_dict(prev_config_flatten)
-        loaded = cls._model_from_checkpoint(checkpoint)
+        loaded = cls._model_from_checkpoint(checkpoint, ckpt_path=checkpoint_path)
         return loaded
 
     def load_weights_from_checkpoint(self, checkpoint_path: tp.Union[str, Path]) -> None:
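
The updated `load_from_checkpoint` path can be sketched as follows (illustrative only: `trainer_with_checkpoint` is a hypothetical stand-in for the `custom_trainer_ckpt` test helper, `dataset` is built as in the sketch near the top of this page, and the checkpoint file name mirrors the one used in the tests):

```python
# Illustrative sketch: recover a model together with its trainer state from a Lightning checkpoint.
import os

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint

from rectools.models import SASRecModel


def trainer_with_checkpoint() -> Trainer:
    # hypothetical trainer factory: trains for one epoch and writes `last_epoch.ckpt`
    return Trainer(max_epochs=1, deterministic=True, callbacks=ModelCheckpoint(filename="last_epoch"))


model = SASRecModel.from_config({"deterministic": True, "get_trainer_func": trainer_with_checkpoint})
model.fit(dataset)  # `dataset` as in the earlier sketch

assert model.fit_trainer is not None and model.fit_trainer.log_dir is not None
ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")

# `_model_from_checkpoint` now replays the checkpoint through a stub `Trainer.fit` call,
# so the recovered model carries restored trainer and optimizer state ...
recovered = SASRecModel.load_from_checkpoint(ckpt_path)

# ... which means it can be saved again and trained further, instead of raising
# RuntimeError as it did before this commit.
recovered.save("recovered_sasrec")
recovered.fit_partial(dataset, min_epochs=1, max_epochs=1)
```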

tests/models/nn/transformers/test_base.py

Lines changed: 136 additions & 24 deletions
@@ -19,10 +19,12 @@
 
 import pandas as pd
 import pytest
+import pytorch_lightning as pl
 import torch
 from pytest import FixtureRequest
 from pytorch_lightning import Trainer, seed_everything
 from pytorch_lightning.loggers import CSVLogger
+from torch import nn
 
 from rectools import Columns
 from rectools.dataset import Dataset
@@ -35,6 +37,35 @@
 from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask
 
 
+def assert_torch_models_equal(model_a: nn.Module, model_b: nn.Module) -> None:
+    assert type(model_a) is type(model_b), "different types"
+
+    with torch.no_grad():
+        for (apn, apv), (bpn, bpv) in zip(model_a.named_parameters(), model_b.named_parameters()):
+            assert apn == bpn, "different parameter name"
+            assert torch.isclose(apv, bpv).all(), "different parameter value"
+
+
+def assert_pl_models_equal(model_a: pl.LightningModule, model_b: pl.LightningModule) -> None:
+    """Assert pl modules are equal in terms of weights and trainer"""
+    assert_torch_models_equal(model_a, model_b)
+
+    trainer_a = model_a.trainer
+    trainer_b = model_b.trainer
+
+    assert_pl_trainers_equal(trainer_a, trainer_b)
+
+
+def assert_pl_trainers_equal(trainer_a: Trainer, trainer_b: Trainer) -> None:
+    """Assert pl trainers are equal in terms of optimizers state"""
+    assert len(trainer_a.optimizers) == len(trainer_b.optimizers), "Different number of optimizers"
+
+    for opt_a, opt_b in zip(trainer_a.optimizers, trainer_b.optimizers):
+        # Check optimizer class
+        assert type(opt_a) is type(opt_b), f"Optimizer types differ: {type(opt_a)} vs {type(opt_b)}"
+        assert opt_a.state_dict() == opt_b.state_dict(), "optimizers state dict differs"
+
+
 class TestTransformerModelBase:
     def setup_method(self) -> None:
         torch.use_deterministic_algorithms(True)
@@ -209,28 +240,6 @@ def test_load_from_checkpoint(
 
         self._assert_same_reco(model, recovered_model, dataset)
 
-    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
-    def test_raises_when_save_model_loaded_from_checkpoint(
-        self,
-        model_cls: tp.Type[TransformerModelBase],
-        dataset: Dataset,
-    ) -> None:
-        model = model_cls.from_config(
-            {
-                "deterministic": True,
-                "get_trainer_func": custom_trainer_ckpt,
-            }
-        )
-        model.fit(dataset)
-        assert model.fit_trainer is not None
-        if model.fit_trainer.log_dir is None:
-            raise ValueError("No log dir")
-        ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
-        recovered_model = model_cls.load_from_checkpoint(ckpt_path)
-        with pytest.raises(RuntimeError):
-            with NamedTemporaryFile() as f:
-                recovered_model.save(f.name)
-
     @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
     def test_load_weights_from_checkpoint(
         self,
@@ -391,8 +400,6 @@ def test_fit_partial_from_checkpoint(
         recovered_fit_partial_model = model_cls.load_from_checkpoint(ckpt_path)
 
         seed_everything(32, workers=True)
-        fit_partial_model.fit_trainer = deepcopy(fit_partial_model._trainer)  # pylint: disable=protected-access
-        fit_partial_model.lightning_model.optimizer = None
         fit_partial_model.fit_partial(dataset, min_epochs=1, max_epochs=1)
 
         seed_everything(32, workers=True)
@@ -410,3 +417,108 @@ def test_raises_when_incorrect_similarity_dist(
         with pytest.raises(ValueError):
             model = model_cls.from_config(model_config)
             model.fit(dataset=dataset)
+
+    @pytest.mark.parametrize("fit", (True, False))
+    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+    @pytest.mark.parametrize("default_trainer", (True, False))
+    def test_resaving(
+        self,
+        model_cls: tp.Type[TransformerModelBase],
+        dataset: Dataset,
+        default_trainer: bool,
+        fit: bool,
+    ) -> None:
+        config: tp.Dict[str, tp.Any] = {"deterministic": True}
+        if not default_trainer:
+            config["get_trainer_func"] = custom_trainer
+        model = model_cls.from_config(config)
+
+        seed_everything(32, workers=True)
+        if fit:
+            model.fit(dataset)
+
+        with NamedTemporaryFile() as f:
+            model.save(f.name)
+            recovered_model = model_cls.load(f.name)
+
+        with NamedTemporaryFile() as f:
+            recovered_model.save(f.name)
+            second_recovered_model = model_cls.load(f.name)
+
+        assert isinstance(recovered_model, model_cls)
+
+        original_model_config = model.get_config()
+        second_recovered_model_config = second_recovered_model.get_config()
+        assert second_recovered_model_config == original_model_config
+
+        if fit:
+            assert_pl_models_equal(model.lightning_model, second_recovered_model.lightning_model)
+
+    # Check that the trainer keeps its state across multiple calls to `fit_partial`
+    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+    def test_fit_partial_multiple_times(
+        self,
+        dataset: Dataset,
+        model_cls: tp.Type[TransformerModelBase],
+    ) -> None:
+        class FixSeedLightningModule(TransformerLightningModule):
+            def on_train_epoch_start(self) -> None:
+                seed_everything(32, workers=True)
+
+        seed_everything(32, workers=True)
+        model = model_cls.from_config(
+            {
+                "epochs": 3,
+                "data_preparator_kwargs": {"shuffle_train": False},
+                "get_trainer_func": custom_trainer,
+                "lightning_module_type": FixSeedLightningModule,
+            }
+        )
+        model.fit_partial(dataset, min_epochs=1, max_epochs=1)
+        t1 = deepcopy(model.fit_trainer)
+        model.fit_partial(
+            Dataset.construct(pd.DataFrame(columns=Columns.Interactions)),
+            min_epochs=1,
+            max_epochs=1,
+        )
+        t2 = deepcopy(model.fit_trainer)
+
+        # Since the second call fits on an empty dataset,
+        # the trainer state should stay exactly the same as after the first fit,
+        # which proves that fit_partial does not change trainer state before proceeding to training.
+        assert t1 is not None
+        assert t2 is not None
+        assert_pl_trainers_equal(t1, t2)
+
+    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+    def test_raises_when_fit_trainer_is_none_on_save_trained_model(
+        self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset
+    ) -> None:
+        config: tp.Dict[str, tp.Any] = {"deterministic": True}
+        model = model_cls.from_config(config)
+
+        seed_everything(32, workers=True)
+        model.fit(dataset)
+        model.fit_trainer = None
+
+        with NamedTemporaryFile() as f:
+            with pytest.raises(RuntimeError):
+                model.save(f.name)
+
+    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+    def test_raises_when_fit_trainer_is_none_on_fit_partial_trained_model(
+        self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset
+    ) -> None:
+        config: tp.Dict[str, tp.Any] = {"deterministic": True}
+        model = model_cls.from_config(config)
+
+        seed_everything(32, workers=True)
+        model.fit(dataset)
+        model.fit_trainer = None
+
+        with pytest.raises(RuntimeError):
+            model.fit_partial(
+                dataset,
+                min_epochs=1,
+                max_epochs=1,
+            )
