Skip to content

Commit 9e8210b

Browse files
authored
[Tracing] Update ignored functions list (#1599)
## Purpose ##

* Fix tracing for model definitions introduced as part of `transformers==4.53`
* Resolves #1603

## Background ##

In the latest transformers release, this change landed which changed the name of the function which generates the causal mask: huggingface/transformers#37866

## Changes ##

* Extend the list of function names to ignore during tracing, specifically targeting functions which create causal masks
* Update debugger tool to use ignore list from `DatasetArguments`
* Update Tracer to skip masking functions as part of autowrapping any functions which were not caught by the autowrapper

## Testing ##

* `tests/llmcompressor/transformers/tracing/test_models.py` now passes with the latest `transformers==4.53`

---------

Signed-off-by: Kyle Sayers <kylesayrs@gmail.com>
1 parent 5f9af3c commit 9e8210b

File tree

4 files changed

+31
-11
lines changed

4 files changed

+31
-11
lines changed

src/llmcompressor/args/dataset_arguments.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -181,7 +181,17 @@ class DatasetArguments(CustomDatasetArguments):
181181
},
182182
)
183183
tracing_ignore: List[str] = field(
184-
default_factory=lambda: ["_update_causal_mask"],
184+
default_factory=lambda: [
185+
"_update_causal_mask",
186+
"create_causal_mask",
187+
"make_causal_mask",
188+
"get_causal_mask",
189+
"mask_interface",
190+
"mask_function",
191+
"_prepare_4d_causal_attention_mask",
192+
"_prepare_fsmt_decoder_inputs",
193+
"_prepare_4d_causal_attention_mask_with_cache_position",
194+
],
185195
metadata={
186196
"help": "List of functions to ignore during tracing, either "
187197
"{module}.{method_name} or {function_name}"

src/llmcompressor/pipelines/sequential/ast_utils/auto_wrapper.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -180,15 +180,9 @@ def _wrap_if_possible(self, node: ast.AST) -> Union[ast.AST, ast.Assign, ast.Cal
180180
return node
181181

182182
if isinstance(node, ast.stmt):
183-
logger.debug("---- Autowrapper ----")
184-
logger.debug(ast.unparse(node))
185-
logger.debug("---------------------")
186183
return self._wrap_stmt(node)
187184

188185
elif isinstance(node, ast.expr):
189-
logger.debug("---- Autowrapper ----")
190-
logger.debug(ast.unparse(node))
191-
logger.debug("---------------------")
192186
return self._wrap_expr(node)
193187

194188
else:
@@ -254,6 +248,11 @@ def _wrap_stmt(self, node: ast.stmt) -> ast.Assign:
254248
# update local names with newly returned values
255249
self._local_names |= returns
256250

251+
# log newly created function definition
252+
logger.debug("---- Autowrapper ----")
253+
logger.debug(ast.unparse(ast.fix_missing_locations(fn_def)))
254+
logger.debug("---------------------")
255+
257256
return assign_call
258257

259258
def _wrap_expr(self, node: ast.expr) -> ast.Call:

src/llmcompressor/pipelines/sequential/helpers.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import inspect
33
from collections import deque
44
from dataclasses import dataclass
5-
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
5+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple
66

77
import torch
88
from compressed_tensors.quantization import find_name_or_class_matches
@@ -169,10 +169,12 @@ class SequentialTracer(HFTracer):
169169
"""
170170

171171
def __init__(self, ancestors: Set[Module], offloaded: Set[Module]):
172-
super().__init__()
173172
self.ancestors = ancestors
174173
self.offloaded = offloaded
175174

175+
# skip any mask creation functions not already caught by the autowrapper
176+
super().__init__(autowrap_functions=_get_autowrap_functions())
177+
176178
# check unlikely case that ancestors have direct params which are offloaded
177179
offloaded_ancestors = offloaded & ancestors
178180
if offloaded_ancestors:
@@ -531,3 +533,12 @@ def dispatch_for_sequential(model: PreTrainedModel) -> PreTrainedModel:
531533
logger.warning("CUDA is not available! Compressing model on CPU instead")
532534

533535
return model
536+
537+
538+
def _get_autowrap_functions() -> Tuple[Callable[[Any], Any], ...]:
539+
try:
540+
from transformers.masking_utils import LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING
541+
542+
return tuple(LAYER_PATTERN_TO_MASK_FUNCTION_MAPPING.values())
543+
except ImportError:
544+
return tuple()

src/llmcompressor/transformers/tracing/debug.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ def parse_args():
2222
parser.add_argument("--model_id", type=str, required=True, help="The stub of the model to load") # noqa: E501
2323
parser.add_argument("--model_class", type=str, required=True, help="The class name of the model") # noqa: E501
2424
parser.add_argument("--sequential_targets", type=str, nargs="*", default=None, metavar="TARGET", help="List of targets for sequential tracing") # noqa: E501
25-
parser.add_argument("--ignore", type=str, nargs="*", default=["_update_causal_mask"], metavar="PATTERN", help="List of patterns to ignore during tracing") # noqa: E501
25+
parser.add_argument("--ignore", type=str, nargs="*", default=DatasetArguments().tracing_ignore, metavar="PATTERN", help="List of patterns to ignore during tracing") # noqa: E501
2626
parser.add_argument("--modality", type=str, default="text", help="Modality of calibration dataset, defaults to text") # noqa: E501
2727
parser.add_argument("--trust_remote_code", type=bool, default=False, help="Whether to trust model remote code") # noqa: E501
2828
parser.add_argument("--skip_weights", type=bool, default=True, help="Whether to load the model with dummy weights") # noqa: E501
@@ -34,7 +34,7 @@ def trace(
3434
model_id: str,
3535
model_class: Type[PreTrainedModel],
3636
sequential_targets: Optional[Union[List[str], str]] = None,
37-
ignore: Union[List[str], str] = ["_update_causal_mask"],
37+
ignore: Union[List[str], str] = DatasetArguments().tracing_ignore,
3838
modality: str = "text",
3939
trust_remote_code: bool = True,
4040
skip_weights: bool = True,

0 commit comments

Comments
 (0)