-
Notifications
You must be signed in to change notification settings - Fork 14
Open
Description
hugr-qir will need to inline all functions to produce valid QIR.
We already have a `CallGraph` in hugr-passes, which inlining requires.
I suggest that the inlining pass take a set of `Call` nodes as input and return an error if those calls contain a cycle; otherwise it inlines all of them. That way we don't dictate the inlining policy, and in particular hugr-qir can easily "inline everything".
Prototype, which predates our `CallGraph` and may be out of date in other ways:
use std::collections::HashMap;

use hugr_core::{
extension::ExtensionRegistry,
hugr::{
hugrmut::HugrMut,
views::{DescendantsGraph, ExtractHugr as _, HierarchyView},
HugrError, Rewrite, ValidationError,
},
ops::{DataflowOpTrait as _, OpTrait, DFG},
Direction, HugrView, Node,
};
use itertools::Itertools as _;
use petgraph::visit::EdgeRef as _;
use thiserror::Error;

use crate::validation::ValidationLevel;
/// A pass that inlines non-recursive calls to function definitions and then removes
/// function definitions that are no longer referenced.
///
/// Configure validation around the pass with [`InlinePass::validation_level`] and run it
/// with [`InlinePass::run`].
#[derive(Debug, Clone, Default)]
pub struct InlinePass {
    /// Validation performed before and after the pass runs.
    validation: ValidationLevel,
}
impl InlinePass {
    /// Sets the validation level used before and after the pass is run.
    pub fn validation_level(mut self, level: ValidationLevel) -> Self {
        self.validation = level;
        self
    }

    /// Runs the pass on `hugr`.
    ///
    /// Every call reported by [`CallGraph::iter_nonrecursive`] is inlined. Calls that
    /// [`InlineRewrite::try_new`] rejects, such as calls to declarations or polymorphic
    /// calls, are left in place. Afterwards, every called function definition that is no
    /// longer referenced and is not named `"main"` is removed together with its descendants.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the Hugr contains recursion;
    /// - applying an inline rewrite fails;
    /// - the resulting Hugr does not validate against `registry`.
    pub fn run(
        &self,
        hugr: &mut impl HugrMut,
        registry: &ExtensionRegistry,
    ) -> Result<(), Box<dyn std::error::Error>> {
        self.validation
            .run_validated_pass_mut(hugr, registry, |hugr, _| {
                // (call, callee) pairs, with callers visited in topological order.
                let calls = {
                    let cg = CallGraph::new(hugr);
                    let Some(calls) = cg.iter_nonrecursive() else {
                        Err("InlinePass: recursion")?
                    };
                    calls.collect_vec()
                };
                // Inline callees before their callers (reverse topological order). Each
                // function body is then already fully inlined when it is copied into a caller.
                let rewrites = calls
                    .iter()
                    .rev()
                    .filter_map(|&(call, _)| InlineRewrite::try_new(hugr, call, registry).ok())
                    .collect_vec();
                for rewrite in rewrites {
                    hugr.apply_rewrite(rewrite)?;
                }
                // A callee may appear for several callers, not necessarily consecutively.
                // Deduplicate over the whole list so no function is visited (and removed) twice.
                for func_node in calls.into_iter().map(|(_, func)| func).unique() {
                    // The function may sit inside a function that was already removed.
                    if !hugr.valid_node(func_node) {
                        continue;
                    }
                    // Callees may be declarations, which are never inlined or removed.
                    let Some(func) = hugr.get_optype(func_node).as_func_defn() else {
                        continue;
                    };
                    if hugr.linked_inputs(func_node, 0).count() == 0 && func.name != "main" {
                        let to_delete = DescendantsGraph::<Node>::try_new(hugr, func_node)?
                            .nodes()
                            .collect_vec();
                        for n in to_delete {
                            hugr.remove_node(n);
                        }
                    }
                }
                hugr.validate(registry)?;
                Ok(())
            })
    }
}
/// A call graph over the function definitions and declarations of a Hugr.
///
/// Each graph node holds a `FuncDefn` or `FuncDecl` node of the Hugr. Each edge goes from
/// the function definition containing a `Call` to the function that `Call` targets, and is
/// weighted by the `Call` node itself.
#[derive(Debug, Clone)]
pub struct CallGraph {
    g: petgraph::graph::Graph<Node, Node>,
}
/// Finds the closest proper ancestor of `node` whose op is a function definition.
///
/// Returns `None` when no ancestor of `node` is a `FuncDefn`.
fn func_of_node(hugr: &impl HugrView, node: Node) -> Option<Node> {
    std::iter::successors(hugr.get_parent(node), |&ancestor| hugr.get_parent(ancestor))
        .find(|&ancestor| hugr.get_optype(ancestor).is_func_defn())
}
impl CallGraph {
    /// Builds the call graph of `hugr`.
    ///
    /// Every `FuncDefn` and `FuncDecl` node becomes a graph node. For each `Call` nested
    /// inside a function definition whose function input is connected, an edge is added.
    /// The edge goes from the enclosing definition to the callee and is weighted by the
    /// `Call` node.
    pub fn new(hugr: &impl HugrView) -> Self {
        let mut g: petgraph::graph::Graph<Node, Node> = Default::default();
        let node_to_cg: HashMap<_, _> = hugr
            .nodes()
            .filter(|&n| {
                let op = hugr.get_optype(n);
                op.is_func_decl() || op.is_func_defn()
            })
            .map(|n| (n, g.add_node(n)))
            .collect();
        for n in hugr.nodes() {
            let Some(call) = hugr.get_optype(n).as_call() else {
                continue;
            };
            let Some(caller_func) = func_of_node(hugr, n) else {
                continue;
            };
            let Some((callee_func, _)) = hugr.single_linked_output(n, call.called_function_port())
            else {
                continue;
            };
            g.add_edge(node_to_cg[&caller_func], node_to_cg[&callee_func], n);
        }
        Self { g }
    }

