Skip to content

Commit b668d96

Browse files
committed
Signed-off-by: Yikun Jiang <yikunkero@gmail.com>
1 parent 73ff936 commit b668d96

File tree

2 files changed

+30
-9
lines changed

2 files changed

+30
-9
lines changed

tests/singlecard/compile/test_simple.py

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414
set_current_vllm_config)
1515
from vllm.utils import direct_register_custom_op
1616

17+
from vllm_ascend.utils import vllm_version_is
18+
1719
global_counter = 0
1820

1921
# create a library to hold the custom op
@@ -92,14 +94,28 @@ def test_simple_piecewise_compile():
9294

9395
inputs = torch.randn(100).npu()
9496

95-
with compilation_counter.expect(
96-
num_graphs_seen=1, # one graph for the model
97-
num_piecewise_graphs_seen=5, # 2 * num_layers + 1
98-
num_piecewise_capturable_graphs_seen=3, # 1 + num_layers
99-
num_backend_compilations=3, # num_piecewise_capturable_graphs_seen
100-
num_cudagraph_caputured=
101-
6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
102-
):
97+
if vllm_version_is("0.9.0"):
98+
kwargs = {
99+
"num_graphs_seen": 1, # one graph for the model
100+
"num_piecewise_graphs_seen": 5, # 2 * num_layers + 1
101+
"num_piecewise_capturable_graphs_seen": 3, # 1 + num_layers
102+
"num_backend_compilations":
103+
3, # num_piecewise_capturable_graphs_seen
104+
"num_cudagraph_caputured":
105+
6 # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
106+
}
107+
else:
108+
kwargs = {
109+
"num_graphs_seen": 1, # one graph for the model
110+
"num_piecewise_graphs_seen": 5, # 2 * num_layers + 1
111+
"num_piecewise_capturable_graphs_seen": 3, # 1 + num_layers
112+
"num_backend_compilations":
113+
3, # num_piecewise_capturable_graphs_seen
114+
"num_cudagraph_captured":
115+
6 # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen
116+
}
117+
118+
with compilation_counter.expect(kwargs):
103119

104120
model(inputs)
105121

vllm_ascend/compilation/piecewise_backend.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,8 @@
3131
from vllm.logger import logger
3232
from vllm.utils import weak_ref_tensors
3333

34+
from vllm_ascend.utils import vllm_version_is
35+
3436

3537
@dataclasses.dataclass
3638
class ConcreteSizeEntry:
@@ -205,7 +207,10 @@ def __call__(self, *args) -> Any:
205207
entry.output = weak_ref_tensors(output)
206208
entry.aclgraph = aclgraph
207209

208-
compilation_counter.num_cudagraph_caputured += 1
210+
if vllm_version_is("0.9.0"):
211+
compilation_counter.num_cudagraph_caputured += 1
212+
else:
213+
compilation_counter.num_cudagraph_captured += 1
209214

210215
# important: we need to return the output, rather than
211216
# the weak ref of the output, so that pytorch can correctly

0 commit comments

Comments
 (0)