import pandas as pd
import pytest
+ import pytorch_lightning as pl
import torch
from pytest import FixtureRequest
from pytorch_lightning import Trainer, seed_everything
from pytorch_lightning.loggers import CSVLogger
+ from torch import nn

from rectools import Columns
from rectools.dataset import Dataset
from .utils import custom_trainer, custom_trainer_ckpt, custom_trainer_multiple_ckpt, leave_one_out_mask


+ def assert_torch_models_equal(model_a: nn.Module, model_b: nn.Module) -> None:
+     """Assert torch modules are equal in terms of parameter names and values"""
+     assert type(model_a) is type(model_b), "different types"
+
+     with torch.no_grad():
+         for (apn, apv), (bpn, bpv) in zip(model_a.named_parameters(), model_b.named_parameters()):
+             assert apn == bpn, "different parameter name"
+             assert torch.isclose(apv, bpv).all(), "different parameter value"
+
+
+ def assert_pl_models_equal(model_a: pl.LightningModule, model_b: pl.LightningModule) -> None:
+     """Assert pl modules are equal in terms of weights and trainer"""
+     assert_torch_models_equal(model_a, model_b)
+
+     trainer_a = model_a.trainer
+     trainer_b = model_b.trainer
+
+     assert_pl_trainers_equal(trainer_a, trainer_b)
+
+
+ def assert_pl_trainers_equal(trainer_a: Trainer, trainer_b: Trainer) -> None:
+     """Assert pl trainers are equal in terms of optimizer state"""
+     assert len(trainer_a.optimizers) == len(trainer_b.optimizers), "Different number of optimizers"
+
+     for opt_a, opt_b in zip(trainer_a.optimizers, trainer_b.optimizers):
+         # Check optimizer class and full state dict
+         assert type(opt_a) is type(opt_b), f"Optimizer types differ: {type(opt_a)} vs {type(opt_b)}"
+         assert opt_a.state_dict() == opt_b.state_dict(), "optimizers state dict differs"
+
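+ # A minimal usage sketch (illustrative only, names as in the tests below): round-trip a
+ # fitted model through save/load and verify that weights and trainer state survive:
+ #     model.fit(dataset)
+ #     with NamedTemporaryFile() as f:
+ #         model.save(f.name)
+ #         recovered_model = model_cls.load(f.name)
+ #     assert_pl_models_equal(model.lightning_model, recovered_model.lightning_model)
+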
class TestTransformerModelBase:
    def setup_method(self) -> None:
        torch.use_deterministic_algorithms(True)
@@ -209,28 +240,6 @@ def test_load_from_checkpoint(

        self._assert_same_reco(model, recovered_model, dataset)

-     @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
-     def test_raises_when_save_model_loaded_from_checkpoint(
-         self,
-         model_cls: tp.Type[TransformerModelBase],
-         dataset: Dataset,
-     ) -> None:
-         model = model_cls.from_config(
-             {
-                 "deterministic": True,
-                 "get_trainer_func": custom_trainer_ckpt,
-             }
-         )
-         model.fit(dataset)
-         assert model.fit_trainer is not None
-         if model.fit_trainer.log_dir is None:
-             raise ValueError("No log dir")
-         ckpt_path = os.path.join(model.fit_trainer.log_dir, "checkpoints", "last_epoch.ckpt")
-         recovered_model = model_cls.load_from_checkpoint(ckpt_path)
-         with pytest.raises(RuntimeError):
-             with NamedTemporaryFile() as f:
-                 recovered_model.save(f.name)
-
    @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
    def test_load_weights_from_checkpoint(
        self,
@@ -391,8 +400,6 @@ def test_fit_partial_from_checkpoint(
        recovered_fit_partial_model = model_cls.load_from_checkpoint(ckpt_path)

        seed_everything(32, workers=True)
-         fit_partial_model.fit_trainer = deepcopy(fit_partial_model._trainer)  # pylint: disable=protected-access
-         fit_partial_model.lightning_model.optimizer = None
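+         # The manual trainer copy and optimizer reset are no longer needed here: fit_partial
+         # is now expected to keep trainer state itself (see test_fit_partial_multiple_times below).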
        fit_partial_model.fit_partial(dataset, min_epochs=1, max_epochs=1)

        seed_everything(32, workers=True)
@@ -410,3 +417,108 @@ def test_raises_when_incorrect_similarity_dist(
        with pytest.raises(ValueError):
            model = model_cls.from_config(model_config)
            model.fit(dataset=dataset)
+
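+     # Resaving is now supported: the test removed above asserted that saving a model
+     # loaded from a checkpoint raised RuntimeError. This test checks that save/load
+     # round-trips are lossless for the config and, when fitted, for weights and trainer state.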
+     @pytest.mark.parametrize("fit", (True, False))
+     @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+     @pytest.mark.parametrize("default_trainer", (True, False))
+     def test_resaving(
+         self,
+         model_cls: tp.Type[TransformerModelBase],
+         dataset: Dataset,
+         default_trainer: bool,
+         fit: bool,
+     ) -> None:
+         config: tp.Dict[str, tp.Any] = {"deterministic": True}
+         if not default_trainer:
+             config["get_trainer_func"] = custom_trainer
+         model = model_cls.from_config(config)
+
+         seed_everything(32, workers=True)
+         if fit:
+             model.fit(dataset)
+
+         with NamedTemporaryFile() as f:
+             model.save(f.name)
+             recovered_model = model_cls.load(f.name)
+
+         with NamedTemporaryFile() as f:
+             recovered_model.save(f.name)
+             second_recovered_model = model_cls.load(f.name)
+
+         assert isinstance(recovered_model, model_cls)
+
+         original_model_config = model.get_config()
+         second_recovered_model_config = second_recovered_model.get_config()
+         assert second_recovered_model_config == original_model_config
+
+         if fit:
+             assert_pl_models_equal(model.lightning_model, second_recovered_model.lightning_model)
+
+     # Check that the trainer keeps its state across multiple fit_partial calls
+     @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+     def test_fit_partial_multiple_times(
+         self,
+         dataset: Dataset,
+         model_cls: tp.Type[TransformerModelBase],
+     ) -> None:
+         class FixSeedLightningModule(TransformerLightningModule):
+             def on_train_epoch_start(self) -> None:
+                 seed_everything(32, workers=True)
+
+         seed_everything(32, workers=True)
+         model = model_cls.from_config(
+             {
+                 "epochs": 3,
+                 "data_preparator_kwargs": {"shuffle_train": False},
+                 "get_trainer_func": custom_trainer,
+                 "lightning_module_type": FixSeedLightningModule,
+             }
+         )
+         model.fit_partial(dataset, min_epochs=1, max_epochs=1)
+         t1 = deepcopy(model.fit_trainer)
+         model.fit_partial(
+             Dataset.construct(pd.DataFrame(columns=Columns.Interactions)),
+             min_epochs=1,
+             max_epochs=1,
+         )
+         t2 = deepcopy(model.fit_trainer)
+
+         # Since the second call fits on an empty dataset, the trainer state should stay
+         # exactly as it was after the first fit, proving that fit_partial does not
+         # reset trainer state before proceeding to training.
+         assert t1 is not None
+         assert t2 is not None
+         assert_pl_trainers_equal(t1, t2)
+
+     @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+     def test_raises_when_fit_trainer_is_none_on_save_trained_model(
+         self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset
+     ) -> None:
+         config: tp.Dict[str, tp.Any] = {"deterministic": True}
+         model = model_cls.from_config(config)
+
+         seed_everything(32, workers=True)
+         model.fit(dataset)
+         model.fit_trainer = None
+
+         with NamedTemporaryFile() as f:
+             with pytest.raises(RuntimeError):
+                 model.save(f.name)
+
+     @pytest.mark.parametrize("model_cls", (SASRecModel, BERT4RecModel))
+     def test_raises_when_fit_trainer_is_none_on_fit_partial_trained_model(
+         self, model_cls: tp.Type[TransformerModelBase], dataset: Dataset
+     ) -> None:
+         config: tp.Dict[str, tp.Any] = {"deterministic": True}
+         model = model_cls.from_config(config)
+
+         seed_everything(32, workers=True)
+         model.fit(dataset)
+         model.fit_trainer = None
+
+         with pytest.raises(RuntimeError):
+             model.fit_partial(
+                 dataset,
+                 min_epochs=1,
+                 max_epochs=1,
+             )