Skip to content

Update llvm version #1752

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 216 commits into from
Jun 3, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
216 commits
Select commit Hold shift + click to select a range
0d83653
init; changelog
paul0403 May 12, 2025
fb9c6db
boilerplate
paul0403 May 12, 2025
62a3f5a
missing namespace
paul0403 May 12, 2025
81e6e97
missing include
paul0403 May 12, 2025
6242a0b
typo
paul0403 May 12, 2025
7c149d0
(cherry pick) Add ForwardOp Bufferization
tzunghanjuang Sep 6, 2024
5504e29
move over Tzunghan's old work
paul0403 May 13, 2025
fcc0424
adjoint op
paul0403 May 13, 2025
a310d05
backprop op
paul0403 May 13, 2025
2e26140
add adjoint test with multiple results
paul0403 May 14, 2025
b7f62ae
some cleanup on backprop
paul0403 May 14, 2025
d113b4a
clean up backprop
paul0403 May 14, 2025
6282518
test file rename
paul0403 May 14, 2025
277f547
changelog
paul0403 May 14, 2025
5419116
update backprop test
paul0403 May 14, 2025
19e25f8
do not manually copy the cotangents for backprop
paul0403 May 14, 2025
0962b81
use ValueRange instead of vector
paul0403 May 14, 2025
7273e4e
update comment about backprop's mem write
paul0403 May 14, 2025
4714f3c
move over Tzunghan's forward and reverse
paul0403 May 15, 2025
4d995e7
try CI
paul0403 May 15, 2025
d037193
format
paul0403 May 15, 2025
14b8599
remove old gradient bufferization
paul0403 May 15, 2025
b98c9ea
clean prints
paul0403 May 16, 2025
d3389a9
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 16, 2025
89e0bbd
format
paul0403 May 16, 2025
491cfa1
add gradient preprocessing test
paul0403 May 16, 2025
c7b57f2
easier on the eyes
paul0403 May 16, 2025
128978b
reverse op preprocessing test
paul0403 May 16, 2025
bc3caf9
add bufferization test for forward and reverse
paul0403 May 16, 2025
1eb49c1
add post processing test for forward and reverse
paul0403 May 16, 2025
eff06f4
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 16, 2025
a194c06
make most dialects into one-shot-bufferize(dialect)
paul0403 May 19, 2025
66f043f
a stable version before trying one-shot-bufferize pass
paul0403 May 19, 2025
3ae9083
try one-shot
paul0403 May 20, 2025
e35b25e
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 20, 2025
ad3a83c
try new llvm (with Tzunghan's patches already in upstream)
paul0403 May 20, 2025
642ebcd
update llvm to 7f04a8ad131881b5a58b97c8191733ed42d18e20
paul0403 May 20, 2025
6a5a2f8
GreedyRewriteConfig.enableRegionSimplification is no longer just a pl…
paul0403 May 20, 2025
18ac813
update TopologicalSortUtils.h location
paul0403 May 20, 2025
4f18b6c
include <variant>
paul0403 May 20, 2025
335220e
track llvm and mhlo versions to jax 0.4.32
paul0403 May 20, 2025
9ed0272
update llvm and mhlo submodules to jax 0.4.32 versions
paul0403 May 20, 2025
e4a9c19
.dep-versions format :sweat-smile:
paul0403 May 20, 2025
5a48957
enzymestatic-19 -> 20
paul0403 May 20, 2025
2d1d1d6
MhloQuantToIntConversion is removed
paul0403 May 20, 2025
aaf3771
`translateModuleToLLVMIR` got a new argument `disableVerification`
paul0403 May 20, 2025
5b870ca
just comment out old bufferization passes in cpp pipeline for now
paul0403 May 20, 2025
5f5d94d
Remove optional modifier from unit attribute (#1746)
ritu-thombre99 May 15, 2025
7432b03
applyPatternsAndFoldGreedily -> applyPatternsGreedily
paul0403 May 20, 2025
b155169
pattern rewtier no longer has `match` the `rewrite`.
paul0403 May 20, 2025
b4be79d
`GreedyRewriteConfig` no longer has `strictMode` and `enableRegionSim…
paul0403 May 20, 2025
1a2c907
Update set_dep_version to account for jax repo rename
mehrdad2m May 20, 2025
c9e1b95
Update Enzyme to latest version
mehrdad2m May 20, 2025
506d083
update llvm and mhlo to compatiable tags required by jax v0.6.0
mehrdad2m May 20, 2025
b3f352c
Merge branch 'paul0403/one-shot-bufferize-final' into paul0403/update…
mehrdad2m May 20, 2025
dd04fb9
change the functionArgTypeConverterFn in gradient bufferization to ta…
paul0403 May 20, 2025
5e9d75c
use one-shot bufferization in python pipeline
paul0403 May 20, 2025
e306252
update llvm to the commit that has
paul0403 May 20, 2025
83f6fc2
add restrict unitattr to the ToTensorOps in gradient lowering pass
paul0403 May 20, 2025
ffb2c60
GreedyRewriteConfig.enableRegionSimplification is no longer just a pl…
paul0403 May 20, 2025
1848bed
update TopologicalSortUtils.h location
paul0403 May 20, 2025
9b090f3
include <variant>
paul0403 May 20, 2025
93ff924
track llvm and mhlo versions to jax 0.4.32
paul0403 May 20, 2025
ece1c6d
update llvm and mhlo submodules to jax 0.4.32 versions
paul0403 May 20, 2025
7dfd0cc
.dep-versions format :sweat-smile:
paul0403 May 20, 2025
b628f47
enzymestatic-19 -> 20
paul0403 May 20, 2025
d2bc0b6
MhloQuantToIntConversion is removed
paul0403 May 20, 2025
3790a68
`translateModuleToLLVMIR` got a new argument `disableVerification`
paul0403 May 20, 2025
f370449
just comment out old bufferization passes in cpp pipeline for now
paul0403 May 20, 2025
268bc64
update cpp pipeline
paul0403 May 21, 2025
d3a75ab
turn on `copy-before-write` for async
paul0403 May 21, 2025
6e0d8df
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 21, 2025
a55e941
Add Tzunghan as author
paul0403 May 21, 2025
0739553
changelog
paul0403 May 21, 2025
0fc83a5
Use `eliminate-empty-tensors` pass instead of `empty-tensor-to-alloc-…
paul0403 May 21, 2025
b8895ed
add `restrict` attr to to_tensor ops in mlir lit test
paul0403 May 21, 2025
26d8f50
line-too-long on the python bufferization options string
paul0403 May 21, 2025
d5cdb13
skip upstream ml_dtypes lit test
paul0403 May 21, 2025
ef323be
Update mlir lit tests impacted by mlir update.
paul0403 May 21, 2025
692cf99
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 21, 2025
0e88009
enzymestatic-19 -> 20
paul0403 May 20, 2025
e04b79f
Merge remote-tracking branch 'origin/main' into paul0403/new_bufferiz…
paul0403 May 22, 2025
1550045
Merge remote-tracking branch 'origin/paul0403/new_bufferize_gradient_…
paul0403 May 22, 2025
e347497
remove llvm keyword from comment
mehrdad2m May 22, 2025
018b493
Merge remote-tracking branch 'origin/paul0403/one-shot-bufferize-fina…
paul0403 May 22, 2025
6c4fa07
Merge remote-tracking branch 'origin/paul0403/update_llvm' into paul0…
paul0403 May 22, 2025
6af9f91
checkout to an earilier mhlo commit without the build bug
paul0403 May 22, 2025
33d482d
go back to the llvm and mhlo versions tracked by jax 0.6.0
paul0403 May 22, 2025
4a13d73
install nanobind==2.4 before building llvm
mehrdad2m May 22, 2025
05ef90b
enzymestatic-20 -> 21
paul0403 May 22, 2025
6eadec2
Patch enzyme's dependency on nvidia fabs llvm intrinsics
paul0403 May 22, 2025
7304b9f
patch away shardy in mhlo
paul0403 May 22, 2025
87aacce
just avoid CI enzyme cache for now
paul0403 May 22, 2025
e99392d
typo
paul0403 May 22, 2025
1d27581
more typo
paul0403 May 22, 2025
32220e5
The enzyme right now is db0181320d6e425ee963bd496ed0d8dbb615be18
paul0403 May 22, 2025
5113a23
Merge remote-tracking branch 'origin/main' into paul0403/update_llvm
paul0403 May 26, 2025
1d3f3e3
valuerange/typerange mixup from some git mixup before
paul0403 May 26, 2025
3c6fbed
a missed override removel on match/rewrite
paul0403 May 26, 2025
fb2bbf1
'applyOpPatternsAndFold' is deprecated: Use applyOpPatternsGreedily()…
paul0403 May 26, 2025
41933e4
mlir::LLVM::lookupOrCreateFn() now returns FailureOr<llvmfuncop> inst…
paul0403 May 26, 2025
1eb8c9b
four missed match/rewrite overrides
paul0403 May 26, 2025
252c494
a missed applyPatternsGreedily
paul0403 May 26, 2025
f080ce6
```
paul0403 May 26, 2025
1523018
more match/rewrite override and applypatternsgreedily
paul0403 May 26, 2025
2cafd50
override
paul0403 May 26, 2025
d1f52eb
override
paul0403 May 26, 2025
2123432
another lookuporcreate/FailureOr/.value()
paul0403 May 26, 2025
40d42c8
merge `match()` and `rewrite()` into `matchAndRewrite()` for a bunch …
paul0403 May 26, 2025
a485d9e
return success() in all branches in gradient return op conversion pat…
paul0403 May 26, 2025
5aeca47
OneShotBufferizationOptions -> OneShotBufferizePassOptions
paul0403 May 26, 2025
3687033
one-shot bufferization layout options don't have to be lambdas!
paul0403 May 26, 2025
a6e4fbe
mlir::call_interface_impl::resolveCallable
paul0403 May 26, 2025
7fff30f
format
paul0403 May 26, 2025
26ffeb0
machine target triples no longer do the `Triple`->`std::string`->`Tri…
paul0403 May 26, 2025
ad96d6e
ops with CallOpInterface must have two new optional attrs `arg_attrs`…
paul0403 May 26, 2025
9306998
createBufferizationToMemRefPass -> createConvertBufferizationToMemRef…
paul0403 May 27, 2025
fd219d0
mhlo removed 3 passes:
paul0403 May 27, 2025
24e9cce
createConvertSCFToCFPass -> createSCFToControlFlowPass :sweat_smile_a…
paul0403 May 27, 2025
d0ee361
-buffer-deallocation pass is removed. Use -ownership-based-buffer-dea…
paul0403 May 27, 2025
02ea1f0
typo
paul0403 May 27, 2025
2a457a4
--buffer-deallocation is removed.
paul0403 May 27, 2025
df8a289
3 missed matchAndRewrite migration
paul0403 May 27, 2025
51050d9
the three removed stablehlo passes are necessary.
paul0403 May 27, 2025
a0f6d69
enable --buffer-deallocation-pipeline;
paul0403 May 27, 2025
95bbe03
OpAdaptor -> OneToNOpAdaptor for listpush, pop, dealloc ops in cataly…
paul0403 May 27, 2025
45be033
mhlo patch TODO cleanup
paul0403 May 27, 2025
759ce19
Fix adjoint gradient.
paul0403 May 27, 2025
36a11b8
GreedyRewriteConfig.strictMode, enableRegionSimplification, maxIterat…
paul0403 May 27, 2025
d060f06
the lowering of cf::AssertOp to llvm was split from the overall --con…
paul0403 May 27, 2025
aab184d
a missed `OneToNOpAdaptor` on catalyst.list_load_data op
paul0403 May 27, 2025
2749846
bufferization.to_memref updated assembly format:
paul0403 May 27, 2025
e248670
add expval return in gradient adjoint's to-llvm lowering lit test
paul0403 May 28, 2025
b10f0e8
to_tensor and to_memref add `to` in assembly format in gradient postp…
paul0403 May 28, 2025
6e3c0c9
totensor tomemref assembly format in catalyst conversion lit test
paul0403 May 28, 2025
9f050c2
update quantum conversion lit test:
paul0403 May 28, 2025
9e2cc07
temporarily diable Quantum/AllocationTest.mlir
paul0403 May 28, 2025
bc1b37a
gradient adjoint minor update: don't just force a return type
paul0403 May 28, 2025
f8f9fe3
a std::optional include order
paul0403 May 28, 2025
4037fb7
typo
paul0403 May 28, 2025
e198cd3
move QuantumResource from quantum ops to quantum dialect
paul0403 May 28, 2025
21d9dba
MQBC measurement in basis has QuantumMemory read/write
paul0403 May 28, 2025
1dbd529
add quantum memory effects to quantum ops
paul0403 May 28, 2025
860512f
gates only read, do not write
paul0403 May 28, 2025
aa1957d
Add mem effects to gradient ops.
paul0403 May 28, 2025
3b85584
Add mem effects to catalyst.list ops
paul0403 May 28, 2025
82ac61f
printop is mem read and alloc
paul0403 May 28, 2025
e487a2e
assert op is memwrite
paul0403 May 28, 2025
7790d5e
global phase needs to write
paul0403 May 28, 2025
5c5ac37
xfail three gradient tests
paul0403 May 28, 2025
e1013fd
Quantum/AllocationTest.mlir uses the removed --buffer-deallocation pa…
paul0403 May 28, 2025
534150e
Merge remote-tracking branch 'origin/main' into paul0403/update_llvm
paul0403 May 28, 2025
6fd2dba
changelog
paul0403 May 29, 2025
e1d3c96
add buffer deallocation pipeline pass in cpp pipeline
paul0403 May 29, 2025
8a2b317
don't walk in the adjoint lowering pass when erasing device release a…
paul0403 May 29, 2025
31d98be
remove commented-out `bufferizesToAllocation` methods
paul0403 May 29, 2025
1ede6fb
git-style patch shardy
paul0403 May 29, 2025
03ada9e
git-style patch removed mhlo passes
paul0403 May 29, 2025
21bf3ac
mem alloc -> write on print op
paul0403 May 29, 2025
2acb631
fix enzyme tbaa: some `index` ops are pointers, but we used to always…
paul0403 May 29, 2025
281df43
is CI auto merging main??
paul0403 May 30, 2025
0ca7dc8
Merge remote-tracking branch 'origin/main' into paul0403/update_llvm
paul0403 May 30, 2025
6786286
a `applyPatternsAndFoldGreedily` fix from new stuff on main
paul0403 May 30, 2025
04c526d
add "compilers developers only" in changelog
paul0403 May 30, 2025
ab31034
just return the reg in quantum/allocation list test instead of a manu…
paul0403 May 30, 2025
d00284c
Update mlir/lib/Quantum/Transforms/emit_catalyst_pyface.cpp
paul0403 May 30, 2025
6897130
Update mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp
paul0403 May 30, 2025
611d094
Update mlir/lib/Catalyst/Transforms/InlineNestedModules.cpp
paul0403 May 30, 2025
2112b21
Update mlir/lib/Catalyst/Transforms/DetectQNodes.cpp
paul0403 May 30, 2025
f0dd569
plumbing savepoint
paul0403 May 30, 2025
376e606
revert all added mem effects in tablegen
paul0403 May 30, 2025
6971a48
add back old buffer dealloc pass; credit llvm authors
paul0403 May 30, 2025
78212d1
restore quantum allocation test (that was impacted by buffer dealloc …
paul0403 May 30, 2025
3bfb843
Merge remote-tracking branch 'origin/main' into paul0403/update_llvm
paul0403 Jun 2, 2025
0f54789
remove "For example, a gate operation will both read from and write t…
paul0403 Jun 2, 2025
e9fbb2d
git apply --check on mhlo; remove "clean-mhlo" at the start of "make …
paul0403 Jun 2, 2025
c44deb2
Update tutorial to reflect the removal of separate match and rewrite …
mehrdad2m Jun 2, 2025
59ccb5a
add back bufferizesToAllocation
paul0403 Jun 2, 2025
b04b059
remove mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch
paul0403 Jun 2, 2025
b85e6a1
pin nanobind==2.4 to requirements.txt
paul0403 Jun 2, 2025
8fa6fd9
Update mlir/lib/Gradient/Transforms/GradMethods/Adjoint.cpp
paul0403 Jun 2, 2025
5b32057
getTerminator directly for returnop in grad adjoint
paul0403 Jun 2, 2025
4cc38e4
use mapper to avoid iterating within the clone
paul0403 Jun 2, 2025
034098e
collect everything by iterating over `callee` once at the beginning
paul0403 Jun 2, 2025
f98af5d
remove buffer dealloc patch (we patch this entire pass now anyways)
paul0403 Jun 2, 2025
901478e
update patches in CI wheels script
paul0403 Jun 2, 2025
ed563b0
fix patch path in wheels script
paul0403 Jun 2, 2025
65350a3
fix wheel script again (no need to cd)
paul0403 Jun 2, 2025
8ff300f
remove ;\ in make clean-mhlo
paul0403 Jun 2, 2025
7a7e763
dyncast->cast
paul0403 Jun 2, 2025
04a221b
wheels script (where is $GITHUB_WORKSPACE??)
paul0403 Jun 2, 2025
90a184f
indent back in quantumops tablegen
paul0403 Jun 2, 2025
13853d5
return when we see non expval measurements
paul0403 Jun 2, 2025
1240c6a
only one device release; allow multiple expval; remove added expval i…
paul0403 Jun 2, 2025
5d249e2
one block = no unstructured control flow
paul0403 Jun 2, 2025
ed103bd
changelog add link; use "mlih-hlo"
paul0403 Jun 2, 2025
78ab6f1
(wheels script) I do need to cd...
paul0403 Jun 2, 2025
84ddaaf
(wheels script) I'm very confused
paul0403 Jun 2, 2025
2cd9f34
(wheels script) actually, no need to check patch success in wheels CI
paul0403 Jun 2, 2025
24d96ca
check that memref is the first two args to dealloc helper
paul0403 Jun 2, 2025
b6833bf
check for arg number in TBAA dealloc helper checker
paul0403 Jun 3, 2025
4aa8ddc
Merge remote-tracking branch 'origin/main' into paul0403/update_llvm
paul0403 Jun 3, 2025
f7455dd
backward slice in isFromExtractAlignedPointerAsIndexOp
paul0403 Jun 3, 2025
76613a2
(wheels script) `pushd . && cd blad` -> `pushd blah`
paul0403 Jun 3, 2025
af74ac6
wheels CI
paul0403 Jun 3, 2025
c2e8247
wheels CI check mhlo, delete cache
paul0403 Jun 3, 2025
f399ec8
pushd not cd-ing?
paul0403 Jun 3, 2025
d8aa363
remove prints in wheels script
paul0403 Jun 3, 2025
cff7ea1
clean all mhlo CI cache and try wheels
paul0403 Jun 3, 2025
5449518
wheels build passed, check standard CI
paul0403 Jun 3, 2025
36fdf95
clean CI cache and try again
paul0403 Jun 3, 2025
33e287f
small change to mhlo cache key (the old one is stuck in a reserved st…
paul0403 Jun 3, 2025
38cd129
nanobind requirement is >=2.4, no need to be 2.4 exactly
paul0403 Jun 3, 2025
0084fe7
explain why nanobind>=2.4 in requirements.txt
paul0403 Jun 3, 2025
ece9f19
Update requirements.txt
paul0403 Jun 3, 2025
6f4c162
Update .github/workflows/check-catalyst.yaml
paul0403 Jun 3, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 5 additions & 11 deletions .dep-versions
Original file line number Diff line number Diff line change
@@ -1,16 +1,10 @@
# Always update the version check in catalyst.__init__ when changing the JAX version.

#############
# We track mlir submodule versions from jax 0.4.32 for now
# These are the earliest versions with complete upstream bufferization changes
# Versions are retrieved from
# python3 .github/workflows/set_dep_versions.py 0.4.32
#############

# To update JAX version alongside compatible dependency tags, run the following script:
# python3 .github/workflows/set_dep_versions.py {JAX_version}
jax=0.6.0
mhlo=25b008569f413d76cfa8f481f3a84e82b89c47f4
llvm=5f74671c85877e03622e8d308aee15ed73ccee7c
enzyme=v0.0.149
mhlo=617a9361d186199480c080c9e8c474a5e30c22d1
llvm=179d30f8c3fddd3c85056fd2b8e877a4a8513158
enzyme=v0.0.180

# Always remove custom PL/LQ versions before release.

Expand Down
14 changes: 10 additions & 4 deletions .github/workflows/build-wheel-linux-arm64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -187,9 +187,10 @@ jobs:
run: |
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH

export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt
export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi
pushd $GITHUB_WORKSPACE/mlir/mlir-hlo
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch
popd

cmake -S mlir/mlir-hlo -B $GITHUB_WORKSPACE/mhlo-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
Expand All @@ -215,14 +216,19 @@ jobs:
if: steps.cache-enzyme-build.outputs.cache-hit != 'true'
run: |
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH

export TARGET_FILE=mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
export PATCH_FILE=mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi

cmake -S mlir/Enzyme/enzyme -B $GITHUB_WORKSPACE/enzyme-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/llvm" \
-DENZYME_STATIC_LIB=ON \
-DCMAKE_CXX_VISIBILITY_PRESET=default \
-DCMAKE_CXX_FLAGS="-fuse-ld=lld"

cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-21

- name: Save Enzyme Build
id: save-enzyme-build
Expand Down
14 changes: 10 additions & 4 deletions .github/workflows/build-wheel-linux-x86_64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -210,9 +210,10 @@ jobs:
run: |
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH

export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt
export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi
pushd $GITHUB_WORKSPACE/mlir/mlir-hlo
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch
popd

cmake -S mlir/mlir-hlo -B $GITHUB_WORKSPACE/mhlo-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
Expand All @@ -238,14 +239,19 @@ jobs:
if: steps.cache-enzyme-build.outputs.cache-hit != 'true'
run: |
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH

export TARGET_FILE=mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
export PATCH_FILE=mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi

cmake -S mlir/Enzyme/enzyme -B $GITHUB_WORKSPACE/enzyme-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/llvm" \
-DENZYME_STATIC_LIB=ON \
-DCMAKE_CXX_VISIBILITY_PRESET=default \
-DCMAKE_CXX_FLAGS="-fuse-ld=lld"

cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-21

- name: Save Enzyme Build
id: save-enzyme-build
Expand Down
13 changes: 9 additions & 4 deletions .github/workflows/build-wheel-macos-arm64.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -185,9 +185,10 @@ jobs:
run: |
export PATH=$GITHUB_WORKSPACE/llvm-build/bin:$PATH

export TARGET_FILE=mlir/mlir-hlo/mhlo/transforms/CMakeLists.txt
export PATCH_FILE=mlir/patches/mhlo-Add-PassesIncGen-in-transforms-CMakeList.patch
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi
pushd $GITHUB_WORKSPACE/mlir/mlir-hlo
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-remove-shardy.patch
git apply $GITHUB_WORKSPACE/mlir/patches/mhlo-add-back-necessary-passes.patch
popd

cmake -S mlir/mlir-hlo -B $GITHUB_WORKSPACE/mhlo-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
Expand All @@ -212,13 +213,17 @@ jobs:
- name: Build Enzyme
if: steps.cache-enzyme-build.outputs.cache-hit != 'true'
run: |
export TARGET_FILE=mlir/Enzyme/enzyme/Enzyme/TypeAnalysis/TypeAnalysis.cpp
export PATCH_FILE=mlir/patches/enzyme-nvvm-fabs-intrinsics.patch
if patch --dry-run -p1 -N $TARGET_FILE $PATCH_FILE > /dev/null 2>&1; then patch -p1 $TARGET_FILE $PATCH_FILE; fi

cmake -S mlir/Enzyme/enzyme -B $GITHUB_WORKSPACE/enzyme-build -G Ninja \
-DCMAKE_BUILD_TYPE=Release \
-DLLVM_DIR="$GITHUB_WORKSPACE/llvm-build/lib/cmake/llvm" \
-DENZYME_STATIC_LIB=ON \
-DCMAKE_CXX_VISIBILITY_PRESET=default

cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-20
cmake --build $GITHUB_WORKSPACE/enzyme-build --target EnzymeStatic-21

- name: Save Enzyme Build
id: save-enzyme-build
Expand Down
6 changes: 3 additions & 3 deletions .github/workflows/check-catalyst.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,7 @@ jobs:
sudo apt-get update
sudo apt-get install -y python3 python3-pip cmake ninja-build clang lld
python3 --version | grep ${{ needs.constants.outputs.primary_python_version }}
python3 -m pip install numpy pybind11
python3 -m pip install numpy pybind11 nanobind

- name: Build LLVM
if: steps.cache-llvm-build.outputs.cache-hit != 'true'
Expand Down Expand Up @@ -194,7 +194,7 @@ jobs:
uses: actions/cache@v4
with:
path: mhlo-build
key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }}
key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }}-0

