|  | 
|  | 1 | +/* | 
|  | 2 | + * The MIT License (MIT) | 
|  | 3 | + * | 
|  | 4 | + * Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved. | 
|  | 5 | + * | 
|  | 6 | + * Permission is hereby granted, free of charge, to any person obtaining a copy | 
|  | 7 | + * of this software and associated documentation files (the "Software"), to deal | 
|  | 8 | + * in the Software without restriction, including without limitation the rights | 
|  | 9 | + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell | 
|  | 10 | + * copies of the Software, and to permit persons to whom the Software is | 
|  | 11 | + * furnished to do so, subject to the following conditions: | 
|  | 12 | + * | 
|  | 13 | + * The above copyright notice and this permission notice shall be included in | 
|  | 14 | + * all copies or substantial portions of the Software. | 
|  | 15 | + * | 
|  | 16 | + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR | 
|  | 17 | + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, | 
|  | 18 | + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE | 
|  | 19 | + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER | 
|  | 20 | + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, | 
|  | 21 | + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN | 
|  | 22 | + * THE SOFTWARE. | 
|  | 23 | + * | 
|  | 24 | + */ | 
|  | 25 | +#include <migraphx/propagate_precision.hpp> | 
|  | 26 | +#include <migraphx/module.hpp> | 
|  | 27 | +#include <migraphx/iterator_for.hpp> | 
|  | 28 | +#include <migraphx/instruction.hpp> | 
|  | 29 | +#include <migraphx/pass_manager.hpp> | 
|  | 30 | +#include <migraphx/make_op.hpp> | 
|  | 31 | +#include <migraphx/functional.hpp> | 
|  | 32 | +#include <migraphx/ranges.hpp> | 
|  | 33 | +#include <migraphx/eliminate_convert.hpp> | 
|  | 34 | +#include <unordered_set> | 
|  | 35 | +#include <unordered_map> | 
|  | 36 | + | 
|  | 37 | +namespace migraphx { | 
|  | 38 | +inline namespace MIGRAPHX_INLINE_NS { | 
|  | 39 | + | 
|  | 40 | +namespace { | 
|  | 41 | +#ifdef __clang__ | 
|  | 42 | +#pragma clang diagnostic push | 
|  | 43 | +#pragma clang diagnostic ignored "-Wunused-function" | 
|  | 44 | +#endif | 
|  | 45 | +// Class wrappper so we can compare precision using comparison operators | 
|  | 46 | +struct precision | 
|  | 47 | +{ | 
|  | 48 | +    shape::type_t type; | 
|  | 49 | + | 
|  | 50 | +    friend bool operator==(const precision& xp, const precision& yp) { return xp.type == yp.type; } | 
|  | 51 | +    friend bool operator<(const precision& xp, const precision& yp) | 
|  | 52 | +    { | 
|  | 53 | +        bool is_less = false; | 
|  | 54 | +        shape::visit(xp.type, [&](auto x) { | 
|  | 55 | +            shape::visit(yp.type, [&](auto y) { | 
|  | 56 | +                if(x.is_integral() != y.is_integral()) | 
|  | 57 | +                    return; | 
|  | 58 | +                if(x.is_integral()) | 
|  | 59 | +                { | 
|  | 60 | +                    if(x.is_unsigned() != y.is_unsigned() and x.size() == y.size()) | 
|  | 61 | +                        is_less = y.is_unsigned(); | 
|  | 62 | +                    else | 
|  | 63 | +                        is_less = x.size() < y.size(); | 
|  | 64 | +                } | 
|  | 65 | +                else | 
|  | 66 | +                { | 
|  | 67 | +                    is_less = x.size() < y.size(); | 
|  | 68 | +                } | 
|  | 69 | +            }); | 
|  | 70 | +        }); | 
|  | 71 | +        return is_less; | 
|  | 72 | +    } | 
|  | 73 | +    friend bool operator!=(const precision& xp, const precision& yp) { return not(xp == yp); } | 
|  | 74 | +    friend bool operator>(const precision& xp, const precision& yp) { return yp < xp; } | 
|  | 75 | +    // This is not totally ordered | 
|  | 76 | +    friend bool operator<=(const precision& xp, const precision& yp) | 
|  | 77 | +    { | 
|  | 78 | +        return (xp < yp) or (xp == yp); | 
|  | 79 | +    } | 
|  | 80 | +    friend bool operator>=(const precision& xp, const precision& yp) | 
|  | 81 | +    { | 
|  | 82 | +        return (xp > yp) or (xp == yp); | 
|  | 83 | +    } | 
|  | 84 | +}; | 
|  | 85 | +#ifdef __clang__ | 
|  | 86 | +#pragma clang diagnostic pop | 
|  | 87 | +#endif | 
|  | 88 | +} // namespace | 
|  | 89 | + | 
|  | 90 | +static bool is_pointwise_or_reduce(instruction_ref ins) | 
|  | 91 | +{ | 
|  | 92 | +    return contains(ins->name(), "reduce") or | 
|  | 93 | +           ins->get_operator().attributes().get("pointwise", false); | 
|  | 94 | +} | 
|  | 95 | +// Check if its not a scalar constant | 
|  | 96 | +static bool is_non_scalar_const(instruction_ref ins) | 
|  | 97 | +{ | 
|  | 98 | +    return not(ins->get_shape().scalar() and ins->can_eval()); | 
|  | 99 | +} | 
|  | 100 | +// Get the next input instruction otherwise return a nullopt | 
|  | 101 | +static std::optional<instruction_ref> get_next_input(instruction_ref ins) | 
|  | 102 | +{ | 
|  | 103 | +    if(ins->inputs().size() == 1) | 
|  | 104 | +        return ins->inputs().front(); | 
|  | 105 | +    if(ins->inputs().size() > 1) | 
|  | 106 | +    { | 
|  | 107 | +        std::unordered_set<instruction_ref> non_scalars; | 
|  | 108 | +        std::copy_if(ins->inputs().begin(), | 
|  | 109 | +                     ins->inputs().end(), | 
|  | 110 | +                     std::inserter(non_scalars, non_scalars.end()), | 
|  | 111 | +                     &is_non_scalar_const); | 
|  | 112 | +        if(non_scalars.size() == 1) | 
|  | 113 | +            return *non_scalars.begin(); | 
|  | 114 | +    } | 
|  | 115 | +    return nullopt; | 
|  | 116 | +} | 
|  | 117 | + | 
|  | 118 | +// Find all adjacent instructions that could be upgraded with higher precision | 
|  | 119 | +// by traversing the inputs from a convert | 
|  | 120 | + | 
|  | 121 | +static std::unordered_set<instruction_ref> find_adjacent_inputs(instruction_ref start) | 
|  | 122 | +{ | 
|  | 123 | +    std::unordered_set<instruction_ref> result; | 
|  | 124 | +    // Promote inputs | 
|  | 125 | +    fix([&](auto self, instruction_ref ins) { | 
|  | 126 | +        if(not is_pointwise_or_reduce(ins)) | 
|  | 127 | +            return; | 
|  | 128 | +        if(contains(result, ins)) | 
|  | 129 | +            return; | 
|  | 130 | +        auto next = get_next_input(ins); | 
|  | 131 | +        if(not next.has_value()) | 
|  | 132 | +            return; | 
|  | 133 | +        result.insert(ins); | 
|  | 134 | +        self(*next); | 
|  | 135 | +    })(start->inputs().front()); | 
|  | 136 | +    return result; | 
|  | 137 | +} | 
|  | 138 | + | 
|  | 139 | +// Find all adjacent instructions that could be upgraded with higher precision | 
|  | 140 | +// by traversing the outputs from a convert | 
|  | 141 | +static std::unordered_set<instruction_ref> find_adjacent_outputs(instruction_ref start) | 
|  | 142 | +{ | 
|  | 143 | +    std::unordered_set<instruction_ref> result; | 
|  | 144 | +    // Promote outputs | 
|  | 145 | +    fix([&](auto self, instruction_ref ins) { | 
|  | 146 | +        for(auto output : ins->outputs()) | 
|  | 147 | +        { | 
|  | 148 | +            if(not is_pointwise_or_reduce(output)) | 
|  | 149 | +                continue; | 
|  | 150 | +            if(contains(result, output)) | 
|  | 151 | +                continue; | 
|  | 152 | +            auto next = get_next_input(output); | 
|  | 153 | +            if(not next.has_value()) | 
|  | 154 | +                continue; | 
|  | 155 | +            if(*next != ins) | 
|  | 156 | +                continue; | 
|  | 157 | +            result.insert(output); | 
|  | 158 | +            self(output); | 
|  | 159 | +        } | 
|  | 160 | +    })(start); | 
|  | 161 | +    return result; | 
|  | 162 | +} | 
|  | 163 | + | 
|  | 164 | +// Insert the instructions to upgrade into the map. If the map already has the | 
|  | 165 | +// instruction then choose the highest precision | 
|  | 166 | +template <class Map, class Instructions> | 
|  | 167 | +static void | 
|  | 168 | +insert_instructions_to_upgrade(Map& m, const Instructions& instructions, shape::type_t t) | 
|  | 169 | +{ | 
|  | 170 | +    for(auto ins : instructions) | 
|  | 171 | +    { | 
|  | 172 | +        auto it = m.find(ins); | 
|  | 173 | +        if(it == m.end()) | 
|  | 174 | +        { | 
|  | 175 | +            m[ins] = t; | 
|  | 176 | +        } | 
|  | 177 | +        else | 
|  | 178 | +        { | 
|  | 179 | +            it->second = std::max(precision{t}, precision{it->second}).type; | 
|  | 180 | +        } | 
|  | 181 | +    } | 
|  | 182 | +} | 
|  | 183 | + | 
|  | 184 | +// Find adjacent instructions from a convert to upgrade to use a higher | 
|  | 185 | +// precision | 
|  | 186 | +static std::unordered_map<instruction_ref, shape::type_t> find_instruction_to_upgrade(module& m) | 
|  | 187 | +{ | 
|  | 188 | +    std::unordered_map<instruction_ref, shape::type_t> result; | 
|  | 189 | +    for(auto ins : iterator_for(m)) | 
|  | 190 | +    { | 
|  | 191 | +        if(ins->name() != "convert") | 
|  | 192 | +            continue; | 
|  | 193 | +        auto output = precision{ins->get_shape().type()}; | 
|  | 194 | +        auto input  = precision{ins->inputs().front()->get_shape().type()}; | 
|  | 195 | +        if(output.type == shape::type_t::bool_type) | 
|  | 196 | +            continue; | 
|  | 197 | +        if(input < output) | 
|  | 198 | +        { | 
|  | 199 | +            insert_instructions_to_upgrade(result, find_adjacent_inputs(ins), output.type); | 
|  | 200 | +        } | 
|  | 201 | +        else if(input > output) | 
|  | 202 | +        { | 
|  | 203 | +            insert_instructions_to_upgrade(result, find_adjacent_outputs(ins), input.type); | 
|  | 204 | +        } | 
|  | 205 | +    } | 
|  | 206 | +    return result; | 
|  | 207 | +} | 
|  | 208 | + | 
|  | 209 | +void propagate_precision::apply(module_pass_manager& mpm) const | 
|  | 210 | +{ | 
|  | 211 | +    auto upgrade = find_instruction_to_upgrade(mpm.get_module()); | 
|  | 212 | +    for(const auto& p : upgrade) | 
|  | 213 | +    { | 
|  | 214 | +        auto ins      = p.first; | 
|  | 215 | +        auto t        = p.second; | 
|  | 216 | +        auto convert1 = mpm.get_module().insert_instruction( | 
|  | 217 | +            std::next(ins), make_op("convert", {{"target_type", ins->get_shape().type()}}), ins); | 
|  | 218 | +        mpm.get_module().replace_instruction(ins, convert1); | 
|  | 219 | +        std::vector<instruction_ref> inputs; | 
|  | 220 | +        std::transform(ins->inputs().begin(), | 
|  | 221 | +                       ins->inputs().end(), | 
|  | 222 | +                       std::back_inserter(inputs), | 
|  | 223 | +                       [&](auto input) { | 
|  | 224 | +                           return mpm.get_module().insert_instruction( | 
|  | 225 | +                               ins, make_op("convert", {{"target_type", t}}), input); | 
|  | 226 | +                       }); | 
|  | 227 | +        mpm.get_module().replace_instruction(ins, ins->get_operator(), inputs); | 
|  | 228 | +    } | 
|  | 229 | +    mpm.run_pass(eliminate_convert{}); | 
|  | 230 | +} | 
|  | 231 | + | 
|  | 232 | +} // namespace MIGRAPHX_INLINE_NS | 
|  | 233 | +} // namespace migraphx | 
0 commit comments