Skip to content

Implement ppm_specs pass #1794

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 35 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
eacfe86
test new print spec pass
ritu-thombre99 May 30, 2025
24e0afd
fix some descriptions
ritu-thombre99 May 30, 2025
ca994e4
fix passes.td file
ritu-thombre99 Jun 1, 2025
54c2460
4 stats recorded
ritu-thombre99 Jun 2, 2025
1f6ac67
implement ppm_spec logic
ritu-thombre99 Jun 4, 2025
ba5fb0c
slicing working
ritu-thombre99 Jun 5, 2025
940a16a
revert ppm_compilation
ritu-thombre99 Jun 5, 2025
dca7989
working in frontend
ritu-thombre99 Jun 6, 2025
09512cf
format and comment
ritu-thombre99 Jun 6, 2025
2b53308
track with main
ritu-thombre99 Jun 6, 2025
aa16ead
check if opt contains json
ritu-thombre99 Jun 8, 2025
ef174bf
resolve PR comments
ritu-thombre99 Jun 11, 2025
3304c10
fix formatting
ritu-thombre99 Jun 11, 2025
726237a
track with main
ritu-thombre99 Jun 11, 2025
ddc30f1
check nullptr for num_qubit
ritu-thombre99 Jun 11, 2025
eeffd55
move get_ppm_spec() API function out of jit.py and compiler.py; remov…
paul0403 Jun 12, 2025
bc4efce
remove ppm_specs from unnecessary places
ritu-thombre99 Jun 12, 2025
2dc9fc2
code cleanup
ritu-thombre99 Jun 12, 2025
f17706a
assert dynamic qubit count
ritu-thombre99 Jun 12, 2025
05ef5c3
formatting
ritu-thombre99 Jun 12, 2025
c1d5322
track with main
ritu-thombre99 Jun 12, 2025
f72df33
resolve some PR comments
ritu-thombre99 Jun 12, 2025
19900c7
check op type instead of name:
ritu-thombre99 Jun 12, 2025
d88e70f
Merge branch 'main' into ritu/ppm_specs
ritu-thombre99 Jun 13, 2025
34aa5ff
separate count functions
ritu-thombre99 Jun 13, 2025
ecae329
remove snake_case
ritu-thombre99 Jun 13, 2025
1260335
add frontend pytests
ritu-thombre99 Jun 13, 2025
641a416
formatting changes
ritu-thombre99 Jun 13, 2025
341e75b
fix pytests
ritu-thombre99 Jun 13, 2025
29d5b74
mlir test
ritu-thombre99 Jun 13, 2025
b596d18
update with main
ritu-thombre99 Jun 13, 2025
04e914e
mlir test with ppm specs
ritu-thombre99 Jun 13, 2025
e4c17b9
fix asserts
ritu-thombre99 Jun 13, 2025
a2265be
lit tests
ritu-thombre99 Jun 13, 2025
b282b0b
lit tests
ritu-thombre99 Jun 13, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 28 additions & 1 deletion frontend/catalyst/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,12 @@
"""
import glob
import importlib
import json
import logging
import os
import pathlib
import platform
import re
import shutil
import subprocess
import sys
Expand Down Expand Up @@ -371,7 +373,32 @@
return _quantum_opt(*args, stdin=stdin)

opts = _options_to_cli_flags(options)
return _quantum_opt(*opts, *args, stdin=stdin)
raw_result = _quantum_opt(*opts, *args, stdin=stdin)
regex_search_for_json = re.search(r"\{[a-zA-Z0-9_\":\{\},\n]+\}", raw_result)
raw_result = raw_result.replace(regex_search_for_json.group(0), "")
return raw_result


def to_ppm_spec(*args, stdin=None, options: Optional[CompileOptions] = None):
"""echo ${input} | catalyst --tool=opt *args *opts -"""
# These are the options that may affect compilation
if not options:
return _quantum_opt(*args, stdin=stdin)

opts = _options_to_cli_flags(options)

raw_json_format = _quantum_opt(*opts, *args, stdin=stdin)
regex_search_for_json = re.search(r"\{[a-zA-Z0-9_\":\{\},\n]+\}", raw_json_format)

# No ppm_specs json is found
if regex_search_for_json is None:
return raw_json_format

Check notice on line 396 in frontend/catalyst/compiler.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/compiler.py#L396

Trailing whitespace (trailing-whitespace)
json_ppm_specs = regex_search_for_json.group(0)
json_ppm_specs = json_ppm_specs.replace(",\n}", "\n}")
json_ppm_specs = json.loads(json_ppm_specs)
return json_ppm_specs



class Compiler:
Expand Down
8 changes: 8 additions & 0 deletions frontend/catalyst/jit.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
canonicalize,
to_llvmir,
to_mlir_opt,
to_ppm_spec,
)
from catalyst.debug.instruments import instrument
from catalyst.from_plxpr import trace_from_pennylane
Expand Down Expand Up @@ -578,6 +579,13 @@ def mlir_opt(self):

return to_mlir_opt(stdin=str(self.mlir_module), options=self.compile_options)

def get_ppm_spec(self):
"""obtain the PPM specs after optimization"""
if not self.mlir_module:
return None

return to_ppm_spec(stdin=str(self.mlir_module), options=self.compile_options)

@debug_logger
def __call__(self, *args, **kwargs):
# Transparantly call Python function in case of nested QJIT calls.
Expand Down
2 changes: 2 additions & 0 deletions frontend/catalyst/passes/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
merge_ppr_ppm,
merge_rotations,
ppm_compilation,
ppm_specs,
ppr_to_ppm,
to_ppr,
)
Expand All @@ -58,4 +59,5 @@
"merge_ppr_ppm",
"ppr_to_ppm",
"ppm_compilation",
"ppm_specs",
)
5 changes: 5 additions & 0 deletions frontend/catalyst/passes/builtin_passes.py
Original file line number Diff line number Diff line change
Expand Up @@ -781,3 +781,8 @@
)

return PassPipelineWrapper(qnode, passes)


def ppm_specs(qnode):

Check notice on line 786 in frontend/catalyst/passes/builtin_passes.py

View check run for this annotation

codefactor.io / CodeFactor

frontend/catalyst/passes/builtin_passes.py#L786

Missing function or method docstring (missing-function-docstring)
# TODO: Add docstring
return PassPipelineWrapper(qnode, "ppm_specs")
1 change: 1 addition & 0 deletions frontend/catalyst/passes/pass_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,4 +379,5 @@ def _API_name_to_pass_name():
"merge_ppr_ppm": "merge_ppr_ppm",
"ppr_to_ppm": "ppr_to_ppm",
"ppm_compilation": "ppm_compilation",
"ppm_specs": "ppm_specs",
}
1 change: 1 addition & 0 deletions mlir/include/QEC/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,5 +29,6 @@ std::unique_ptr<mlir::Pass> createCommuteCliffordPastPPMPass();
std::unique_ptr<mlir::Pass> createDecomposeNonCliffordPPRPass();
std::unique_ptr<mlir::Pass> createDecomposeCliffordPPRPass();
std::unique_ptr<mlir::Pass> createCliffordTToPPMPass();
std::unique_ptr<mlir::Pass> createCountPPMSpecsPass();

} // namespace catalyst
10 changes: 10 additions & 0 deletions mlir/include/QEC/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -121,4 +121,14 @@ def CliffordTToPPMPass : Pass<"ppm_compilation"> {
let options = [MaxPauliSizeOption, DecomposeMethodOption, AvoidYMeasureOption];
}

def CountPPMSpecsPass : Pass<"ppm_specs"> {
let summary = "Count specs in Pauli Product Measurement operations.";

let dependentDialects = [
"catalyst::qec::QECDialect",
];

let constructor = "catalyst::createCountPPMSpecsPass()";
}

#endif // QEC_PASSES
1 change: 1 addition & 0 deletions mlir/lib/Catalyst/Transforms/RegisterAllPasses.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ void catalyst::registerAllCatalystPasses()
mlir::registerPass(catalyst::createCliffordTToPPMPass);
mlir::registerPass(catalyst::createDecomposeNonCliffordPPRPass);
mlir::registerPass(catalyst::createDecomposeCliffordPPRPass);
mlir::registerPass(catalyst::createCountPPMSpecsPass);
mlir::registerPass(catalyst::createDetensorizeSCFPass);
mlir::registerPass(catalyst::createDisableAssertionPass);
mlir::registerPass(catalyst::createDisentangleCNOTPass);
Expand Down
1 change: 1 addition & 0 deletions mlir/lib/QEC/Transforms/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ file(GLOB SRC
commute_clifford_t_ppr.cpp
CommuteCliffordPastPPM.cpp
commute_clifford_past_ppm.cpp
CountPPMSpecs.cpp
decompose_non_clifford_ppr.cpp
DecomposeNonCliffordPPR.cpp
decompose_clifford_ppr.cpp
Expand Down
126 changes: 126 additions & 0 deletions mlir/lib/QEC/Transforms/CountPPMSpecs.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright 2025 Xanadu Quantum Technologies Inc.

// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at

// http://www.apache.org/licenses/LICENSE-2.0

// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#define DEBUG_TYPE "ppm_specs"

#include "mlir/Analysis/SliceAnalysis.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"


#include "QEC/IR/QECDialect.h"
#include "QEC/Transforms/Patterns.h"
#include "Quantum/IR/QuantumOps.h"
#include "QEC/Utils/PauliStringWrapper.h"
#include <algorithm>
#include <string>

using namespace llvm;
using namespace mlir;
using namespace catalyst;
using namespace catalyst::qec;

namespace catalyst {
namespace qec {

#define GEN_PASS_DEF_COUNTPPMSPECSPASS
#define GEN_PASS_DECL_COUNTPPMSPECSPASS
#include "QEC/Transforms/Passes.h.inc"


struct CountPPMSpecsPass : public impl::CountPPMSpecsPassBase<CountPPMSpecsPass> {
using CountPPMSpecsPassBase::CountPPMSpecsPassBase;

void print_specs()
{
llvm::BumpPtrAllocator string_allocator;
llvm::DenseMap<StringRef, int> PPM_Specs;
PPM_Specs["num_logical_qubits"] = 0;
PPM_Specs["num_of_ppm"] = 0;

// Walk over all operations in the IR (could be ModuleOp or FuncOp)
getOperation()->walk([&](Operation *op) {
// Skip top-level container ops if desired
if (isa<ModuleOp>(op)) return;

// TODO: Remove debug in future
// llvm::outs()<<"\n-----------------------------MLIR------------------------------\n";
// op->print(llvm::outs());
// llvm::outs()<<"\n-----------------------------MLIR------------------------------\n";

StringRef gate_name = op->getName().getStringRef();

if (gate_name == "quantum.alloc") {
auto num_qubits_attr = op->getAttrOfType<mlir::IntegerAttr>("nqubits_attr");
u_int64_t num_qubits = num_qubits_attr ? static_cast<u_int64_t>(num_qubits_attr.getInt()) : 0;
PPM_Specs["num_logical_qubits"] = num_qubits;
}

if (gate_name == "qec.ppm") {
PPM_Specs["num_of_ppm"] = PPM_Specs["num_of_ppm"] + 1;
}

if (gate_name == "qec.ppr") {
auto rotation_attr = op->getAttrOfType<mlir::IntegerAttr>("rotation_kind");
auto pauli_product_attr = op->getAttrOfType<mlir::ArrayAttr>("pauli_product");
int16_t rotation_kind = rotation_attr ? static_cast<int16_t>(rotation_attr.getInt()) : 0;
if (rotation_kind) {
llvm::StringSaver saver(string_allocator);
StringRef num_pi_key = saver.save("num_pi"+std::to_string(abs(rotation_kind))+"_gates");
StringRef max_weight_pi_key = saver.save("max_weight_pi"+std::to_string(abs(rotation_kind)));

if (PPM_Specs.find(llvm::StringRef(num_pi_key)) == PPM_Specs.end()) {
PPM_Specs[num_pi_key] = 1;
PPM_Specs[max_weight_pi_key] = static_cast<int>(pauli_product_attr.size());
}
else {
PPM_Specs[num_pi_key] = PPM_Specs[num_pi_key] + 1;
PPM_Specs[max_weight_pi_key] = std::max(PPM_Specs[max_weight_pi_key], static_cast<int>(pauli_product_attr.size()));
}
}
}
// TODO: Implement depth using slicing
// mlir::SetVector <Operation *> backwardSlice;
// getBackwardSlice(op, &backwardSlice);
// llvm::outs()<<"\n-----------------------------SLICE-----------------------------\n";
// llvm::outs()<<"Backward slicing\n";
// for (Operation *o : backwardSlice) {
// if (o->getName().getStringRef() == "quantum.extract") {
// llvm::outs() << *o << "\n";
// }
// }
// llvm::outs()<<"\n-----------------------------SLICE------------------------------\n";
});

llvm::outs() << "{\n";
for (const auto &entry : PPM_Specs) {
llvm::outs() << '"' << entry.first << '"' << ":" << entry.second << ",\n";
}
llvm::outs() << "}\n";
return;
}

void runOnOperation() final
{
print_specs();
}
};

} // namespace qec

/// Create a pass for lowering operations in the `QECDialect`.
std::unique_ptr<Pass> createCountPPMSpecsPass() { return std::make_unique<CountPPMSpecsPass>(); }

} // namespace catalyst
Loading