Skip to content

Commit f03850f

Browse files
authored
Update test_rejection_sampler.py
1 parent 060e5d2 commit f03850f

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

tests/sample/test_rejection_sampler.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -8,14 +8,14 @@
88
from vllm.v1.spec_decode.metadata import SpecDecodeMetadata
99

1010
from vllm_ascend.sample.rejection_sampler import (PLACEHOLDER_TOKEN_ID,
11-
RejectionSampler)
11+
AscendRejectionSampler)
1212

1313
DEVICE = "npu"
1414

1515

1616
@pytest.fixture
1717
def rejection_sampler():
18-
return RejectionSampler()
18+
return AscendRejectionSampler()
1919

2020

2121
def create_logits_tensor(output_token_ids: list[list[int]],
@@ -423,7 +423,7 @@ def estimate_rejection_sampling_pdf(
423423
Returns:
424424
Estimated probability distribution of the output tokens.
425425
"""
426-
rejection_sampler = RejectionSampler()
426+
rejection_sampler = AscendRejectionSampler()
427427
num_tokens = num_samples * k
428428
# Repeat draft probs num_samples * k times.
429429
draft_probs = draft_probs.reshape(1, 1,

0 commit comments

Comments
 (0)