Issue with parameter passing of fp8_index in index.forward in model.py #43

@Rrostyy

Description

I've encountered an issue with how arguments are passed to fp8_index in the index.forward method in model.py. Specifically, the scaling factor for q (named weights in the code) has a different shape in the model than in the kernel.

In the model, the scaling factor for q has shape (B, S, index_H, index_D_h / block_size), whereas the signature of fp8_index_kernel_ in the kernel declares q_s with shape (b, m, h).
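To make the two shapes concrete, here is a minimal shape trace in plain Python (no PyTorch required). The sizes and the helper `broadcast_shape` are hypothetical and only for illustration. The sketch assumes index_D_h equals block_size, in which case the trailing dimension is 1 and the tensor holds the same number of elements as one of shape (b, m, h). Whether the real configuration satisfies that assumption is part of what this issue asks.

```python
import math
from itertools import zip_longest


def broadcast_shape(*shapes):
    """Return the NumPy/PyTorch-style broadcast of several shapes."""
    result = []
    # Align shapes from the trailing dimension; missing dims count as 1.
    for dims in zip_longest(*(reversed(s) for s in shapes), fillvalue=1):
        sizes = {d for d in dims if d != 1}
        if len(sizes) > 1:
            raise ValueError(f"incompatible dimensions: {dims}")
        result.append(sizes.pop() if sizes else 1)
    return tuple(reversed(result))


# Hypothetical sizes; BLOCK == D is an assumption, not taken from the repo.
B, S, H, D, BLOCK = 2, 3, 4, 128, 128

proj_shape = (B, S, H)                 # output of weights_proj(x)
unsqueezed = proj_shape + (1,)         # weights.unsqueeze(-1)
q_scale_shape = (B, S, H, D // BLOCK)  # one scale per block of the last dim

# Shape of `weights` as passed to fp8_index in model.py.
weights_shape = broadcast_shape(unsqueezed, q_scale_shape)
kernel_q_s_shape = (B, S, H)           # q_s as declared by the kernel

print(weights_shape)                   # (2, 3, 4, 1)
print(math.prod(weights_shape) == math.prod(kernel_q_s_shape))  # True
```

Under this assumption the two shapes differ only by a trailing singleton dimension. If index_D_h were a larger multiple of block_size, the element counts would no longer match.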

Could you please clarify the design intent behind this discrepancy? I'd like to understand why the shape built in the model differs from the one the kernel declares.

model.py
        # (B, S, index_n_head) -> (B, S, index_n_head, index_d // block_size)
        weights = self.weights_proj(x) * self.n_heads ** -0.5

        # score = ((q_fp8 * q_scale) @ (k_fp8 * k_scale).T) / sqrt(d_k)
        #       = (q_fp8 @ k_fp8.T) * q_scale * k_scale / sqrt(d_k)
        weights = weights.unsqueeze(-1) * q_scale * self.softmax_scale
        index_score = fp8_index(q_fp8.contiguous(), weights, 
                                self.k_cache[:bsz, :end_pos].contiguous(), 
                                self.k_scale_cache[:bsz, :end_pos].contiguous())
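
The comment in the snippet above relies on pulling the scales out of the dot product. A small numeric check of that identity for a single q row and a single k row, using made-up values and assuming one scale per row (a single block spanning the whole head dimension):

```python
import math


def dot(a, b):
    """Plain dot product of two equal-length sequences."""
    return sum(x * y for x, y in zip(a, b))


# Hypothetical values standing in for one quantized q row and one k row.
q_fp8 = [1.0, -2.0, 0.5, 4.0]
k_fp8 = [0.5, 1.5, -1.0, 2.0]
q_scale = 0.25    # single scale for the whole q row (assumption)
k_scale = 0.125   # single scale for the whole k row (assumption)
d_k = len(q_fp8)

# Dequantize first, then take the dot product.
dequant_first = dot(
    [x * q_scale for x in q_fp8],
    [y * k_scale for y in k_fp8],
) / math.sqrt(d_k)

# Take the dot product on the raw values, then apply both scales.
scale_after = dot(q_fp8, k_fp8) * q_scale * k_scale / math.sqrt(d_k)

assert math.isclose(dequant_first, scale_after)
print(scale_after)  # 0.078125
```

The factoring only works this simply when each row has a single scale. With several blocks per row, each block's partial dot product would need its own scale before summing.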

kernel.py
@T.prim_func
def fp8_index_kernel_(
    q: T.Tensor[(b, m, h, d), FP8],
    q_s: T.Tensor[(b, m, h), FP32],
    k: T.Tensor[(b, n, d), FP8],
    k_s: T.Tensor[(b, n), FP32],
    o: T.Tensor[(b, m, n), FP32],
) -> None:
