Skip to content

feat: update hugr-passes to use visibility #2418

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion hugr-core/src/core.rs
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,8 @@ impl<N: HugrNode> std::fmt::Display for Wire<N> {
}

/// Marks [FuncDefn](crate::ops::FuncDefn)s and [FuncDecl](crate::ops::FuncDecl)s as
/// to whether they should be considered for linking.
/// to whether they should be considered for linking, and as reachable (starting points)
/// for optimization/analysis.
#[derive(
Clone,
Debug,
Expand Down
2 changes: 1 addition & 1 deletion hugr-passes/src/call_graph.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ pub enum CallGraphNode<N = Node> {
FuncDecl(N),
/// petgraph-node corresponds to a [`FuncDefn`](OpType::FuncDefn) node (specified) in the Hugr
FuncDefn(N),
/// petgraph-node corresponds to the root node of the hugr, that is not
/// petgraph-node corresponds to the entrypoint node of the hugr, that is not
/// a [`FuncDefn`](OpType::FuncDefn). Note that it will not be a [Module](OpType::Module)
/// either, as such a node could not have outgoing edges, so is not represented in the petgraph.
NonFuncRoot,
Expand Down
77 changes: 59 additions & 18 deletions hugr-passes/src/const_fold.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,38 +15,44 @@ use hugr_core::{
};
use value_handle::ValueHandle;

use crate::dataflow::{
ConstLoader, ConstLocation, DFContext, Machine, PartialValue, TailLoopTermination,
partial_from_const,
};
use crate::dead_code::{DeadCodeElimPass, PreserveNode};
use crate::{ComposablePass, composable::validate_if_test};
use crate::{
VisPolicy,
dataflow::{
ConstLoader, ConstLocation, DFContext, Machine, PartialValue, TailLoopTermination,
partial_from_const,
},
};

#[derive(Debug, Clone, Default)]
/// A configuration for the Constant Folding pass.
///
/// Note that by default we assume that only the entrypoint is reachable and
/// only if it is not the module root; see [Self::with_inputs]. Mutation
/// occurs anywhere beneath the entrypoint.
pub struct ConstantFoldPass {
allow_increase_termination: bool,
/// Each outer key Node must be either:
/// - a `FuncDefn` child of the root, if the root is a module; or
/// - the root, if the root is not a Module
/// - a `FuncDefn` child of the module-root
/// - the entrypoint
inputs: HashMap<Node, HashMap<IncomingPort, Value>>,
}

#[derive(Clone, Debug, Error, PartialEq)]
#[non_exhaustive]
/// Errors produced by [`ConstantFoldPass`].
pub enum ConstFoldError {
/// Error raised when a Node is specified as an entry-point but
/// is neither a dataflow parent, nor a [CFG](OpType::CFG), nor
/// a [Conditional](OpType::Conditional).
/// Error raised when inputs are provided for a Node that is neither a dataflow
/// parent, nor a [CFG](OpType::CFG), nor a [Conditional](OpType::Conditional).
#[error("{node} has OpType {op} which cannot be an entry-point")]
InvalidEntryPoint {
/// The node which was specified as an entry-point
node: Node,
/// The `OpType` of the node
op: OpType,
},
/// The chosen entrypoint is not in the hugr.
/// Inputs were provided for a node that is not in the hugr.
#[error("Entry-point {node} is not part of the Hugr")]
MissingEntryPoint {
/// The missing node
Expand All @@ -67,15 +73,25 @@ impl ConstantFoldPass {
}

/// Specifies a number of external inputs to an entry point of the Hugr.
/// In normal use, for Module-rooted Hugrs, `node` is a `FuncDefn` child of the root;
/// or for non-Module-rooted Hugrs, `node` is the root of the Hugr. (This is not
/// In normal use, for Module-rooted Hugrs, `node` is a `FuncDefn` (child of the root);
/// for non-Module-rooted Hugrs, `node` is the [HugrView::entrypoint]. (This is not
/// enforced, but it must be a container and not a module itself.)
///
/// Multiple calls for the same entry-point combine their values, with later
/// values on the same in-port replacing earlier ones.
///
/// Note that if `inputs` is empty, this still marks the node as an entry-point, i.e.
/// we must preserve nodes required to compute its result.
/// Note that providing empty `inputs` indicates that we must preserve the ability
/// to compute the result of `node` for all possible inputs.
/// * If the entrypoint is the module-root, this method should be called for every
/// [FuncDefn] that is externally callable
/// * Otherwise, i.e. if the entrypoint is not the module-root,
/// * The default is to assume the entrypoint is callable with any inputs;
/// * If `node` is the entrypoint, this method allows to restrict the possible inputs
/// * If `node` is beneath the entrypoint, this merely degrades the analysis. (We
/// will mutate only beneath the entrypoint, but using results of analysing the
/// whole Hugr wrt. the specified/any inputs too).
///
/// [FuncDefn]: hugr_core::ops::FuncDefn
pub fn with_inputs(
mut self,
node: Node,
Expand All @@ -97,8 +113,7 @@ impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for ConstantFoldPass {
///
/// # Errors
///
/// [`ConstFoldError::InvalidEntryPoint`] if an entry-point added by [`Self::with_inputs`]
/// was of an invalid [`OpType`]
/// [ConstFoldError] if inputs were provided via [`Self::with_inputs`] for an invalid node.
fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> {
let fresh_node = Node::from(portgraph::NodeIndex::new(
hugr.nodes().max().map_or(0, |n| n.index() + 1),
Expand Down Expand Up @@ -184,25 +199,51 @@ impl<H: HugrMut<Node = Node> + 'static> ComposablePass<H> for ConstantFoldPass {
}
}

const NO_INPUTS: [(IncomingPort, Value); 0] = [];

/// Exhaustively apply constant folding to a HUGR.
/// If the Hugr's entrypoint is its [`Module`], assumes all [`FuncDefn`] children are reachable.
/// Otherwise, assume that the [HugrView::entrypoint] is itself reachable.
///
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
/// [`Module`]: hugr_core::ops::OpType::Module
#[deprecated(note = "Use fold_constants, or manually configure ConstantFoldPass")]
pub fn constant_fold_pass<H: HugrMut<Node = Node> + 'static>(mut h: impl AsMut<H>) {
let h = h.as_mut();
let c = ConstantFoldPass::default();
let c = if h.get_optype(h.entrypoint()).is_module() {
let no_inputs: [(IncomingPort, _); 0] = [];
h.children(h.entrypoint())
.filter(|n| h.get_optype(*n).is_func_defn())
.fold(c, |c, n| c.with_inputs(n, no_inputs.iter().cloned()))
.fold(c, |c, n| c.with_inputs(n, NO_INPUTS.clone()))
} else {
c
};
validate_if_test(c, h).unwrap();
}

/// Exhaustively apply constant folding to a HUGR.
/// Assumes that the Hugr's entrypoint is reachable (if it is not a [`Module`]).
/// Also uses `policy` to determine which public [`FuncDefn`] children of the [`HugrView::module_root`] are reachable.
///
/// [`Module`]: hugr_core::ops::OpType::Module
/// [`FuncDefn`]: hugr_core::ops::OpType::FuncDefn
pub fn fold_constants(h: &mut (impl HugrMut<Node = Node> + 'static), policy: VisPolicy) {
let mut funcs = Vec::new();
if !h.entrypoint_optype().is_module() {
funcs.push(h.entrypoint());
}
if policy.for_hugr(&h) {
funcs.extend(
h.children(h.module_root())
.filter(|n| h.get_optype(*n).is_func_defn()),
)
}
let c = funcs.into_iter().fold(ConstantFoldPass::default(), |c, n| {
c.with_inputs(n, NO_INPUTS.clone())
});
validate_if_test(c, h).unwrap();
}

struct ConstFoldContext;

impl ConstLoader<ValueHandle<Node>> for ConstFoldContext {
Expand Down
8 changes: 6 additions & 2 deletions hugr-passes/src/const_fold/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,14 @@ use hugr_core::std_extensions::logic::LogicOp;
use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV};
use hugr_core::{Hugr, HugrView, IncomingPort, Node, type_row};

use crate::ComposablePass as _;
use crate::dataflow::{DFContext, PartialValue, partial_from_const};
use crate::{ComposablePass as _, VisPolicy};

use super::{ConstFoldContext, ConstantFoldPass, ValueHandle, constant_fold_pass};
use super::{ConstFoldContext, ConstantFoldPass, ValueHandle, fold_constants};

fn constant_fold_pass(h: &mut (impl HugrMut<Node = Node> + 'static)) {
fold_constants(h, VisPolicy::AllPublic);
}

#[rstest]
#[case(ConstInt::new_u(4, 2).unwrap(), true)]
Expand Down
2 changes: 1 addition & 1 deletion hugr-passes/src/dataflow/datalog.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ impl<H: HugrView, V: AbstractValue> Machine<H, V> {
} else {
let ep = self.0.entrypoint();
let mut p = in_values.into_iter().peekable();
// We must provide some inputs to the root so that they are Top rather than Bottom.
// We must provide some inputs to the entrypoint so that they are Top rather than Bottom.
// (However, this test will fail for DataflowBlock or Case roots, i.e. if no
// inputs have been provided they will still see Bottom. We could store the "input"
// values for even these nodes in self.1 and then convert to actual Wire values
Expand Down
Loading
Loading