
Commit 75c8622

cjatinmangupta authored and committed
SWDEV-446047 - initial version of fp8 documentation
Change-Id: I9d9d45435534381d22dff4d4860549e861d52c3f (cherry picked from commit 2b36235)
1 parent 48ab3c3 commit 75c8622

File tree

3 files changed: +233, -0 lines

docs/index.md

Lines changed: 1 addition & 0 deletions

@@ -56,6 +56,7 @@ On non-AMD platforms, like NVIDIA, HIP provides header files required to support
 * [Comparing Syntax for different APIs](./reference/terms)
 * [HSA Runtime API for ROCm](./reference/virtual_rocr)
 * [List of deprecated APIs](./reference/deprecated_api_list)
+* [FP8 numbers in HIP](./reference/fp8_numbers)

 :::
docs/reference/fp8_numbers.rst

Lines changed: 230 additions & 0 deletions
@@ -0,0 +1,230 @@
.. meta::
   :description: This page describes FP8 numbers present in HIP.
   :keywords: AMD, ROCm, HIP, fp8, fnuz, ocp

*******************************************************************************
FP8 Numbers
*******************************************************************************
`FP8 numbers <https://arxiv.org/pdf/2209.05433>`_ were introduced to accelerate deep learning inference. They provide higher throughput for matrix operations because their smaller size allows more of them to fit in the available fixed memory.

HIP has two FP8 number representations, called *FP8-OCP* and *FP8-FNUZ*.

The Open Compute Project (OCP) number definition can be found in the `OCP 8-bit floating point specification <https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1>`_.

Definition of FNUZ: the *fnuz* suffix means only finite and NaN values are supported. Unlike the OCP types, Inf is not supported. NaN is encoded with the sign bit set and all exponent and mantissa bits zero; every other encoding is a finite value. Reclaiming the Inf encodings provides one extra exponent value and extends the range of representable FP8 numbers.
FP8 Definition
==============

FP8 numbers are composed of a sign, an exponent, and a mantissa; their sizes depend on the format. There are two FP8 formats, E4M3 and E5M2:

- E4M3: 1 sign bit, 4 exponent bits, 3 mantissa bits
- E5M2: 1 sign bit, 5 exponent bits, 2 mantissa bits
HIP Header
==========

The HIP header `amd_hip_fp8.h <https://github.com/ROCm/clr/blob/develop/hipamd/include/hip/amd_detail/amd_hip_fp8.h>`_ defines the OCP and FNUZ FP8 numbers.
Supported Devices
=================

.. list-table:: Supported devices for FP8 numbers
   :header-rows: 1

   * - Device Type
     - FNUZ FP8
     - OCP FP8
   * - Host
     - Yes
     - Yes
   * - gfx940/gfx941/gfx942
     - Yes
     - No
   * - gfx1200/gfx1201
     - No
     - Yes
Usage
=====

To use FP8 numbers inside HIP programs, include the header:

.. code-block:: cpp

   #include <hip/hip_fp8.h>

FP8 numbers can be used on the CPU side:

.. code-block:: cpp

   __hip_fp8_storage_t convert_float_to_fp8(
       float in,                             /* Input value */
       __hip_fp8_interpretation_t interpret, /* Interpretation of number: E4M3/E5M2 */
       __hip_saturation_t sat                /* Saturation behavior */
   ) {
     return __hip_cvt_float_to_fp8(in, sat, interpret);
   }
The same can be done in kernels as well:

.. code-block:: cpp

   __device__ __hip_fp8_storage_t d_convert_float_to_fp8(
       float in,
       __hip_fp8_interpretation_t interpret,
       __hip_saturation_t sat) {
     return __hip_cvt_float_to_fp8(in, sat, interpret);
   }

An important thing to note: on a gfx94x GPU the result is an FNUZ number, but on any other supported GPU it is an OCP number.
The following code example does round-trip FP8 conversions on both the CPU and GPU and compares the results.

