Commit b18284b

[Serving] Enable GPU Sampling (#2368)
enable gpu sampling
1 parent 135419e · commit b18284b

4 files changed: +15 −6 lines changed

cpp/serve/sampler/gpu_sampler.cc

Lines changed: 4 additions & 2 deletions
@@ -545,7 +545,8 @@ class GPUSampler : public SamplerObj {
     if (!need_top_p && !need_prob_values) {
       // - Short path: If top_p and prob values are not needed, we directly sample from multinomial.
       SyncCopyStream(device_, compute_stream_, copy_stream_);
-      if (flashinfer_multinomial_sample_func_ != nullptr) {
+      if (device_.device_type == DLDeviceType::kDLCUDA &&
+          flashinfer_multinomial_sample_func_ != nullptr) {
         sampled_token_ids_device =
             sampled_token_ids_device_.CreateView({sample_indices_device->shape[0]}, dtype_i32_);
         (*flashinfer_multinomial_sample_func_)(probs_on_device, uniform_samples_device,
@@ -588,7 +589,8 @@ class GPUSampler : public SamplerObj {
                                              uniform_samples_device, sample_indices_device, top_p_device);
     } else {
       // - Sample without top_p.
-      if (flashinfer_multinomial_sample_func_ != nullptr) {
+      if (device_.device_type == DLDeviceType::kDLCUDA &&
+          flashinfer_multinomial_sample_func_ != nullptr) {
         sampled_token_ids_device =
             sampled_token_ids_device_.CreateView({sample_indices_device->shape[0]}, dtype_i32_);
         (*flashinfer_multinomial_sample_func_)(probs_on_device, uniform_samples_device,
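Note on the two hunks above: the FlashInfer sampling function is a CUDA-only fast path, so it is now additionally gated on the device type; on other GPU-sampling devices (Vulkan after this change) the sampler takes the non-FlashInfer branch instead. A minimal sketch of that decision in Python terms, where pick_multinomial_path is an illustrative helper rather than anything in the codebase:

# Illustrative sketch only: mirrors the device gate added in gpu_sampler.cc using
# TVM's Python device handles. `pick_multinomial_path` is hypothetical.
import tvm

def pick_multinomial_path(device: tvm.runtime.Device, has_flashinfer_func: bool) -> str:
    """Describe which multinomial sampling path the GPU sampler would take."""
    if device.device_type == tvm.cuda().device_type and has_flashinfer_func:
        return "flashinfer fast path"     # CUDA-only packed function
    return "generic sampling kernel"      # used on Vulkan, or on CUDA without FlashInfer

print(pick_multinomial_path(tvm.cuda(0), has_flashinfer_func=True))    # flashinfer fast path
print(pick_multinomial_path(tvm.vulkan(0), has_flashinfer_func=True))  # generic sampling kernel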

cpp/serve/sampler/sampler.h

Lines changed: 2 additions & 1 deletion
@@ -140,7 +140,8 @@ class Sampler : public ObjectRef {
 
   /*! \brief Check if the given device supports GPU sampling. */
   static bool SupportGPUSampler(Device device) {
-    return device.device_type == DLDeviceType::kDLCUDA;
+    return device.device_type == DLDeviceType::kDLCUDA ||
+           device.device_type == DLDeviceType::kDLVulkan;
   }
 
   TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Sampler, ObjectRef, SamplerObj);
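With this change Sampler::SupportGPUSampler reports true for Vulkan devices as well as CUDA. A rough Python rendering of the same predicate, using TVM's device handles (the function name below is illustrative):

# Illustrative only: Python equivalent of Sampler::SupportGPUSampler after this
# commit. DLPack device-type codes: kDLCUDA == 2, kDLVulkan == 7.
import tvm

def support_gpu_sampler(device: tvm.runtime.Device) -> bool:
    return device.device_type in (tvm.cuda().device_type, tvm.vulkan().device_type)

assert support_gpu_sampler(tvm.cuda(0))
assert support_gpu_sampler(tvm.vulkan(0))
assert not support_gpu_sampler(tvm.cpu(0))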

python/mlc_llm/compiler_pass/attach_sampler.py

Lines changed: 8 additions & 3 deletions
@@ -28,7 +28,7 @@ def __init__(self, target: tvm.target.Target, variable_bounds: Dict[str, int]):
 
     def transform_module(self, mod: IRModule, _ctx: tvm.transform.PassContext) -> IRModule:
         """Entrypoint"""
-        if str(self.target.kind) != "cuda":
+        if str(self.target.kind) not in ["cuda", "vulkan"]:
             # Only enable GPU sampling for CUDA.
             return mod
 
@@ -87,7 +87,11 @@ def _attach_multinomial_sampling_func(bb: relax.BlockBuilder):
             name="sample_indices",
         )
         result_tensor = nn.multinomial_from_uniform(  # pylint:disable=too-many-function-args
-            probs_tensor, uniform_samples_tensor, sample_indices_tensor, "int32"
+            probs_tensor,
+            uniform_samples_tensor,
+            sample_indices_tensor,
+            "int32",
+            name="nn_multinomial_from_uniform",
         )
         result = bb.emit(
             relax.call_pure_packed(
@@ -97,7 +101,8 @@ def _attach_multinomial_sampling_func(bb: relax.BlockBuilder):
                 sinfo_args=sample_indices.struct_info,  # pylint: disable=no-member
             )
         )
-        gv = bb.emit_func_output(result)
+        output = bb.emit_output(result)
+        gv = bb.emit_func_output(output)
         return gv
 
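Two things change in this pass: the target gate now admits Vulkan in addition to CUDA, and the multinomial sampling call is given an explicit name and routed through bb.emit_output before the function output. A small sketch of the broadened target check in isolation (the helper name is illustrative, not part of the pass):

# Illustrative only: the same target-kind check transform_module now applies
# before attaching the GPU sampling functions.
import tvm

def gpu_sampling_enabled(target: tvm.target.Target) -> bool:
    return str(target.kind) in ["cuda", "vulkan"]

print(gpu_sampling_enabled(tvm.target.Target("cuda")))    # True
print(gpu_sampling_enabled(tvm.target.Target("vulkan")))  # True
print(gpu_sampling_enabled(tvm.target.Target("llvm")))    # False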

python/mlc_llm/compiler_pass/pipeline.py

Lines changed: 1 addition & 0 deletions
@@ -121,6 +121,7 @@ def _pipeline(mod: tvm.ir.IRModule, _ctx: tvm.transform.PassContext) -> tvm.ir.I
                 _DebugDump("debug-phase1.py", debug_dump, show_meta=False),
                 # Phase 2. Lowering to TIR, inherited TVM Relax's official "zero" pipeline
                 _LogProgress("Lowering to TVM TIR kernels"),
+                tvm.relax.backend.DispatchSampling(),
                 tvm.relax.backend.DispatchSortScan(),
                 tvm.relax.transform.LegalizeOps(),
                 tvm.relax.transform.AnnotateTIROpPattern(),
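The pipeline gains a single pass: tvm.relax.backend.DispatchSampling() now runs just before the sort/scan dispatch and legalization passes, so Relax sampling ops can be dispatched to backend-specific lowerings during phase 2. A sketch of that ordering run as a standalone sequence, assuming mod is the Relax IRModule produced by the earlier phases:

# Illustrative only: the phase-2 lowering order used by pipeline.py after this
# change, expressed as a standalone Sequential over an existing Relax IRModule.
import tvm

lowering = tvm.transform.Sequential(
    [
        tvm.relax.backend.DispatchSampling(),   # newly added: lower sampling ops
        tvm.relax.backend.DispatchSortScan(),
        tvm.relax.transform.LegalizeOps(),
        tvm.relax.transform.AnnotateTIROpPattern(),
    ]
)
# mod = lowering(mod)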
