2
2
3
3
from compressed_tensors .transform import TransformScheme , apply_transform_config
4
4
5
- from llmcompressor .core import State
5
+ from llmcompressor .core import Event , EventType , State
6
6
from llmcompressor .modifiers import Modifier
7
7
8
8
from .template .quip import QUIP
@@ -12,9 +12,9 @@ class TransformModifier(Modifier):
12
12
preset_config : Optional [str ] = None
13
13
config_groups : Optional [Dict [str , TransformScheme ]] = None
14
14
15
- # model validator to validate both preset and config gropus are not provided
15
+ # model validator to validate both preset and config groups are not provided
16
16
17
- def on_initialize (self , state : State , ** kwargs ):
17
+ def on_initialize (self , state : State , ** kwargs ) -> bool :
18
18
if self .preset_config is not None :
19
19
# import config template and customize to model
20
20
pass
@@ -23,4 +23,29 @@ def on_initialize(self, state: State, **kwargs):
23
23
config = QUIP
24
24
25
25
apply_transform_config (state .model , config )
26
- breakpoint ()
26
+
27
+ return True
28
+
29
+ def on_start (self , state : State , event : Event , ** kwargs ):
30
+ self .started_ = True
31
+
32
+ def on_event (self , state : State , event : Event , ** kwargs ):
33
+ if event .type_ == EventType .CALIBRATION_EPOCH_START :
34
+ if not self .started_ :
35
+ self .on_start (state , None )
36
+
37
+ elif event .type_ == EventType .SEQUENTIAL_EPOCH_END :
38
+ pass
39
+
40
+ elif event .type_ == EventType .CALIBRATION_EPOCH_END :
41
+ if not self .ended_ :
42
+ self .on_end (state , None )
43
+
44
+ def on_end (self , state : State , event : Event , ** kwargs ):
45
+ self .ended_ = True
46
+
47
+ def on_finalize (self , state : State , ** kwargs ) -> bool :
48
+ if not self .ended_ :
49
+ self .on_end (state , None )
50
+
51
+ return True
0 commit comments