Description
I've encountered an issue with the parameter passing of fp8_index in the index.forward method in model.py. Specifically, the scaling factor for q (referred to as weights in the code) has a shape mismatch between the model and the kernel.
In the model, the q scaling factor has shape (B, S, index_H, index_D_h / block_size), while in the kernel, the signature of fp8_index_kernel_ declares q_s with shape (b, m, h).
Could you please clarify the design intention behind this discrepancy? I'm particularly interested in understanding the rationale for the shape mismatch.
model.py:
# (B, S, index_n_head) -> (B, S, index_n_head, index_d // block_size)
weights = self.weights_proj(x) * self.n_heads ** -0.5
# score = (q_fp8 * q_scale) @ (k_fp8 * k_scale).T / sqrt(d_k)
#       = (q_fp8 @ k_fp8.T) * q_scale * k_scale / sqrt(d_k)
weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
index_score = fp8_index(q_fp8.contiguous(), weights,
                        self.k_cache[:bsz, :end_pos].contiguous(),
                        self.k_scale_cache[:bsz, :end_pos].contiguous())
kernel.py:
@T.prim_func
def fp8_index_kernel_(
q: T.Tensor[(b, m, h, d), FP8],
q_s: T.Tensor[(b, m, h), FP32],
k: T.Tensor[(b, n, d), FP8],
k_s: T.Tensor[(b, n), FP32],
o: T.Tensor[(b, m, n), FP32],
) -> None:
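To make the mismatch concrete, here is a minimal NumPy sketch of the shape arithmetic in the model.py snippet above. The dimension values (B, S, H, D, block_size) are hypothetical toy numbers chosen only for illustration, not taken from the repository; the point is that whenever index_d // block_size > 1, the broadcasted weights tensor carries an extra trailing axis relative to the (b, m, h) shape that fp8_index_kernel_ declares for q_s.

```python
import numpy as np

# Hypothetical toy dimensions (illustrative only, not the repo's values)
B, S, H = 2, 4, 8          # batch, query length, number of index heads
D, block_size = 256, 128   # index head dim and FP8 quantization block size

proj = np.random.randn(B, S, H)                      # weights_proj(x): (B, S, H)
q_scale = np.random.randn(B, S, H, D // block_size)  # per-block FP8 scales for q

weights = proj * H ** -0.5
# unsqueeze(-1) then broadcast: (B, S, H, 1) * (B, S, H, D // block_size)
weights = weights[..., None] * q_scale

print(weights.shape)  # -> (2, 4, 8, 2): one extra axis vs. the kernel's (b, m, h)
```

Only when index_d == block_size does the trailing axis collapse to size 1, making the tensor trivially reshapeable to (b, m, h); for any other ratio the two declared shapes genuinely disagree, which is the discrepancy the question asks about.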
Carlomus