Skip to content

Commit db1ebc7

Browse files
kylesayrs and dsikka authored
[Bugfix] Remove tracing imports from tests (#1498)
## Purpose ##

* We no longer use traceable definitions as of #1411, so we should no longer be importing these definitions in the tests

## Changes ##

* Remove traceable model definition imports from test configs
* Remove `get_model_class` utility, which is no longer necessary since we only import model classes from transformers now

## Testing ##

```bash
grep -r -i 'trace' src tests examples
```

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
Co-authored-by: Dipika Sikka <dipikasikka1@gmail.com>
1 parent 4edaa1a commit db1ebc7

File tree

6 files changed

+9
-18
lines changed

6 files changed

+9
-18
lines changed
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
1-
from .debug import get_model_class
1+
from .debug import trace
22

3-
__all__ = ["get_model_class"]
3+
__all__ = ["trace"]

src/llmcompressor/transformers/tracing/debug.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,15 +7,14 @@
77
import transformers
88
from transformers import AutoProcessor, PreTrainedModel
99

10-
from llmcompressor.transformers import tracing
1110
from llmcompressor.utils.pytorch.module import get_no_split_params
1211
from llmcompressor.pipelines.sequential.helpers import trace_subgraphs, Subgraph
1312
from llmcompressor.transformers import TextGenerationDataset
1413
from llmcompressor.args import DatasetArguments
1514

1615
from llmcompressor.utils.dev import skip_weights_download
1716

18-
__all__ = ["get_model_class"]
17+
__all__ = ["trace"]
1918

2019

2120
def parse_args():
@@ -111,14 +110,6 @@ def trace(
111110
return model, subgraphs, sample
112111

113112

114-
def get_model_class(model_class: str) -> Type[PreTrainedModel]:
115-
model_cls = getattr(tracing, model_class, getattr(transformers, model_class, None))
116-
if model_cls is None:
117-
raise ValueError(f"Could not import model class {model_class}")
118-
119-
return model_cls
120-
121-
122113
def get_dataset_kwargs(modality: str, ignore: List[str]) -> Dict[str, str]:
123114
dataset_kwargs = {
124115
"text": {
@@ -167,7 +158,7 @@ def main():
167158

168159
trace(
169160
model_id=args.model_id,
170-
model_class=get_model_class(args.model_class),
161+
model_class=getattr(transformers, args.model_class),
171162
sequential_targets=args.sequential_targets,
172163
ignore=args.ignore,
173164
modality=args.modality,

tests/e2e/e2e_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,11 @@
11
import torch
2+
import transformers
23
from datasets import load_dataset
34
from loguru import logger
45
from transformers import AutoProcessor
56

67
from llmcompressor import oneshot
78
from llmcompressor.modifiers.quantization import GPTQModifier, QuantizationModifier
8-
from llmcompressor.transformers.tracing import get_model_class
99
from tests.test_timer.timer_utils import log_time
1010
from tests.testing_utils import process_dataset
1111

@@ -16,7 +16,7 @@ def _load_model_and_processor(
1616
model_class: str,
1717
device: str,
1818
):
19-
pretrained_model_class = get_model_class(model_class)
19+
pretrained_model_class = getattr(transformers, model_class)
2020
loaded_model = pretrained_model_class.from_pretrained(
2121
model, device_map=device, torch_dtype="auto"
2222
)

tests/lmeval/configs/vl_fp8_dynamic_per_token.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
cadence: weekly
22
model: Qwen/Qwen2.5-VL-7B-Instruct
3-
model_class: TraceableQwen2_5_VLForConditionalGeneration
3+
model_class: Qwen2_5_VLForConditionalGeneration
44
scheme: FP8_DYNAMIC
55
lmeval:
66
model: "hf-multimodal"

tests/lmeval/configs/vl_int8_w8a8_dynamic_per_token.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
cadence: "weekly"
22
model: Qwen/Qwen2.5-VL-7B-Instruct
3-
model_class: TraceableQwen2_5_VLForConditionalGeneration
3+
model_class: Qwen2_5_VLForConditionalGeneration
44
scheme: INT8_dyn_per_token
55
recipe: tests/e2e/vLLM/recipes/INT8/recipe_int8_channel_weight_dynamic_per_token.yaml
66
dataset_id: lmms-lab/flickr30k

tests/lmeval/configs/vl_w4a16_actorder_weight.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
cadence: "weekly"
22
model: Qwen/Qwen2.5-VL-7B-Instruct
3-
model_class: TraceableQwen2_5_VLForConditionalGeneration
3+
model_class: Qwen2_5_VLForConditionalGeneration
44
scheme: W4A16_actorder_weight
55
recipe: tests/e2e/vLLM/recipes/actorder/recipe_w4a16_actorder_weight.yaml
66
dataset_id: lmms-lab/flickr30k

0 commit comments

Comments (0)