[FEA] CVT F32 -> TF32 PTX for sm80 #2254

Open
osayamenja opened this issue Apr 19, 2025 · 0 comments

osayamenja commented Apr 19, 2025

Is your feature request related to a problem? Please describe.
Currently, converting from f32 to tf32 with round-to-nearest dispatches to a PTX cvt instruction only on sm90.

Describe the solution you'd like
If we allow rna rounding (round to nearest, ties away from zero), we can dispatch to cvt.rna.tf32.f32, which is available on sm80.
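
As a rough illustration of the proposed dispatch, here is a minimal sketch, not the CUTLASS implementation. All names in it (`TF32_HOST_DEVICE`, `f32_to_tf32_bits_rna`, `f32_to_tf32_bits_rna_emulated`, `bits_to_float`) are hypothetical. It assumes the PTX path is taken only when compiling device code for sm80 or newer, and that a software fallback is acceptable elsewhere. The fallback assumes finite inputs are rounded to the nearest 10-bit mantissa with ties away from zero and the 13 discarded bits cleared, and that NaN/Inf can be passed through unchanged; neither assumption has been checked against hardware for special values.

```cpp
#include <cstdint>
#include <cstring>

// Expands to CUDA qualifiers under nvcc and to nothing for a plain host compiler.
#if defined(__CUDACC__)
#define TF32_HOST_DEVICE __host__ __device__
#else
#define TF32_HOST_DEVICE
#endif

// Software fallback (assumption, see lead-in): round the f32 mantissa to 10 bits,
// nearest with ties away from zero, and clear the 13 discarded low bits.
inline TF32_HOST_DEVICE std::uint32_t f32_to_tf32_bits_rna_emulated(float x) {
    std::uint32_t bits;
    std::memcpy(&bits, &x, sizeof(bits));
    // NaN and Inf have an all-ones exponent; pass them through unchanged.
    if ((bits & 0x7F800000u) == 0x7F800000u) {
        return bits;
    }
    // Adding half of the discarded range (1 << 12) to the sign-magnitude encoding
    // rounds the magnitude up on ties, i.e. away from zero for either sign.
    return (bits + 0x1000u) & ~0x1FFFu;
}

// Hypothetical dispatch: use the PTX instruction on sm80+, otherwise emulate.
inline TF32_HOST_DEVICE std::uint32_t f32_to_tf32_bits_rna(float x) {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
    std::uint32_t d;
    asm("cvt.rna.tf32.f32 %0, %1;" : "=r"(d) : "f"(x));
    return d;
#else
    return f32_to_tf32_bits_rna_emulated(x);
#endif
}

// Reinterpret a 32-bit pattern as a float (helper for inspecting results).
inline TF32_HOST_DEVICE float bits_to_float(std::uint32_t bits) {
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}
```

On the host, `bits_to_float(f32_to_tf32_bits_rna(-0.45466f))` takes the emulated path and yields -0.45458984375, which prints as -0.454590 with `%f`.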

Describe alternatives you've considered
N/A

Additional context
A simple code sample is given below:

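#include <cstdint>
#include <cstdio>
#include <cutlass/numeric_conversion.h>  // cutlass::NumericConverter, cutlass::tfloat32_t
#include <cute/tensor.hpp>               // cute::print

__global__ void f2tfK() {
    constexpr float x = -0.45466f;
    uint32_t d = 0;
    // CUTLASS converter with its default rounding style.
    constexpr auto f2tf = cutlass::NumericConverter<cutlass::tfloat32_t, float>{};
    // Direct PTX conversion: round to nearest, ties away from zero (sm80+).
    asm volatile("cvt.rna.tf32.f32 %0, %1;" : "=r"(d) : "f"(x));
    // Reinterpret the raw 32-bit result as a tfloat32_t.
    const auto res = cutlass::tfloat32_t::bitcast(d);
    const auto cRes = f2tf(x);
    printf("Intrinsic: "); cute::print(res); printf("\n");
    printf("Other: "); cute::print(cRes); printf("\n");
    printf("isEqual? %s\n", cRes == res ? "yes" : "no");
}
// Output:
// Intrinsic: -0.454590
// Other: -0.454590
// isEqual? yes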

osayamenja changed the title from "[FEA] CVT TF32 -> F32 PTX for sm80" to "[FEA] CVT F32 -> TF32 PTX for sm80" on Apr 19, 2025