Skip to content

Commit a18d381

Browse files
committed
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
1 parent 73ff936 commit a18d381

File tree

2 files changed

+26
-9
lines changed

2 files changed

+26
-9
lines changed

tests/singlecard/compile/test_simple.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
set_current_vllm_config)
1515
from vllm.utils import direct_register_custom_op
1616

17+
from vllm_ascend.utils import vllm_version_is
18+
1719
global_counter = 0
1820

1921
# create a library to hold the custom op
@@ -92,14 +94,24 @@ def test_simple_piecewise_compile():
9294

9395
inputs = torch.randn(100).npu()
9496

95-
with compilation_counter.expect(
96-
num_graphs_seen=1, # one graph for the model
97-
num_piecewise_graphs_seen=5, # 2 * num_layers + 1
98-
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
99-
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
100-
num_cudagraph_caputured=
101-
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
102-
):
97+
if vllm_version_is("0.9.0"):
98+
kwargs = {
99+
"num_graphs_seen": 1, # one graph for the model
100+
"num_piecewise_graphs_seen": 5, # 2 * num_layers + 1
101+
"num_piecewise_capturable_graphs_seen": 3, # 1 + num_layers
102+
"num_backend_compilations": 3, # num_piecewise_capturable_graphs_seen
103+
"num_cudagraph_caputured": 6 # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
104+
}
105+
else:
106+
kwargs = {
107+
"num_graphs_seen": 1, # one graph for the model
108+
"num_piecewise_graphs_seen": 5, # 2 * num_layers + 1
109+
"num_piecewise_capturable_graphs_seen": 3, # 1 + num_layers
110+
"num_backend_compilations": 3, # num_piecewise_capturable_graphs_seen
111+
"num_cudagraph_captured": 6 # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
112+
}
113+
114+
with compilation_counter.expect(kwargs):
103115

104116
model(inputs)
105117

vllm_ascend/compilation/piecewise_backend.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
from vllm.logger import logger
3232
from vllm.utils import weak_ref_tensors
3333

34+
from vllm_ascend.utils import vllm_version_is
35+
3436

3537
@dataclasses.dataclass
3638
class ConcreteSizeEntry:
@@ -205,7 +207,10 @@ def __call__(self, *args) -> Any:
205207
entry.output = weak_ref_tensors(output)
206208
entry.aclgraph = aclgraph
207209

208-
compilation_counter.num_cudagraph_caputured += 1
210+
if vllm_version_is("0.9.0"):
211+
compilation_counter.num_cudagraph_caputured += 1
212+
else:
213+
compilation_counter.num_cudagraph_captured += 1
209214

210215
# important: we need to return the output, rather than
211216
# the weak ref of the output, so that pytorch can correctly

0 commit comments

Comments
 (0)