Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
48 commits
Select commit Hold shift + click to select a range
9166a80
initial progress
CharlieL7 Jun 9, 2025
003f329
progress 2
CharlieL7 Jun 9, 2025
dbb8b09
unpack work
CharlieL7 Jun 11, 2025
a3165ac
more progress
CharlieL7 Jun 15, 2025
09a4e47
Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into mxfp4_pack…
CharlieL7 Jun 17, 2025
0d92773
comments on what to do
CharlieL7 Jun 20, 2025
d4169f8
mm progress
CharlieL7 Jun 20, 2025
bed35c0
Merge branch 'mxfp4_pack_unpack' of github.com:ROCm/AMDMIGraphX into …
CharlieL7 Jun 20, 2025
a7e2715
parse and test progress
CharlieL7 Jun 26, 2025
a5d11f1
Add tests and fixes
CharlieL7 Jun 27, 2025
e880b5c
Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into mxfp4_pack…
CharlieL7 Jun 27, 2025
2741ed7
Fix parsing
CharlieL7 Jun 30, 2025
ffa4c00
Fix parsing and formatting
CharlieL7 Jul 1, 2025
cf8edf7
remove fp4_type and just use float_type
CharlieL7 Jul 1, 2025
e15bf89
Change parsing to work with odd number of elements
CharlieL7 Jul 1, 2025
5d74c7b
progress on tests
CharlieL7 Jul 2, 2025
8d1bb58
Fix errors and make tests
CharlieL7 Jul 7, 2025
2310272
Make onnx verify test and fix abs bug
CharlieL7 Jul 8, 2025
9f3f6a6
tidy fixes
CharlieL7 Jul 8, 2025
4062593
Fix api calls with new enum type
CharlieL7 Jul 9, 2025
78ecbd9
formatting
CharlieL7 Jul 9, 2025
564ffc7
some cleanup
CharlieL7 Jul 10, 2025
3b29957
simplify fp4_to_float with switch statement
CharlieL7 Jul 10, 2025
6fc0bf2
rename to packed_fp4 to fpx2
CharlieL7 Jul 10, 2025
83ac88e
Change switch case to static lookup table
CharlieL7 Jul 10, 2025
1620e23
review edits
CharlieL7 Jul 14, 2025
a5a6d57
review edits
CharlieL7 Jul 14, 2025
afade85
tracking commit
CharlieL7 Jul 14, 2025
357fcad
Change to using comparisons for float to fp4
CharlieL7 Jul 14, 2025
be8b610
Merge branch 'mxfp4_pack_unpack' of github.com:ROCm/AMDMIGraphX into …
CharlieL7 Jul 14, 2025
2e23dc8
add constexpr tests
CharlieL7 Jul 15, 2025
575fd0e
Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into mxfp4_pack…
CharlieL7 Jul 15, 2025
537317a
remove constexpr for float_to_fp4
CharlieL7 Jul 17, 2025
a7c0966
tidy fixes
CharlieL7 Jul 17, 2025
f9a505e
formatting
CharlieL7 Jul 17, 2025
6468662
Fix compilation
CharlieL7 Jul 17, 2025
b9f3215
fix hip_gemm switch statement
CharlieL7 Jul 18, 2025
896d4fc
More tidy fixes
CharlieL7 Jul 18, 2025
1bdadb0
formatting
CharlieL7 Jul 18, 2025
4536d58
more review comments
CharlieL7 Jul 18, 2025
04eed95
Remove fp4_casts.cpp from cmakelists
CharlieL7 Jul 21, 2025
c733c34
Add note on float_to_fp4 code
CharlieL7 Jul 21, 2025
05005b4
re-disable constexpr float_to_fp4
CharlieL7 Jul 21, 2025
aceec60
Merge branch 'develop' of github.com:ROCm/AMDMIGraphX into mxfp4_pack…
CharlieL7 Jul 21, 2025
161a54a
Make float_to_fp4 not constexpr again
CharlieL7 Jul 21, 2025
ee236da
make inline instead
CharlieL7 Jul 21, 2025
92d7978
use alternative lookup for float_to_fp4, add back constexpr test
CharlieL7 Jul 23, 2025
c64285b
Update src/include/migraphx/op/pack_fp4.hpp
CharlieL7 Jul 30, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,7 @@ register_migraphx_ops(
nonzero
onehot
outline
pack_fp4
pack_int4
pad
pointwise
Expand Down Expand Up @@ -291,6 +292,7 @@ register_migraphx_ops(
undefined
unique
unknown
unpack_fp4
unpack_int4
unsqueeze
where
Expand Down
2 changes: 2 additions & 0 deletions src/api/api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ static shape::type_t to_shape_type(migraphx_shape_datatype_t t)
switch(t)
{
case migraphx_shape_tuple_type: return shape::tuple_type;
case migraphx_shape_fp4x2_type: return shape::fp4x2_type;
#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \
case migraphx_shape_##x: return shape::x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT)
Expand All @@ -111,6 +112,7 @@ static migraphx_shape_datatype_t to_shape_type(shape::type_t t)
switch(t)
{
case shape::tuple_type: return migraphx_shape_tuple_type;
case shape::fp4x2_type: return migraphx_shape_fp4x2_type;
#define MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT(x, y) \
case shape::x: return migraphx_shape_##x;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_DETAIL_SHAPE_CASE_CONVERT)
Expand Down
1 change: 1 addition & 0 deletions src/api/include/migraphx/migraphx.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,7 @@ typedef enum
typedef enum
{
migraphx_shape_tuple_type,
migraphx_shape_fp4x2_type,
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES)
} migraphx_shape_datatype_t;
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES
Expand Down
80 changes: 80 additions & 0 deletions src/include/migraphx/fp4_casts.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_RTGLIB_FLOAT4_CASTS_HPP
#define MIGRAPHX_GUARD_RTGLIB_FLOAT4_CASTS_HPP

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdint>
#include <iterator>
#include <utility>
#include <migraphx/errors.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

namespace fp4_detail {
// Value of each 4-bit fp4 code. Codes 8-15 are the negated values of codes 0-7,
// i.e. bit 3 acts as the sign bit.
// `inline` (C++17) gives the table exactly one definition across all translation
// units, so the inline functions below that odr-use it stay ODR-safe when this
// header is included from several source files.
inline constexpr std::array<float, 16> fp4_lut = {0.0f,
                                                  0.5f,
                                                  1.0f,
                                                  1.5f,
                                                  2.0f,
                                                  3.0f,
                                                  4.0f,
                                                  6.0f,
                                                  -0.0f,
                                                  -0.5f,
                                                  -1.0f,
                                                  -1.5f,
                                                  -2.0f,
                                                  -3.0f,
                                                  -4.0f,
                                                  -6.0f};

// pair is {fp4_tie_value, round_to_zero}
// fp4_tie_value is the midpoint between two neighbouring non-negative fp4 values.
// if round_to_zero is 1 the tie rounds towards zero, else it rounds away from zero.
// The flag doubles as the tie-breaker of the lexicographic pair comparison that
// float_to_fp4 performs with std::upper_bound.
inline constexpr std::array<std::pair<float, uint8_t>, 7> fp4_even_round = {
    {{0.25f, 1}, {0.75f, 0}, {1.25f, 1}, {1.75f, 0}, {2.5f, 1}, {3.5f, 0}, {5.0f, 1}}};
} // namespace fp4_detail

// Converts the 4 least significant bits of x to float; the upper 4 bits are ignored.
constexpr float fp4_to_float(uint8_t x)
{
    return fp4_detail::fp4_lut[x % fp4_detail::fp4_lut.size()];
}

// Converts a float to its 4-bit fp4 code, returned in the low nibble of the result.
// rounding mode = roundToNearestRoundTiesToEven
// Values whose magnitude exceeds the largest tie value (5) map to the largest
// magnitude code (+/-6); this includes +/-infinity. NaN maps to code 0 (+0.0).
// The sign of the input is preserved, so -0.0f maps to code 8.
// Reference quantization code from Microsoft:
// https://github.com/microsoft/microxcaling/blob/main/mx/elemwise_ops.py#L82
// Not constexpr because std::signbit is not constexpr until C++23
inline uint8_t float_to_fp4(float f_x)
{
    using fp4_detail::fp4_even_round;
    using fp4_detail::fp4_lut;
    if(std::isnan(f_x))
    {
        return 0;
    }
    // Negative codes live in the upper half of fp4_lut.
    const uint8_t sign_add =
        std::signbit(f_x) ? static_cast<uint8_t>(fp4_lut.size() / 2) : uint8_t{0};
    const float abs_f = std::abs(f_x);
    // The number of tie points at or below abs_f is the index of the positive fp4 value.
    // Searching for {abs_f, 0} makes an exact tie count as "below" only when its
    // round_to_zero flag is 0, which yields round-ties-to-even.
    const auto first_above = std::upper_bound(
        fp4_even_round.begin(), fp4_even_round.end(), std::make_pair(abs_f, uint8_t{0}));
    const auto magnitude =
        static_cast<uint8_t>(std::distance(fp4_even_round.begin(), first_above));

    return static_cast<uint8_t>(magnitude + sign_add);
}

} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
104 changes: 104 additions & 0 deletions src/include/migraphx/op/pack_fp4.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_PACK_FP4_HPP
#define MIGRAPHX_GUARD_OPERATORS_PACK_FP4_HPP

#include <migraphx/check_shapes.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/fp4_casts.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {
namespace op {

/**
 * Packs pairs of adjacent elements along `axis` into single bytes, each holding two
 * 4-bit fp4 codes produced by float_to_fp4.
 * The output has type fp4x2_type and the same dimensions as the input, except that
 * the length of `axis` is halved.
 */
struct pack_fp4
{
    // Axis whose adjacent element pairs are packed together.
    int64_t axis = -1;

    std::string name() const { return "pack_fp4"; }

    // Registers "axis" under "normalize_axes" with the include_min option.
    // NOTE(review): presumably this is what lets a negative axis be normalized before
    // normalize_compute_shape/compute index with it -- confirm against the normalize pass.
    value attributes() const
    {
        value normalize   = value::object{};
        normalize["axis"] = value::array{normalize_attribute::include_min};
        return {{"normalize_axes", normalize}};
    }

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axis, "axis"));
    }

    // Requires exactly one input whose length along `axis` is even.
    // Returns an fp4x2_type shape with that length halved.
    // Throws if the length along `axis` is odd.
    migraphx::shape normalize_compute_shape(std::vector<migraphx::shape> inputs) const
    {
        check_shapes{inputs, *this}.same_dims().has(1);
        const auto& in_shape = inputs.front();
        auto new_lens        = in_shape.lens();
        if(new_lens[axis] % 2 != 0)
        {
            MIGRAPHX_THROW("PACK_FP4: Can not pack axis that has odd lengths");
        }
        new_lens[axis] /= 2;
        return {migraphx::shape::fp4x2_type, new_lens};
    }

    // For every output element, reads input elements 2*k and 2*k+1 along `axis`,
    // converts each with float_to_fp4 and stores the first in the low nibble and the
    // second in the high nibble of the output byte.
    argument compute(const shape& output_shape, const std::vector<argument>& args) const
    {
        const auto& input = args.front();

        // The bytes are written through a uint8 buffer and the result is reshaped to
        // fp4x2_type afterwards.
        migraphx::shape uint8_shape = shape{migraphx::shape::uint8_type, output_shape.lens()};
        argument uint8_arg{uint8_shape};
        uint8_arg.visit([&](auto out) {
            input.visit([&](auto inp) {
                par_for(output_shape.elements(), [&](auto i) {
                    using inp_type = typename decltype(inp)::value_type;
                    // Multi-index of the first of the two input elements for output i.
                    auto in_data_multi_idx = output_shape.multi(i);
                    in_data_multi_idx[axis] *= 2;
                    inp_type inp_val0 = inp[in_data_multi_idx];
                    in_data_multi_idx[axis] += 1;
                    inp_type inp_val1 = inp[in_data_multi_idx];
                    uint8_t out_val0  = float_to_fp4(inp_val0);
                    uint8_t out_val1  = float_to_fp4(inp_val1);
                    // NOTE: integral promotion occurs when bitshifting for uint8_t
                    out[i] = static_cast<uint8_t>(out_val1 << 4u) |
                             static_cast<uint8_t>(out_val0 & 0xFu);
                });
            });
        });
        migraphx::argument result =
            uint8_arg.reshape({migraphx::shape::fp4x2_type, output_shape.lens()});
        return result;
    }
};

} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
103 changes: 103 additions & 0 deletions src/include/migraphx/op/unpack_fp4.hpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
/*
* The MIT License (MIT)
*
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to deal
* in the Software without restriction, including without limitation the rights
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
* THE SOFTWARE.
*/
#ifndef MIGRAPHX_GUARD_OPERATORS_UNPACK_FP4_HPP
#define MIGRAPHX_GUARD_OPERATORS_UNPACK_FP4_HPP

#include <migraphx/check_shapes.hpp>
#include <migraphx/op/normalize_attribute.hpp>
#include <migraphx/shape.hpp>
#include <migraphx/config.hpp>
#include <migraphx/par_for.hpp>
#include <migraphx/argument.hpp>
#include <migraphx/fp4_casts.hpp>

namespace migraphx {
inline namespace MIGRAPHX_INLINE_NS {

namespace op {
/**
 * Expands every byte of an fp4x2_type input into two float values along `axis`.
 * The low nibble of each byte becomes the element at position 2*k and the high
 * nibble the element at position 2*k+1, so the length of `axis` doubles.
 */
struct unpack_fp4
{
    // Axis along which each packed byte expands into two output elements.
    int64_t axis = -1;

    std::string name() const { return "unpack_fp4"; }

    // Registers "axis" under "normalize_axes" with the include_min option.
    value attributes() const
    {
        value normalize   = value::object{};
        normalize["axis"] = value::array{normalize_attribute::include_min};
        return {{"normalize_axes", normalize}};
    }

    template <class Self, class F>
    static auto reflect(Self& self, F f)
    {
        return pack(f(self.axis, "axis"));
    }

    // Requires exactly one fp4x2_type input; returns a float_type shape whose
    // length along `axis` is twice that of the input. Throws for any other input type.
    migraphx::shape normalize_compute_shape(std::vector<migraphx::shape> inputs) const
    {
        check_shapes{inputs, *this}.same_dims().has(1);
        const auto& packed_shape = inputs.front();
        if(packed_shape.type() != migraphx::shape::fp4x2_type)
        {
            MIGRAPHX_THROW("UNPACK_FP4: Only fp4x2_type is supported for unpacking");
        }
        auto unpacked_lens = packed_shape.lens();
        unpacked_lens[axis] *= 2;
        return {migraphx::shape::float_type, unpacked_lens};
    }

    // Decodes both nibbles of every input byte with fp4_to_float and writes them to
    // neighbouring positions along `axis` of a freshly allocated float argument.
    argument compute(const shape& output_shape, const std::vector<argument>& args) const
    {
        const auto& packed_arg = args.front();
        auto packed_shape      = packed_arg.get_shape();

        // Reinterpret the packed data as uint8 so the bytes can be visited.
        argument byte_arg = packed_arg.reshape({migraphx::shape::uint8_type, packed_shape.lens()});
        migraphx::shape result_shape = shape{migraphx::shape::float_type, output_shape.lens()};
        argument result{result_shape};

        result.visit([&](auto dst) {
            byte_arg.visit([&](auto src) {
                par_for(packed_shape.elements(), [&](auto i) {
                    // fp4_to_float only looks at the 4 least significant bits, so the
                    // low nibble can be decoded from the byte directly.
                    uint8_t packed_byte = src[i];
                    const float low     = fp4_to_float(packed_byte);
                    const float high    = fp4_to_float(static_cast<uint8_t>(packed_byte >> 4u));

                    // Position of the low-nibble value; the high-nibble value follows it.
                    auto out_idx = packed_shape.multi(i);
                    out_idx[axis] *= 2;
                    dst[out_idx] = low;
                    out_idx[axis] += 1;
                    dst[out_idx] = high;
                });
            });
        });
        return result;
    }
};
} // namespace op
} // namespace MIGRAPHX_INLINE_NS
} // namespace migraphx

#endif
7 changes: 6 additions & 1 deletion src/include/migraphx/shape.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ struct MIGRAPHX_EXPORT shape
#define MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES(x, t) x,
enum type_t
{
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) tuple_type
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES) tuple_type,
fp4x2_type // packed fp4 contained in uint8
};
#undef MIGRAPHX_SHAPE_GENERATE_ENUM_TYPES

Expand Down Expand Up @@ -414,6 +415,9 @@ struct MIGRAPHX_EXPORT shape
tv();
return;
}
case fp4x2_type: {
MIGRAPHX_THROW("fp4x2_type cannot be visited.");
}
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE(x, t) \
case x: v(as<t>()); return;
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_CASE)
Expand All @@ -439,6 +443,7 @@ struct MIGRAPHX_EXPORT shape
{
#define MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL(x, t) v(as<t>());
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL)
v(as<uint8_t>());
#undef MIGRAPHX_SHAPE_GENERATE_VISITOR_ALL
}

Expand Down
1 change: 1 addition & 0 deletions src/netron_output.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ int get_onnx_type(shape::type_t s_type)
case shape::fp8e5m2_type: return 19;
case shape::fp8e5m2fnuz_type: return 20;
case shape::tuple_type: return 0;
case shape::fp4x2_type: return 21; // TODO update this when the type is added
}
Comment on lines +60 to 61
Copy link

Copilot AI Jul 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO comment indicates incomplete implementation. The magic number 21 should be updated to the correct ONNX type when available.

Suggested change
case shape::fp4x2_type: return 21; // TODO update this when the type is added
}
case shape::fp4x2_type:
{
constexpr int placeholder_fp4x2_type = 21; // Temporary placeholder for ONNX type
// TODO: Replace placeholder_fp4x2_type with the correct ONNX type when available
return placeholder_fp4x2_type;
}
}

Copilot uses AI. Check for mistakes.
MIGRAPHX_THROW("MIGraphX type " + std::to_string(s_type) + " not supported");
}
Expand Down
Loading
Loading