.. code-block:: cpp

   #include <hip/hip_fp8.h>
   #include <hip/hip_runtime.h>
   #include <iostream>
   #include <string>
   #include <vector>

   #define hip_check(hip_call)                                                  \
     {                                                                          \
       auto hip_res = hip_call;                                                 \
       if (hip_res != hipSuccess) {                                             \
         std::cerr << "Failed in hip call: " << #hip_call                       \
                   << " with error: " << hipGetErrorName(hip_res) << std::endl; \
         std::abort();                                                          \
       }                                                                        \
     }

   __device__ __hip_fp8_storage_t d_convert_float_to_fp8(
       float in, __hip_fp8_interpretation_t interpret, __hip_saturation_t sat) {
     return __hip_cvt_float_to_fp8(in, sat, interpret);
   }

   __device__ float d_convert_fp8_to_float(__hip_fp8_storage_t in,
                                           __hip_fp8_interpretation_t interpret) {
     __half hf = __hip_cvt_fp8_to_halfraw(in, interpret);
     return hf;
   }

   __global__ void float_to_fp8_to_float(float *in,
                                         __hip_fp8_interpretation_t interpret,
                                         __hip_saturation_t sat, float *out,
                                         size_t size) {
     int i = threadIdx.x;
     if (i < size) {
       auto fp8 = d_convert_float_to_fp8(in[i], interpret, sat);
       out[i] = d_convert_fp8_to_float(fp8, interpret);
     }
   }

   __hip_fp8_storage_t
   convert_float_to_fp8(float in,                             /* Input value */
                        __hip_fp8_interpretation_t interpret, /* Interpretation: E4M3/E5M2 */
                        __hip_saturation_t sat                /* Saturation behavior */
   ) {
     return __hip_cvt_float_to_fp8(in, sat, interpret);
   }

   float convert_fp8_to_float(
       __hip_fp8_storage_t in,              /* Input value */
       __hip_fp8_interpretation_t interpret /* Interpretation: E4M3/E5M2 */
   ) {
     __half hf = __hip_cvt_fp8_to_halfraw(in, interpret);
     return hf;
   }

   int main() {
     constexpr size_t size = 32;
     hipDeviceProp_t prop;
     hip_check(hipGetDeviceProperties(&prop, 0));
     bool is_supported =
         (std::string(prop.gcnArchName).find("gfx94") != std::string::npos) ||  // gfx94x
         (std::string(prop.gcnArchName).find("gfx120") != std::string::npos);   // gfx120x
     if (!is_supported) {
       std::cerr << "Need a gfx94x or gfx120x, but found: " << prop.gcnArchName
                 << std::endl;
       std::cerr << "No device conversions are supported, only host conversions "
                    "are supported." << std::endl;
       return -1;
     }

     const __hip_fp8_interpretation_t interpret =
         (std::string(prop.gcnArchName).find("gfx94") != std::string::npos)
             ? __HIP_E4M3_FNUZ  // gfx94x
             : __HIP_E4M3;      // gfx120x
     constexpr __hip_saturation_t sat = __HIP_SATFINITE;

     std::vector<float> in;
     in.reserve(size);
     for (size_t i = 0; i < size; i++) {
       in.push_back(i + 1.1f);
     }

     std::cout << "Converting float to fp8 and back..." << std::endl;
     // CPU convert
     std::vector<float> cpu_out;
     cpu_out.reserve(size);
     for (const auto &fval : in) {
       auto fp8 = convert_float_to_fp8(fval, interpret, sat);
       cpu_out.push_back(convert_fp8_to_float(fp8, interpret));
     }

     // GPU convert
     float *d_in, *d_out;
     hip_check(hipMalloc(&d_in, sizeof(float) * size));
     hip_check(hipMalloc(&d_out, sizeof(float) * size));

     hip_check(hipMemcpy(d_in, in.data(), sizeof(float) * in.size(),
                         hipMemcpyHostToDevice));

     float_to_fp8_to_float<<<1, size>>>(d_in, interpret, sat, d_out, size);

     std::vector<float> gpu_out(size, 0.0f);
     hip_check(hipMemcpy(gpu_out.data(), d_out, sizeof(float) * gpu_out.size(),
                         hipMemcpyDeviceToHost));

     hip_check(hipFree(d_in));
     hip_check(hipFree(d_out));

     // Validation
     for (size_t i = 0; i < size; i++) {
       if (cpu_out[i] != gpu_out[i]) {
         std::cerr << "cpu round trip result: " << cpu_out[i]
                   << " - gpu round trip result: " << gpu_out[i] << std::endl;
         std::abort();
       }
     }
     std::cout << "...CPU and GPU round trip convert matches." << std::endl;
   }
There are C++-style classes available as well:

.. code-block:: cpp

   __hip_fp8_e4m3_fnuz fp8_val(1.1f); // FNUZ type, gfx94x
   __hip_fp8_e4m3 fp8_val(1.1f);      // OCP type, gfx120x

Each type of FP8 number has its own class:

- __hip_fp8_e4m3
- __hip_fp8_e5m2
- __hip_fp8_e4m3_fnuz
- __hip_fp8_e5m2_fnuz
Vector versions of the FP8 types are also supported:

- __hip_fp8x2_e4m3: holds 2 OCP FP8 E4M3 numbers
- __hip_fp8x4_e4m3: holds 4 OCP FP8 E4M3 numbers
- __hip_fp8x2_e5m2: holds 2 OCP FP8 E5M2 numbers
- __hip_fp8x4_e5m2: holds 4 OCP FP8 E5M2 numbers
- __hip_fp8x2_e4m3_fnuz: holds 2 FNUZ FP8 E4M3 numbers
- __hip_fp8x4_e4m3_fnuz: holds 4 FNUZ FP8 E4M3 numbers
- __hip_fp8x2_e5m2_fnuz: holds 2 FNUZ FP8 E5M2 numbers
- __hip_fp8x4_e5m2_fnuz: holds 4 FNUZ FP8 E5M2 numbers

The FNUZ extensions are available on gfx94x only.

docs/sphinx/_toc.yml.in

Lines changed: 2 additions & 0 deletions

@@ -41,6 +41,8 @@ subtrees:
   - file: reference/virtual_rocr
   - file: reference/deprecated_api_list
     title: List of deprecated APIs
+  - file: reference/fp8_numbers
+    title: FP8 numbers in HIP

 - caption: Tutorials
   entries: