Skip to content

Commit b6c088e

Browse files
add on lifecycle methods
Signed-off-by: Brian Dellabetta <bdellabe@redhat.com>
1 parent 3aa35e7 commit b6c088e

File tree

2 files changed

+31
-6
lines changed

2 files changed

+31
-6
lines changed

examples/transform/llama3_example.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,13 @@
33

44
from llmcompressor.modifiers.quantization import GPTQModifier
55
from llmcompressor.modifiers.transform import TransformModifier
6-
from llmcompressor.transformers import oneshot
6+
from llmcompressor import oneshot
77

88
# Select model and load it.
99
MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"
1010

1111
model = AutoModelForCausalLM.from_pretrained(
1212
MODEL_ID,
13-
device_map="auto",
1413
torch_dtype="auto",
1514
)
1615
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
@@ -66,6 +65,7 @@ def tokenize(sample):
6665
model=model,
6766
dataset=ds,
6867
recipe=recipe,
68+
pipeline="sequential",
6969
max_seq_length=MAX_SEQUENCE_LENGTH,
7070
num_calibration_samples=NUM_CALIBRATION_SAMPLES,
7171
)

src/llmcompressor/modifiers/transform/transform.py

Lines changed: 29 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
from compressed_tensors.transform import TransformScheme, apply_transform_config
44

5-
from llmcompressor.core import State
5+
from llmcompressor.core import Event, EventType, State
66
from llmcompressor.modifiers import Modifier
77

88
from .template.quip import QUIP
@@ -12,9 +12,9 @@ class TransformModifier(Modifier):
1212
preset_config: Optional[str] = None
1313
config_groups: Optional[Dict[str, TransformScheme]] = None
1414

15-
# model validator to validate both preset and config gropus are not provided
15+
# model validator to validate both preset and config groups are not provided
1616

17-
def on_initialize(self, state: State, **kwargs):
17+
def on_initialize(self, state: State, **kwargs) -> bool:
1818
if self.preset_config is not None:
1919
# import config template and customize to model
2020
pass
@@ -23,4 +23,29 @@ def on_initialize(self, state: State, **kwargs):
2323
config = QUIP
2424

2525
apply_transform_config(state.model, config)
26-
breakpoint()
26+
27+
return True
28+
29+
def on_start(self, state: State, event: Event, **kwargs):
30+
self.started_ = True
31+
32+
def on_event(self, state: State, event: Event, **kwargs):
33+
if event.type_ == EventType.CALIBRATION_EPOCH_START:
34+
if not self.started_:
35+
self.on_start(state, None)
36+
37+
elif event.type_ == EventType.SEQUENTIAL_EPOCH_END:
38+
pass
39+
40+
elif event.type_ == EventType.CALIBRATION_EPOCH_END:
41+
if not self.ended_:
42+
self.on_end(state, None)
43+
44+
def on_end(self, state: State, event: Event, **kwargs):
45+
self.ended_ = True
46+
47+
def on_finalize(self, state: State, **kwargs) -> bool:
48+
if not self.ended_:
49+
self.on_end(state, None)
50+
51+
return True

0 commit comments

Comments
 (0)