Skip to content

Commit 3093272

Browse files
authored
MXFP4 GPU pack and unpack (#4181)
GPU JIT kernels and updates for running pack_fp4 and unpack_fp4 on GPU
1 parent 7ff4219 commit 3093272

File tree

15 files changed

+529
-42
lines changed

15 files changed

+529
-42
lines changed

src/include/migraphx/byte.hpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -46,7 +46,7 @@ template <class IntType,
4646
MIGRAPHX_REQUIRES(std::is_integral<IntType>{} and std::is_unsigned<IntType>{})>
4747
constexpr byte operator<<(byte b, IntType shift) noexcept
4848
{
49-
return static_cast<byte>(static_cast<unsigned char>(b) << shift);
49+
return static_cast<byte>(static_cast<uint8_t>(b) << shift);
5050
};
5151

5252
template <class IntType,

src/include/migraphx/op/pack_fp4.hpp

Lines changed: 21 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -57,6 +57,10 @@ struct pack_fp4
5757
{
5858
check_shapes{inputs, *this}.same_dims().has(1);
5959
const auto& in_shape = inputs.front();
60+
if(in_shape.type() != migraphx::shape::float_type)
61+
{
62+
MIGRAPHX_THROW("PACK_FP4: Only float32 type input is supported");
63+
}
6064
auto new_lens = in_shape.lens();
6165
if(new_lens[axis] % 2 != 0)
6266
{
@@ -68,31 +72,27 @@ struct pack_fp4
6872

6973
argument compute(const shape& output_shape, const std::vector<argument>& args) const
7074
{
71-
auto input = args.front();
75+
const auto& input = args.front();
7276
auto in_shape = input.get_shape();
7377

74-
migraphx::shape uint8_shape = shape{migraphx::shape::uint8_type, output_shape.lens()};
75-
argument uint8_arg{uint8_shape};
76-
uint8_arg.visit([&](auto out) {
77-
input.visit([&](auto inp) {
78-
par_for(output_shape.elements(), [&](auto i) {
79-
using inp_type = typename decltype(inp)::value_type;
80-
auto data_idx = output_shape.multi(i);
81-
auto in_data_multi_idx = data_idx;
82-
in_data_multi_idx[axis] *= 2;
83-
inp_type inp_val0 = inp[in_data_multi_idx];
84-
in_data_multi_idx[axis] += 1;
85-
inp_type inp_val1 = inp[in_data_multi_idx];
86-
uint8_t out_val0 = float_to_fp4(inp_val0);
87-
uint8_t out_val1 = float_to_fp4(inp_val1);
88-
// NOTE: integral promotion occurs when bitshifting for uint8_t
89-
out[i] = static_cast<uint8_t>(out_val1 << 4u) |
90-
static_cast<uint8_t>(out_val0 & 0xFu);
91-
});
78+
argument result{output_shape};
79+
auto out = result.get<uint8_t>();
80+
input.visit([&](auto inp) {
81+
par_for(output_shape.elements(), [&](auto i) {
82+
using inp_type = typename decltype(inp)::value_type;
83+
auto data_idx = output_shape.multi(i);
84+
auto in_data_multi_idx = data_idx;
85+
in_data_multi_idx[axis] *= 2;
86+
inp_type inp_val0 = inp[in_data_multi_idx];
87+
in_data_multi_idx[axis] += 1;
88+
inp_type inp_val1 = inp[in_data_multi_idx];
89+
uint8_t out_val0 = float_to_fp4(inp_val0);
90+
uint8_t out_val1 = float_to_fp4(inp_val1);
91+
// NOTE: integral promotion occurs when bitshifting for uint8_t
92+
out[i] =
93+
static_cast<uint8_t>(out_val1 << 4u) | static_cast<uint8_t>(out_val0 & 0xFu);
9294
});
9395
});
94-
migraphx::argument result =
95-
uint8_arg.reshape({migraphx::shape::fp4x2_type, output_shape.lens()});
9696
return result;
9797
}
9898
};

src/include/migraphx/op/unpack_fp4.hpp

Lines changed: 11 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -73,24 +73,21 @@ struct unpack_fp4
7373
const auto& input = args.front();
7474
auto in_shape = input.get_shape();
7575

76-
argument uint8_input = input.reshape({migraphx::shape::uint8_type, in_shape.lens()});
7776
migraphx::shape float_shape = shape{migraphx::shape::float_type, output_shape.lens()};
7877
argument float_arg{float_shape};
79-
78+
auto inp = input.get<uint8_t>();
8079
float_arg.visit([&](auto out) {
81-
uint8_input.visit([&](auto inp) {
82-
par_for(in_shape.elements(), [&](auto i) {
83-
auto data_idx = in_shape.multi(i);
84-
data_idx[axis] *= 2;
85-
// unpacking 2 unsigned parts
86-
// unpacking 4 least significant bits first
87-
uint8_t fp4_val = inp[i];
88-
out[data_idx] = fp4_to_float(fp4_val);
80+
par_for(in_shape.elements(), [&](auto i) {
81+
auto data_idx = in_shape.multi(i);
82+
data_idx[axis] *= 2;
83+
// unpacking 2 unsigned parts
84+
// unpacking 4 least significant bits first
85+
uint8_t fp4_val = inp[i];
86+
out[data_idx] = fp4_to_float(fp4_val);
8987

90-
data_idx[axis] += 1;
91-
fp4_val = fp4_val >> 4u;
92-
out[data_idx] = fp4_to_float(fp4_val);
93-
});
88+
data_idx[axis] += 1;
89+
fp4_val = fp4_val >> 4u;
90+
out[data_idx] = fp4_to_float(fp4_val);
9491
});
9592
});
9693
return float_arg;

src/include/migraphx/reduce_dims.hpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
/*
22
* The MIT License (MIT)
33
*
4-
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
4+
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
55
*
66
* Permission is hereby granted, free of charge, to any person obtaining a copy
77
* of this software and associated documentation files (the "Software"), to deal
@@ -31,6 +31,7 @@
3131
namespace migraphx {
3232
inline namespace MIGRAPHX_INLINE_NS {
3333

34+
/// Collapse adjacent shape dimensions that are the same between shapes.
3435
MIGRAPHX_EXPORT std::vector<shape> reduce_dims(const std::vector<shape>& shapes);
3536

3637
} // namespace MIGRAPHX_INLINE_NS

src/permutation.cpp

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
/*
22
* The MIT License (MIT)
33
*
4-
* Copyright (c) 2015-2023 Advanced Micro Devices, Inc. All rights reserved.
4+
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
55
*
66
* Permission is hereby granted, free of charge, to any person obtaining a copy
77
* of this software and associated documentation files (the "Software"), to deal
@@ -74,6 +74,7 @@ std::vector<int64_t> find_permutation(const std::vector<shape>& shapes)
7474
return it->first;
7575
}
7676

77+
/// Normalize shapes by reordering them by their permutation
7778
std::vector<shape> normalize_permutation(const std::vector<shape>& shapes)
7879
{
7980
auto result = shapes;

src/shape.cpp

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -270,7 +270,7 @@ std::string shape::cpp_type(shape::type_t t)
270270
switch(t)
271271
{
272272
case tuple_type: MIGRAPHX_THROW("No C++ type for tuple");
273-
case fp4x2_type: MIGRAPHX_THROW("No C++ type for fp4x2_type");
273+
case fp4x2_type: return "uint8_t";
274274
#define MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE(x, t) \
275275
case x: return #t;
276276
MIGRAPHX_SHAPE_VISIT_TYPES(MIGRAPHX_SHAPE_GENERATE_CPP_TYPE_CASE)

src/targets/gpu/jit/pack_fp4.cpp

Lines changed: 88 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* The MIT License (MIT)
3+
*
4+
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to deal
8+
* in the Software without restriction, including without limitation the rights
9+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
* copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in
14+
* all copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22+
* THE SOFTWARE.
23+
*/
24+
#include "migraphx/instruction.hpp"
25+
#include "migraphx/instruction_ref.hpp"
26+
#include <migraphx/reduce_dims.hpp>
27+
#include <migraphx/gpu/compiler.hpp>
28+
#include <migraphx/gpu/context.hpp>
29+
#include <migraphx/gpu/compile_hip_code_object.hpp>
30+
#include <migraphx/gpu/compile_hip.hpp>
31+
#include <migraphx/gpu/compile_gen.hpp>
32+
33+
namespace migraphx {
34+
inline namespace MIGRAPHX_INLINE_NS {
35+
namespace gpu {
36+
37+
static const char* const pack_fp4_kernel = R"__migraphx__(
38+
#include <migraphx/kernels/pack_fp4.hpp>
39+
#include <args.hpp>
40+
41+
namespace migraphx {
42+
43+
extern "C" {
44+
45+
MIGRAPHX_GLOBAL void ${kernel}(${params})
46+
{
47+
transform_args(make_tensors())(${args})([](auto... xs) {
48+
pack_fp4<${axis}>(xs...);
49+
});
50+
}
51+
52+
}
53+
54+
} // namespace migraphx
55+
56+
)__migraphx__";
57+
58+
struct pack_fp4_compiler : compiler<pack_fp4_compiler>
59+
{
60+
std::vector<std::string> names() const { return {"pack_fp4"}; }
61+
62+
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
63+
{
64+
hip_compile_options options;
65+
options.inputs = inputs;
66+
options.output = inputs.back();
67+
options.virtual_inputs = reduce_dims(normalize_permutation(options.inputs));
68+
options.kernel_name = "pack_fp4_kernel";
69+
options.set_launch_params(v, compute_global_for(ctx, inputs.back().elements()));
70+
71+
auto src =
72+
interpolate_string(pack_fp4_kernel,
73+
{{"kernel", options.kernel_name},
74+
{"params", enum_params(options.inputs.size(), "void * private_p")},
75+
{"args", enum_params(options.inputs.size(), "private_p")},
76+
{"axis", std::to_string(v.at("axis").to<int>())}});
77+
return compile_hip_code_object(ctx, src, options);
78+
}
79+
80+
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
81+
{
82+
return compile_op(ctx, to_shapes(ins->inputs()), op.to_value());
83+
}
84+
};
85+
86+
} // namespace gpu
87+
} // namespace MIGRAPHX_INLINE_NS
88+
} // namespace migraphx

src/targets/gpu/jit/unpack_fp4.cpp

Lines changed: 88 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,88 @@
1+
/*
2+
* The MIT License (MIT)
3+
*
4+
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
5+
*
6+
* Permission is hereby granted, free of charge, to any person obtaining a copy
7+
* of this software and associated documentation files (the "Software"), to deal
8+
* in the Software without restriction, including without limitation the rights
9+
* to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10+
* copies of the Software, and to permit persons to whom the Software is
11+
* furnished to do so, subject to the following conditions:
12+
*
13+
* The above copyright notice and this permission notice shall be included in
14+
* all copies or substantial portions of the Software.
15+
*
16+
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17+
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18+
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19+
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20+
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21+
* OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22+
* THE SOFTWARE.
23+
*/
24+
#include "migraphx/instruction.hpp"
25+
#include "migraphx/instruction_ref.hpp"
26+
#include <migraphx/reduce_dims.hpp>
27+
#include <migraphx/gpu/compiler.hpp>
28+
#include <migraphx/gpu/context.hpp>
29+
#include <migraphx/gpu/compile_hip_code_object.hpp>
30+
#include <migraphx/gpu/compile_hip.hpp>
31+
#include <migraphx/gpu/compile_gen.hpp>
32+
33+
namespace migraphx {
34+
inline namespace MIGRAPHX_INLINE_NS {
35+
namespace gpu {
36+
37+
static const char* const unpack_fp4_kernel = R"__migraphx__(
38+
#include <migraphx/kernels/unpack_fp4.hpp>
39+
#include <args.hpp>
40+
41+
namespace migraphx {
42+
43+
extern "C" {
44+
45+
MIGRAPHX_GLOBAL void ${kernel}(${params})
46+
{
47+
transform_args(make_tensors())(${args})([](auto... xs) {
48+
unpack_fp4<${axis}>(xs...);
49+
});
50+
}
51+
52+
}
53+
54+
} // namespace migraphx
55+
56+
)__migraphx__";
57+
58+
struct unpack_fp4_compiler : compiler<unpack_fp4_compiler>
59+
{
60+
std::vector<std::string> names() const { return {"unpack_fp4"}; }
61+
62+
operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const
63+
{
64+
hip_compile_options options;
65+
options.inputs = inputs;
66+
options.output = inputs.back();
67+
options.virtual_inputs = reduce_dims(normalize_permutation(options.inputs));
68+
options.kernel_name = "unpack_fp4_kernel";
69+
options.set_launch_params(v, compute_global_for(ctx, inputs.front().elements()));
70+
71+
auto src =
72+
interpolate_string(unpack_fp4_kernel,
73+
{{"kernel", options.kernel_name},
74+
{"params", enum_params(options.inputs.size(), "void * private_p")},
75+
{"args", enum_params(options.inputs.size(), "private_p")},
76+
{"axis", std::to_string(v.at("axis").to<int>())}});
77+
return compile_hip_code_object(ctx, src, options);
78+
}
79+
80+
compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const
81+
{
82+
return compile_op(ctx, to_shapes(ins->inputs()), op.to_value());
83+
}
84+
};
85+
86+
} // namespace gpu
87+
} // namespace MIGRAPHX_INLINE_NS
88+
} // namespace migraphx

src/targets/gpu/kernels/include/migraphx/kernels/algorithm.hpp

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -281,6 +281,7 @@ constexpr Iterator upper_bound(Iterator first, Iterator last, const T& value, Co
281281

282282
while(count > 0)
283283
{
284+
// NOLINTNEXTLINE(readability-qualified-auto)
284285
auto it = first;
285286
auto step = count / 2;
286287
it += step;

0 commit comments

Comments
 (0)