Commit aa2d5aa

Merge pull request #788 from turboderp-org/dev
Merge Dev into Master
2 parents c820539 + e312b74 commit aa2d5aa

7 files changed: +204 −36 lines

exllamav2/architecture.py

Lines changed: 37 additions & 0 deletions
@@ -53,6 +53,10 @@
                            ["block_sparse_moe.experts.*.w2"],
                            ["block_sparse_moe.experts.*.w3"],
                            ["block_sparse_moe.gate"]]
+layer_keys_qwen3moe_mlp = [["mlp.experts.*.gate_proj"],
+                           ["mlp.experts.*.up_proj"],
+                           ["mlp.experts.*.down_proj"],
+                           ["mlp.gate"]]
 layer_keys_dbrx_mlp = [["block_sparse_moe.experts.*.v1", "block_sparse_moe.experts.v1"],
                        ["block_sparse_moe.experts.*.w1", "block_sparse_moe.experts.w1"],
                        ["block_sparse_moe.experts.*.w2", "block_sparse_moe.experts.w2"],
@@ -428,6 +432,39 @@ class Params:
             self.lm.attention_bias_qkv = True
             self.lm.supports_tp = True
 
+        # Qwen3
+
+        if arch_string == "Qwen3ForCausalLM":
+            arch_recognized = True
+            self.lm.layer_keys += \
+                layer_keys_llama_norms + \
+                layer_keys_llama_attn + \
+                layer_keys_llama_mlp
+            self.lm.expect_keys += \
+                expect_keys_llama
+            self.lm.supports_tp = True
+            self.lm.default_use_qk_norm = True
+
+        # Qwen3MoE
+
+        if arch_string == "Qwen3MoeForCausalLM":
+            arch_recognized = True
+            self.lm.layer_keys += \
+                layer_keys_llama_norms + \
+                layer_keys_llama_attn + \
+                layer_keys_qwen3moe_mlp
+            self.lm.expect_keys += \
+                expect_keys_llama
+            self.lm.supports_tp = True
+            self.lm.default_use_qk_norm = True
+            self.lm.keys.update({
+                "mlp_gate": ".mlp.experts.*.gate_proj",
+                "mlp_up": ".mlp.experts.*.up_proj",
+                "mlp_down": ".mlp.experts.*.down_proj",
+                "mlp_expert_gate": ".mlp.gate"
+            })
+            self.lm.is_moe = True
+
         # Qwen2-VL (2, 2.5)
 
         if arch_string in ["Qwen2VLForConditionalGeneration", "Qwen2_5_VLForConditionalGeneration"]:
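Note: the wildcard patterns in layer_keys_qwen3moe_mlp are matched per layer and per expert at load time. As a rough illustration of how a pattern such as "mlp.experts.*.gate_proj" fans out into concrete tensor names for a 128-expert Qwen3-MoE checkpoint, here is a minimal Python sketch; the expand_keys helper, the "model.layers.{i}" prefix and the expert count are illustrative assumptions, not exllamav2 internals.

# Minimal sketch: expand wildcard layer-key patterns such as "mlp.experts.*.gate_proj"
# into concrete tensor names, roughly the way a Qwen3MoeForCausalLM checkpoint lays them out.
# The helper and the counts below are illustrative assumptions, not exllamav2 API.

layer_keys_qwen3moe_mlp = [["mlp.experts.*.gate_proj"],
                           ["mlp.experts.*.up_proj"],
                           ["mlp.experts.*.down_proj"],
                           ["mlp.gate"]]

def expand_keys(layer_idx: int, num_experts: int) -> list[str]:
    names = []
    for alternatives in layer_keys_qwen3moe_mlp:
        pattern = alternatives[0]
        if "*" in pattern:
            # One tensor per expert, e.g. model.layers.0.mlp.experts.17.gate_proj
            names += [f"model.layers.{layer_idx}." + pattern.replace("*", str(e))
                      for e in range(num_experts)]
        else:
            # Shared router tensor, e.g. model.layers.0.mlp.gate
            names.append(f"model.layers.{layer_idx}." + pattern)
    return names

if __name__ == "__main__":
    keys = expand_keys(layer_idx = 0, num_experts = 128)
    print(len(keys))           # 3 * 128 expert projections + 1 router gate = 385
    print(keys[0], keys[-1])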

exllamav2/config.py

Lines changed: 4 additions & 1 deletion
@@ -319,9 +319,12 @@ def prepare(self, no_tensors: bool = False):
             default_intermediate_size,
             opt_subkey = "text_config",
         )
-        self.num_experts = read(read_config, int, ["num_local_experts", "ffn_config->moe_num_experts"], None)
+        self.num_experts = read(read_config, int, ["num_local_experts", "ffn_config->moe_num_experts", "num_experts"], None)
         self.num_experts_per_token = read(read_config, int,["num_experts_per_tok", "ffn_config->moe_top_k"], None)
 
+        if self.arch.lm.is_moe:
+            self.intermediate_size = read(read_config, int, ["moe_intermediate_size"], self.intermediate_size)
+
         # Logit/embedding/residual scale
 
         self.logit_scale = read(read_config, float, "logit_scale", 1)
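This change works because config.py resolves MoE parameters by trying several config.json keys in order, so Qwen3-MoE's "num_experts" and "moe_intermediate_size" are picked up alongside the older Mixtral/DBRX spellings. A minimal sketch of that fallback lookup follows, using a simplified stand-in for the read() helper and example values; the real implementation lives in exllamav2/config.py.

# Sketch of the key-fallback lookup this change relies on: try several config.json keys
# in order (a "->" separates nested dicts) and return the first hit, else a default.
# This mirrors the behavior of exllamav2's read(); the helper here is a simplified stand-in.
import json

def read_cfg(cfg: dict, keys: list[str], default = None):
    for key in keys:
        node = cfg
        for part in key.split("->"):
            if not isinstance(node, dict) or part not in node:
                node = None
                break
            node = node[part]
        if node is not None:
            return node
    return default

if __name__ == "__main__":
    cfg = json.loads('{"num_experts": 128, "num_experts_per_tok": 8, '
                     '"moe_intermediate_size": 768, "intermediate_size": 6144}')
    num_experts = read_cfg(cfg, ["num_local_experts", "ffn_config->moe_num_experts", "num_experts"])
    inter = read_cfg(cfg, ["moe_intermediate_size"], cfg["intermediate_size"])
    print(num_experts, inter)  # 128 768 -- the per-expert FFN width replaces the dense one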

exllamav2/conversion/adaptivegptq.py

Lines changed: 4 additions & 1 deletion
@@ -229,7 +229,10 @@ def prepare(self, no_h_inv = False):
 
         with torch.inference_mode():
 
-            self.hessian /= self.num_batches
+            if self.hessian is None or self.num_batches == 0:
+                self.hessian = torch.eye(self.rows, device = self.device, dtype = torch.float)
+            else:
+                self.hessian /= self.num_batches
             diagonal = torch.diag(self.hessian)
 
             # Prepare weights
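The GPTQ Hessian is accumulated from calibration activations, and an expert that the router never selects during calibration ends up with no batches at all; the new branch substitutes an identity Hessian so quantizing that expert degrades gracefully instead of dividing by zero. A small sketch of the idea in isolation, with illustrative shapes and names rather than the adaptivegptq internals:

# Sketch: GPTQ accumulates H as a sum of x xᵀ over calibration batches. A module that saw
# no batches would make the division undefined; substituting an identity Hessian makes the
# quantizer behave roughly like plain round-to-nearest for that module.
import torch

def finalize_hessian(hessian: torch.Tensor | None, num_batches: int, rows: int,
                     device = "cpu") -> torch.Tensor:
    if hessian is None or num_batches == 0:
        return torch.eye(rows, device = device, dtype = torch.float)
    return hessian / num_batches

if __name__ == "__main__":
    rows = 8
    xs = [torch.randn(rows, 4) for _ in range(3)]         # calibration activations
    h = sum(x @ x.T for x in xs)                          # accumulated outer products
    print(finalize_hessian(h, len(xs), rows).shape)       # averaged Hessian, (8, 8)
    print(finalize_hessian(None, 0, rows).diagonal())     # identity fallback, all ones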

exllamav2/exllamav2_ext/cuda/q_mlp.cu

Lines changed: 57 additions & 28 deletions
@@ -324,9 +324,15 @@ void QMoEMLP::forward_
     // half* lora_temp
 )
 {
-    if (num_experts != 4 && num_experts != 8 && num_experts != 16)
+    if (rows > MAX_Q_GEMM_WEIGHTS)
     {
-        printf(" ## num_experts must be 4, 8 or 16\n");
+        printf(" ## ropws > %i not implemented\n", MAX_Q_GEMM_WEIGHTS);
+        DBGI(rows);
+    }
+
+    if (num_experts != 4 && num_experts != 8 && num_experts != 16 && num_experts != 128)
+    {
+        printf(" ## num_experts must be 4, 8, 16 or 128\n");
         return;
     }
 
@@ -354,54 +360,77 @@ void QMoEMLP::forward_
         &beta_,
         temp_logits, num_experts);
 
-    // Compute softmax filter to and normalize top-k outputs
+    // Select activation kernel
 
-    dim3 blockDim, gridDim;
-    blockDim.x = WARPS;
-    blockDim.y = 1;
-    gridDim.x = 1;
-    gridDim.y = DIVIDE(rows, WARPS);
-    if (num_experts == 4)
-        softmax4_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
-    else if (num_experts == 8)
-        softmax8_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
-    else if (num_experts == 16)
-        softmax16_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
+    int intermediate_size = w1[0]->width;
+    fp_act_mul_kernel kernel = pick_act_mul_kernel(use_half2, true, act_gelu);
 
     // For small no. rows, execute all kernels but pass the routing weights. Rows with a weight of zero will skip dot
     // product accum and kernels launched with only zero-weights will exit prematurely.
 
-    if (rows <= MAX_Q_GEMM_WEIGHTS)
+    if (num_experts == 4 || num_experts == 8 || num_experts == 16)
     {
-        int intermediate_size = w1[0]->width;
-        fp_act_mul_kernel kernel = pick_act_mul_kernel(use_half2, true, act_gelu);
+        dim3 blockDim, gridDim;
+        blockDim.x = WARPSIZE;
+        blockDim.y = 1;
+        gridDim.x = 1;
+        gridDim.y = DIVIDE(rows, WARPSIZE);
+        if (num_experts == 4)
+            softmax4_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
+        else if (num_experts == 8)
+            softmax8_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
+        else if (num_experts == 16)
+            softmax16_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
 
         for (int i = 0; i < num_experts; i++)
         {
             gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w1[i], temp_a, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false);
             gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w3[i], temp_b, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false);
 
-            // apply_loras_cuda(cublas_handle, w1_lora[i], loras, w1[i], temp_state, temp_a, lora_temp, rows);
-            // apply_loras_cuda(cublas_handle, w3_lora[i], loras, w3[i], temp_state, temp_b, lora_temp, rows);
-
             blockDim.x = THREADS_X;
             blockDim.y = THREADS_Y;
             gridDim.x = DIVIDE(intermediate_size, THREADS_X) / (use_half2 ? 2 : 1);
             gridDim.y = DIVIDE(rows, THREADS_Y);
             kernel<<<gridDim, blockDim, 0, stream>>>(temp_a, temp_b, rows, intermediate_size, temp_logits + i, num_experts);
 
             gemm_half_q_half_cuda(stream, cublas_handle, temp_a, w2[i], x, rows, columns, intermediate_size, false, temp_dq, true, temp_logits + i, num_experts, true);
-
-            // apply_loras_cuda(cublas_handle, w2_lora[i], loras, w2[i], temp_a, x, lora_temp, rows);
         }
-    }
+    }
 
-    // Gather larger number of rows in separate batches according to which experts they trigger, evaluate each MLP
-    // only on the affected rows and scale by routing weights while adding back directly onto the residual hidden state
+    // For very large number of experts (Qwen3 etc.) copy to CPU, synchronize and only launch top K experts. This is
+    // not optimal but the kernel launch overhead is very severe otherwise. Really needs a graph
 
-    else
+    else if (num_experts == 128)
     {
-        printf(" ## ropws > %i not implemented\n", MAX_Q_GEMM_WEIGHTS);
-        DBGI(rows);
+        dim3 blockDim, gridDim;
+        blockDim.x = WARPSIZE;
+        blockDim.y = 1;
+        gridDim.x = 1;
+        gridDim.y = DIVIDE(rows, WARPSIZE);
+        softmax128_topk_norm_kernel<<<gridDim, blockDim, 0, stream>>>(temp_logits, rows, num_experts_per_token);
+
+        half* h_logits;
+        h_logits = (half*) malloc(128 * sizeof(half));
+        cudaMemcpyAsync(h_logits, temp_logits, 128 * sizeof(half), cudaMemcpyDeviceToHost, stream);
+        cudaStreamSynchronize(stream);
+
+        for (int i = 0; i < num_experts; i++)
+        {
+            uint16_t w = __half_as_ushort(h_logits[i]);
+            if (!w) continue;
+
+            gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w1[i], temp_a, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false);
+            gemm_half_q_half_cuda(stream, cublas_handle, temp_state, w3[i], temp_b, rows, intermediate_size, columns, true, temp_dq, true, temp_logits + i, num_experts, false);
+
+            blockDim.x = THREADS_X;
+            blockDim.y = THREADS_Y;
+            gridDim.x = DIVIDE(intermediate_size, THREADS_X) / (use_half2 ? 2 : 1);
+            gridDim.y = DIVIDE(rows, THREADS_Y);
+            kernel<<<gridDim, blockDim, 0, stream>>>(temp_a, temp_b, rows, intermediate_size, temp_logits + i, num_experts);
+
+            gemm_half_q_half_cuda(stream, cublas_handle, temp_a, w2[i], x, rows, columns, intermediate_size, false, temp_dq, true, temp_logits + i, num_experts, true);
+        }
+
+        free(h_logits);
     }
 }
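The 128-expert branch reads the post-softmax routing weights back to the host once, then launches kernels only for experts with a nonzero weight, trading one stream synchronization for skipping every expert the router did not select. The torch sketch below mirrors that control flow for a single token (copy the routing weights to the CPU, test for zero, run gate/up/activation/down per selected expert); the tensor layouts and the SiLU gate activation are assumptions for illustration, not the kernel's exact data path.

# Host-side MoE dispatch in miniature: only a handful of the 128 routing weights per row
# are nonzero after the top-k softmax, so most expert MLPs can be skipped entirely.
import torch

def moe_forward_single_row(x, routing_weights, gate_w, up_w, down_w):
    # x: (hidden,), routing_weights: (num_experts,) already softmaxed and top-k-normalized
    out = torch.zeros_like(x)
    weights_cpu = routing_weights.to("cpu")          # one sync instead of 128 launches
    for i, w in enumerate(weights_cpu.tolist()):
        if w == 0.0:
            continue                                 # expert not in the top k: skip it
        a = torch.nn.functional.silu(x @ gate_w[i])  # gate projection + activation
        b = x @ up_w[i]                              # up projection
        out += w * ((a * b) @ down_w[i])             # down projection, scaled by routing weight
    return out

if __name__ == "__main__":
    hidden, inter, n_exp = 64, 32, 128
    x = torch.randn(hidden)
    gate = torch.randn(n_exp, hidden, inter); up = torch.randn(n_exp, hidden, inter)
    down = torch.randn(n_exp, inter, hidden)
    rw = torch.zeros(n_exp); rw[[3, 17, 42]] = torch.tensor([0.5, 0.3, 0.2])
    print(moe_forward_single_row(x, rw, gate, up, down).shape)  # torch.Size([64])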

exllamav2/exllamav2_ext/cuda/q_mlp_softmax.cuh

Lines changed: 98 additions & 4 deletions
@@ -1,5 +1,5 @@
 
-#define WARPS 32
+#define WARPSIZE 32
 
 __global__ void softmax16_topk_norm_kernel
 (
@@ -8,7 +8,7 @@ __global__ void softmax16_topk_norm_kernel
     const int topk
 )
 {
-    int row = blockIdx.y * WARPS + threadIdx.x;
+    int row = blockIdx.y * WARPSIZE + threadIdx.x;
     if (row >= rows) return;
 
     // Softmax
@@ -122,7 +122,7 @@ __global__ void softmax8_topk_norm_kernel
     const int topk
 )
 {
-    int row = blockIdx.y * WARPS + threadIdx.x;
+    int row = blockIdx.y * WARPSIZE + threadIdx.x;
     if (row >= rows) return;
 
     // Softmax
@@ -206,7 +206,7 @@ __global__ void softmax4_topk_norm_kernel
     const int topk
 )
 {
-    int row = blockIdx.y * WARPS + threadIdx.x;
+    int row = blockIdx.y * WARPSIZE + threadIdx.x;
    if (row >= rows) return;
 
     // Softmax
@@ -268,3 +268,97 @@ __global__ void softmax4_topk_norm_kernel
     logits_int2.y = l23.as_uint32;
     *row_ptr = logits_int2;
 }
+
+__global__ void softmax128_topk_norm_kernel
+(
+    half* __restrict__ x,
+    const int rows,
+    const int topk
+)
+{
+    const int row = blockIdx.y * WARPSIZE + threadIdx.x;
+    if (row >= rows) return;
+
+    register float f[128];
+
+    int4* row_ptr = reinterpret_cast<int4*>(x + row * 128);
+
+    #pragma unroll
+    for (int v = 0; v < 16; ++v)  // 16 × 8 halfs = 128 halfs
+    {
+        int4 v4 = row_ptr[v];
+
+        half2_uint32 h0(v4.x), h1(v4.y), h2(v4.z), h3(v4.w);
+
+        const int base = v * 8;
+        f[base + 0] = __low2float (h0.as_half2);
+        f[base + 1] = __high2float(h0.as_half2);
+        f[base + 2] = __low2float (h1.as_half2);
+        f[base + 3] = __high2float(h1.as_half2);
+        f[base + 4] = __low2float (h2.as_half2);
+        f[base + 5] = __high2float(h2.as_half2);
+        f[base + 6] = __low2float (h3.as_half2);
+        f[base + 7] = __high2float(h3.as_half2);
+    }
+
+    float maxf = -FLT_MAX;
+    #pragma unroll
+    for (int i = 0; i < 128; ++i) maxf = fmaxf(maxf, f[i]);
+
+    float sum = 0.f;
+    #pragma unroll
+    for (int i = 0; i < 128; ++i)
+    {
+        float e = __expf(f[i] - maxf);
+        f[i] = e;
+        sum += e;
+    }
+
+    constexpr float epsilon = 1e-8f;
+    const float isum = 1.f / (sum + 128.0f * epsilon);
+
+    #pragma unroll
+    for (int i = 0; i < 128; ++i) f[i] = f[i] * isum + epsilon;
+
+    float remaining = 1.0f;
+    for (int drop = 0; drop < 128 - topk; ++drop)
+    {
+        float minv = 1.0f;
+        int mini = -1;
+        #pragma unroll
+        for (int j = 0; j < 128; ++j)
+        {
+            if (f[j] > 0.0f && f[j] < minv)
+            {
+                minv = f[j];
+                mini = j;
+            }
+        }
+        remaining -= f[mini];
+        f[mini] = 0.0f;
+    }
+
+    const float inv_remaining = 1.f / remaining;
+    #pragma unroll
+    for (int i = 0; i < 128; ++i) f[i] *= inv_remaining;
+
+    #pragma unroll
+    for (int v = 0; v < 16; ++v)
+    {
+        const int base = v * 8;
+
+        half2_uint32 h0, h1, h2, h3;
+        h0.as_half2 = __floats2half2_rn(f[base + 0], f[base + 1]);
+        h1.as_half2 = __floats2half2_rn(f[base + 2], f[base + 3]);
+        h2.as_half2 = __floats2half2_rn(f[base + 4], f[base + 5]);
+        h3.as_half2 = __floats2half2_rn(f[base + 6], f[base + 7]);
+
+        int4 v4;
+        v4.x = h0.as_uint32;
+        v4.y = h1.as_uint32;
+        v4.z = h2.as_uint32;
+        v4.w = h3.as_uint32;
+
+        row_ptr[v] = v4;
+    }
+}
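For reference, the per-row math of softmax128_topk_norm_kernel restated in torch: one thread handles a full row of 128 logits, computes a numerically stable softmax with a small epsilon floor, drops everything outside the top k, and renormalizes the survivors so they sum to one. The sketch below is an approximate reference (its epsilon bookkeeping and tie handling differ slightly from the kernel), intended only for sanity-checking fp16 output, not as a drop-in equivalent.

# Approximate torch reference for the 128-expert softmax / top-k / renormalize kernel.
import torch

def softmax_topk_norm_ref(logits: torch.Tensor, topk: int, eps: float = 1e-8) -> torch.Tensor:
    # logits: (rows, num_experts)
    f = logits.float()
    e = torch.exp(f - f.max(dim = -1, keepdim = True).values)       # stable exponentials
    n = f.shape[-1]
    p = e / (e.sum(dim = -1, keepdim = True) + n * eps) + eps        # softmax with epsilon floor
    vals, idx = torch.topk(p, topk, dim = -1)                        # keep the top k weights
    out = torch.zeros_like(p).scatter_(-1, idx, vals)
    out = out / out.sum(dim = -1, keepdim = True)                    # renormalize survivors
    return out.to(logits.dtype)

if __name__ == "__main__":
    logits = torch.randn(4, 128, dtype = torch.float16)
    w = softmax_topk_norm_ref(logits, topk = 8)
    print((w > 0).sum(dim = -1))   # 8 nonzero weights per row
    print(w.sum(dim = -1))         # each row sums to ~1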

exllamav2/moe_mlp.py

Lines changed: 2 additions & 2 deletions
@@ -167,7 +167,7 @@ def scratch_space_fixed(self) -> int:
 
     def scratch_space(self) -> int:
 
-        assert self.model.config.intermediate_size >= self.model.config.hidden_size
+        # assert self.model.config.intermediate_size >= self.model.config.hidden_size
         return self.temp_state_size() + \
                self.temp_gathered_state_size() + \
                self.temp_a_size() + \
@@ -235,7 +235,7 @@ def forward(
         # TODO: LoRA currently uses the Torch codepath. Needs conditional (early-exit) kernels with output scaling
         # for the LoRA matmuls in order to work with the C++ path
 
-        if self.q_handle is None or intermediates or batch_size * sequence_length > 4 or self.num_experts not in [4, 8, 16] or (loras is not None and len(loras) > 0):
+        if self.q_handle is None or intermediates or batch_size * sequence_length > 4 or self.num_experts not in [4, 8, 16, 128] or (loras is not None and len(loras) > 0):
             return self.forward_torch(hidden_states, cache, attn_params, past_len, intermediates, loras = loras, **kwargs)
 
         # if loras is None or self.temp_lora_size == 0:
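The updated condition simply adds 128 to the expert counts the fused C++ path accepts; everything else (long prefills, LoRA, intermediate capture, missing quantized handle) still routes through forward_torch. Distilled into a standalone predicate for clarity; the names below are placeholders rather than the module's API.

# Sketch of the fused-path dispatch decision in exllamav2's MoE MLP forward.
SUPPORTED_EXPERT_COUNTS = {4, 8, 16, 128}

def use_fused_path(q_handle, want_intermediates: bool, batch_size: int,
                   sequence_length: int, num_experts: int, loras) -> bool:
    if q_handle is None or want_intermediates:
        return False
    if batch_size * sequence_length > 4:          # prompt processing: take the Torch path
        return False
    if num_experts not in SUPPORTED_EXPERT_COUNTS:
        return False
    if loras:                                     # LoRA still runs through Torch
        return False
    return True

if __name__ == "__main__":
    print(use_fused_path(object(), False, 1, 1, 128, None))    # True: token-by-token decode
    print(use_fused_path(object(), False, 1, 512, 128, None))  # False: long prefill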

exllamav2/vlm/vision_tower.py

Lines changed: 2 additions & 0 deletions
@@ -42,6 +42,8 @@ def __init__(
         km = self.archparams.keys
         self.modules = []
 
+        self.tp_context = None
+
         # Preprocessor
 
         if cfg.vision_model_type == "pixtral":
