
Commit d3b75fb

Revert "Revert "SWDEV-446047 - initial version of fp8 documentation""
This reverts commit 57b22e1. Change-Id: I66cf95b5bd9edcab7789b42330c2d666e77a037b
1 parent 57b22e1 commit d3b75fb

File tree

3 files changed: +230 −0 lines changed

- docs/index.md
- docs/reference/fp8_numbers.rst
- docs/sphinx/_toc.yml.in

docs/index.md

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ On non-AMD platforms, like NVIDIA, HIP provides header files required to support
  * [Comparing Syntax for different APIs](./reference/terms)
  * [HSA Runtime API for ROCm](./reference/virtual_rocr)
  * [List of deprecated APIs](./reference/deprecated_api_list)
+ * [FP8 numbers in HIP](./reference/fp8_numbers)

:::

docs/reference/fp8_numbers.rst

Lines changed: 227 additions & 0 deletions
@@ -0,0 +1,227 @@
.. meta::
   :description: This page describes FP8 numbers present in HIP.
   :keywords: AMD, ROCm, HIP, fp8, fnuz, ocp

*******************************************************************************
FP8 Numbers
*******************************************************************************

`FP8 numbers <https://arxiv.org/pdf/2209.05433>`_ were introduced to accelerate deep learning inference. They provide higher throughput for matrix operations because their smaller size fits more values into the same amount of memory.

HIP has two FP8 number representations: *FP8-OCP* and *FP8-FNUZ*.

The Open Compute Project (OCP) number definition can be found in the `OFP8 specification <https://www.opencompute.org/documents/ocp-8-bit-floating-point-specification-ofp8-revision-1-0-2023-12-01-pdf-1>`_.

Definition of FNUZ: the ``fnuz`` suffix means that only finite and NaN values are supported; unlike the other types, infinities are not. NaN is encoded by setting the sign bit while all exponent and mantissa bits are 0; every other encoding is a finite value. Dropping infinities frees up one extra exponent value, which extends the range of representable FP8 numbers.
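
Because NaN is the single encoding with only the sign bit set, a FNUZ NaN test reduces to one byte comparison. A minimal host-side sketch; ``is_fnuz_nan`` is a hypothetical helper, not part of the HIP API:

.. code-block:: cpp

   #include <cstdint>

   // NaN in a FNUZ format: sign bit set, exponent and mantissa all zero,
   // i.e. the single bit pattern 0x80. Every other pattern is finite.
   bool is_fnuz_nan(uint8_t bits) { return bits == 0x80; }
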
FP8 Definition
==============

FP8 numbers are composed of a sign, an exponent, and a mantissa, whose sizes depend on the format. There are two FP8 formats, E4M3 and E5M2:

- E4M3: 1-bit sign, 4-bit exponent, 3-bit mantissa
- E5M2: 1-bit sign, 5-bit exponent, 2-bit mantissa
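
To make these layouts concrete, here is a small host-side sketch that extracts the fields of a raw 8-bit value for both formats. The ``fp8_fields`` struct and the unpack helpers are illustrative only, not part of the HIP API:

.. code-block:: cpp

   #include <cstdint>
   #include <cstdio>

   struct fp8_fields { uint8_t sign, exponent, mantissa; };

   // E4M3 layout: [s | eeee | mmm]
   fp8_fields unpack_e4m3(uint8_t v) {
     return {uint8_t(v >> 7), uint8_t((v >> 3) & 0xF), uint8_t(v & 0x7)};
   }

   // E5M2 layout: [s | eeeee | mm]
   fp8_fields unpack_e5m2(uint8_t v) {
     return {uint8_t(v >> 7), uint8_t((v >> 2) & 0x1F), uint8_t(v & 0x3)};
   }

   int main() {
     fp8_fields f = unpack_e4m3(0xC8); // an example bit pattern
     std::printf("sign=%u exponent=%u mantissa=%u\n", f.sign, f.exponent, f.mantissa);
   }
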
HIP Header
==========

The HIP header `amd_hip_fp8.h <https://github.com/ROCm/clr/blob/develop/hipamd/include/hip/amd_detail/amd_hip_fp8.h>`_ defines the OCP and FNUZ FP8 numbers.

Supported Devices
=================

.. list-table:: Supported devices for FP8 numbers
   :header-rows: 1

   * - Device Type
     - FNUZ FP8
     - OCP FP8
   * - Host
     - Yes
     - Yes
   * - gfx940/gfx941/gfx942
     - Yes
     - No
   * - gfx120x
     - No
     - Yes
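
Which FP8 representation a device uses can be checked at runtime from the device architecture name, as the complete example further below also does. A minimal sketch, assuming the gfx94x/gfx120x mapping from the table above:

.. code-block:: cpp

   #include <hip/hip_runtime.h>
   #include <iostream>
   #include <string>

   int main() {
     hipDeviceProp_t prop;
     if (hipGetDeviceProperties(&prop, 0) != hipSuccess) return 1;
     std::string arch = prop.gcnArchName;
     // gfx94x devices use the FNUZ formats; gfx120x devices use the OCP formats.
     if (arch.find("gfx94") != std::string::npos)
       std::cout << arch << ": FNUZ FP8" << std::endl;
     else if (arch.find("gfx120") != std::string::npos)
       std::cout << arch << ": OCP FP8" << std::endl;
     else
       std::cout << arch << ": host-side conversions only" << std::endl;
   }
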
Usage
=====

To use FP8 numbers inside HIP programs, include the FP8 header:

.. code-block:: cpp

   #include <hip/hip_fp8.h>

FP8 numbers can be used on the CPU side:

.. code-block:: cpp

   __hip_fp8_storage_t convert_float_to_fp8(
       float in,                             /* input value */
       __hip_fp8_interpretation_t interpret, /* interpretation of number: E4M3/E5M2 */
       __hip_saturation_t sat                /* saturation behavior */
   ) {
     return __hip_cvt_float_to_fp8(in, sat, interpret);
   }
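
For example, a host-side call to this helper with the OCP E4M3 interpretation might look like this (a sketch; the interpretation must match what the target platform uses):

.. code-block:: cpp

   __hip_fp8_storage_t v = convert_float_to_fp8(1.5f, __HIP_E4M3, __HIP_SATFINITE);
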
The same can be done in kernels as well:

.. code-block:: cpp

   __device__ __hip_fp8_storage_t d_convert_float_to_fp8(
       float in,
       __hip_fp8_interpretation_t interpret,
       __hip_saturation_t sat) {
     return __hip_cvt_float_to_fp8(in, sat, interpret);
   }

Note that on a gfx94x GPU the result is an FNUZ number, while on any other supported GPU it is an OCP number.

The following code example does roundtrip FP8 conversions on both the CPU and GPU and compares the results.

.. code-block:: cpp

   #include <hip/hip_fp8.h>
   #include <hip/hip_runtime.h>
   #include <iostream>
   #include <string>
   #include <vector>

   #define hip_check(hip_call)                                                   \
     {                                                                           \
       auto hip_res = hip_call;                                                  \
       if (hip_res != hipSuccess) {                                              \
         std::cerr << "Failed in hip call: " << #hip_call                        \
                   << " with error: " << hipGetErrorName(hip_res) << std::endl;  \
         std::abort();                                                           \
       }                                                                         \
     }

   __device__ __hip_fp8_storage_t d_convert_float_to_fp8(
       float in, __hip_fp8_interpretation_t interpret, __hip_saturation_t sat) {
     return __hip_cvt_float_to_fp8(in, sat, interpret);
   }

   __device__ float d_convert_fp8_to_float(__hip_fp8_storage_t in,
                                           __hip_fp8_interpretation_t interpret) {
     __half hf = __hip_cvt_fp8_to_halfraw(in, interpret);
     return hf;
   }

   __global__ void float_to_fp8_to_float(float *in,
                                         __hip_fp8_interpretation_t interpret,
                                         __hip_saturation_t sat, float *out,
                                         size_t size) {
     int i = threadIdx.x;
     if (i < size) {
       auto fp8 = d_convert_float_to_fp8(in[i], interpret, sat);
       out[i] = d_convert_fp8_to_float(fp8, interpret);
     }
   }

   __hip_fp8_storage_t convert_float_to_fp8(
       float in,                             /* input value */
       __hip_fp8_interpretation_t interpret, /* interpretation of number: E4M3/E5M2 */
       __hip_saturation_t sat                /* saturation behavior */
   ) {
     return __hip_cvt_float_to_fp8(in, sat, interpret);
   }

   float convert_fp8_to_float(
       __hip_fp8_storage_t in,               /* input value */
       __hip_fp8_interpretation_t interpret  /* interpretation of number: E4M3/E5M2 */
   ) {
     __half hf = __hip_cvt_fp8_to_halfraw(in, interpret);
     return hf;
   }

   int main() {
     constexpr size_t size = 32;
     hipDeviceProp_t prop;
     hip_check(hipGetDeviceProperties(&prop, 0));
     bool is_supported =
         (std::string(prop.gcnArchName).find("gfx94") != std::string::npos) ||  // gfx94x
         (std::string(prop.gcnArchName).find("gfx120") != std::string::npos);   // gfx120x
     if (!is_supported) {
       std::cerr << "Need a gfx94x or gfx120x, but found: " << prop.gcnArchName << std::endl;
       std::cerr << "No device conversions are supported, only host conversions are supported." << std::endl;
       return -1;
     }

     const __hip_fp8_interpretation_t interpret =
         (std::string(prop.gcnArchName).find("gfx94") != std::string::npos)
             ? __HIP_E4M3_FNUZ  // gfx94x
             : __HIP_E4M3;      // gfx120x
     constexpr __hip_saturation_t sat = __HIP_SATFINITE;

     std::vector<float> in;
     in.reserve(size);
     for (size_t i = 0; i < size; i++) {
       in.push_back(i + 1.1f);
     }

     std::cout << "Converting float to fp8 and back..." << std::endl;
     // CPU convert
     std::vector<float> cpu_out;
     cpu_out.reserve(size);
     for (const auto &fval : in) {
       auto fp8 = convert_float_to_fp8(fval, interpret, sat);
       cpu_out.push_back(convert_fp8_to_float(fp8, interpret));
     }

     // GPU convert
     float *d_in, *d_out;
     hip_check(hipMalloc(&d_in, sizeof(float) * size));
     hip_check(hipMalloc(&d_out, sizeof(float) * size));

     hip_check(hipMemcpy(d_in, in.data(), sizeof(float) * in.size(),
                         hipMemcpyHostToDevice));

     float_to_fp8_to_float<<<1, size>>>(d_in, interpret, sat, d_out, size);

     std::vector<float> gpu_out(size, 0.0f);
     hip_check(hipMemcpy(gpu_out.data(), d_out, sizeof(float) * gpu_out.size(),
                         hipMemcpyDeviceToHost));

     hip_check(hipFree(d_in));
     hip_check(hipFree(d_out));

     // Validation
     for (size_t i = 0; i < size; i++) {
       if (cpu_out[i] != gpu_out[i]) {
         std::cerr << "cpu round trip result: " << cpu_out[i]
                   << " - gpu round trip result: " << gpu_out[i] << std::endl;
         std::abort();
       }
     }
     std::cout << "...CPU and GPU round trip convert matches." << std::endl;
   }

There are C++ style classes available as well:

.. code-block:: cpp

   __hip_fp8_e4m3_fnuz fp8_val(1.1f); // FNUZ number, gfx94x
   __hip_fp8_e4m3 fp8_val(1.1f);      // OCP number, gfx120x

Each type of FP8 number has its own class:

- ``__hip_fp8_e4m3``
- ``__hip_fp8_e5m2``
- ``__hip_fp8_e4m3_fnuz``
- ``__hip_fp8_e5m2_fnuz``
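
A short sketch of how these classes might be used, assuming construction from ``float`` and a conversion back to ``float`` as provided by the header above:

.. code-block:: cpp

   #include <hip/hip_fp8.h>
   #include <iostream>

   int main() {
     __hip_fp8_e4m3 v(1.1f);             // narrow a float to 8 bits
     float back = static_cast<float>(v); // widen it back to a float
     // Prints the nearest value the 8-bit format can actually represent.
     std::cout << "1.1f round-trips to " << back << std::endl;
   }
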
Vectors of FP8 types are supported as well; see the sketch after this list:

- ``__hip_fp8x2_e4m3``: holds 2 values of OCP FP8 e4m3 numbers
- ``__hip_fp8x4_e4m3``: holds 4 values of OCP FP8 e4m3 numbers
- ``__hip_fp8x2_e5m2``: holds 2 values of OCP FP8 e5m2 numbers
- ``__hip_fp8x4_e5m2``: holds 4 values of OCP FP8 e5m2 numbers
- ``__hip_fp8x2_e4m3_fnuz``: holds 2 values of FNUZ FP8 e4m3 numbers
- ``__hip_fp8x4_e4m3_fnuz``: holds 4 values of FNUZ FP8 e4m3 numbers
- ``__hip_fp8x2_e5m2_fnuz``: holds 2 values of FNUZ FP8 e5m2 numbers
- ``__hip_fp8x4_e5m2_fnuz``: holds 4 values of FNUZ FP8 e5m2 numbers
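
A sketch of how a packed type might be used; the ``float2`` constructor and conversion assumed here are modeled on the scalar classes, so consult ``amd_hip_fp8.h`` for the authoritative interface:

.. code-block:: cpp

   #include <hip/hip_fp8.h>
   #include <hip/hip_runtime.h>
   #include <iostream>

   int main() {
     // Pack two floats into one fp8x2 value, then unpack them again
     // (assumed: constructor from float2 and conversion back to float2).
     __hip_fp8x2_e4m3 pair(float2{1.5f, 2.5f});
     float2 back = static_cast<float2>(pair);
     std::cout << back.x << ", " << back.y << std::endl;
   }
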
The FNUZ extensions are available on gfx94x only.

docs/sphinx/_toc.yml.in

Lines changed: 2 additions & 0 deletions
@@ -41,6 +41,8 @@ subtrees:
    - file: reference/virtual_rocr
    - file: reference/deprecated_api_list
      title: List of deprecated APIs
+   - file: reference/fp8_numbers
+     title: FP8 numbers in HIP

  - caption: Tutorials
    entries:
