Skip to content

Commit fc6c0b2

Browse files
authored
Code comments update (#4032)
1 parent aca9be5 commit fc6c0b2

File tree

6 files changed

+53
-21
lines changed

6 files changed

+53
-21
lines changed

src/include/migraphx/matcher.hpp

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -752,6 +752,11 @@ MIGRAPHX_PRED_MATCHER(broadcast, instruction_ref ins)
752752
return contains({"broadcast", "multibroadcast"}, ins->name());
753753
}
754754

755+
/**
756+
* Makes a matcher that recursively traverses over single inputs to an instruction that
757+
* match the given matchers. The matcher will then be at the instruction before the `ms`
758+
* matched instructions.
759+
*/
755760
template <class... Ms>
756761
auto skip(Ms... ms)
757762
{
@@ -771,6 +776,12 @@ auto skip(Ms... ms)
771776
});
772777
}
773778

779+
/**
780+
* Makes a matcher that recursively traverses over single outputs to an instruction that
781+
* match the given matchers. The matcher will then return at the instruction after the `ms`
782+
* matched instructions. If any instruction matched has more than one output the matcher
783+
* returns nullopt.
784+
*/
774785
template <class... Ms>
775786
auto skip_output(Ms... ms)
776787
{

src/include/migraphx/module.hpp

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -259,16 +259,23 @@ struct MIGRAPHX_EXPORT module
259259
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
260260
const std::function<shape(const shape&)>& shape_transform = nullptr);
261261

262-
// Fuse the instruction into the module by inserting the instructions and
263-
// parameters for any missing inputs.
262+
/**
263+
* Fuse the instruction into the module by inserting the instructions and
264+
* parameters for any missing inputs.
265+
* `map_ins` is mapping from previous instructions to new instructions.
266+
*/
264267
std::vector<instruction_ref>
265268
fuse(const std::vector<instruction_ref>& inss,
266269
std::unordered_map<instruction_ref, instruction_ref>* map_ins = nullptr,
267270
inserter insert = nullptr,
268271
const std::function<shape(const shape&)>& shape_transform = nullptr);
269272

270-
// Fuse another module into this module by inserting the instructions and
271-
// parameters from the module
273+
/**
274+
* Fuse another module into this module by inserting the instructions and
275+
* parameters from the module
276+
* map_ins is mapping from previous instructions to new instructions
277+
* Returns output instructions to the module.
278+
*/
272279
std::vector<instruction_ref>
273280
fuse(const module& m,
274281
const std::vector<instruction_ref>& inputs,

src/include/migraphx/op/layout.hpp

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
/*
22
* The MIT License (MIT)
33
*
4-
* Copyright (c) 2015-2022 Advanced Micro Devices, Inc. All rights reserved.
4+
* Copyright (c) 2015-2025 Advanced Micro Devices, Inc. All rights reserved.
55
*
66
* Permission is hereby granted, free of charge, to any person obtaining a copy
77
* of this software and associated documentation files (the "Software"), to deal
@@ -38,6 +38,11 @@ namespace migraphx {
3838
inline namespace MIGRAPHX_INLINE_NS {
3939
namespace op {
4040

41+
/**
42+
* Rearrange the data layout of the input instruction based on the permutation attribute.
43+
* permutation: List with how to rearrange data buffer of input instruction from order of slowest
44+
* dimension to fastest dimension. Integers refer to input axes.
45+
*/
4146
struct layout : unary<layout>
4247
{
4348
std::vector<int64_t> permutation;

src/include/migraphx/shape.hpp

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -197,23 +197,17 @@ struct MIGRAPHX_EXPORT shape
197197
explicit shape(const std::vector<shape>& subs);
198198

199199
/**
200-
* Creates an output shape with dimensions equal to the input lengths and strides determined
201-
* by the permutation argument such that find_permutation() of the output shape returns the
202-
* inputted permuation.
200+
* Creates an output shape with dimensions `l` and strides computed to fulfill the given
201+
* permutation.
203202
*
204-
* 2D example:
205-
* parameters:
206-
* l = [2, 3], perm = [1, 0]
207-
* therefore:
208-
* "original" shape = {lens = [3, 2], strides = [2, 1]}
209-
* output_shape = {lens = [2, 3], strides = [1, 2]
203+
* `t` = shape type
204+
* `l` = output dimensions
205+
* `perm` = order dimensions from slowest dimension to fastest dimension
210206
*
211-
* 3D example:
212-
* parameters:
213-
* l = [2, 3, 4], perm = [1, 2, 0]
214-
* therefore:
215-
* "original" shape = {lens = [3, 4, 2], strides = [8, 2, 1]}
216-
* output_shape = {lens = [2, 3, 4], strides = [1, 8, 2]}
207+
* Example:
208+
* `t` = float_type, `l` = [2, 3, 4], `perm` = [1, 2, 0]
209+
* axis=1 to slowest dimension, axis=2 to second slowest, axis=0 to fastest
210+
* returns shape{type = float, lens = [2, 3, 4], strides = [1, 8 ,2]}
217211
*/
218212
static shape
219213
from_permutation(type_t t, const std::vector<std::size_t>& l, const std::vector<int64_t>& perm);

src/module.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ struct module_impl
5959
std::unordered_set<instruction*> instruction_set;
6060
std::string name;
6161
uint32_t nparams = 0;
62-
bool bypass = false;
62+
bool bypass = false; // used for skipping compiler passes
6363
bit_signal<64> changed{};
6464

6565
bool contains(instruction_ref ins) const

src/targets/gpu/fuse_mlir.cpp

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -642,6 +642,10 @@ struct find_mlir_split_reduce
642642
}
643643
};
644644

645+
/**
646+
* Fuses rocMLIR compatible dot or conv op -> reshapes -> pointwise
647+
* into a mlir_op with submodule.
648+
*/
645649
struct find_mlir_fused_ops
646650
{
647651
mlir_mode conv_mode = mlir_mode::none;
@@ -655,6 +659,12 @@ struct find_mlir_fused_ops
655659
return names;
656660
}
657661

662+
/**
663+
* Matches:
664+
* mlir_dot_or_conv <binds to "gemm_based_op"> ->
665+
* skip(conv_dot_reshaper_names) <binds to "x"> ->
666+
* mlir_pointwise <matcher result>
667+
*/
658668
auto matcher() const
659669
{
660670
static const auto conv_dot_reshaper_names = make_conv_dot_reshaper_names();
@@ -933,6 +943,11 @@ struct find_mlir_attention_fused_ops : public find_mlir_standalone_attention_op
933943
}
934944
};
935945

946+
/**
947+
* Input fusion of pointwise operators into a mlir_op.
948+
* Only fuses unary pointwise operators by default.
949+
* Fuses all fusable pw ops with MIGRAPHX_ENABLE_MLIR_INPUT_FUSION
950+
*/
936951
struct find_pointwise_mlir
937952
{
938953
auto supported_pointwise() const { return mlir_input_pointwise(match::used_once()); }

0 commit comments

Comments
 (0)