- name: Get Cached LLVM Source
id: cache-llvm-source
Expand Down Expand Up @@ -351,7 +351,7 @@ jobs:
uses: actions/cache/restore@v4
with:
path: mhlo-build
key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }}
key: ${{ runner.os }}-mhlo-${{ needs.constants.outputs.mhlo_version }}-ci-build-${{ matrix.compiler }}-0
fail-on-cache-miss: true

- name: Get Cached Enzyme Source
Expand Down
33 changes: 14 additions & 19 deletions .github/workflows/set_dep_versions.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,11 +32,11 @@
assert os.path.isfile(dep_versions_path)
assert os.path.isfile(catalyst_init_path)

url = f"https://raw.githubusercontent.com/google/jax/jaxlib-v{jax_version}/WORKSPACE"
url = f"https://raw.githubusercontent.com/jax-ml/jax/jax-v{jax_version}/WORKSPACE"
response = requests.get(url)
match = re.search(r'strip_prefix = "xla-([a-zA-Z0-9]*)"', response.text)
if not match:
url = f"https://raw.githubusercontent.com/google/jax/jaxlib-v{jax_version}/third_party/xla/workspace.bzl"
url = f"https://raw.githubusercontent.com/jax-ml/jax/jax-v{jax_version}/third_party/xla/workspace.bzl"
response = requests.get(url)
match = re.search(r'XLA_COMMIT = "([a-zA-Z0-9]*)"', response.text)
xla_commit = match.group(1)
Expand Down Expand Up @@ -67,21 +67,16 @@
response = requests.get(url).json()
hlo_commit = response["items"][0]["sha"]

existing_text = open(dep_versions_path, "r", encoding="UTF-8").read()
match = re.search(r"enzyme=([a-zA-Z0-9]*)", existing_text)
enzyme_commit = match.group(1)

with open(dep_versions_path, "w", encoding="UTF-8") as f:
f.write(
f"""\
jax={jax_version}
mhlo={hlo_commit}
llvm={llvm_commit}
enzyme={enzyme_commit}
"""
)

quote = '"'
cmd = f"sed -i 's/_jaxlib_version = {quote}\([0-9.]\+\){quote}/_jaxlib_version = {quote}{jax_version}{quote}/g' {catalyst_init_path}"
res = os.system(cmd)
assert res == 0
# Update each version using sed
cmds = [
f"sed -i '' 's/^jax=.*/jax={jax_version}/' {dep_versions_path}",
f"sed -i '' 's/^mhlo=.*/mhlo={hlo_commit}/' {dep_versions_path}",
f"sed -i '' 's/^llvm=.*/llvm={llvm_commit}/' {dep_versions_path}",
# Update jaxlib version in __init__.py
rf"sed -i '' 's/_jaxlib_version = {quote}\([0-9.]\+\){quote}/_jaxlib_version = {quote}{jax_version}{quote}/g' {catalyst_init_path}",
]

for cmd in cmds:
res = os.system(cmd)
assert res == 0
72 changes: 47 additions & 25 deletions doc/dev/transforms.rst
Original file line number Diff line number Diff line change
Expand Up @@ -252,23 +252,22 @@ Note how the value ``%q2`` links the two operations together from definition ``(
across several other instructions.

As seen in the `pattern rewriter documentation <https://mlir.llvm.org/docs/PatternRewriter/#defining-patterns>`_,
a new rewrite pattern can be defined as a C++ class as follows, where we will focus on the ``match``
and ``rewrite`` methods (refer to the link for the full class and up to date information):
a new rewrite pattern can be defined as a C++ class as follows, where we will focus on the
``matchAndRewrite`` method (refer to the link for the full class and up to date information):

.. code-block:: cpp

struct QubitUnitaryFusion : public OpRewritePattern<QubitUnitaryOp>
{
...

LogicalResult match(QubitUnitaryOp op) const override {
// The ``match`` method returns ``success()`` if the pattern is a match, failure
// otherwise.
}

void rewrite(QubitUnitaryOp op, PatternRewriter &rewriter) {
// The ``rewrite`` method performs mutations on the IR rooted at ``op`` using
// the provided rewriter. All mutations must go through the provided rewriter.
LogicalResult matchAndRewrite(QubitUnitaryOp op, PatternRewriter &rewriter) const override {
// The `matchAndRewrite` method performs both the pattern matching and the mutation
// on the IR rooted at `op` using the provided rewriter.
// All mutations must go through the provided rewriter and IR mutation should only
// take place after the match is deemed successful.
// matchAndRewrite must return "success" if and only if the IR was modified.
// The root operation is required to either be: updated in-place, replaced, or erased.
}

...
Expand All @@ -286,11 +285,11 @@ the second is a list of qubits):

QubitUnitary(*, QubitUnitary(*, *))

Let's implement it in C++:
Let's add the pattern-matching logic to the ``matchAndRewrite`` method:

.. code-block:: cpp

LogicalResult match(QubitUnitaryOp op) const override
LogicalResult matchAndRewrite(QubitUnitaryOp op, PatternRewriter &rewriter) const override
{
ValueRange qbs = op.getInQubits();
Operation *parent = qbs[0].getDefiningOp();
Expand All @@ -314,6 +313,9 @@ Let's implement it in C++:
return failure();
}

// Rewrite logic
// ... We have matched the pattern, now rewrite the IR here

return success();
}

Expand Down Expand Up @@ -351,8 +353,8 @@ MLIR will automatically generate canonical ``get*`` methods for attributes like
``out_qubits``, and ``matrix``. When in doubt it's best to have a look at the generated C++ files in
the build folder, named ``QuantumOps.h.inc`` and ``QuantumOps.cpp.inc`` in this instance.

Alright, now that we have the matching part, let's implement the actual transformation via the
``rewrite`` method. All we need to do is replace the original pattern with the following:
Alright, now that we have the matching part, let's add the actual transformation to the
``matchAndRewrite`` method. All we need to do is replace the original pattern with the following:

.. code-block::

Expand All @@ -362,8 +364,13 @@ In C++ it will look as follows:

.. code-block:: cpp

void rewrite(QubitUnitaryOp op, PatternRewriter &rewriter) const override
LogicalResult matchAndRewrite(QubitUnitaryOp op, PatternRewriter &rewriter) const override
{

// Pattern matching logic
// ... match the pattern

// Rewrite logic
ValueRange qbs = op.getInQubits();
QubitUnitaryOp parentOp = cast<QubitUnitaryOp>(qbs[0].getDefiningOp());

Expand Down Expand Up @@ -410,11 +417,13 @@ In C++ it will look as follows:
// The second unitary is not needed anymore
// Whoever uses the second unitary, use the first one instead!
op.replaceAllUsesWith(parentOp);

return success();
}

When writing transformations, the rewriter is the most important tool we have. It can create new
operations for us, delete others, or change the place in the IR where we are choosing to make
changes (also called the insertion point). Let's have look at some of these elements:
changes (also called the insertion point). Let's have a look at some of these elements:

- **Constructing new operations**:

Expand Down Expand Up @@ -512,15 +521,15 @@ and other function operations, which themselves can contain other operations, an
quantumPatterns.add<QubitUnitaryFusion>(ctx);

// Apply patterns in an iterative and greedy manner.
if (failed(applyPatternsAndFoldGreedily(op, std::move(quantumPatterns)))) {
if (failed(applyPatternsGreedily(op, std::move(quantumPatterns)))) {
return signalPassFailure();
}
}
};

