Skip to content

Commit 3d7ed6a

Browse files
xiez22xiezhe-24gemini-code-assist[bot]
authored
[Bug Fix] Fix incorrect argument inference in AutoWrapper with self assignment (#1616)
SUMMARY: Fix a bug in the AutoWrapper `_wrap_stmt` logic where a variable that is both read and written (e.g., `x = x + 1`) is incorrectly excluded from the unbound argument list. This caused some internal variables (like `attention_mask`) to be missing from wrapped function inputs, leading to runtime errors when using torch.fx. An example of reproducing this bug: ```python import ast from llmcompressor.pipelines.sequential.ast_utils.name_analyzer import NameAnalyzer from llmcompressor.pipelines.sequential.ast_utils.auto_wrapper import AutoWrapper code = """ def forward(x, y): if y > 0: x = x + 1 else: x = x - 1 return x """ tree = ast.parse(code) analyzer = NameAnalyzer(omit=set()) unbound, assigned, cond = analyzer.analyze(tree) print("=== Result ===") print("Unbound:", unbound) print("Assigned:", assigned) print("Conditionally Assigned:", cond) namespace = {} wrapper = AutoWrapper(namespace, ignore=[]) wrapped_tree = wrapper.auto_wrap(ast.Module(body=[tree], type_ignores=[])) print("\n=== Wrapped Code ===") print(ast.unparse(wrapped_tree)) ``` The output of the original code: ``` === Result === Unbound: {'y'} Assigned: {'x'} Conditionally Assigned: set() === Wrapped Code === @torch.fx.wrap def wrapped_8751792970731(y): if y > 0: x = x + 1 else: x = x - 1 return (x,) def forward(x, y): x, = wrapped_8751792970731(y) return x ``` `x` is missing from the input args of `wrapped_8751792970731`, leading to unbound variable in the wrapped function The fix add a `visit_Assign` function to properly detect and classify variables as unbound when they are read before assignment, even within the same statement. TEST PLAN: Added a new test case `test_branch_with_self_assignment` to `tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py` --------- Signed-off-by: xiezhe.24 <xiezhe.24@bytedance.com> Co-authored-by: xiezhe.24 <xiezhe.24@bytedance.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
1 parent adaa6bb commit 3d7ed6a

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

src/llmcompressor/pipelines/sequential/ast_utils/name_analyzer.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,14 @@ def visit_Name(self, node: ast.Name) -> ast.Name:
6666

6767
self.generic_visit(node)
6868

69+
def visit_Assign(self, node: ast.Assign):
70+
# Visit the right side of the assignment first
71+
self.visit(node.value)
72+
73+
# Now visit the left side of the assignment
74+
for target in node.targets:
75+
self.visit(target)
76+
6977
def visit_If(self, node: ast.If):
7078
self.visit(node.test)
7179

tests/llmcompressor/pipelines/sequential/ast_utils.py/test_auto_wrapper.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -98,3 +98,28 @@ def forward(self):
9898
"""
9999
namespace = {"self": Model()}
100100
check_wrapping(source, None, 1, namespace=namespace, ignore=["meth_one"])
101+
102+
103+
def test_branch_with_self_assignment():
104+
source = """
105+
def forward(x, y):
106+
if y > 0:
107+
x = x + 1
108+
else:
109+
x = x - 1
110+
return x
111+
"""
112+
113+
tree = ast.parse(textwrap.dedent(source))
114+
wrapper = AutoWrapper(namespace={}, ignore=[])
115+
wrapper.auto_wrap(tree)
116+
117+
assert len(wrapper._wrapper_fn_defs) == 1
118+
119+
# Check if both x, y are included in args
120+
wrapped_fn = wrapper._wrapper_fn_defs[0]
121+
arg_names = {arg.arg for arg in wrapped_fn.args.args}
122+
123+
assert arg_names == {"x", "y"}, (
124+
f"Expected arguments {{'x', 'y'}}, but got {arg_names}"
125+
)

0 commit comments

Comments
 (0)