    /// Yields every `(call, callee)` pair in the graph, grouped by calling function.
    ///
    /// Callers are visited in topological order, so the calls made by a function come
    /// before the calls made by any function it (transitively) calls.
    ///
    /// Returns `None` if the call graph contains a cycle, i.e. the Hugr has recursion.
    pub fn iter_nonrecursive(&self) -> Option<impl Iterator<Item = (Node, Node)> + '_> {
        let order = petgraph::algo::toposort(&self.g, None).ok()?;
        let pairs = order.into_iter().flat_map(move |caller| {
            self.g
                .edges(caller)
                .map(move |edge| (*edge.weight(), self.g[edge.target()]))
        });
        Some(pairs)
    }
}
/// A [`Rewrite`] that replaces a single `Call` node with a `DFG` node containing a copy of
/// the body of the called function definition.
pub struct InlineRewrite<'a> {
    /// The `Call` node to inline.
    call: Node,
    /// The `FuncDefn` node that `call` targets.
    func: Node,
    /// Registry used to validate the extracted function body.
    registry: &'a ExtensionRegistry,
}
impl<'a> InlineRewrite<'a> {
    /// Creates a rewrite that inlines the function definition targeted by `call`.
    ///
    /// # Errors
    ///
    /// - [`InlineRewriteError::InvalidCall`] in any of these cases:
    ///   - `call` is not a valid `Call` node;
    ///   - its function input is unconnected;
    ///   - it has type arguments.
    /// - [`InlineRewriteError::InvalidFunction`] if the called function is not a `FuncDefn`.
    pub fn try_new(
        hugr: &impl HugrView,
        call: Node,
        registry: &'a ExtensionRegistry,
    ) -> Result<Self, InlineRewriteError> {
        if !hugr.valid_node(call) {
            return Err(InlineRewriteError::InvalidCall);
        }
        let Some(call_ot) = hugr.get_optype(call).as_call() else {
            return Err(InlineRewriteError::InvalidCall);
        };
        let Some((func, _)) = hugr.single_linked_output(call, call_ot.called_function_port())
        else {
            return Err(InlineRewriteError::InvalidCall);
        };
        if !hugr.get_optype(func).is_func_defn() {
            return Err(InlineRewriteError::InvalidFunction);
        }
        let r = Self {
            call,
            func,
            registry,
        };
        // Run the full verification eagerly. A rewrite that could never be applied (e.g. for
        // a polymorphic call) is then rejected here in every build profile, instead of
        // tripping a debug assertion or failing later in `apply`.
        r.verify(hugr)?;
        Ok(r)
    }
}
/// Errors that can occur when constructing or applying an [`InlineRewrite`].
#[derive(Debug, Clone, Error)]
pub enum InlineRewriteError {
    /// The called function is not a function definition.
    #[error("Invalid Function")]
    InvalidFunction,
    /// The node is not an inlinable `Call`: it is missing, not a `Call`, unconnected, or has
    /// type arguments.
    #[error("Invalid Call")]
    InvalidCall,
    /// The `Call` is linked to a different function than the one the rewrite targets.
    #[error("Call does not target func")]
    Invalid,
    /// An error raised while modifying the Hugr.
    #[error(transparent)]
    HugrError(#[from] HugrError),
    /// The extracted function body failed validation.
    #[error(transparent)]
    Validation(#[from] ValidationError),
}
impl<'a> Rewrite for InlineRewrite<'a> {
    type Error = InlineRewriteError;
    type ApplyResult = ();
    // NOTE(review): `apply` can still fail in `replace_op` after the body has been inserted,
    // which would leave the Hugr modified. Confirm that this cannot happen for a DFG.
    const UNCHANGED_ON_FAILURE: bool = true;

    /// Checks that the rewrite can be applied to `h`.
    ///
    /// The following must all hold:
    /// - `self.call` is a valid, monomorphic `Call` node;
    /// - `self.func` is a valid `FuncDefn` node;
    /// - the call's function input is linked to exactly `self.func`.
    fn verify(&self, h: &impl HugrView) -> Result<(), Self::Error> {
        if !h.valid_node(self.call) {
            return Err(InlineRewriteError::InvalidCall);
        }
        if !h.valid_node(self.func) {
            return Err(InlineRewriteError::InvalidFunction);
        }
        let Some(call) = h.get_optype(self.call).as_call() else {
            return Err(InlineRewriteError::InvalidCall);
        };
        // Polymorphic calls would need their type arguments substituted into the body.
        if !call.type_args.is_empty() {
            return Err(InlineRewriteError::InvalidCall);
        }
        if !h.get_optype(self.func).is_func_defn() {
            return Err(InlineRewriteError::InvalidFunction);
        }
        // An unconnected function port is rejected rather than silently accepted.
        match h.single_linked_output(self.call, call.called_function_port()) {
            Some((n, _)) if n == self.func => Ok(()),
            Some(_) => Err(InlineRewriteError::Invalid),
            None => Err(InlineRewriteError::InvalidCall),
        }
    }

    /// Replaces `self.call` with a `DFG` holding a copy of `self.func`'s body. The DFG
    /// takes over the call's dataflow and order edges.
    fn apply(self, h: &mut impl HugrMut) -> Result<Self::ApplyResult, Self::Error> {
        self.verify(h)?;
        // Copy the function body out as a standalone Hugr rooted at the FuncDefn.
        // NOTE(review): static edges into the body from outside the function (e.g. calls to
        // other module-level functions) are not part of the extracted Hugr. Presumably the
        // validation below then fails and aborts the rewrite before anything is modified.
        let func_hugr = DescendantsGraph::<Node>::try_new(h, self.func)
            .map_err(|_| InlineRewriteError::InvalidFunction)?
            .extract_hugr();
        func_hugr.validate(self.registry)?;

        let call = h.get_optype(self.call).as_call().unwrap().to_owned();
        let call_parent = h.get_parent(self.call).unwrap();
        let fn_port = call.called_function_port().index();
        let dfg = DFG {
            signature: call.signature(),
        };

        // Insert the body beside the call and turn its FuncDefn root into a DFG with the
        // call's signature.
        let dfg_node = h.insert_hugr(call_parent, func_hugr).new_root;
        h.set_num_ports(
            dfg_node,
            dfg.signature().input_count() + dfg.non_df_port_count(Direction::Incoming),
            dfg.signature().output_count() + dfg.non_df_port_count(Direction::Outgoing),
        );
        h.replace_op(dfg_node, dfg)?;

        // Move every edge of the call (except its static function input) onto the DFG.
        let connections = h
            .node_inputs(self.call)
            .filter(|p| p.index() != fn_port)
            .flat_map(|in_p| {
                // A Call has a static function input that a DFG lacks. Input ports after it
                // (the order port) therefore sit one index lower on the DFG.
                let dfg_p = if in_p.index() > fn_port {
                    in_p.index() - 1
                } else {
                    in_p.index()
                };
                h.linked_outputs(self.call, in_p)
                    .map(move |(out_n, out_p)| (out_n, out_p.index(), dfg_node, dfg_p))
            })
            .chain(h.node_outputs(self.call).flat_map(|out_p| {
                h.linked_inputs(self.call, out_p)
                    .map(move |(in_n, in_p)| (dfg_node, out_p.index(), in_n, in_p.index()))
            }))
            .collect_vec();
        for (from_n, from_p, to_n, to_p) in connections {
            h.connect(from_n, from_p, to_n, to_p);
        }
        h.remove_node(self.call);
        Ok(())
    }

    fn invalidation_set(&self) -> impl Iterator<Item = Node> {
        [self.call, self.func].into_iter()
    }
}
Metadata
Metadata
Assignees
Labels
No labels