Raisetolinalg #412

Draft · wants to merge 58 commits into base: raisetolinalg
Commits (58)
0b97da9
Unfinished changes with prototype function
arpitj1 Jun 6, 2024
69ef423
Loop over linalg.generic's input and output ops
arpitj1 Jun 6, 2024
7678a05
Some comments
arpitj1 Jun 6, 2024
0e88095
Partial changes from coding session to implement fusion of linalg.gen…
arpitj1 Jun 11, 2024
b57c0b8
Incremental changes to fuse linalg and for loop- Logic for shifted op…
arpitj1 Jun 19, 2024
f54c33d
ran clang format
arpitj1 Jun 25, 2024
56e2c54
some compile time fixes
arpitj1 Jun 25, 2024
e253040
Some compile fixes
arpitj1 Jul 2, 2024
e99b8a5
Fixed all the compilation issues. Sample MLIR not raised
arpitj1 Jul 3, 2024
34f595c
Bug fixes, generating some output at getLinalgArgMap
arpitj1 Jul 16, 2024
05bad97
Almost implementated remap in affine dim for multi idx
arpitj1 Jul 17, 2024
5bbf5ef
Added submap op support and refactored the code to use submap
arpitj1 Jul 24, 2024
9018d92
bunch of fixes. Now able to generate raise linalg code
arpitj1 Jul 30, 2024
ec041a0
Now almost working second loop raising to linalg
arpitj1 Jul 31, 2024
23138fc
Fixes to correctly raise 2 level for loops to linalg.generic
arpitj1 Jul 31, 2024
5f20bd7
Missed file update to enable linalg dialect in polygeist
arpitj1 Jul 31, 2024
b0e96aa
Fix for syms and dims calculation
arpitj1 Aug 6, 2024
ea76f0a
More tests added to cover different loop cases
arpitj1 Aug 7, 2024
591c84e
Now able to compile 3/any number of loops with parallel iter type; Ad…
arpitj1 Aug 7, 2024
b0108e3
Non iter-arg variant of matrix-mul and conv are now raised to linalg.…
arpitj1 Aug 7, 2024
4362c80
submap canonicalizer implemented
arpitj1 Aug 21, 2024
77c8168
Added reduction loops for linalg
arpitj1 Aug 22, 2024
98f0119
Fix for incorrect for loop dims
arpitj1 Aug 28, 2024
59eec0b
Linalg.generic 4 loop cases raised- todo: reduction and some if-else …
arpitj1 Sep 5, 2024
a363f13
Adding test case for all passing raising and lowering, example case o…
arpitj1 Sep 18, 2024
814ca51
Added pass remove iter args from scf; Added psuedo code for submap ca…
arpitj1 Oct 12, 2024
701f25a
Added removal of iter_args for affine loops
arpitj1 Oct 12, 2024
d285fb5
Temporary reverted pass registeration as the code was failing
arpitj1 Oct 12, 2024
c40e7a9
WIP commit
arpitj1 Oct 15, 2024
788a3c4
Added submap of submap canonicalizer with test- failing
arpitj1 Oct 18, 2024
8265216
Added canonicalization for linalg with submap and test cases
arpitj1 Oct 25, 2024
532773a
Added modified 2d kernel for harris score- raised successfully to lin…
arpitj1 Oct 25, 2024
e2b4b2d
Added harris score kernel with gradient kernel- just to be able to ra…
arpitj1 Oct 25, 2024
f2ab09e
Initial working implementation of debufferize flow for linalg with ex…
arpitj1 Jan 13, 2025
2342381
Added more complex case to show debufferization ; Fixed bugs in debuf…
arpitj1 Jan 13, 2025
fde88fe
Fixed clang format
arpitj1 Jan 13, 2025
cf9f953
Ran git clang format locally to fix regression failures
arpitj1 Jan 13, 2025
f10c47a
Working implementation for function args memrefType with noinline att…
arpitj1 Jan 17, 2025
490f924
Added debufferization Alloc Removal pass, add working examples with l…
arpitj1 Jan 17, 2025
e20708c
Added support for debufferization across nested regions - working for…
arpitj1 Jan 31, 2025
4a7efe7
Bug fix for erasing the op correctly
arpitj1 Jan 31, 2025
6d8832f
Bug fixes for 1. recursive parent search in sorting users 2. traversi…
arpitj1 Jan 31, 2025
6ca2aeb
Added cases of buffer capture which doesn't debufferize
arpitj1 Jan 31, 2025
803ec30
Canonicalization gets rid of memref capture by loop
arpitj1 Feb 1, 2025
fb0ac18
Working implementation for scf.for op and scf.if op; added bug fix to…
arpitj1 Feb 7, 2025
0472c34
Added data structures to track expandedUsers that can include for loo…
arpitj1 Feb 7, 2025
3272f2c
Added logic in for loop case to find all users of iter_args and updat…
arpitj1 Feb 8, 2025
da2ae5b
Added a bunch of tests with nested regions- all getting connected and…
arpitj1 Feb 8, 2025
a570c1b
Added more complex region cases with mix of if-else statements
arpitj1 Feb 8, 2025
7ee707b
Generic solver to represent linalg.generic as kernel.def ops
arpitj1 May 8, 2025
c8561b4
Adding cases for generic solver
arpitj1 May 12, 2025
07d0dcb
Backup of previous edits
arpitj1 May 28, 2025
009ab9b
Temp changes for kernel dialect
Jun 11, 2025
c0f36d3
Enabled kernel dialect correctly running on sample IR with kernel def…
Jun 11, 2025
6a67379
Added linalgToKernel pass- compile failure
arpitj1 Jun 12, 2025
7f9d00f
Working pattern matching and replacement for linalg generics
arpitj1 Jun 12, 2025
d765bb9
Partial changes for different files for kernel and input
arpitj1 Jun 12, 2025
15ef84e
Crash fix
arpitj1 Jun 13, 2025
360 changes: 360 additions & 0 deletions generic_solver/CublasDefnPattern.cpp
@@ -0,0 +1,360 @@
//===- CublasDefnPattern.cpp - Match linalg.generic against kernel.defn ---===//
//
// This file implements a pattern to rewrite linalg.generic operations to kernel
// operations by matching against patterns defined in kernel.defn_collection.
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Linalg/IR/Linalg.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "llvm/ADT/TypeSwitch.h"
#include "KernelOps.h"

#include <memory>
#include <set>
#include <stack>
#include <string>
#include <utility>

using namespace mlir;
using namespace mlir::linalg;

namespace {

// Cases:
// 1. What if they write a*(b+c) as a*b+a*c ?
// 2. What if they write (a+b)/c as a/c+b/c ?
//    - The required best form can vary based on a cost model for a given
//      architecture.
//    - The expectation is that kernel.defn is the best form an op is
//      expected to take.
//    - The generic solver will employ heuristics to match the best form.
//    - Heuristics can be as simple as "is the op commutative?",
//      "is the op associative?", "is the op distributive?", etc.
// 3. What if the order of operands differs? add(a,b) vs. add(b,a)
//    - This requires a commutativity check: for commutative ops we do not
//      need to match operand positions.
// 4. What if the order of uses differs for an op? E.g.:
//      a1 = ...    | a2 = ...
//      b1 = a1/c1  | d2 = a2*c2
//      d1 = a1*c1  | b2 = a2/c2
//    - In this case, we need to find the corresponding uses of the operands.
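Case 3 above can be sketched outside MLIR. The following standalone helper is illustrative only (the op names and the hard-coded commutative set are stand-ins; in real MLIR the trait query would be something like `op->hasTrait<OpTrait::IsCommutative>()`):

```cpp
#include <string>
#include <utility>

// Illustrative sketch of case 3 (not MLIR API): for a commutative op such as
// "arith.addf", operand pairs should match even when their order differs; for
// a non-commutative op such as "arith.subf" the order must agree exactly.
// The commutative-op set below is a stand-in for a real trait query.
bool operandsMatch(const std::string &opName,
                   const std::pair<std::string, std::string> &lhs,
                   const std::pair<std::string, std::string> &rhs) {
  // Exact positional match always succeeds.
  if (lhs == rhs)
    return true;
  // For commutative ops, also accept the swapped operand order.
  bool commutative = (opName == "arith.addf" || opName == "arith.mulf");
  return commutative && lhs.first == rhs.second && lhs.second == rhs.first;
}
```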

// Non-recursive traversal of the use-def chain using a worklist.
bool compareUseDefChains(Value firstValue, Value secondValue) {
  // Track pairs of values that still need to be visited.
  std::stack<std::pair<Value, Value>> workList;
  std::set<std::pair<void *, void *>> visited;

  // Start with the initial values.
  workList.push({firstValue, secondValue});

  while (!workList.empty()) {
    auto [value1, value2] = workList.top();
    workList.pop();

    // Skip if we've already processed this pair.
    auto valuePtrPair = std::make_pair(value1.getAsOpaquePointer(),
                                       value2.getAsOpaquePointer());
    if (visited.count(valuePtrPair))
      continue;
    visited.insert(valuePtrPair);

    // Compare the values themselves.
    if (value1.getType() != value2.getType())
      return false;

    // Compare all uses.
    auto uses1 = value1.getUses();
    auto uses2 = value2.getUses();

    // Process each use of the first value.
    for (auto &use1 : uses1) {
      Operation *op1 = use1.getOwner();

      // Find the corresponding use of the second value.
      bool foundMatch = false;
      for (auto &use2 : uses2) {
        Operation *op2 = use2.getOwner();

        // Compare operations (customize based on the desired notion of
        // equivalence).
        if (op1->getName() == op2->getName() &&
            // TODO: commutative ops should match regardless of position.
            use1.getOperandNumber() == use2.getOperandNumber()) {
          foundMatch = true;

          // Add results to the worklist to continue the traversal.
          for (unsigned i = 0; i < op1->getNumResults(); ++i) {
            if (i < op2->getNumResults())
              workList.push({op1->getResult(i), op2->getResult(i)});
          }
          break;
        }
      }

      if (!foundMatch)
        return false;
    }
  }

  return true;
}
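The worklist-plus-visited-set shape of `compareUseDefChains` can be exercised on a toy graph model without MLIR. Everything below is illustrative (the `ToyGraph` type and labels are invented for the sketch): nodes stand in for values, labels for op names, and successors for uses.

```cpp
#include <map>
#include <set>
#include <stack>
#include <string>
#include <utility>
#include <vector>

// Toy model of the traversal above: each graph maps a node id to a label and
// its successor ids. Two graphs match when a paired walk from the roots never
// sees differing labels or successor counts; the visited set keeps shared
// subgraphs from being re-checked (and cycles from looping forever).
struct ToyGraph {
  std::map<int, std::string> label;
  std::map<int, std::vector<int>> succ;
};

bool compareFrom(const ToyGraph &g1, int root1,
                 const ToyGraph &g2, int root2) {
  std::stack<std::pair<int, int>> workList;
  std::set<std::pair<int, int>> visited;
  workList.push({root1, root2});

  while (!workList.empty()) {
    auto [n1, n2] = workList.top();
    workList.pop();
    // Skip pairs we have already compared.
    if (!visited.insert({n1, n2}).second)
      continue;
    // Labels (op names) must agree.
    if (g1.label.at(n1) != g2.label.at(n2))
      return false;
    // Successor (use) counts must agree; pair them positionally.
    const auto &s1 = g1.succ.at(n1);
    const auto &s2 = g2.succ.at(n2);
    if (s1.size() != s2.size())
      return false;
    for (size_t i = 0; i < s1.size(); ++i)
      workList.push({s1[i], s2[i]});
  }
  return true;
}
```

The positional pairing of successors mirrors the operand-number check in the real traversal; a commutativity-aware version would try permutations instead.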


// Helper function to check if two regions are structurally equivalent
bool areRegionsEquivalent(Region &first, Region &second) {
  // Compare the number of blocks.
  if (first.getBlocks().size() != second.getBlocks().size())
    return false;

  // Compare corresponding blocks.
  for (auto blockPair : llvm::zip(first.getBlocks(), second.getBlocks())) {
    Block &firstBlock = std::get<0>(blockPair);
    Block &secondBlock = std::get<1>(blockPair);

    // Compare the number of arguments.
    if (firstBlock.getNumArguments() != secondBlock.getNumArguments())
      return false;

    // Compare argument types, then traverse the use-def chain of each pair
    // of corresponding arguments and compare the operations they reach.
    for (auto argPair : llvm::zip(firstBlock.getArguments(),
                                  secondBlock.getArguments())) {
      if (std::get<0>(argPair).getType() != std::get<1>(argPair).getType())
        return false;
      if (!compareUseDefChains(std::get<0>(argPair), std::get<1>(argPair)))
        return false;
    }

    // NOTE: A full implementation would also compare the block's operations
    // directly (operands, attributes, and result types); the use-def walk
    // above is a simplified structural check.
  }

  return true;
}

// Helper to check if indexing maps are equivalent
bool areIndexingMapsEquivalent(ArrayAttr firstMaps, ArrayAttr secondMaps) {
  if (firstMaps.size() != secondMaps.size())
    return false;

  for (auto mapPair : llvm::zip(firstMaps, secondMaps)) {
    auto firstMap = std::get<0>(mapPair).cast<AffineMapAttr>().getValue();
    auto secondMap = std::get<1>(mapPair).cast<AffineMapAttr>().getValue();
    if (firstMap != secondMap)
      return false;
  }

  return true;
}

// Helper to check if iterator types are equivalent
bool areIteratorTypesEquivalent(ArrayAttr firstTypes, ArrayAttr secondTypes) {
  if (firstTypes.size() != secondTypes.size())
    return false;

  for (auto typePair : llvm::zip(firstTypes, secondTypes)) {
    auto firstType = std::get<0>(typePair).cast<StringAttr>().getValue();
    auto secondType = std::get<1>(typePair).cast<StringAttr>().getValue();
    if (firstType != secondType)
      return false;
  }

  return true;
}

// Check if a linalg.generic operation matches a kernel.defn in a collection.
FailureOr<std::string>
matchGenericWithDefn(GenericOp genericOp,
                     kernel::DefnCollectionOp collectionOp) {
  // Get attributes from the generic operation.
  ArrayAttr indexingMaps = genericOp.getIndexingMapsAttr();
  ArrayAttr iteratorTypes = genericOp.getIteratorTypesAttr();
  unsigned numInputs = genericOp.getNumDpsInputs();
  unsigned numOutputs = genericOp.getNumDpsInits();

  // Walk through each defn in the collection.
  for (Operation &op : collectionOp.getDefns()) {
    auto defnOp = cast<kernel::DefnOp>(op);
    StringAttr opName = defnOp.getNameAttr();

    // Look for a matching linalg.generic in the defn's body.
    // DONE: Generalize to a single dialect, with no special ops.
    // TODO: Indexing maps and operand orders might differ.
    // TODO: More complex case- where extra loops exist around the ops we have.
    // TODO: Custom cost model?
    // TODO: Constants might require special handling such as bounds.
    // IDEA: Descheduling / removing tiles.
    bool foundMatch = false;
    defnOp.getBody().walk([&](GenericOp candidateOp) {
      // Skip if we already found a match.
      if (foundMatch)
        return;

      // Check if this linalg.generic matches our target.
      if (candidateOp.getNumDpsInputs() == numInputs &&
          candidateOp.getNumDpsInits() == numOutputs &&
          areIndexingMapsEquivalent(candidateOp.getIndexingMapsAttr(),
                                    indexingMaps) &&
          areIteratorTypesEquivalent(candidateOp.getIteratorTypesAttr(),
                                     iteratorTypes) &&
          areRegionsEquivalent(candidateOp.getRegion(), genericOp.getRegion()))
        foundMatch = true;
    });

    if (foundMatch)
      return opName.str();
  }

  return failure();
}
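The TODO about differing indexing-map orders hints at why heuristics matter: naively trying every ordering of n maps means n! candidate pairings. A minimal standalone helper (illustrative; the name `orderingCandidates` is invented here) makes the blow-up concrete:

```cpp
// Illustrative helper: the number of candidate pairings a brute-force matcher
// would face when the order of n indexing maps may differ is n! (factorial).
// This is why the design notes above propose heuristics such as commutativity
// and associativity checks instead of exhaustive permutation search.
long long orderingCandidates(int numMaps) {
  long long result = 1;
  for (int i = 2; i <= numMaps; ++i)
    result *= i;
  return result;
}
```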

// Rewrite pattern to convert linalg.generic to kernel ops
class LinalgGenericToKernelPattern : public OpRewritePattern<GenericOp> {
public:
  LinalgGenericToKernelPattern(MLIRContext *context,
                               kernel::DefnCollectionOp collectionOp)
      : OpRewritePattern<GenericOp>(context), collectionOp(collectionOp) {}

  LogicalResult matchAndRewrite(GenericOp genericOp,
                                PatternRewriter &rewriter) const override {
    // Try to match with a defn in the collection.
    auto matchResult = matchGenericWithDefn(genericOp, collectionOp);
    if (failed(matchResult))
      return failure();

    std::string opName = *matchResult;

    // Create the appropriate kernel operation based on the matched pattern.
    if (opName == "Kernel_gemm") {
      // Get inputs and outputs.
      Value outputTensor = genericOp.getDpsInitOperand(0)->get();
      Value inputA = genericOp.getDpsInputOperand(0)->get();
      Value inputB = genericOp.getDpsInputOperand(1)->get();

      // Default alpha and beta values (could be extracted from the pattern).
      FloatAttr alpha = rewriter.getF32FloatAttr(1.0);
      FloatAttr beta = rewriter.getF32FloatAttr(0.0);

      // Create the kernel.gemm operation.
      rewriter.replaceOpWithNewOp<kernel::GemmOp>(
          genericOp, genericOp.getResultTypes(), outputTensor, inputA, inputB,
          alpha, beta);
      return success();
    } else if (opName == "Kernel_batched_gemm") {
      // Get inputs and outputs.
      Value outputTensor = genericOp.getDpsInitOperand(0)->get();
      Value inputA = genericOp.getDpsInputOperand(0)->get();
      Value inputB = genericOp.getDpsInputOperand(1)->get();

      // Default alpha and beta values.
      FloatAttr alpha = rewriter.getF32FloatAttr(1.0);
      FloatAttr beta = rewriter.getF32FloatAttr(0.0);

      // Create the kernel.batched_gemm operation.
      rewriter.replaceOpWithNewOp<kernel::BatchedGemmOp>(
          genericOp, genericOp.getResultTypes(), outputTensor, inputA, inputB,
          alpha, beta);
      return success();
    } else if (opName == "Kernel_iamax") {
      // Create the kernel.iamax operation from the single input.
      Value input = genericOp.getDpsInputOperand(0)->get();
      rewriter.replaceOpWithNewOp<kernel::IndexMaxAbsOp>(
          genericOp, genericOp.getResultTypes(), input);
      return success();
    } else if (opName == "Kernel_iamin") {
      // Create the kernel.iamin operation from the single input.
      Value input = genericOp.getDpsInputOperand(0)->get();
      rewriter.replaceOpWithNewOp<kernel::IndexMinAbsOp>(
          genericOp, genericOp.getResultTypes(), input);
      return success();
    } else if (opName == "Kernel_asum") {
      // Create the kernel.asum operation from the single input.
      Value input = genericOp.getDpsInputOperand(0)->get();
      rewriter.replaceOpWithNewOp<kernel::AbsSumOp>(
          genericOp, genericOp.getResultTypes(), input);
      return success();
    }

    return failure();
  }

private:
  kernel::DefnCollectionOp collectionOp;
};

// Pass to apply the rewrite pattern
class LinalgToKernelPass
    : public PassWrapper<LinalgToKernelPass, OperationPass<ModuleOp>> {
public:
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(LinalgToKernelPass)

  StringRef getArgument() const override { return "linalg-to-kernel"; }
  StringRef getDescription() const override {
    return "Convert linalg.generic to kernel operations";
  }

  void runOnOperation() override {
    ModuleOp module = getOperation();

    // Find the first kernel.defn_collection in the module.
    kernel::DefnCollectionOp collectionOp;
    module.walk([&](kernel::DefnCollectionOp op) {
      collectionOp = op;
      return WalkResult::interrupt();
    });

    if (!collectionOp) {
      module.emitError("No kernel.defn_collection found in module");
      return signalPassFailure();
    }

    // Apply the rewrite pattern.
    RewritePatternSet patterns(&getContext());
    patterns.add<LinalgGenericToKernelPattern>(&getContext(), collectionOp);

    if (failed(applyPatternsAndFoldGreedily(module, std::move(patterns))))
      return signalPassFailure();
  }
};

} // namespace

// Create a pass to convert linalg.generic to kernel
std::unique_ptr<Pass> createLinalgToKernelPass() {
  return std::make_unique<LinalgToKernelPass>();
}

// Register the pass
void registerLinalgToKernelPasses() {
  // PassRegistration picks up the pass argument and description from the
  // pass's getArgument()/getDescription() overrides.
  PassRegistration<LinalgToKernelPass>();
}