|  | 
|  | 1 | +/* | 
|  | 2 | + * The MIT License (MIT) | 
|  | 3 | + * | 
|  | 4 | + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. | 
|  | 5 | + * | 
|  | 6 | + * Permission is hereby granted, free of charge, to any person obtaining a copy | 
|  | 7 | + * of this software and associated documentation files (the "Software"), to deal | 
|  | 8 | + * in the Software without restriction, including without limitation the rights | 
|  | 9 | + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | 
|  | 10 | + * copies of the Software, and to permit persons to whom the Software is | 
|  | 11 | + * furnished to do so, subject to the following conditions: | 
|  | 12 | + * | 
|  | 13 | + * The above copyright notice and this permission notice shall be included in | 
|  | 14 | + * all copies or substantial portions of the Software. | 
|  | 15 | + * | 
|  | 16 | + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | 
|  | 17 | + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | 
|  | 18 | + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE | 
|  | 19 | + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | 
|  | 20 | + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | 
|  | 21 | + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | 
|  | 22 | + * THE SOFTWARE. | 
|  | 23 | + */ | 
|  | 24 | +#include "migraphx/instruction.hpp" | 
|  | 25 | +#include "migraphx/instruction_ref.hpp" | 
|  | 26 | +#include <migraphx/reduce_dims.hpp> | 
|  | 27 | +#include <migraphx/gpu/compiler.hpp> | 
|  | 28 | +#include <migraphx/gpu/context.hpp> | 
|  | 29 | +#include <migraphx/gpu/compile_hip_code_object.hpp> | 
|  | 30 | +#include <migraphx/gpu/compile_hip.hpp> | 
|  | 31 | +#include <migraphx/gpu/compile_gen.hpp> | 
|  | 32 | + | 
|  | 33 | +namespace migraphx { | 
|  | 34 | +inline namespace MIGRAPHX_INLINE_NS { | 
|  | 35 | +namespace gpu { | 
|  | 36 | + | 
|  | 37 | +static const char* const unpack_fp4_kernel = R"__migraphx__( | 
|  | 38 | +#include <migraphx/kernels/unpack_fp4.hpp> | 
|  | 39 | +#include <args.hpp> | 
|  | 40 | +
 | 
|  | 41 | +namespace migraphx { | 
|  | 42 | +
 | 
|  | 43 | +extern "C" { | 
|  | 44 | +
 | 
|  | 45 | +MIGRAPHX_GLOBAL void ${kernel}(${params})  | 
|  | 46 | +{ | 
|  | 47 | +    transform_args(make_tensors())(${args})([](auto... xs) { | 
|  | 48 | +        unpack_fp4<${axis}>(xs...); | 
|  | 49 | +    }); | 
|  | 50 | +} | 
|  | 51 | +     | 
|  | 52 | +} | 
|  | 53 | +
 | 
|  | 54 | +} // namespace migraphx | 
|  | 55 | +
 | 
|  | 56 | +)__migraphx__"; | 
|  | 57 | + | 
|  | 58 | +struct unpack_fp4_compiler : compiler<unpack_fp4_compiler> | 
|  | 59 | +{ | 
|  | 60 | +    std::vector<std::string> names() const { return {"unpack_fp4"}; } | 
|  | 61 | + | 
|  | 62 | +    operation compile_op(context& ctx, const std::vector<shape>& inputs, const value& v) const | 
|  | 63 | +    { | 
|  | 64 | +        hip_compile_options options; | 
|  | 65 | +        options.inputs         = inputs; | 
|  | 66 | +        options.output         = inputs.back(); | 
|  | 67 | +        options.virtual_inputs = reduce_dims(normalize_permutation(options.inputs)); | 
|  | 68 | +        options.kernel_name    = "unpack_fp4_kernel"; | 
|  | 69 | +        options.set_launch_params(v, compute_global_for(ctx, inputs.front().elements())); | 
|  | 70 | + | 
|  | 71 | +        auto src = | 
|  | 72 | +            interpolate_string(unpack_fp4_kernel, | 
|  | 73 | +                               {{"kernel", options.kernel_name}, | 
|  | 74 | +                                {"params", enum_params(options.inputs.size(), "void * private_p")}, | 
|  | 75 | +                                {"args", enum_params(options.inputs.size(), "private_p")}, | 
|  | 76 | +                                {"axis", std::to_string(v.at("axis").to<int>())}}); | 
|  | 77 | +        return compile_hip_code_object(ctx, src, options); | 
|  | 78 | +    } | 
|  | 79 | + | 
|  | 80 | +    compiler_replace compile(context& ctx, instruction_ref ins, const operation& op) const | 
|  | 81 | +    { | 
|  | 82 | +        return compile_op(ctx, to_shapes(ins->inputs()), op.to_value()); | 
|  | 83 | +    } | 
|  | 84 | +}; | 
|  | 85 | + | 
|  | 86 | +} // namespace gpu | 
|  | 87 | +} // namespace MIGRAPHX_INLINE_NS | 
|  | 88 | +} // namespace migraphx | 
0 commit comments