To apply patterns we need a `pattern applicator <https://mlir.llvm.org/docs/PatternRewriter/#common-pattern-drivers>`_.
There a few in MLIR but typically you can just use the greedy pattern rewrite driver
(``applyPatternsAndFoldGreedily``), which will iterative over the IR and apply patterns until a
(``applyPatternsGreedily``), which will iterative over the IR and apply patterns until a
fixed point is reached.

.. note::
Expand Down Expand Up @@ -565,21 +574,30 @@ gradient ops that specify the finite-difference method, indicated via the ``"fd"

.. code-block:: cpp

LogicalResult FiniteDiffLowering::match(GradOp op)
LogicalResult FiniteDiffLowering::matchAndRewrite(GradOp op, PatternRewriter &rewriter)
{
if (op.getMethod() == "fd")
return success();
// Pattern matching logic
if (op.getMethod() != "fd")
return failure();

return failure();
// Rewrite logic
// ...

return success();
}

For the rewriting part we'll want to introduce a few new elements, such as looking up symbols
(function names), creating new functions, and changing the insertion point.

.. code-block:: cpp

void FiniteDiffLowering::rewrite(GradOp op, PatternRewriter &rewriter)
LogicalResult FiniteDiffLowering::matchAndRewrite(GradOp op, PatternRewriter &rewriter)
{
// Pattern matching logic
if (op.getMethod() != "fd")
return failure();

// Rewrite logic
// First let's find the function the grad operation is referencing.
func::FuncOp callee =
SymbolTable::lookupNearestSymbolFrom<func::FuncOp>(op, op.getCalleeAttr());
Expand Down Expand Up @@ -609,6 +627,8 @@ For the rewriting part we'll want to introduce a few new elements, such as looki
// Populate the function body.
populateFiniteDiffMethod(rewriter, op, gradFn);
}

return success();
}

Symbols are string references to IR objects, which rather than containing a physical reference or
Expand Down Expand Up @@ -711,18 +731,20 @@ Alright, our function should now look something like this:
func.return %dx, %dy, %dz : f64, f64, f64
}

Finally, we have to amend our rewrite function to invoke the new function we created and delete the
Finally, we have to amend our ``matchAndRewrite`` function to invoke the new function we created and delete the
``GradOp`` from the IR:

.. code-block:: cpp

void FiniteDiffLowering::rewrite(GradOp op, PatternRewriter &rewriter)
LogicalResult FiniteDiffLowering::matchAndRewrite(GradOp op, PatternRewriter &rewriter)
{
...
populateFiniteDiffMethod(rewriter, op, gradFn);
}

rewriter.replaceOpWithNewOp<func::CallOp>(op, gradFn, op.getArgOperands());

return success();
}

Note how we can create a new operation, take its results, and use those to replace another operation
Expand Down
Loading