MXFP4 reference implementation and parsing MXFixNeuron #4111
Merged
48 commits
9166a80 initial progress (CharlieL7)
003f329 progress 2 (CharlieL7)
dbb8b09 unpack work (CharlieL7)
a3165ac more progress (CharlieL7)
09a4e47 Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into mxfp4_pack… (CharlieL7)
0d92773 comments on what to do (CharlieL7)
d4169f8 mm progress (CharlieL7)
bed35c0 Merge branch 'mxfp4_pack_unpack' of github.com:ROCm/AMDMIGraphX into … (CharlieL7)
a7e2715 parse and test progress (CharlieL7)
a5d11f1 Add tests and fixes (CharlieL7)
e880b5c Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into mxfp4_pack… (CharlieL7)
2741ed7 Fix parsing (CharlieL7)
ffa4c00 Fix parsing and formatting (CharlieL7)
cf8edf7 remove fp4_type and just use float_type (CharlieL7)
e15bf89 Change parsing to work with odd number of elements (CharlieL7)
5d74c7b progress on tests (CharlieL7)
8d1bb58 Fix errors and make tests (CharlieL7)
2310272 Make onnx verify test and fix abs bug (CharlieL7)
9f3f6a6 tidy fixes (CharlieL7)
4062593 Fix api calls with new enum type (CharlieL7)
78ecbd9 formatting (CharlieL7)
564ffc7 some cleanup (CharlieL7)
3b29957 simplify fp4_to_float with switch statement (CharlieL7)
6fc0bf2 rename to packed_fp4 to fpx2 (CharlieL7)
83ac88e Change switch case to static lookup table (CharlieL7)
1620e23 review edits (CharlieL7)
a5a6d57 review edits (CharlieL7)
afade85 tracking commit (CharlieL7)
357fcad Change to using comparisons for float to fp4 (CharlieL7)
be8b610 Merge branch 'mxfp4_pack_unpack' of github.com:ROCm/AMDMIGraphX into … (CharlieL7)
2e23dc8 add constexpr tests (CharlieL7)
575fd0e Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into mxfp4_pack… (CharlieL7)
537317a remove constexpr for float_to_fp4 (CharlieL7)
a7c0966 tidy fixes (CharlieL7)
f9a505e formatting (CharlieL7)
6468662 Fix compilation (CharlieL7)
b9f3215 fix hip_gemm switch statement (CharlieL7)
896d4fc More tidy fixes (CharlieL7)
1bdadb0 formatting (CharlieL7)
4536d58 more review comments (CharlieL7)
04eed95 Remove fp4_casts.cpp from cmakelists (CharlieL7)
c733c34 Add note on float_to_fp4 code (CharlieL7)
05005b4 re-disable constexpr float_to_fp4 (CharlieL7)
aceec60 Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into mxfp4_pack… (CharlieL7)
161a54a Make float_to_fp4 not constexpr again (CharlieL7)
ee236da make inline instead (CharlieL7)
92d7978 use alternative lookup for float_to_fp4, add back constexpr test (CharlieL7)
c64285b Update src/include/migraphx/op/pack_fp4.hpp (CharlieL7)
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,80 @@ | ||
| /* | ||
| * The MIT License (MIT) | ||
| * | ||
| * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. | ||
| * | ||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| * of this software and associated documentation files (the "Software"), to deal | ||
| * in the Software without restriction, including without limitation the rights | ||
| * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
| * copies of the Software, and to permit persons to whom the Software is | ||
| * furnished to do so, subject to the following conditions: | ||
| * | ||
| * The above copyright notice and this permission notice shall be included in | ||
| * all copies or substantial portions of the Software. | ||
| * | ||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
| * THE SOFTWARE. | ||
| */ | ||
| #ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT4_CASTS_HPP | ||
| #define MIGRAPHX_GUARD_RTGLIB_FLOAT4_CASTS_HPP | ||
|
|
||
| #include <cstdint> | ||
| #include <algorithm> | ||
| #include <array> | ||
| #include <cmath> | ||
| #include <iterator> | ||
| #include <migraphx/errors.hpp> | ||
|
|
||
| namespace migraphx { | ||
| inline namespace MIGRAPHX_INLINE_NS { | ||
|
|
||
| namespace fp4_detail { | ||
| static constexpr std::array<float, 16> fp4_lut = { | ||
| 0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0, -0.0, -0.5, -1.0, -1.5, -2.0, -3.0, -4.0, -6.0}; | ||
|
|
||
| // pair is {fp4_tie_value, round_to_zero} | ||
| // if round_to_zero round tie towards zero, else round tie away from zero | ||
| static constexpr std::array<std::pair<float, uint8_t>, 7> fp4_even_round = { | ||
| {{0.25, 1}, {0.75, 0}, {1.25, 1}, {1.75, 0}, {2.5, 1}, {3.5, 0}, {5, 1}}}; | ||
| } // namespace fp4_detail | ||
|
|
||
| // converts 4 LSB to float | ||
| constexpr float fp4_to_float(uint8_t x) | ||
| { | ||
| return fp4_detail::fp4_lut[x % fp4_detail::fp4_lut.size()]; | ||
| } | ||
|
|
||
| // rounding mode = roundToNearestRoundTiesToEven | ||
| // Reference quantization code from Microsoft: | ||
| // https://github.com/microsoft/microxcaling/blob/main/mx/elemwise_ops.py#L82 | ||
| // Not constexpr because std::signbit is not constexpr until C++23 | ||
| inline uint8_t float_to_fp4(float f_x) | ||
| { | ||
| using fp4_detail::fp4_even_round; | ||
| using fp4_detail::fp4_lut; | ||
| if(std::isnan(f_x)) | ||
| { | ||
| return 0; | ||
| } | ||
| bool sign = std::signbit(f_x); | ||
| uint8_t sign_add = sign ? fp4_lut.size() / 2 : 0u; | ||
| float abs_f = std::abs(f_x); | ||
| // index value is the positive fp4 value | ||
| uint8_t i = std::upper_bound(fp4_even_round.begin(), | ||
| fp4_even_round.end(), | ||
| std::make_pair(abs_f, uint8_t{0})) - | ||
| fp4_even_round.begin(); | ||
|
|
||
| return i + sign_add; | ||
| } | ||
|
|
||
| } // namespace MIGRAPHX_INLINE_NS | ||
| } // namespace migraphx | ||
|
|
||
| #endif | ||
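
The rounding in `float_to_fp4` can be hard to follow from the diff alone, so here is a minimal standalone sketch of the same idea, assuming nothing from MIGraphX. The namespace `fp4_sketch` and the names `decode`, `encode`, and `self_check` are hypothetical; the two tables are copied from the header above. Each entry in `ties` is the midpoint between two adjacent fp4 magnitudes, and the second field decides whether an exact tie stays on the lower code or moves to the upper one, so that ties always land on the code with an even index.

```cpp
#include <algorithm>
#include <array>
#include <cassert>
#include <cmath>
#include <cstdint>
#include <utility>

namespace fp4_sketch { // hypothetical namespace, not part of MIGraphX

// Code -> value table: codes 0..7 are non-negative, codes 8..15 set the sign bit.
constexpr std::array<float, 16> lut = {0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
                                       -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// {midpoint between adjacent magnitudes, 1 = an exact tie rounds toward zero}
constexpr std::array<std::pair<float, std::uint8_t>, 7> ties = {
    {{0.25f, 1}, {0.75f, 0}, {1.25f, 1}, {1.75f, 0}, {2.5f, 1}, {3.5f, 0}, {5.0f, 1}}};

// Only the 4 least significant bits select a table entry.
constexpr float decode(std::uint8_t code) { return lut[code % lut.size()]; }

inline std::uint8_t encode(float f)
{
    if(std::isnan(f))
        return 0; // NaN maps to +0, as in the header above
    const std::uint8_t sign_bit = std::signbit(f) ? 8 : 0;
    // The key's second field is 0, so it compares below {t, 1} and equal to {t, 0}.
    // upper_bound therefore stops at a "toward zero" tie and steps past an "away" tie.
    const auto key = std::make_pair(std::abs(f), std::uint8_t{0});
    const auto it  = std::upper_bound(ties.begin(), ties.end(), key);
    return static_cast<std::uint8_t>((it - ties.begin()) + sign_bit);
}

inline void self_check()
{
    assert(decode(encode(0.25f)) == 0.0f);    // tie between 0.0 and 0.5 stays at code 0
    assert(decode(encode(0.75f)) == 1.0f);    // tie between 0.5 and 1.0 moves to code 2
    assert(decode(encode(5.0f)) == 4.0f);     // tie between 4.0 and 6.0 stays at code 6
    assert(decode(encode(-100.0f)) == -6.0f); // out-of-range magnitudes saturate
}

} // namespace fp4_sketch
```

Calling `fp4_sketch::self_check()` from a test harness exercises the tie cases. Magnitudes above 5.0, including infinity, fall past the last table entry and saturate to code 7 (or 15 for negative inputs).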
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
| /* | ||
| * The MIT License (MIT) | ||
| * | ||
| * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. | ||
| * | ||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| * of this software and associated documentation files (the "Software"), to deal | ||
| * in the Software without restriction, including without limitation the rights | ||
| * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
| * copies of the Software, and to permit persons to whom the Software is | ||
| * furnished to do so, subject to the following conditions: | ||
| * | ||
| * The above copyright notice and this permission notice shall be included in | ||
| * all copies or substantial portions of the Software. | ||
| * | ||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
| * THE SOFTWARE. | ||
| */ | ||
| #ifndef MIGRAPHX_GUARD_OPERATORS_PACK_FP4_HPP | ||
| #define MIGRAPHX_GUARD_OPERATORS_PACK_FP4_HPP | ||
|
|
||
| #include <migraphx/check_shapes.hpp> | ||
| #include <migraphx/op/normalize_attribute.hpp> | ||
| #include <migraphx/argument.hpp> | ||
| #include <migraphx/par_for.hpp> | ||
| #include <migraphx/fp4_casts.hpp> | ||
|
|
||
| namespace migraphx { | ||
| inline namespace MIGRAPHX_INLINE_NS { | ||
| namespace op { | ||
|
|
||
| struct pack_fp4 | ||
| { | ||
| int64_t axis = -1; | ||
|
|
||
| std::string name() const { return "pack_fp4"; } | ||
|
|
||
| value attributes() const | ||
| { | ||
| value normalize = value::object{}; | ||
| normalize["axis"] = value::array{normalize_attribute::include_min}; | ||
| return {{"normalize_axes", normalize}}; | ||
| } | ||
|
|
||
| template <class Self, class F> | ||
| static auto reflect(Self& self, F f) | ||
| { | ||
| return pack(f(self.axis, "axis")); | ||
| } | ||
|
|
||
| migraphx::shape normalize_compute_shape(std::vector<migraphx::shape> inputs) const | ||
| { | ||
| check_shapes{inputs, *this}.same_dims().has(1); | ||
| const auto& in_shape = inputs.front(); | ||
| auto new_lens = in_shape.lens(); | ||
| if(new_lens[axis] % 2 != 0) | ||
| { | ||
| MIGRAPHX_THROW("PACK_FP4: Can not pack axis that has odd lengths"); | ||
| } | ||
| new_lens[axis] /= 2; | ||
| return {migraphx::shape::fp4x2_type, new_lens}; | ||
| } | ||
|
|
||
| argument compute(const shape& output_shape, const std::vector<argument>& args) const | ||
| { | ||
| auto input = args.front(); | ||
| auto in_shape = input.get_shape(); | ||
|
|
||
| migraphx::shape uint8_shape = shape{migraphx::shape::uint8_type, output_shape.lens()}; | ||
| argument uint8_arg{uint8_shape}; | ||
| uint8_arg.visit([&](auto out) { | ||
| input.visit([&](auto inp) { | ||
| par_for(output_shape.elements(), [&](auto i) { | ||
| using inp_type = typename decltype(inp)::value_type; | ||
| auto data_idx = output_shape.multi(i); | ||
| auto in_data_multi_idx = data_idx; | ||
| in_data_multi_idx[axis] *= 2; | ||
| inp_type inp_val0 = inp[in_data_multi_idx]; | ||
| in_data_multi_idx[axis] += 1; | ||
| inp_type inp_val1 = inp[in_data_multi_idx]; | ||
| uint8_t out_val0 = float_to_fp4(inp_val0); | ||
| uint8_t out_val1 = float_to_fp4(inp_val1); | ||
| // NOTE: integral promotion occurs when bitshifting for uint8_t | ||
| out[i] = static_cast<uint8_t>(out_val1 << 4u) | | ||
| static_cast<uint8_t>(out_val0 & 0xFu); | ||
| }); | ||
| }); | ||
| }); | ||
| migraphx::argument result = | ||
| uint8_arg.reshape({migraphx::shape::fp4x2_type, output_shape.lens()}); | ||
| return result; | ||
| } | ||
| }; | ||
|
|
||
| } // namespace op | ||
| } // namespace MIGRAPHX_INLINE_NS | ||
| } // namespace migraphx | ||
|
|
||
| #endif |
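
As an illustration of the byte layout that `pack_fp4::compute` produces, the sketch below packs a flat sequence of 4-bit codes two per byte. It is a simplified stand-in, assuming a one-dimensional input rather than a MIGraphX `argument` with an `axis`; the namespace `pack_sketch` and the functions `pack_pair`, `pack_codes`, and `self_check` are hypothetical names. The even-indexed element goes in the low nibble and the odd-indexed element in the high nibble, matching the `out_val1 << 4u | out_val0 & 0xFu` expression in the operator.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

namespace pack_sketch { // hypothetical namespace, not part of MIGraphX

// lo is the even-indexed element, hi the odd-indexed one.
// Shifting promotes to int, so the result is cast back to one byte.
constexpr std::uint8_t pack_pair(std::uint8_t lo, std::uint8_t hi)
{
    return static_cast<std::uint8_t>(((hi & 0xFu) << 4u) | (lo & 0xFu));
}

// Packs a flat list of fp4 codes; the length must be even, as in the operator.
inline std::vector<std::uint8_t> pack_codes(const std::vector<std::uint8_t>& codes)
{
    if(codes.size() % 2 != 0)
        throw std::invalid_argument("cannot pack an odd number of fp4 codes");
    std::vector<std::uint8_t> packed(codes.size() / 2);
    for(std::size_t i = 0; i < packed.size(); ++i)
        packed[i] = pack_pair(codes[2 * i], codes[2 * i + 1]);
    return packed;
}

inline void self_check()
{
    static_assert(pack_pair(0x2, 0xA) == 0xA2, "low nibble holds the first element");
    const std::vector<std::uint8_t> expected = {0x21, 0x43};
    assert(pack_codes({1, 2, 3, 4}) == expected);
}

} // namespace pack_sketch
```

The odd-length check mirrors the `MIGRAPHX_THROW` in `normalize_compute_shape`: a packed axis halves in length, so an odd axis has no valid output shape.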
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,103 @@ | ||
| /* | ||
| * The MIT License (MIT) | ||
| * | ||
| * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. | ||
| * | ||
| * Permission is hereby granted, free of charge, to any person obtaining a copy | ||
| * of this software and associated documentation files (the "Software"), to deal | ||
| * in the Software without restriction, including without limitation the rights | ||
| * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | ||
| * copies of the Software, and to permit persons to whom the Software is | ||
| * furnished to do so, subject to the following conditions: | ||
| * | ||
| * The above copyright notice and this permission notice shall be included in | ||
| * all copies or substantial portions of the Software. | ||
| * | ||
| * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | ||
| * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | ||
| * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE | ||
| * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | ||
| * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | ||
| * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | ||
| * THE SOFTWARE. | ||
| */ | ||
| #ifndef MIGRAPHX_GUARD_OPERATORS_UNPACK_FP4_HPP | ||
| #define MIGRAPHX_GUARD_OPERATORS_UNPACK_FP4_HPP | ||
|
|
||
| #include <migraphx/check_shapes.hpp> | ||
| #include <migraphx/op/normalize_attribute.hpp> | ||
| #include <migraphx/shape.hpp> | ||
| #include <migraphx/config.hpp> | ||
| #include <migraphx/par_for.hpp> | ||
| #include <migraphx/argument.hpp> | ||
| #include <migraphx/fp4_casts.hpp> | ||
|
|
||
| namespace migraphx { | ||
| inline namespace MIGRAPHX_INLINE_NS { | ||
|
|
||
| namespace op { | ||
| struct unpack_fp4 | ||
| { | ||
| int64_t axis = -1; | ||
|
|
||
| std::string name() const { return "unpack_fp4"; } | ||
|
|
||
| value attributes() const | ||
| { | ||
| value normalize = value::object{}; | ||
| normalize["axis"] = value::array{normalize_attribute::include_min}; | ||
| return {{"normalize_axes", normalize}}; | ||
| } | ||
|
|
||
| template <class Self, class F> | ||
| static auto reflect(Self& self, F f) | ||
| { | ||
| return pack(f(self.axis, "axis")); | ||
| } | ||
|
|
||
| migraphx::shape normalize_compute_shape(std::vector<migraphx::shape> inputs) const | ||
| { | ||
| check_shapes{inputs, *this}.same_dims().has(1); | ||
| const auto& in_shape = inputs.front(); | ||
| if(in_shape.type() != migraphx::shape::fp4x2_type) | ||
| { | ||
| MIGRAPHX_THROW("UNPACK_FP4: Only fp4x2_type is supported for unpacking"); | ||
| } | ||
| auto new_lens = in_shape.lens(); | ||
| new_lens[axis] *= 2; | ||
| return {migraphx::shape::float_type, new_lens}; | ||
| } | ||
|
|
||
| argument compute(const shape& output_shape, const std::vector<argument>& args) const | ||
| { | ||
| const auto& input = args.front(); | ||
| auto in_shape = input.get_shape(); | ||
|
|
||
| argument uint8_input = input.reshape({migraphx::shape::uint8_type, in_shape.lens()}); | ||
| migraphx::shape float_shape = shape{migraphx::shape::float_type, output_shape.lens()}; | ||
| argument float_arg{float_shape}; | ||
|
|
||
| float_arg.visit([&](auto out) { | ||
| uint8_input.visit([&](auto inp) { | ||
| par_for(in_shape.elements(), [&](auto i) { | ||
| auto data_idx = in_shape.multi(i); | ||
| data_idx[axis] *= 2; | ||
| // unpacking 2 unsigned parts | ||
| // unpacking 4 least significant bits first | ||
| uint8_t fp4_val = inp[i]; | ||
| out[data_idx] = fp4_to_float(fp4_val); | ||
|
|
||
| data_idx[axis] += 1; | ||
| fp4_val = fp4_val >> 4u; | ||
| out[data_idx] = fp4_to_float(fp4_val); | ||
| }); | ||
| }); | ||
| }); | ||
| return float_arg; | ||
| } | ||
| }; | ||
| } // namespace op | ||
| } // namespace MIGRAPHX_INLINE_NS | ||
| } // namespace migraphx | ||
|
|
||
| #endif |
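
The unpacking order can be sketched in the same standalone style. This is a simplified illustration that assumes a flat byte buffer instead of a MIGraphX `argument`; the namespace `unpack_sketch` and the functions `unpack_bytes` and `self_check` are hypothetical, and the lookup table is copied from `fp4_casts.hpp`. The low nibble is decoded first and written to the even output index, then the high nibble to the odd index.

```cpp
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

namespace unpack_sketch { // hypothetical namespace, not part of MIGraphX

// Same code -> value table as fp4_lut in the header shown earlier.
constexpr std::array<float, 16> lut = {0.0f,  0.5f,  1.0f,  1.5f,  2.0f,  3.0f,  4.0f,  6.0f,
                                       -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f};

// Each input byte yields two floats: low nibble first, then high nibble.
inline std::vector<float> unpack_bytes(const std::vector<std::uint8_t>& packed)
{
    std::vector<float> out(packed.size() * 2);
    for(std::size_t i = 0; i < packed.size(); ++i)
    {
        out[2 * i]     = lut[packed[i] & 0xFu]; // explicit mask; the operator relies on x % 16
        out[2 * i + 1] = lut[packed[i] >> 4u];  // a uint8_t shifted by 4 is already in 0..15
    }
    return out;
}

inline void self_check()
{
    // 0xA2: low nibble 0x2 decodes to 1.0, high nibble 0xA decodes to -1.0
    const std::vector<float> expected = {1.0f, -1.0f};
    assert(unpack_bytes({0xA2}) == expected);
}

} // namespace unpack_sketch
```

In the operator itself, the low nibble is never masked explicitly; `fp4_to_float` takes the index modulo the table size, which has the same effect for a 16-entry table.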
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -57,6 +57,7 @@ int get_onnx_type(shape::type_t s_type) | ||
| case shape::fp8e5m2_type: return 19; | ||
| case shape::fp8e5m2fnuz_type: return 20; | ||
| case shape::tuple_type: return 0; | ||
| case shape::fp4x2_type: return 21; // TODO update this when the type is added | ||
| } | ||

Comment on lines +60 to 61

| - case shape::fp4x2_type: return 21; // TODO update this when the type is added | |
| - } | |
| + case shape::fp4x2_type: | |
| + { | |
| + constexpr int placeholder_fp4x2_type = 21; // Temporary placeholder for ONNX type | |
| + // TODO: Replace placeholder_fp4x2_type with the correct ONNX type when available | |
| + return placeholder_fp4x2_type; | |
| + } | |
| + } |