Skip to content

Commit 963900a

Browse files
add on lifecycle methods
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 3aa35e7 commit 963900a

File tree

1 file changed

+29
-4
lines changed

1 file changed

+29
-4
lines changed

src/llmcompressor/modifiers/transform/transform.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from compressed_tensors.transform import TransformScheme, apply_transform_config
44

5-
from llmcompressor.core import State
5+
from llmcompressor.core import Event, EventType, State
66
from llmcompressor.modifiers import Modifier
77

88
from .template.quip import QUIP
@@ -12,9 +12,9 @@ class TransformModifier(Modifier):
1212
preset_config: Optional[str] = None
1313
config_groups: Optional[Dict[str, TransformScheme]] = None
1414

15-
# model validator to validate both preset and config gropus are not provided
15+
# model validator to validate both preset and config groups are not provided
1616

17-
def on_initialize(self, state: State, **kwargs):
17+
def on_initialize(self, state: State, **kwargs) -> bool:
1818
if self.preset_config is not None:
1919
# import config template and customize to model
2020
pass
@@ -23,4 +23,29 @@ def on_initialize(self, state: State, **kwargs):
2323
config = QUIP
2424

2525
apply_transform_config(state.model, config)
26-
breakpoint()
26+
27+
return True
28+
29+
def on_start(self, state: State, event: Event, **kwargs):
30+
self.started_ = True
31+
32+
def on_event(self, state: State, event: Event, **kwargs):
33+
if event.type_ == EventType.CALIBRATION_EPOCH_START:
34+
if not self.started_:
35+
self.on_start(state, None)
36+
37+
elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
38+
pass
39+
40+
elif event.type_ == EventType.CALIBRATION_EPOCH_END:
41+
if not self.ended_:
42+
self.on_end(state, None)
43+
44+
def on_end(self, state: State, event: Event, **kwargs):
45+
self.ended_ = True
46+
47+
def on_finalize(self, state: State, **kwargs) -> bool:
48+
if not self.ended_:
49+
self.on_end(state, None)
50+
51+
return True

0 commit comments

Comments
 (0)