Skip to content

Commit 8680d95

Browse files
shanjiazkylesayrs
andauthored
Removed recipeStage, recipeModifer, stageModifier classes (#1514)
SUMMARY: Removed recipeStage, recipeModifer, stageModifier classes and dependencies. TEST PLAN: Tested locally DETAILS & CONCERNS: 1. `recipe.py` is where most changes happened. - Rewrote the parsing logic in `parse_from_dict`. - `Recipe` class now is basically just a list of `Modifiers`, all other changes are downstream dependencies. - All the serilization helpers and deserialization helpers now live in `utils`. Had to rewrite serialization logic. - Stage is stored as a string by `self.stage`. 2. Had to update a test in `test_recipe`. We support creating a recipe instance by reading a multi-stage recipe but no longer support serializing multi-stage recipes directly. The current structure assumes one recipe could only have one stage. 3. Had to rewrite `update_and_save_recipe` function since now each recipe only has one stage and therefore can't handle merging multi-stage recipes. The new setup: - `infer_recipe_from_model_path` will find existing recipe path. - `yaml` method now takes in an optional argument `existing_recipe_path` and merge the existing recipe with the current recipe. Since there would only be at most one existing recipe at a time, we should be safe. TODO: - [x] Add validations for recipe class. - [x] Simplify `recipe.py` structure --------- Signed-off-by: shanjiaz <zsjwpianpian@gmail.com> Co-authored-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 9341907 commit 8680d95

File tree

20 files changed

+283
-787
lines changed

20 files changed

+283
-787
lines changed

src/llmcompressor/core/lifecycle.py

Lines changed: 16 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
from llmcompressor.core.events import Event, EventType
1414
from llmcompressor.core.state import State
15-
from llmcompressor.modifiers import StageModifiers
1615
from llmcompressor.recipe import Recipe, RecipeArgsInput, RecipeInput, RecipeStageInput
1716

1817
__all__ = ["CompressionLifecycle"]
@@ -33,7 +32,6 @@ class CompressionLifecycle:
3332

3433
state: State = field(default_factory=State)
3534
recipe: Recipe = field(default_factory=Recipe)
36-
modifiers: List[StageModifiers] = field(default_factory=list)
3735

3836
initialized_: bool = False
3937
finalized: bool = False
@@ -60,7 +58,7 @@ def reset(self):
6058
"""
6159
logger.debug("Resetting compression lifecycle")
6260

63-
for mod in self.modifiers:
61+
for mod in self.recipe.modifiers:
6462
if not mod.initialized or mod.finalized:
6563
continue
6664
try:
@@ -92,21 +90,26 @@ def initialize(
9290
return
9391

9492
logger.debug("Initializing compression lifecycle")
95-
self.recipe = Recipe.simplify_recipe(
96-
recipe=recipe, target_stage=recipe_stage, override_args=recipe_args
97-
)
98-
self.modifiers = self.recipe.create_modifier()
93+
if not recipe:
94+
self.recipe = Recipe()
95+
else:
96+
self.recipe = Recipe.create_instance(
97+
path_or_modifiers=recipe, target_stage=recipe_stage
98+
)
99+
if recipe_args:
100+
self.recipe.args = {**recipe_args}
99101

100102
mod_data = []
101-
for mod in self.modifiers:
103+
for mod in self.recipe.modifiers:
102104
data = mod.initialize(state=self.state, **kwargs)
103105
logger.debug("Initialized modifier: {}", mod)
104106
if data is not None:
105107
mod_data.append(data)
106108

107109
self.initialized_ = True
108110
logger.info(
109-
"Compression lifecycle initialized for {} modifiers", len(self.modifiers)
111+
"Compression lifecycle initialized for {} modifiers",
112+
len(self.recipe.modifiers),
110113
)
111114

112115
return mod_data
@@ -130,7 +133,7 @@ def finalize(self, **kwargs) -> List[Any]:
130133

131134
logger.debug("Finalizing compression lifecycle")
132135
mod_data = []
133-
for mod in self.modifiers:
136+
for mod in self.recipe.modifiers:
134137
data = mod.finalize(state=self.state, **kwargs)
135138
logger.debug("Finalized modifier: {}", mod)
136139
if data is not None:
@@ -139,7 +142,8 @@ def finalize(self, **kwargs) -> List[Any]:
139142
self.finalized = True
140143

141144
logger.info(
142-
"Compression lifecycle finalized for {} modifiers", len(self.modifiers)
145+
"Compression lifecycle finalized for {} modifiers",
146+
len(self.recipe.modifiers),
143147
)
144148

145149
return mod_data
@@ -196,7 +200,7 @@ def event(
196200

197201
event = Event(type_=event_type)
198202
mod_data = []
199-
for mod in self.modifiers:
203+
for mod in self.recipe.modifiers:
200204
data = mod.update_event(state=self.state, event=event, **kwargs)
201205
logger.debug("Updated event with modifier: {}", mod)
202206
if data is not None:

src/llmcompressor/core/session.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -220,17 +220,6 @@ def get_serialized_recipe(self) -> Optional[str]:
220220

221221
logger.warning("Recipe not found in session - it may have been reset")
222222

223-
def get_modifiers(self):
224-
"""
225-
Get all modifiers across all stages
226-
"""
227-
stage_modifiers = self.lifecycle.modifiers
228-
return [
229-
modifier
230-
for stage_modifier in stage_modifiers
231-
for modifier in stage_modifier.modifiers
232-
] # noqa: E127
233-
234223
def _log_model_info(self):
235224
# Log model level logs if cadence reached
236225
current_index = self._lifecycle.global_step

src/llmcompressor/entrypoints/oneshot.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,9 +186,8 @@ def apply_recipe_modifiers(
186186
recipe_args=self.recipe_args.recipe_args,
187187
calib_data=calibration_dataloader,
188188
)
189-
190189
user_pipeline = self.dataset_args.pipeline
191-
modifiers = session.get_modifiers()
190+
modifiers = session.lifecycle.recipe.modifiers
192191
pipeline = CalibrationPipeline.from_modifiers(modifiers, user=user_pipeline)
193192
pipeline(self.model, calibration_dataloader, self.dataset_args)
194193

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
from .factory import ModifierFactory
22
from .interface import ModifierInterface
33
from .modifier import Modifier
4-
from .stage import StageModifiers
54

65
__all__ = [
76
"ModifierFactory",
87
"ModifierInterface",
98
"Modifier",
10-
"StageModifiers",
119
]

src/llmcompressor/modifiers/stage.py

Lines changed: 0 additions & 105 deletions
This file was deleted.

src/llmcompressor/pipelines/independent/pipeline.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
from torch.utils.data.dataloader import DataLoader
66

77
from llmcompressor.core import active_session
8-
from llmcompressor.modifiers.stage import StageModifiers
98
from llmcompressor.pipelines.registry import CalibrationPipeline
109
from llmcompressor.utils.helpers import patch_attr
1110

@@ -34,18 +33,15 @@ def __call__(
3433
_logger = logger.patch(lambda r: r.update(function="IndependentPipeline"))
3534

3635
session = active_session()
37-
modifiers = session.get_modifiers()
38-
with patch_attr(session.lifecycle, "modifiers", None):
39-
for index, modifier in enumerate(modifiers):
40-
mod_type = str(type(modifier).__name__)
41-
session.lifecycle.modifiers = [
42-
StageModifiers(modifiers=[modifier], group=mod_type, index=index)
43-
]
44-
36+
modifiers = session.lifecycle.recipe.modifiers
37+
with patch_attr(session.lifecycle.recipe, "modifiers", None):
38+
for modifier in modifiers:
39+
mod_type = type(modifier).__name__
40+
session.lifecycle.recipe.modifiers = [modifier]
4541
pipeline = CalibrationPipeline.from_modifiers([modifier])
4642
pipeline_name = pipeline.__class__.__name__
4743
_logger.info(f"Inferred `{pipeline_name}` for `{mod_type}`")
4844

4945
pipeline(model, dataloader, dataset_args)
5046

51-
# restore modifiers on exit so model can be compressed based on recipe
47+
# restore modifiers on exit so model can be compressed based on recipe

src/llmcompressor/pipelines/layer_sequential/pipeline.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ def __call__(
6363
model_device = get_execution_device(model)
6464

6565
# find layers
66-
modifiers = session.get_modifiers()
66+
modifiers = session.lifecycle.recipe.modifiers
6767
sequential_targets = get_sequential_targets(modifiers, model, dataset_args)
6868
layers = match_modules(model, sequential_targets)
6969

src/llmcompressor/pipelines/sequential/pipeline.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,8 +57,9 @@ def __call__(
5757
model_device = get_execution_device(model)
5858

5959
# prepare to trace subgraphs
60-
modifiers = session.get_modifiers()
60+
modifiers = session.lifecycle.recipe.modifiers
6161
sequential_targets = get_sequential_targets(modifiers, model, dataset_args)
62+
6263
ignore = dataset_args.tracing_ignore
6364

6465
# trace subgraphs

src/llmcompressor/recipe/__init__.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
1-
from .base import RecipeBase
21
from .metadata import DatasetMetaData, LayerMetaData, ModelMetaData, ParamMetaData
3-
from .modifier import RecipeModifier
42
from .recipe import Recipe, RecipeArgsInput, RecipeInput, RecipeStageInput
5-
from .stage import RecipeStage
63

74
__all__ = [
85
"DatasetMetaData",
96
"ParamMetaData",
107
"LayerMetaData",
118
"ModelMetaData",
12-
"RecipeBase",
13-
"RecipeModifier",
14-
"RecipeStage",
159
"Recipe",
1610
"RecipeInput",
1711
"RecipeStageInput",

src/llmcompressor/recipe/base.py

Lines changed: 0 additions & 25 deletions
This file was deleted.

0 commit comments

Comments
 (0)