
Commit 78d6975

squash

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>

1 parent 50bb656 · commit 78d6975

14 files changed: +91, -159 lines


src/llmcompressor/modifiers/awq/base.py

Lines changed: 17 additions & 8 deletions
@@ -1,5 +1,5 @@
 import inspect
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 from compressed_tensors.quantization import (
@@ -12,7 +12,7 @@
     update_offload_parameter,
 )
 from loguru import logger
-from pydantic import ConfigDict, PrivateAttr, model_validator
+from pydantic import ConfigDict, PrivateAttr, field_validator, model_validator
 from torch.nn import Module
 from tqdm import tqdm

@@ -27,6 +27,7 @@
 from llmcompressor.modifiers.quantization.quantization import QuantizationMixin
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.cache import IntermediatesCache
+from llmcompressor.sentinel import Sentinel
 from llmcompressor.utils.fsdp.helpers import get_fsdp_parent
 from llmcompressor.utils.helpers import calibration_forward_context
 from llmcompressor.utils.pytorch.module import get_layer_by_name, get_layers
@@ -96,8 +97,6 @@ class AWQModifier(Modifier, QuantizationMixin):
     - on_finalize
         - clear resolved mappings and captured activations

-    :param sequential_targets: list of module names to compress in
-        the same calibration pass
     :param mappings: list activation layers to smooth, and which layers to
         scale the output such that activations are smoothed.
         Each entry of the mapping list should be a list itself, in which the first
@@ -116,11 +115,7 @@ class AWQModifier(Modifier, QuantizationMixin):
         and weights to determine the scaling factor
     """

-    # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
-    model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)
-
     # User-provided vars (in addition to QuantizationMixin args)
-    sequential_targets: Union[str, List[str], None] = None
     mappings: Optional[List[AWQMapping]] = None
     offload_device: Optional[torch.device] = None
     duo_scaling: bool = True
@@ -141,6 +136,20 @@ class AWQModifier(Modifier, QuantizationMixin):
         default_factory=dict
     )

+    # deprecated
+    sequential_targets: Union[Sentinel, Any] = Sentinel("deprecated")
+
+    # Allow arbitrary types because AWQMapping has fields of type torch.nn.Module
+    model_config: ConfigDict = ConfigDict(arbitrary_types_allowed=True)
+
+    @field_validator("sequential_targets", mode="before")
+    def validate_sequential_targets(cls, value: bool) -> bool:
+        if value is not Sentinel("deprecated"):
+            raise ValueError(
+                "Setting `sequential_targets` via modifiers is no longer supported, "
+                "Please use `oneshot(sequential_targets=...)`"
+            )
+
     @model_validator(mode="after")
     def validate_model_after(model: "AWQModifier") -> "AWQModifier":
         """

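Note: with `sequential_targets` removed from the modifier above, the value is expected to be passed to the `oneshot` entrypoint instead, as the new validator's error message states. A minimal sketch of the intended call pattern, assuming the usual `oneshot(model=..., dataset=..., recipe=...)` signature; the model id, dataset name, and layer name below are placeholders:

from llmcompressor import oneshot
from llmcompressor.modifiers.awq import AWQModifier

# The modifier no longer accepts sequential_targets; passing it here now raises a ValueError
recipe = AWQModifier(targets=["Linear"], scheme="W4A16", ignore=["lm_head"])

oneshot(
    model="org/example-model",             # placeholder model id
    dataset="example-calibration-set",     # placeholder dataset name
    recipe=recipe,
    sequential_targets=["ExampleDecoderLayer"],  # now a pipeline-level argument
)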
src/llmcompressor/modifiers/obcq/base.py

Lines changed: 4 additions & 6 deletions
@@ -45,6 +45,10 @@ class SparseGPTModifier(SparsityModifierBase):
     - on_finalize
         - remove_hooks()

+    :param targets: list of module names to quantize if a scheme is provided. Defaults
+        to Linear layers
+    :param ignore: optional list of module class names or submodule names to not
+        quantize even if they match a target. Defaults to empty list.
     :param sparsity: Sparsity to compress model to
     :param sparsity_profile: Can be set to 'owl' to use Outlier Weighed
         Layerwise Sparsity (OWL), more information can be found
@@ -62,12 +66,6 @@ class SparseGPTModifier(SparsityModifierBase):
         previously pruned model, defaults to False.
     :param offload_hessians: Set to True for decreased memory usage but increased
         runtime.
-    :param sequential_targets: list of layer names to compress during OBCQ, or '__ALL__'
-        to compress every layer in the model. Alias for `targets`
-    :param targets: list of layer names to compress during OBCQ, or '__ALL__'
-        to compress every layer in the model. Alias for `sequential_targets`
-    :param ignore: optional list of module class names or submodule names to not
-        quantize even if they match a target. Defaults to empty list.
     """

     # modifier arguments

src/llmcompressor/modifiers/obcq/sgpt_base.py

Lines changed: 19 additions & 21 deletions
@@ -8,13 +8,14 @@
 import torch
 from loguru import logger
 from pydantic import Field, PrivateAttr, field_validator, model_validator
+from transformers import PreTrainedModel

 from llmcompressor.core import Event, EventType, State
 from llmcompressor.modifiers.modifier import Modifier
 from llmcompressor.modifiers.utils.hooks import HooksMixin
+from llmcompressor.sentinel import Sentinel
 from llmcompressor.utils.pytorch.module import (
     get_layers,
-    get_no_split_params,
     get_prunable_layers,
     match_targets,
 )
@@ -27,35 +28,41 @@ class SparsityModifierBase(Modifier):
     """

     # modifier arguments
+    targets: Union[str, List[str]] = ["Linear"]
+    ignore: List[str] = Field(default_factory=list)
     sparsity: Optional[Union[float, List[float]]]
     sparsity_profile: Optional[str] = None
     mask_structure: str = "0:0"
     owl_m: Optional[int] = None
     owl_lmbda: Optional[float] = None

-    # data pipeline arguments
-    sequential_update: Optional[bool] = False  # deprecated
-    sequential_targets: Union[str, List[str], None] = None
-    targets: Union[str, List[str]] = ["Linear"]
-    ignore: List[str] = Field(default_factory=list)
-
     # private variables
     _prune_n: Optional[int] = PrivateAttr(default=None)
     _prune_m: Optional[int] = PrivateAttr(default=None)
     _module_names: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)
     _target_layers: Dict[str, torch.nn.Module] = PrivateAttr(default_factory=dict)
     _module_sparsities: Dict[torch.nn.Module, str] = PrivateAttr(default_factory=dict)

+    # deprecated
+    sequential_update: Union[Sentinel, Any] = Sentinel("deprecated")
+    sequential_targets: Union[Sentinel, Any] = Sentinel("deprecated")
+
     @field_validator("sequential_update", mode="before")
     def validate_sequential_update(cls, value: bool) -> bool:
-        if not value:
+        if value is not Sentinel("deprecated"):
             warnings.warn(
                 "`sequential_update=False` is no longer supported, setting "
                 "sequential_update=True",
                 DeprecationWarning,
             )

-        return True
+    @field_validator("sequential_targets", mode="before")
+    def validate_sequential_targets(cls, value: bool) -> bool:
+        if value is not Sentinel("deprecated"):
+            raise ValueError(
+                "Setting `sequential_targets` via modifiers is no longer supported, "
+                "Please use `oneshot(sequential_targets=...)`"
+            )

     @field_validator("sparsity_profile", mode="before")
     def validate_sparsity_profile(cls, value: Optional[str]) -> bool:
@@ -109,12 +116,12 @@ def on_initialize(self, state: "State", **kwargs) -> bool:

         :param state: session state storing input model and calibration data
         """
-        model: torch.nn.Module = state.model
+        model: PreTrainedModel = state.model
         dataloader: torch.utils.data.DataLoader = state.data.calib

         # infer module and sequential targets
-        self.sequential_targets = self._infer_sequential_targets(model)
-        layers = get_layers(self.sequential_targets, model)
+        sequential_targets = model._get_no_split_modules("auto")
+        layers = get_layers(sequential_targets, model)
         self._target_layers = get_layers(
             self.targets, model
         )  # layers containing targets
@@ -191,15 +198,6 @@ def on_end(self, state: State, event: Event, **kwargs):
         self.ended_ = True
         self.remove_hooks()

-    def _infer_sequential_targets(
-        self, model: torch.nn.Module
-    ) -> Union[str, List[str]]:
-        if self.sequential_targets is None:
-            return get_no_split_params(model)
-        if isinstance(self.sequential_targets, str):
-            return [self.sequential_targets]
-        return self.sequential_targets
-
     def _infer_owl_layer_sparsity(
         self,
         model: torch.nn.Module,

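Note: the diffs above replace boolean/None defaults with a `Sentinel("deprecated")` default plus a `mode="before"` pydantic validator, so any user-supplied value (even a falsy one such as `False` or `None`) is detected and either warned about or rejected. A standalone sketch of that pattern; `Sentinel` below is a simplified stand-in for `llmcompressor.sentinel.Sentinel` (the diff compares by identity via `is not`, the stub compares by name), and `ExampleModifier` is a hypothetical model, not code from this commit:

# Standalone sketch of the sentinel-based deprecation pattern (assumptions noted above).
import warnings
from typing import Any, Union

from pydantic import BaseModel, ConfigDict, field_validator


class Sentinel:
    """Named marker object; two sentinels with the same name compare equal."""

    def __init__(self, name: str):
        self.name = name

    def __eq__(self, other: Any) -> bool:
        return isinstance(other, Sentinel) and other.name == self.name


class ExampleModifier(BaseModel):
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # Defaulting to a sentinel makes "the user passed something" detectable,
    # even when the passed value is falsy (None, False, []).
    sequential_update: Union[Sentinel, Any] = Sentinel("deprecated")
    sequential_targets: Union[Sentinel, Any] = Sentinel("deprecated")

    @field_validator("sequential_update", mode="before")
    def warn_sequential_update(cls, value: Any) -> Any:
        if value != Sentinel("deprecated"):
            warnings.warn("`sequential_update` is deprecated and ignored", DeprecationWarning)
        return value

    @field_validator("sequential_targets", mode="before")
    def reject_sequential_targets(cls, value: Any) -> Any:
        if value != Sentinel("deprecated"):
            raise ValueError("pass `sequential_targets` to `oneshot(...)` instead")
        return value


ExampleModifier()                                       # ok: defaults untouched
ExampleModifier(sequential_update=False)                # emits a DeprecationWarning
# ExampleModifier(sequential_targets=["DecoderLayer"])  # would raise a ValidationError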
src/llmcompressor/modifiers/pruning/wanda/base.py

Lines changed: 4 additions & 9 deletions
@@ -36,14 +36,15 @@ class WandaPruningModifier(SparsityModifierBase):
     Lifecycle:
     - on_initialize
         - register_hook(module, calibrate_module, "forward")
-    - run_sequential / run_layer_sequential / run_basic
-        - make_empty_row_scalars
-        - accumulate_row_scalars
     - on_sequential_batch_end
         - sparsify_weight
     - on_finalize
         - remove_hooks()

+    :param targets: list of module names to quantize if a scheme is provided. Defaults
+        to Linear layers
+    :param ignore: optional list of module class names or submodule names to not
+        quantize even if they match a target. Defaults to empty list.
     :param sparsity: Sparsity to compress model to
     :param sparsity_profile: Can be set to 'owl' to use Outlier Weighed
         Layerwise Sparsity (OWL), more information can be found
@@ -53,12 +54,6 @@ class WandaPruningModifier(SparsityModifierBase):
         shape. Defaults to 0:0 which represents an unstructured mask.
     :param owl_m: Number of outliers to use for OWL
     :param owl_lmbda: Lambda value to use for OWL
-    :param sequential_targets: list of layer names to compress during OBCQ, or '__ALL__'
-        to compress every layer in the model. Alias for `targets`
-    :param targets: list of layer names to compress during OBCQ, or '__ALL__'
-        to compress every layer in the model. Alias for `sequential_targets`
-    :param ignore: optional list of module class names or submodule names to not
-        quantize even if they match a target. Defaults to empty list.
     """

     # private variables

src/llmcompressor/modifiers/quantization/gptq/base.py

Lines changed: 15 additions & 9 deletions
@@ -1,6 +1,6 @@
 import contextlib
 import warnings
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, Optional, Tuple, Union

 import torch
 from compressed_tensors.quantization import (
@@ -61,7 +61,7 @@ class GPTQModifier(Modifier, QuantizationMixin):

     Lifecycle:
     - on_initialize
-        - apply config to model
+        - apply quantization config to model
     - on_start
         - add activation calibration hooks
         - add gptq weight calibration hooks
@@ -71,8 +71,6 @@ class GPTQModifier(Modifier, QuantizationMixin):
         - remove_hooks()
         - model.apply(freeze_module_quantization)

-    :param sequential_targets: list of layer names to compress during GPTQ, or
-        '__ALL__' to compress every layer in the model
     :param block_size: Used to determine number of columns to compress in one pass
     :param dampening_frac: Amount of dampening to apply to H, as a fraction of the
         diagonal norm
@@ -83,7 +81,7 @@ class GPTQModifier(Modifier, QuantizationMixin):

     :param config_groups: dictionary specifying quantization schemes to apply to target
         modules. Modules not matching a scheme target will NOT be quantized.
-    :param targets: list of layer names to quantize if a scheme is provided. Defaults
+    :param targets: list of module names to quantize if a scheme is provided. Defaults
         to Linear layers
     :param ignore: optional list of module class names or submodule names to not
         quantize even if they match a target in config_groups. Defaults to empty list.
@@ -106,8 +104,6 @@ class GPTQModifier(Modifier, QuantizationMixin):
     """

     # gptq modifier arguments
-    sequential_update: bool = True  # DEPRECATED
-    sequential_targets: Union[str, List[str], None] = None
     block_size: int = 128
     dampening_frac: Optional[float] = 0.01
     actorder: Optional[Union[ActivationOrdering, Sentinel]] = None
@@ -118,16 +114,26 @@ class GPTQModifier(Modifier, QuantizationMixin):
     _hessians: Dict[torch.nn.Module, torch.Tensor] = PrivateAttr(default_factory=dict)
     _num_samples: Dict[torch.nn.Module, int] = PrivateAttr(default_factory=dict)

+    # deprecated
+    sequential_update: Union[Sentinel, Any] = Sentinel("deprecated")
+    sequential_targets: Union[Sentinel, Any] = Sentinel("deprecated")
+
     @field_validator("sequential_update", mode="before")
     def validate_sequential_update(cls, value: bool) -> bool:
-        if not value:
+        if value is not Sentinel("deprecated"):
             warnings.warn(
                 "`sequential_update=False` is no longer supported, setting "
                 "sequential_update=True",
                 DeprecationWarning,
             )

-        return True
+    @field_validator("sequential_targets", mode="before")
+    def validate_sequential_targets(cls, value: bool) -> bool:
+        if value is not Sentinel("deprecated"):
+            raise ValueError(
+                "Setting `sequential_targets` via modifiers is no longer supported, "
+                "Please use `oneshot(sequential_targets=...)`"
+            )

     def resolve_quantization_config(self) -> QuantizationConfig:
         config = super().resolve_quantization_config()

src/llmcompressor/pipelines/layer_sequential/helpers.py

Lines changed: 1 addition & 1 deletion
@@ -69,7 +69,7 @@ def capture_first_layer_intermediates(
     desc = "Preparing intermediates cache"
     for batch_index, batch in enumerate(tqdm.tqdm(dataloader, desc=desc)):
         batch = apply_pad_mask_to_batch(batch) if mask_padding else batch
-        batch = tensors_to_device(batch, model_device)
+        batch = tensors_to_device(batch, torch.device("cpu"))

         try:
             model(**batch)

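Note: the one-line change above keeps captured calibration batches in CPU memory rather than moving them to the model's execution device up front, so the intermediates cache does not occupy GPU memory while the whole dataloader is consumed. A small sketch of the effect; `tensors_to_device` below is a simplified stand-in for the project's helper, and moving batches back to the execution device is assumed to happen later, when each layer is onloaded:

import torch

def tensors_to_device(batch: dict, device: torch.device) -> dict:
    # simplified stand-in: move only the tensor values of a batch dict
    return {k: v.to(device) if isinstance(v, torch.Tensor) else v for k, v in batch.items()}

batch = {
    "input_ids": torch.randint(0, 1000, (1, 8)),
    "attention_mask": torch.ones(1, 8, dtype=torch.long),
}

# cached intermediates stay off the GPU while the dataloader is consumed
cpu_batch = tensors_to_device(batch, torch.device("cpu"))

# later, a sequential pipeline would onload per layer, e.g.:
# gpu_batch = tensors_to_device(cpu_batch, torch.device("cuda:0"))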
src/llmcompressor/pipelines/layer_sequential/pipeline.py

Lines changed: 5 additions & 5 deletions
@@ -5,7 +5,7 @@
 from compressed_tensors.utils import disable_offloading, get_execution_device
 from torch.utils.data.dataloader import DataLoader

-from llmcompressor.core import LifecycleCallbacks, active_session
+from llmcompressor.core import LifecycleCallbacks
 from llmcompressor.modifiers.utils.hooks import HooksMixin
 from llmcompressor.pipelines.cache import IntermediatesCache
 from llmcompressor.pipelines.layer_sequential.helpers import (
@@ -17,7 +17,7 @@
 from llmcompressor.pipelines.registry import CalibrationPipeline
 from llmcompressor.pipelines.sequential.helpers import (
     dispatch_for_sequential,
-    get_sequential_targets,
+    infer_sequential_targets,
 )
 from llmcompressor.utils.helpers import DisableQuantization, calibration_forward_context

@@ -56,15 +56,15 @@ def __call__(
         :param dataloader: loads data for calibration
         :param dataset_args: dataset arguments relevant to pipelines
         """
-        session = active_session()
+        # prepare model for sequential onloading
+        dispatch_for_sequential(model)

         # prepare model for sequential onloading
         dispatch_for_sequential(model)
         model_device = get_execution_device(model)

         # find layers
-        modifiers = session.get_modifiers()
-        sequential_targets = get_sequential_targets(modifiers, model, dataset_args)
+        sequential_targets = infer_sequential_targets(model, dataset_args)
         layers = match_modules(model, sequential_targets)

         LifecycleCallbacks.calibration_epoch_start()

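Note: the pipeline above now resolves sequential targets from the model and dataset arguments alone, via `infer_sequential_targets(model, dataset_args)`, instead of reading them off the active session's modifiers. That helper's body is not shown in this commit; the following is only a hedged sketch of plausible behavior, consistent with the `oneshot(sequential_targets=...)` flow and the `_get_no_split_modules("auto")` fallback used in sgpt_base.py above (the attribute lookup on `dataset_args` is an assumption):

from typing import Any, List, Optional, Union

from transformers import PreTrainedModel


def infer_sequential_targets(model: PreTrainedModel, dataset_args: Any) -> List[str]:
    """Hedged sketch: prefer user-supplied oneshot(sequential_targets=...), otherwise
    fall back to the model's no-split module class names (e.g. decoder layer classes)."""
    user_targets: Optional[Union[str, List[str]]] = getattr(
        dataset_args, "sequential_targets", None  # assumed to be carried on dataset_args
    )
    if user_targets is not None:
        return [user_targets] if isinstance(user_targets, str) else list(user_targets)
    return list(model._get_no_split_modules("auto"))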