From 4db3f5d72d1661bd0ec099a33afdcb075740785a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= Date: Tue, 15 Apr 2025 13:52:31 +0100 Subject: [PATCH 01/18] ci: Run ci checks on PRs to any branch --- .github/workflows/ci-py.yml | 2 +- .github/workflows/ci-rs.yml | 2 +- .github/workflows/pr-title.yml | 2 +- .github/workflows/semver-checks.yml | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/.github/workflows/ci-py.yml b/.github/workflows/ci-py.yml index c393c195f..6ef3edce5 100644 --- a/.github/workflows/ci-py.yml +++ b/.github/workflows/ci-py.yml @@ -6,7 +6,7 @@ on: - main pull_request: branches: - - '*' + - '**' merge_group: types: [checks_requested] workflow_dispatch: {} diff --git a/.github/workflows/ci-rs.yml b/.github/workflows/ci-rs.yml index d7c94e3a3..824291a8b 100644 --- a/.github/workflows/ci-rs.yml +++ b/.github/workflows/ci-rs.yml @@ -6,7 +6,7 @@ on: - main pull_request: branches: - - '*' + - '**' merge_group: types: [checks_requested] workflow_dispatch: {} diff --git a/.github/workflows/pr-title.yml b/.github/workflows/pr-title.yml index 333eec96f..f778c1363 100644 --- a/.github/workflows/pr-title.yml +++ b/.github/workflows/pr-title.yml @@ -2,7 +2,7 @@ name: Check Conventional Commits format on: pull_request_target: branches: - - main + - '**' types: - opened - edited diff --git a/.github/workflows/semver-checks.yml b/.github/workflows/semver-checks.yml index e884b2e36..2c410aa85 100644 --- a/.github/workflows/semver-checks.yml +++ b/.github/workflows/semver-checks.yml @@ -2,7 +2,7 @@ name: Rust Semver Checks on: pull_request_target: branches: - - main + - '**' jobs: # Check if changes were made to the relevant files. From 81447ecf83acdba6d50427a64637447899553054 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 15 Apr 2025 15:58:19 +0100 Subject: [PATCH 02/18] feat!: Allow generic Nodes in HugrMut insert operations (#2075) `insert_hugr`, `insert_from_view`, and `insert_subgraph` were written before we made `Node` a type generic, and incorrectly assumed the return type maps were always `hugr::Node`s. The methods were either unusable or incorrect when using generic `HugrView`s source/targets with non-base node types. This PR fixes that, and additionally allows us us to have `SiblingSubgraph::extract_subgraph` work for generic `HugrViews`. BREAKING CHANGE: Added Node type parameters to extraction operations in `HugrMut`. --- hugr-core/src/builder/build_traits.rs | 2 +- hugr-core/src/hugr/hugrmut.rs | 114 ++++++++++++------- hugr-core/src/hugr/views/sibling_subgraph.rs | 4 +- 3 files changed, 75 insertions(+), 45 deletions(-) diff --git a/hugr-core/src/builder/build_traits.rs b/hugr-core/src/builder/build_traits.rs index f1613895d..e17d172ca 100644 --- a/hugr-core/src/builder/build_traits.rs +++ b/hugr-core/src/builder/build_traits.rs @@ -119,7 +119,7 @@ pub trait Container { } /// Insert a copy of a HUGR as a child of the container. - fn add_hugr_view(&mut self, child: &impl HugrView) -> InsertionResult { + fn add_hugr_view(&mut self, child: &H) -> InsertionResult { let parent = self.container_node(); self.hugr_mut().insert_from_view(parent, child) } diff --git a/hugr-core/src/hugr/hugrmut.rs b/hugr-core/src/hugr/hugrmut.rs index f3ef094be..38eb59222 100644 --- a/hugr-core/src/hugr/hugrmut.rs +++ b/hugr-core/src/hugr/hugrmut.rs @@ -1,13 +1,15 @@ //! Low-level interface for modifying a HUGR. use core::panic; -use std::collections::{BTreeMap, HashMap}; +use std::collections::{BTreeMap, HashMap, HashSet}; use std::sync::Arc; use portgraph::view::{NodeFilter, NodeFiltered}; use portgraph::{LinkMut, PortMut, PortView, SecondaryMap}; +use crate::core::HugrNode; use crate::extension::ExtensionRegistry; +use crate::hugr::internal::HugrInternals; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrView, Node, OpType, RootTagged}; use crate::hugr::{NodeMetadata, Rewrite}; @@ -162,10 +164,10 @@ pub trait HugrMut: HugrMutInternals { /// correspondingly for `Dom` edges) fn copy_descendants( &mut self, - root: Node, - new_parent: Node, + root: Self::Node, + new_parent: Self::Node, subst: Option, - ) -> BTreeMap { + ) -> BTreeMap { panic_invalid_node(self, root); panic_invalid_node(self, new_parent); self.hugr_mut().copy_descendants(root, new_parent, subst) @@ -225,7 +227,7 @@ pub trait HugrMut: HugrMutInternals { /// /// If the root node is not in the graph. #[inline] - fn insert_hugr(&mut self, root: Node, other: Hugr) -> InsertionResult { + fn insert_hugr(&mut self, root: Self::Node, other: Hugr) -> InsertionResult { panic_invalid_node(self, root); self.hugr_mut().insert_hugr(root, other) } @@ -236,7 +238,11 @@ pub trait HugrMut: HugrMutInternals { /// /// If the root node is not in the graph. #[inline] - fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult { + fn insert_from_view( + &mut self, + root: Self::Node, + other: &H, + ) -> InsertionResult { panic_invalid_node(self, root); self.hugr_mut().insert_from_view(root, other) } @@ -255,12 +261,12 @@ pub trait HugrMut: HugrMutInternals { // TODO: Try to preserve the order when possible? We cannot always ensure // it, since the subgraph may have arbitrary nodes without including their // parent. - fn insert_subgraph( + fn insert_subgraph( &mut self, - root: Node, - other: &impl HugrView, - subgraph: &SiblingSubgraph, - ) -> HashMap { + root: Self::Node, + other: &H, + subgraph: &SiblingSubgraph, + ) -> HashMap { panic_invalid_node(self, root); self.hugr_mut().insert_subgraph(root, other, subgraph) } @@ -307,20 +313,32 @@ pub trait HugrMut: HugrMutInternals { /// Records the result of inserting a Hugr or view /// via [HugrMut::insert_hugr] or [HugrMut::insert_from_view]. -pub struct InsertionResult { +/// +/// Contains a map from the nodes in the source HUGR to the nodes in the +/// target HUGR, using their respective `Node` types. +pub struct InsertionResult { /// The node, after insertion, that was the root of the inserted Hugr. /// /// That is, the value in [InsertionResult::node_map] under the key that was the [HugrView::root] - pub new_root: Node, + pub new_root: TargetN, /// Map from nodes in the Hugr/view that was inserted, to their new /// positions in the Hugr into which said was inserted. - pub node_map: HashMap, + pub node_map: HashMap, } -fn translate_indices( +/// Translate a portgraph node index map into a map from nodes in the source +/// HUGR to nodes in the target HUGR. +/// +/// This is as a helper in `insert_hugr` and `insert_subgraph`, where the source +/// HUGR may be an arbitrary `HugrView` with generic node types. +fn translate_indices( + mut source_node: impl FnMut(portgraph::NodeIndex) -> N, + mut target_node: impl FnMut(portgraph::NodeIndex) -> Node, node_map: HashMap, -) -> impl Iterator { - node_map.into_iter().map(|(k, v)| (k.into(), v.into())) +) -> impl Iterator { + node_map + .into_iter() + .map(move |(k, v)| (source_node(k), target_node(v))) } /// Impl for non-wrapped Hugrs. Overwrites the recursive default-impls to directly use the hugr. @@ -406,7 +424,11 @@ impl + AsMut> HugrMut for T (src_port, dst_port) } - fn insert_hugr(&mut self, root: Node, mut other: Hugr) -> InsertionResult { + fn insert_hugr( + &mut self, + root: Self::Node, + mut other: Hugr, + ) -> InsertionResult { let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, &other); // Update the optypes and metadata, taking them from the other graph. // @@ -423,11 +445,16 @@ impl + AsMut> HugrMut for T ); InsertionResult { new_root, - node_map: translate_indices(node_map).collect(), + node_map: translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map) + .collect(), } } - fn insert_from_view(&mut self, root: Node, other: &impl HugrView) -> InsertionResult { + fn insert_from_view( + &mut self, + root: Self::Node, + other: &H, + ) -> InsertionResult { let (new_root, node_map) = insert_hugr_internal(self.as_mut(), root, other); // Update the optypes and metadata, copying them from the other graph. // @@ -444,22 +471,28 @@ impl + AsMut> HugrMut for T ); InsertionResult { new_root, - node_map: translate_indices(node_map).collect(), + node_map: translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map) + .collect(), } } - fn insert_subgraph( + fn insert_subgraph( &mut self, - root: Node, - other: &impl HugrView, - subgraph: &SiblingSubgraph, - ) -> HashMap { + root: Self::Node, + other: &H, + subgraph: &SiblingSubgraph, + ) -> HashMap { // Create a portgraph view with the explicit list of nodes defined by the subgraph. - let portgraph: NodeFiltered<_, NodeFilter<&[Node]>, &[Node]> = + let context: HashSet = subgraph + .nodes() + .iter() + .map(|&n| other.get_pg_index(n)) + .collect(); + let portgraph: NodeFiltered<_, NodeFilter>, _> = NodeFiltered::new_node_filtered( other.portgraph(), - |node, ctx| ctx.contains(&node.into()), - subgraph.nodes(), + |node, ctx| ctx.contains(&node), + context, ); let node_map = insert_subgraph_internal(self.as_mut(), root, other, &portgraph); // Update the optypes and metadata, copying them from the other graph. @@ -473,25 +506,24 @@ impl + AsMut> HugrMut for T self.use_extensions(exts); } } - translate_indices(node_map).collect() + translate_indices(|n| other.get_node(n), |n| self.get_node(n), node_map).collect() } fn copy_descendants( &mut self, - root: Node, - new_parent: Node, + root: Self::Node, + new_parent: Self::Node, subst: Option, - ) -> BTreeMap { + ) -> BTreeMap { let mut descendants = self.base_hugr().hierarchy.descendants(root.pg_index()); let root2 = descendants.next(); debug_assert_eq!(root2, Some(root.pg_index())); let nodes = Vec::from_iter(descendants); - let node_map = translate_indices( - portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes) - .copy_in_parent() - .expect("Is a MultiPortGraph"), - ) - .collect::>(); + let node_map = portgraph::view::Subgraph::with_nodes(&mut self.as_mut().graph, nodes) + .copy_in_parent() + .expect("Is a MultiPortGraph"); + let node_map = translate_indices(|n| self.get_node(n), |n| self.get_node(n), node_map) + .collect::>(); for node in self.children(root).collect::>() { self.set_parent(*node_map.get(&node).unwrap(), new_parent); @@ -563,10 +595,10 @@ fn insert_hugr_internal( /// sibling order in the hierarchy. This is due to the subgraph not necessarily /// having a single root, so the logic for reconstructing the hierarchy is not /// able to just do a BFS. -fn insert_subgraph_internal( +fn insert_subgraph_internal( hugr: &mut Hugr, root: Node, - other: &impl HugrView, + other: &impl HugrView, portgraph: &impl portgraph::LinkView, ) -> HashMap { let node_map = hugr diff --git a/hugr-core/src/hugr/views/sibling_subgraph.rs b/hugr-core/src/hugr/views/sibling_subgraph.rs index a0bf1a3da..c681fafc9 100644 --- a/hugr-core/src/hugr/views/sibling_subgraph.rs +++ b/hugr-core/src/hugr/views/sibling_subgraph.rs @@ -446,16 +446,14 @@ impl SiblingSubgraph { nu_out, )) } -} -impl SiblingSubgraph { /// Create a new Hugr containing only the subgraph. /// /// The new Hugr will contain a [FuncDefn][crate::ops::FuncDefn] root /// with the same signature as the subgraph and the specified `name` pub fn extract_subgraph( &self, - hugr: &impl HugrView, + hugr: &impl HugrView, name: impl Into, ) -> Hugr { let mut builder = FunctionBuilder::new(name, self.signature(hugr)).unwrap(); From ef1cba0a85f803423e9f14450844ad4f7300f1fd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 15 Apr 2025 16:02:17 +0100 Subject: [PATCH 03/18] fix!: Don't expose `HugrMutInternals` (#2071) `HugrMutInternals` is part of the semi-private traits defined in `hugr-core`. While most things get re-exported in `hugr`, we `*Internal` traits require you to explicitly declare a dependency on the `-core` package (as we don't want most users to have to interact with them). For some reason there was a public re-export of the trait in a re-exported module, so it ended up appearing in `hugr` anyways. BREAKING CHANGE: Removed public re-export of `HugrMutInternals` from `hugr`. --- hugr-core/src/hugr/rewrite/simple_replace.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/hugr-core/src/hugr/rewrite/simple_replace.rs b/hugr-core/src/hugr/rewrite/simple_replace.rs index cf7f2922a..b4ec37db1 100644 --- a/hugr-core/src/hugr/rewrite/simple_replace.rs +++ b/hugr-core/src/hugr/rewrite/simple_replace.rs @@ -4,7 +4,6 @@ use std::collections::HashMap; use crate::core::HugrNode; use crate::hugr::hugrmut::InsertionResult; -pub use crate::hugr::internal::HugrMutInternals; use crate::hugr::views::SiblingSubgraph; use crate::hugr::{HugrMut, HugrView, Rewrite}; use crate::ops::{OpTag, OpTrait, OpType}; From fac6c8b92bab8904f47eaa9ebf4581eb1e1f4095 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Tue, 15 Apr 2025 17:10:32 +0100 Subject: [PATCH 04/18] feat!: Mark all Error enums as non_exhaustive (#2056) #2027 ended up being breaking due to adding a new variant to an error enum missing the `non_exhaustive` marker. This (breaking) PR makes sure all error enums have the flag. BREAKING CHANGE: Marked all Error enums as `non_exhaustive` --- hugr-core/src/extension.rs | 1 + hugr-core/src/hugr/serialize/upgrade.rs | 1 + hugr-core/src/import.rs | 2 ++ hugr-model/src/v0/ast/resolve.rs | 1 + hugr-model/src/v0/table/mod.rs | 1 + hugr-passes/src/lower.rs | 1 + hugr-passes/src/non_local.rs | 1 + hugr-passes/src/replace_types/linearize.rs | 1 + hugr-passes/src/validation.rs | 1 + 9 files changed, 10 insertions(+) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index 408c88e15..b6e059050 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -378,6 +378,7 @@ pub static EMPTY_REG: ExtensionRegistry = ExtensionRegistry { /// TODO: decide on failure modes #[derive(Debug, Clone, Error, PartialEq, Eq)] #[allow(missing_docs)] +#[non_exhaustive] pub enum SignatureError { /// Name mismatch #[error("Definition name ({0}) and instantiation name ({1}) do not match.")] diff --git a/hugr-core/src/hugr/serialize/upgrade.rs b/hugr-core/src/hugr/serialize/upgrade.rs index 2741b6175..ac1ac1eea 100644 --- a/hugr-core/src/hugr/serialize/upgrade.rs +++ b/hugr-core/src/hugr/serialize/upgrade.rs @@ -1,6 +1,7 @@ use thiserror::Error; #[derive(Debug, Error)] +#[non_exhaustive] pub enum UpgradeError { #[error(transparent)] Deserialize(#[from] serde_json::Error), diff --git a/hugr-core/src/import.rs b/hugr-core/src/import.rs index 642c84c41..899deb17d 100644 --- a/hugr-core/src/import.rs +++ b/hugr-core/src/import.rs @@ -35,6 +35,7 @@ use thiserror::Error; /// Error during import. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum ImportError { /// The model contains a feature that is not supported by the importer yet. /// Errors of this kind are expected to be removed as the model format and @@ -75,6 +76,7 @@ pub enum ImportError { /// Import error caused by incorrect order hints. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum OrderHintError { /// Duplicate order hint key in the same region. #[error("duplicate order hint key {0}")] diff --git a/hugr-model/src/v0/ast/resolve.rs b/hugr-model/src/v0/ast/resolve.rs index 2f8a5ba6e..c9be8896b 100644 --- a/hugr-model/src/v0/ast/resolve.rs +++ b/hugr-model/src/v0/ast/resolve.rs @@ -362,6 +362,7 @@ impl<'a> Context<'a> { /// Error that may occur in [`Module::resolve`]. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum ResolveError { /// Unknown variable. #[error("unknown var: {0}")] diff --git a/hugr-model/src/v0/table/mod.rs b/hugr-model/src/v0/table/mod.rs index 756a52c1e..55a4b9889 100644 --- a/hugr-model/src/v0/table/mod.rs +++ b/hugr-model/src/v0/table/mod.rs @@ -456,6 +456,7 @@ pub struct VarId(pub NodeId, pub VarIndex); /// Errors that can occur when traversing and interpreting the model. #[derive(Debug, Clone, Error)] +#[non_exhaustive] pub enum ModelError { /// There is a reference to a node that does not exist. #[error("node not found: {0}")] diff --git a/hugr-passes/src/lower.rs b/hugr-passes/src/lower.rs index 09e02c41d..8f8920967 100644 --- a/hugr-passes/src/lower.rs +++ b/hugr-passes/src/lower.rs @@ -35,6 +35,7 @@ pub fn replace_many_ops>( /// Errors produced by the [`lower_ops`] function. #[derive(Debug, Error)] #[error(transparent)] +#[non_exhaustive] pub enum LowerError { /// Invalid subgraph. #[error("Subgraph formed by node is invalid: {0}")] diff --git a/hugr-passes/src/non_local.rs b/hugr-passes/src/non_local.rs index fca74657b..180e9d6fc 100644 --- a/hugr-passes/src/non_local.rs +++ b/hugr-passes/src/non_local.rs @@ -23,6 +23,7 @@ pub fn nonlocal_edges(hugr: &H) -> impl Iterator { #[error("Found {} nonlocal edges", .0.len())] Edges(Vec<(N, IncomingPort)>), diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 371798dce..b3fc20da9 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -127,6 +127,7 @@ pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); #[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] #[allow(missing_docs)] +#[non_exhaustive] pub enum LinearizeError { #[error("Need copy/discard op for {_0}")] NeedCopyDiscard(Type), diff --git a/hugr-passes/src/validation.rs b/hugr-passes/src/validation.rs index 5f53f403c..6c3e61fb4 100644 --- a/hugr-passes/src/validation.rs +++ b/hugr-passes/src/validation.rs @@ -25,6 +25,7 @@ pub enum ValidationLevel { #[derive(Error, Debug, PartialEq)] #[allow(missing_docs)] +#[non_exhaustive] pub enum ValidatePassError { #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] InputError { From baaca02359a307f8691ab3985313272339a8c494 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 16 Apr 2025 11:11:13 +0100 Subject: [PATCH 05/18] feat!: Handle CallIndirect in Dataflow Analysis (#2059) * PartialValue now has a LoadedFunction variant, created by LoadFunction nodes (only, although other analyses are able to create PartialValues if they want) * This requires adding a type parameter to PartialValue for the type of Node, which gets everywhere :-(. * Use this to handle CallIndirects *with known targets* (it'll be a single known target or none at all) just like other Calls to the same function * deprecate (and ignore) `value_from_function` * Add a new trait `AsConcrete` for the result type of `PartialValue::try_into_concrete` and `PartialSum::try_into_sum` Note almost no change to constant folding (only to drop impl of `value_from_function`) BREAKING CHANGE: in dataflow framework, PartialValue now has additional variant; `try_into_concrete` requires the target type to implement AsConcrete. --- hugr-passes/src/const_fold.rs | 63 ++--- hugr-passes/src/const_fold/test.rs | 6 +- hugr-passes/src/const_fold/value_handle.rs | 23 +- hugr-passes/src/dataflow.rs | 17 +- hugr-passes/src/dataflow/datalog.rs | 171 +++++++++++-- hugr-passes/src/dataflow/partial_value.rs | 267 +++++++++++++-------- hugr-passes/src/dataflow/results.rs | 22 +- hugr-passes/src/dataflow/test.rs | 108 ++++++++- hugr-passes/src/dataflow/value_row.rs | 38 +-- 9 files changed, 492 insertions(+), 223 deletions(-) diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index 7552ed36f..e73e3cd0e 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -7,15 +7,11 @@ use std::{collections::HashMap, sync::Arc}; use thiserror::Error; use hugr_core::{ - hugr::{ - hugrmut::HugrMut, - views::{DescendantsGraph, ExtractHugr, HierarchyView}, - }, + hugr::hugrmut::HugrMut, ops::{ - constant::OpaqueValue, handle::FuncID, Const, DataflowOpTrait, ExtensionOp, LoadConstant, - OpType, Value, + constant::OpaqueValue, Const, DataflowOpTrait, ExtensionOp, LoadConstant, OpType, Value, }, - types::{EdgeKind, TypeArg}, + types::EdgeKind, HugrView, IncomingPort, Node, NodeIndex, OutgoingPort, PortIndex, Wire, }; use value_handle::ValueHandle; @@ -102,7 +98,7 @@ impl ConstantFoldPass { n, in_vals.iter().map(|(p, v)| { let const_with_dummy_loc = partial_from_const( - &ConstFoldContext(hugr), + &ConstFoldContext, ConstLocation::Field(p.index(), &fresh_node.into()), v, ); @@ -112,7 +108,7 @@ impl ConstantFoldPass { .map_err(|opty| ConstFoldError::InvalidEntryPoint(n, opty))?; } - let results = m.run(ConstFoldContext(hugr), []); + let results = m.run(ConstFoldContext, []); let mb_root_inp = hugr.get_io(hugr.root()).map(|[i, _]| i); let wires_to_break = hugr @@ -131,7 +127,7 @@ impl ConstantFoldPass { n, ip, results - .try_read_wire_concrete::(Wire::new(src, outp)) + .try_read_wire_concrete::(Wire::new(src, outp)) .ok()?, )) }) @@ -205,60 +201,35 @@ pub fn constant_fold_pass(h: &mut H) { c.run(h).unwrap() } -struct ConstFoldContext<'a, H>(&'a H); - -impl std::ops::Deref for ConstFoldContext<'_, H> { - type Target = H; - fn deref(&self) -> &H { - self.0 - } -} +struct ConstFoldContext; -impl> ConstLoader> for ConstFoldContext<'_, H> { - type Node = H::Node; +impl ConstLoader> for ConstFoldContext { + type Node = Node; fn value_from_opaque( &self, - loc: ConstLocation, + loc: ConstLocation, val: &OpaqueValue, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_opaque(loc, val.clone())) } fn value_from_const_hugr( &self, - loc: ConstLocation, + loc: ConstLocation, h: &hugr_core::Hugr, - ) -> Option> { + ) -> Option> { Some(ValueHandle::new_const_hugr(loc, Box::new(h.clone()))) } - - fn value_from_function( - &self, - node: H::Node, - type_args: &[TypeArg], - ) -> Option> { - if !type_args.is_empty() { - // TODO: substitution across Hugr (https://github.com/CQCL/hugr/issues/709) - return None; - }; - // Returning the function body as a value, here, would be sufficient for inlining IndirectCall - // but not for transforming to a direct Call. - let func = DescendantsGraph::>::try_new(&**self, node).ok()?; - Some(ValueHandle::new_const_hugr( - ConstLocation::Node(node), - Box::new(func.extract_hugr()), - )) - } } -impl> DFContext> for ConstFoldContext<'_, H> { +impl DFContext> for ConstFoldContext { fn interpret_leaf_op( &mut self, - node: H::Node, + node: Node, op: &ExtensionOp, - ins: &[PartialValue>], - outs: &mut [PartialValue>], + ins: &[PartialValue>], + outs: &mut [PartialValue>], ) { let sig = op.signature(); let known_ins = sig diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index b84d65d7d..58e69c568 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -42,8 +42,7 @@ fn value_handling(#[case] k: impl CustomConst + Clone, #[case] eq: bool) { let n = Node::from(portgraph::NodeIndex::new(7)); let st = SumType::new([vec![k.get_type()], vec![]]); let subject_val = Value::sum(0, [k.clone().into()], st).unwrap(); - let temp = Hugr::default(); - let ctx: ConstFoldContext = ConstFoldContext(&temp); + let ctx = ConstFoldContext; let v1 = partial_from_const(&ctx, n, &subject_val); let v1_subfield = { @@ -114,8 +113,7 @@ fn test_add(#[case] a: f64, #[case] b: f64, #[case] c: f64) { v.get_custom_value::().unwrap().value() } let [n, n_a, n_b] = [0, 1, 2].map(portgraph::NodeIndex::new).map(Node::from); - let temp = Hugr::default(); - let mut ctx = ConstFoldContext(&temp); + let mut ctx = ConstFoldContext; let v_a = partial_from_const(&ctx, n_a, &f2c(a)); let v_b = partial_from_const(&ctx, n_b, &f2c(b)); assert_eq!(unwrap_float(v_a.clone()), a); diff --git a/hugr-passes/src/const_fold/value_handle.rs b/hugr-passes/src/const_fold/value_handle.rs index bda7bffd2..e5c99a8e7 100644 --- a/hugr-passes/src/const_fold/value_handle.rs +++ b/hugr-passes/src/const_fold/value_handle.rs @@ -1,16 +1,18 @@ //! Total equality (and hence [AbstractValue] support for [Value]s //! (by adding a source-Node and part unhashable constants) use std::collections::hash_map::DefaultHasher; // Moves into std::hash in Rust 1.76. +use std::convert::Infallible; use std::hash::{Hash, Hasher}; use std::sync::Arc; use hugr_core::core::HugrNode; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::Value; +use hugr_core::types::ConstTypeError; use hugr_core::{Hugr, Node}; use itertools::Either; -use crate::dataflow::{AbstractValue, ConstLocation}; +use crate::dataflow::{AbstractValue, AsConcrete, ConstLocation, LoadedFunction, Sum}; /// A custom constant that has been successfully hashed via [TryHash](hugr_core::ops::constant::TryHash) #[derive(Clone, Debug)] @@ -153,9 +155,12 @@ impl Hash for ValueHandle { // Unfortunately we need From for Value to be able to pass // Value's into interpret_leaf_op. So that probably doesn't make sense... -impl From> for Value { - fn from(value: ValueHandle) -> Self { - match value { +impl AsConcrete, N> for Value { + type ValErr = Infallible; + type SumErr = ConstTypeError; + + fn from_value(value: ValueHandle) -> Result { + Ok(match value { ValueHandle::Hashable(HashedConst { val, .. }) | ValueHandle::Unhashable { leaf: Either::Left(val), @@ -169,7 +174,15 @@ impl From> for Value { } => Value::function(Arc::try_unwrap(hugr).unwrap_or_else(|a| a.as_ref().clone())) .map_err(|e| e.to_string()) .unwrap(), - } + }) + } + + fn from_sum(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } + + fn from_func(func: LoadedFunction) -> Result> { + Err(func) } } diff --git a/hugr-passes/src/dataflow.rs b/hugr-passes/src/dataflow.rs index 43caa9c94..1f7c1ae5a 100644 --- a/hugr-passes/src/dataflow.rs +++ b/hugr-passes/src/dataflow.rs @@ -9,7 +9,7 @@ mod results; pub use results::{AnalysisResults, TailLoopTermination}; mod partial_value; -pub use partial_value::{AbstractValue, PartialSum, PartialValue, Sum}; +pub use partial_value::{AbstractValue, AsConcrete, LoadedFunction, PartialSum, PartialValue, Sum}; use hugr_core::ops::constant::OpaqueValue; use hugr_core::ops::{ExtensionOp, Value}; @@ -31,8 +31,8 @@ pub trait DFContext: ConstLoader { &mut self, _node: Self::Node, _e: &ExtensionOp, - _ins: &[PartialValue], - _outs: &mut [PartialValue], + _ins: &[PartialValue], + _outs: &mut [PartialValue], ) { } } @@ -55,8 +55,8 @@ impl From for ConstLocation<'_, N> { } /// Trait for loading [PartialValue]s from constant [Value]s in a Hugr. -/// Implementors will likely want to override some/all of [Self::value_from_opaque], -/// [Self::value_from_const_hugr], and [Self::value_from_function]: the defaults +/// Implementors will likely want to override either/both of [Self::value_from_opaque] +/// and [Self::value_from_const_hugr]: the defaults /// are "correct" but maximally conservative (minimally informative). pub trait ConstLoader { /// The type of nodes in the Hugr. @@ -81,6 +81,7 @@ pub trait ConstLoader { /// [FuncDefn]: hugr_core::ops::FuncDefn /// [FuncDecl]: hugr_core::ops::FuncDecl /// [LoadFunction]: hugr_core::ops::LoadFunction + #[deprecated(note = "Automatically handled by Datalog, implementation will be ignored")] fn value_from_function(&self, _node: Self::Node, _type_args: &[TypeArg]) -> Option { None } @@ -94,7 +95,7 @@ pub fn partial_from_const<'a, V, CL: ConstLoader>( cl: &CL, loc: impl Into>, cst: &Value, -) -> PartialValue +) -> PartialValue where CL::Node: 'a, { @@ -120,8 +121,8 @@ where /// A row of inputs to a node contains bottom (can't happen, the node /// can't execute) if any element [contains_bottom](PartialValue::contains_bottom). -pub fn row_contains_bottom<'a, V: AbstractValue + 'a>( - elements: impl IntoIterator>, +pub fn row_contains_bottom<'a, V: 'a, N: 'a>( + elements: impl IntoIterator>, ) -> bool { elements.into_iter().any(PartialValue::contains_bottom) } diff --git a/hugr-passes/src/dataflow/datalog.rs b/hugr-passes/src/dataflow/datalog.rs index 13e510daf..ad1a99345 100644 --- a/hugr-passes/src/dataflow/datalog.rs +++ b/hugr-passes/src/dataflow/datalog.rs @@ -3,19 +3,22 @@ use std::collections::HashMap; use ascent::lattice::BoundedLattice; +use ascent::Lattice; use itertools::Itertools; use hugr_core::extension::prelude::{MakeTuple, UnpackTuple}; -use hugr_core::ops::{OpTrait, OpType, TailLoop}; +use hugr_core::ops::{DataflowOpTrait, OpTrait, OpType, TailLoop}; use hugr_core::{HugrView, IncomingPort, OutgoingPort, PortIndex as _, Wire}; use super::value_row::ValueRow; use super::{ partial_from_const, row_contains_bottom, AbstractValue, AnalysisResults, DFContext, - PartialValue, + LoadedFunction, PartialValue, }; -type PV = PartialValue; +type PV = PartialValue; + +type NodeInputs = Vec<(IncomingPort, PV)>; /// Basic structure for performing an analysis. Usage: /// 1. Make a new instance via [Self::new()] @@ -25,10 +28,7 @@ type PV = PartialValue; /// [Self::prepopulate_inputs] can be used on each externally-callable /// [FuncDefn](OpType::FuncDefn) to set all inputs to [PartialValue::Top]. /// 3. Call [Self::run] to produce [AnalysisResults] -pub struct Machine( - H, - HashMap)>>, -); +pub struct Machine(H, HashMap>); impl Machine { /// Create a new Machine to analyse the given Hugr(View) @@ -40,7 +40,7 @@ impl Machine { impl Machine { /// Provide initial values for a wire - these will be `join`d with any computed /// or any value previously prepopulated for the same Wire. - pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { + pub fn prepopulate_wire(&mut self, w: Wire, v: PartialValue) { for (n, inp) in self.0.linked_inputs(w.node(), w.source()) { self.1.entry(n).or_default().push((inp, v.clone())); } @@ -54,7 +54,7 @@ impl Machine { pub fn prepopulate_inputs( &mut self, parent: H::Node, - in_values: impl IntoIterator)>, + in_values: impl IntoIterator)>, ) -> Result<(), OpType> { match self.0.get_optype(parent) { OpType::DataflowBlock(_) | OpType::Case(_) | OpType::FuncDefn(_) => { @@ -102,7 +102,7 @@ impl Machine { pub fn run( mut self, context: impl DFContext, - in_values: impl IntoIterator)>, + in_values: impl IntoIterator)>, ) -> AnalysisResults { let root = self.0.root(); if self.0.get_optype(root).is_module() { @@ -135,10 +135,12 @@ impl Machine { } } +pub(super) type InWire = (N, IncomingPort, PartialValue); + pub(super) fn run_datalog( mut ctx: impl DFContext, hugr: H, - in_wire_value_proto: Vec<(H::Node, IncomingPort, PV)>, + in_wire_value_proto: Vec>, ) -> AnalysisResults { // ascent-(macro-)generated code generates a bunch of warnings, // keep code in here to a minimum. @@ -155,9 +157,9 @@ pub(super) fn run_datalog( relation parent_of_node(H::Node, H::Node); // is parent of relation input_child(H::Node, H::Node); // has 1st child that is its `Input` relation output_child(H::Node, H::Node); // has 2nd child that is its `Output` - lattice out_wire_value(H::Node, OutgoingPort, PV); // produces, on , the value - lattice in_wire_value(H::Node, IncomingPort, PV); // receives, on , the value - lattice node_in_value_row(H::Node, ValueRow); // 's inputs are + lattice out_wire_value(H::Node, OutgoingPort, PV); // produces, on , the value + lattice in_wire_value(H::Node, IncomingPort, PV); // receives, on , the value + lattice node_in_value_row(H::Node, ValueRow); // 's inputs are node(n) <-- for n in hugr.nodes(); @@ -322,6 +324,37 @@ pub(super) fn run_datalog( func_call(call, func), output_child(func, outp), in_wire_value(outp, p, v); + + // CallIndirect -------------------- + lattice indirect_call(H::Node, LatticeWrapper); // is an `IndirectCall` to `FuncDefn` + indirect_call(call, tgt) <-- + node(call), + if let OpType::CallIndirect(_) = hugr.get_optype(*call), + in_wire_value(call, IncomingPort::from(0), v), + let tgt = load_func(v); + + out_wire_value(inp, OutgoingPort::from(p.index()-1), v) <-- + indirect_call(call, lv), + if let LatticeWrapper::Value(func) = lv, + input_child(func, inp), + in_wire_value(call, p, v) + if p.index() > 0; + + out_wire_value(call, OutgoingPort::from(p.index()), v) <-- + indirect_call(call, lv), + if let LatticeWrapper::Value(func) = lv, + output_child(func, outp), + in_wire_value(outp, p, v); + + // Default out-value is Bottom, but if we can't determine the called function, + // assign everything to Top + out_wire_value(call, p, PV::Top) <-- + node(call), + if let OpType::CallIndirect(ci) = hugr.get_optype(*call), + in_wire_value(call, IncomingPort::from(0), v), + // Second alternative below addresses function::Value's: + if matches!(v, PartialValue::Top | PartialValue::Value(_)), + for p in ci.signature().output_ports(); }; let out_wire_values = all_results .out_wire_value @@ -337,13 +370,58 @@ pub(super) fn run_datalog( } } +#[derive(Debug, PartialEq, Eq, Hash, Clone, PartialOrd)] +enum LatticeWrapper { + Bottom, + Value(T), + Top, +} + +impl Lattice for LatticeWrapper { + fn meet_mut(&mut self, other: Self) -> bool { + if *self == other || *self == LatticeWrapper::Bottom || other == LatticeWrapper::Top { + return false; + }; + if *self == LatticeWrapper::Top || other == LatticeWrapper::Bottom { + *self = other; + return true; + }; + // Both are `Value`s and not equal + *self = LatticeWrapper::Bottom; + true + } + + fn join_mut(&mut self, other: Self) -> bool { + if *self == other || *self == LatticeWrapper::Top || other == LatticeWrapper::Bottom { + return false; + }; + if *self == LatticeWrapper::Bottom || other == LatticeWrapper::Top { + *self = other; + return true; + }; + // Both are `Value`s and are not equal + *self = LatticeWrapper::Top; + true + } +} + +fn load_func(v: &PV) -> LatticeWrapper { + match v { + PartialValue::Bottom | PartialValue::PartialSum(_) => LatticeWrapper::Bottom, + PartialValue::LoadedFunction(LoadedFunction { func_node, .. }) => { + LatticeWrapper::Value(*func_node) + } + PartialValue::Value(_) | PartialValue::Top => LatticeWrapper::Top, + } +} + fn propagate_leaf_op( ctx: &mut impl DFContext, hugr: &H, n: H::Node, - ins: &[PV], + ins: &[PV], num_outs: usize, -) -> Option> { +) -> Option> { match hugr.get_optype(n) { // Handle basics here. We could instead leave these to DFContext, // but at least we'd want these impls to be easily reusable. @@ -362,8 +440,7 @@ fn propagate_leaf_op( ins.iter().cloned(), )])), OpType::Input(_) | OpType::Output(_) | OpType::ExitBlock(_) => None, // handled by parent - OpType::Call(_) => None, // handled via Input/Output of FuncDefn - OpType::Const(_) => None, // handled by LoadConstant: + OpType::Call(_) | OpType::CallIndirect(_) => None, // handled via Input/Output of FuncDefn OpType::LoadConstant(load_op) => { assert!(ins.is_empty()); // static edge, so need to find constant let const_node = hugr @@ -380,10 +457,10 @@ fn propagate_leaf_op( .unwrap() .0; // Node could be a FuncDefn or a FuncDecl, so do not pass the node itself - Some(ValueRow::singleton( - ctx.value_from_function(func_node, &load_op.type_args) - .map_or(PV::Top, PV::Value), - )) + Some(ValueRow::singleton(PartialValue::new_load( + func_node, + load_op.type_args.clone(), + ))) } OpType::ExtensionOp(e) => { Some(ValueRow::from_iter(if row_contains_bottom(ins) { @@ -401,6 +478,54 @@ fn propagate_leaf_op( outs })) } - o => todo!("Unhandled: {:?}", o), // At least CallIndirect, and OpType is "non-exhaustive" + // We only call propagate_leaf_op for dataflow op non-containers, + o => todo!("Unhandled: {:?}", o), // and OpType is non-exhaustive + } +} + +#[cfg(test)] +mod test { + use ascent::Lattice; + + use super::LatticeWrapper; + + #[test] + fn latwrap_join() { + for lv in [ + LatticeWrapper::Value(3), + LatticeWrapper::Value(5), + LatticeWrapper::Top, + ] { + let mut subject = LatticeWrapper::Bottom; + assert!(subject.join_mut(lv.clone())); + assert_eq!(subject, lv); + assert!(!subject.join_mut(lv.clone())); + assert_eq!(subject, lv); + assert_eq!( + subject.join_mut(LatticeWrapper::Value(11)), + lv != LatticeWrapper::Top + ); + assert_eq!(subject, LatticeWrapper::Top); + } + } + + #[test] + fn latwrap_meet() { + for lv in [ + LatticeWrapper::Bottom, + LatticeWrapper::Value(3), + LatticeWrapper::Value(5), + ] { + let mut subject = LatticeWrapper::Top; + assert!(subject.meet_mut(lv.clone())); + assert_eq!(subject, lv); + assert!(!subject.meet_mut(lv.clone())); + assert_eq!(subject, lv); + assert_eq!( + subject.meet_mut(LatticeWrapper::Value(11)), + lv != LatticeWrapper::Bottom + ); + assert_eq!(subject, LatticeWrapper::Bottom); + } } } diff --git a/hugr-passes/src/dataflow/partial_value.rs b/hugr-passes/src/dataflow/partial_value.rs index f2a497806..240f4f2d6 100644 --- a/hugr-passes/src/dataflow/partial_value.rs +++ b/hugr-passes/src/dataflow/partial_value.rs @@ -1,7 +1,7 @@ use ascent::lattice::BoundedLattice; use ascent::Lattice; -use hugr_core::ops::Value; -use hugr_core::types::{ConstTypeError, SumType, Type, TypeEnum, TypeRow}; +use hugr_core::types::{SumType, Type, TypeArg, TypeEnum, TypeRow}; +use hugr_core::Node; use itertools::{zip_eq, Itertools}; use std::cmp::Ordering; use std::collections::HashMap; @@ -51,15 +51,25 @@ pub struct Sum { pub st: SumType, } +/// The output of an [LoadFunction](hugr_core::ops::LoadFunction) - a "pointer" +/// to a function at a specific node, instantiated with the provided type-args. +#[derive(Clone, Debug, Hash, PartialEq, Eq)] +pub struct LoadedFunction { + /// The [FuncDefn](hugr_core::ops::FuncDefn) or `FuncDecl`` that was loaded + pub func_node: N, + /// The type arguments provided when loading + pub args: Vec, +} + /// A representation of a value of [SumType], that may have one or more possible tags, /// with a [PartialValue] representation of each element-value of each possible tag. #[derive(PartialEq, Clone, Eq)] -pub struct PartialSum(pub HashMap>>); +pub struct PartialSum(pub HashMap>>); -impl PartialSum { +impl PartialSum { /// New instance for a single known tag. /// (Multi-tag instances can be created via [Self::try_join_mut].) - pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { + pub fn new_variant(tag: usize, values: impl IntoIterator>) -> Self { Self(HashMap::from([(tag, Vec::from_iter(values))])) } @@ -75,9 +85,21 @@ impl PartialSum { pv.assert_invariants(); } } + + /// Whether this sum might have the specified tag + pub fn supports_tag(&self, tag: usize) -> bool { + self.0.contains_key(&tag) + } + + /// Can this ever occur at runtime? See [PartialValue::contains_bottom] + pub fn contains_bottom(&self) -> bool { + self.0 + .iter() + .all(|(_tag, elements)| row_contains_bottom(elements)) + } } -impl PartialSum { +impl PartialSum { /// Joins (towards `Top`) self with another [PartialSum]. If successful, returns /// whether `self` has changed. /// @@ -141,12 +163,33 @@ impl PartialSum { } Ok(changed) } +} - /// Whether this sum might have the specified tag - pub fn supports_tag(&self, tag: usize) -> bool { - self.0.contains_key(&tag) - } +/// Trait implemented by value types into which [PartialValue]s can be converted, +/// so long as the PV has no [Top](PartialValue::Top), [Bottom](PartialValue::Bottom) +/// or [PartialSum]s with more than one possible tag. See [PartialSum::try_into_sum] +/// and [PartialValue::try_into_concrete]. +/// +/// `V` is the type of [AbstractValue] from which `Self` can (fallibly) be constructed, +/// `N` is the type of [HugrNode](hugr_core::core::HugrNode) for function pointers +pub trait AsConcrete: Sized { + /// Kind of error raised when creating `Self` from a value `V`, see [Self::from_value] + type ValErr: std::error::Error; + /// Kind of error that may be raised when creating `Self` from a [Sum] of `Self`s, + /// see [Self::from_sum] + type SumErr: std::error::Error; + + /// Convert an abstract value into concrete + fn from_value(val: V) -> Result; + + /// Convert a sum (of concrete values, already recursively converted) into concrete + fn from_sum(sum: Sum) -> Result; + + /// Convert a function pointer into a concrete value + fn from_func(func: LoadedFunction) -> Result>; +} +impl PartialSum { /// Turns this instance into a [Sum] of some "concrete" value type `C`, /// *if* this PartialSum has exactly one possible tag. /// @@ -155,11 +198,11 @@ impl PartialSum { /// If this PartialSum had multiple possible tags; or if `typ` was not a [TypeEnum::Sum] /// supporting the single possible tag with the correct number of elements and no row variables; /// or if converting a child element failed via [PartialValue::try_into_concrete]. - pub fn try_into_sum(self, typ: &Type) -> Result, ExtractValueError> - where - V: TryInto, - Sum: TryInto, - { + #[allow(clippy::type_complexity)] // Since C is a parameter, can't declare type aliases + pub fn try_into_sum>( + self, + typ: &Type, + ) -> Result, ExtractValueError> { if self.0.len() != 1 { return Err(ExtractValueError::MultipleVariants(self)); } @@ -185,22 +228,15 @@ impl PartialSum { num_elements: v.len(), }) } - - /// Can this ever occur at runtime? See [PartialValue::contains_bottom] - pub fn contains_bottom(&self) -> bool { - self.0 - .iter() - .all(|(_tag, elements)| row_contains_bottom(elements)) - } } /// An error converting a [PartialValue] or [PartialSum] into a concrete value type /// via [PartialValue::try_into_concrete] or [PartialSum::try_into_sum] #[derive(Clone, Debug, PartialEq, Eq, Error)] #[allow(missing_docs)] -pub enum ExtractValueError { +pub enum ExtractValueError { #[error("PartialSum value had multiple possible tags: {0}")] - MultipleVariants(PartialSum), + MultipleVariants(PartialSum), #[error("Value contained `Bottom`")] ValueIsBottom, #[error("Value contained `Top`")] @@ -209,6 +245,8 @@ pub enum ExtractValueError { CouldNotConvert(V, #[source] VE), #[error("Could not build Sum from concrete element values")] CouldNotBuildSum(#[source] SE), + #[error("Could not convert into concrete function pointer {0}")] + CouldNotLoadFunction(LoadedFunction), #[error("Expected a SumType with tag {tag} having {num_elements} elements, found {typ}")] BadSumType { typ: Type, @@ -217,14 +255,14 @@ pub enum ExtractValueError { }, } -impl PartialSum { +impl PartialSum { /// If this Sum might have the specified `tag`, get the elements inside that tag. - pub fn variant_values(&self, variant: usize) -> Option>> { + pub fn variant_values(&self, variant: usize) -> Option>> { self.0.get(&variant).cloned() } } -impl PartialOrd for PartialSum { +impl PartialOrd for PartialSum { fn partial_cmp(&self, other: &Self) -> Option { let max_key = self.0.keys().chain(other.0.keys()).copied().max().unwrap(); let (mut keys1, mut keys2) = (vec![0; max_key + 1], vec![0; max_key + 1]); @@ -254,13 +292,13 @@ impl PartialOrd for PartialSum { } } -impl std::fmt::Debug for PartialSum { +impl std::fmt::Debug for PartialSum { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { self.0.fmt(f) } } -impl Hash for PartialSum { +impl Hash for PartialSum { fn hash(&self, state: &mut H) { for (k, v) in &self.0 { k.hash(state); @@ -273,30 +311,32 @@ impl Hash for PartialSum { /// for use in dataflow analysis, including that an instance may be a [PartialSum] /// of values of the underlying representation #[derive(PartialEq, Clone, Eq, Hash, Debug)] -pub enum PartialValue { +pub enum PartialValue { /// No possibilities known (so far) Bottom, + /// The output of an [LoadFunction](hugr_core::ops::LoadFunction) + LoadedFunction(LoadedFunction), /// A single value (of the underlying representation) Value(V), /// Sum (with at least one, perhaps several, possible tags) of underlying values - PartialSum(PartialSum), + PartialSum(PartialSum), /// Might be more than one distinct value of the underlying type `V` Top, } -impl From for PartialValue { +impl From for PartialValue { fn from(v: V) -> Self { Self::Value(v) } } -impl From> for PartialValue { - fn from(v: PartialSum) -> Self { +impl From> for PartialValue { + fn from(v: PartialSum) -> Self { Self::PartialSum(v) } } -impl PartialValue { +impl PartialValue { fn assert_invariants(&self) { if let Self::PartialSum(ps) = self { ps.assert_invariants(); @@ -312,33 +352,59 @@ impl PartialValue { pub fn new_unit() -> Self { Self::new_variant(0, []) } + + /// New instance of self for a [LoadFunction](hugr_core::ops::LoadFunction) + pub fn new_load(func_node: N, args: impl Into>) -> Self { + Self::LoadedFunction(LoadedFunction { + func_node, + args: args.into(), + }) + } + + /// Tells us whether this value might be a Sum with the specified `tag` + pub fn supports_tag(&self, tag: usize) -> bool { + match self { + PartialValue::Bottom | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => { + false + } + PartialValue::PartialSum(ps) => ps.supports_tag(tag), + PartialValue::Top => true, + } + } + + /// A value contains bottom means that it cannot occur during execution: + /// it may be an artefact during bootstrapping of the analysis, or else + /// the value depends upon a `panic` or a loop that + /// [never terminates](super::TailLoopTermination::NeverBreaks). + pub fn contains_bottom(&self) -> bool { + match self { + PartialValue::Bottom => true, + PartialValue::Top | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => false, + PartialValue::PartialSum(ps) => ps.contains_bottom(), + } + } } -impl PartialValue { +impl PartialValue { /// If this value might be a Sum with the specified `tag`, get the elements inside that tag. /// /// # Panics /// /// if the value is believed, for that tag, to have a number of values other than `len` - pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { + pub fn variant_values(&self, tag: usize, len: usize) -> Option>> { let vals = match self { - PartialValue::Bottom | PartialValue::Value(_) => return None, + PartialValue::Bottom | PartialValue::Value(_) | PartialValue::LoadedFunction(_) => { + return None + } PartialValue::PartialSum(ps) => ps.variant_values(tag)?, PartialValue::Top => vec![PartialValue::Top; len], }; assert_eq!(vals.len(), len); Some(vals) } +} - /// Tells us whether this value might be a Sum with the specified `tag` - pub fn supports_tag(&self, tag: usize) -> bool { - match self { - PartialValue::Bottom | PartialValue::Value(_) => false, - PartialValue::PartialSum(ps) => ps.supports_tag(tag), - PartialValue::Top => true, - } - } - +impl PartialValue { /// Turns this instance into some "concrete" value type `C`, *if* it is a single value, /// or a [Sum](PartialValue::PartialSum) (of a single tag) convertible by /// [PartialSum::try_into_sum]. @@ -348,47 +414,27 @@ impl PartialValue { /// If this PartialValue was `Top` or `Bottom`, or was a [PartialSum](PartialValue::PartialSum) /// that could not be converted into a [Sum] by [PartialSum::try_into_sum] (e.g. if `typ` is /// incorrect), or if that [Sum] could not be converted into a `V2`. - pub fn try_into_concrete(self, typ: &Type) -> Result> - where - V: TryInto, - Sum: TryInto, - { + pub fn try_into_concrete>( + self, + typ: &Type, + ) -> Result> { match self { - Self::Value(v) => v - .clone() - .try_into() - .map_err(|e| ExtractValueError::CouldNotConvert(v.clone(), e)), - Self::PartialSum(ps) => ps - .try_into_sum(typ)? - .try_into() - .map_err(ExtractValueError::CouldNotBuildSum), + Self::Value(v) => { + C::from_value(v.clone()).map_err(|e| ExtractValueError::CouldNotConvert(v, e)) + } + Self::LoadedFunction(lf) => { + C::from_func(lf).map_err(ExtractValueError::CouldNotLoadFunction) + } + Self::PartialSum(ps) => { + C::from_sum(ps.try_into_sum(typ)?).map_err(ExtractValueError::CouldNotBuildSum) + } Self::Top => Err(ExtractValueError::ValueIsTop), Self::Bottom => Err(ExtractValueError::ValueIsBottom), } } - - /// A value contains bottom means that it cannot occur during execution: - /// it may be an artefact during bootstrapping of the analysis, or else - /// the value depends upon a `panic` or a loop that - /// [never terminates](super::TailLoopTermination::NeverBreaks). - pub fn contains_bottom(&self) -> bool { - match self { - PartialValue::Bottom => true, - PartialValue::Top | PartialValue::Value(_) => false, - PartialValue::PartialSum(ps) => ps.contains_bottom(), - } - } } -impl TryFrom> for Value { - type Error = ConstTypeError; - - fn try_from(value: Sum) -> Result { - Self::sum(value.tag, value.values, value.st) - } -} - -impl Lattice for PartialValue { +impl Lattice for PartialValue { fn join_mut(&mut self, other: Self) -> bool { self.assert_invariants(); let mut old_self = Self::Top; @@ -400,13 +446,17 @@ impl Lattice for PartialValue { Some((h3, b)) => (Self::Value(h3), b), None => (Self::Top, true), }, + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) + if lf1.func_node == lf2.func_node => + { + // TODO we should also join the TypeArgs but at the moment these are ignored + (Self::LoadedFunction(lf1), false) + } (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_join_mut(ps2) { Ok(ch) => (Self::PartialSum(ps1), ch), Err(_) => (Self::Top, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - (Self::Top, true) - } + _ => (Self::Top, true), }; *self = res; ch @@ -423,20 +473,24 @@ impl Lattice for PartialValue { Some((h3, ch)) => (Self::Value(h3), ch), None => (Self::Bottom, true), }, + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) + if lf1.func_node == lf2.func_node => + { + // TODO we should also meet the TypeArgs but at the moment these are ignored + (Self::LoadedFunction(lf1), false) + } (Self::PartialSum(mut ps1), Self::PartialSum(ps2)) => match ps1.try_meet_mut(ps2) { Ok(ch) => (Self::PartialSum(ps1), ch), Err(_) => (Self::Bottom, true), }, - (Self::Value(_), Self::PartialSum(_)) | (Self::PartialSum(_), Self::Value(_)) => { - (Self::Bottom, true) - } + _ => (Self::Bottom, true), }; *self = res; ch } } -impl BoundedLattice for PartialValue { +impl BoundedLattice for PartialValue { fn top() -> Self { Self::Top } @@ -446,7 +500,7 @@ impl BoundedLattice for PartialValue { } } -impl PartialOrd for PartialValue { +impl PartialOrd for PartialValue { fn partial_cmp(&self, other: &Self) -> Option { use std::cmp::Ordering; match (self, other) { @@ -457,6 +511,9 @@ impl PartialOrd for PartialValue { (Self::Top, _) => Some(Ordering::Greater), (_, Self::Top) => Some(Ordering::Less), (Self::Value(v1), Self::Value(v2)) => (v1 == v2).then_some(Ordering::Equal), + (Self::LoadedFunction(lf1), Self::LoadedFunction(lf2)) => { + (lf1 == lf2).then_some(Ordering::Equal) + } (Self::PartialSum(ps1), Self::PartialSum(ps2)) => ps1.partial_cmp(ps2), _ => None, } @@ -468,19 +525,20 @@ mod test { use std::sync::Arc; use ascent::{lattice::BoundedLattice, Lattice}; + use hugr_core::NodeIndex; use itertools::{zip_eq, Itertools as _}; use prop::sample::subsequence; use proptest::prelude::*; use proptest_recurse::{StrategyExt, StrategySet}; - use super::{AbstractValue, PartialSum, PartialValue}; + use super::{AbstractValue, LoadedFunction, PartialSum, PartialValue}; #[derive(Debug, PartialEq, Eq, Clone)] enum TestSumType { Branch(Vec>>), - /// None => unit, Some => TestValue <= this *usize* - Leaf(Option), + LeafVal(usize), // contains a TestValue <= this usize + LeafPtr(usize), // contains a LoadedFunction with node <= this *usize* } #[derive(Clone, Debug, PartialEq, Eq, Hash)] @@ -509,8 +567,11 @@ mod test { fn check_value(&self, pv: &PartialValue) -> bool { match (self, pv) { (_, PartialValue::Bottom) | (_, PartialValue::Top) => true, - (Self::Leaf(None), _) => pv == &PartialValue::new_unit(), - (Self::Leaf(Some(max)), PartialValue::Value(TestValue(val))) => val <= max, + (Self::LeafVal(max), PartialValue::Value(TestValue(val))) => val <= max, + ( + Self::LeafPtr(max), + PartialValue::LoadedFunction(LoadedFunction { func_node, args }), + ) => args.is_empty() && func_node.index() <= *max, (Self::Branch(sop), PartialValue::PartialSum(ps)) => { for (k, v) in &ps.0 { if *k >= sop.len() { @@ -537,8 +598,11 @@ mod test { fn arbitrary_with(params: Self::Parameters) -> Self::Strategy { fn arb(params: SumTypeParams, set: &mut StrategySet) -> SBoxedStrategy { use proptest::collection::vec; - let int_strat = (0..usize::MAX).prop_map(|i| TestSumType::Leaf(Some(i))); - let leaf_strat = prop_oneof![Just(TestSumType::Leaf(None)), int_strat]; + let leaf_strat = prop_oneof![ + (0..usize::MAX).prop_map(TestSumType::LeafVal), + // This is the maximum value accepted by portgraph::NodeIndex::new + (0..((2usize ^ 31) - 2)).prop_map(TestSumType::LeafPtr) + ]; leaf_strat.prop_mutually_recursive( params.depth as u32, params.desired_size as u32, @@ -605,11 +669,18 @@ mod test { ust: &TestSumType, ) -> impl Strategy> { match ust { - TestSumType::Leaf(None) => Just(PartialValue::new_unit()).boxed(), - TestSumType::Leaf(Some(i)) => (0..*i) + TestSumType::LeafVal(i) => (0..=*i) .prop_map(TestValue) .prop_map(PartialValue::from) .boxed(), + TestSumType::LeafPtr(i) => (0..=*i) + .prop_map(|i| { + PartialValue::LoadedFunction(LoadedFunction { + func_node: portgraph::NodeIndex::new(i).into(), + args: vec![], + }) + }) + .boxed(), TestSumType::Branch(sop) => partial_sum_strat(sop).prop_map(PartialValue::from).boxed(), } } diff --git a/hugr-passes/src/dataflow/results.rs b/hugr-passes/src/dataflow/results.rs index c40f1d87f..c4a94a9e7 100644 --- a/hugr-passes/src/dataflow/results.rs +++ b/hugr-passes/src/dataflow/results.rs @@ -1,17 +1,19 @@ use std::collections::HashMap; -use hugr_core::{HugrView, IncomingPort, PortIndex, Wire}; +use hugr_core::{HugrView, PortIndex, Wire}; -use super::{partial_value::ExtractValueError, AbstractValue, PartialValue, Sum}; +use super::{ + datalog::InWire, partial_value::ExtractValueError, AbstractValue, AsConcrete, PartialValue, +}; /// Results of a dataflow analysis, packaged with the Hugr for easy inspection. /// Methods allow inspection, specifically [read_out_wire](Self::read_out_wire). pub struct AnalysisResults { pub(super) hugr: H, - pub(super) in_wire_value: Vec<(H::Node, IncomingPort, PartialValue)>, + pub(super) in_wire_value: Vec>, pub(super) case_reachable: Vec<(H::Node, H::Node)>, pub(super) bb_reachable: Vec<(H::Node, H::Node)>, - pub(super) out_wire_values: HashMap, PartialValue>, + pub(super) out_wire_values: HashMap, PartialValue>, } impl AnalysisResults { @@ -21,7 +23,7 @@ impl AnalysisResults { } /// Gets the lattice value computed for the given wire - pub fn read_out_wire(&self, w: Wire) -> Option> { + pub fn read_out_wire(&self, w: Wire) -> Option> { self.out_wire_values.get(&w).cloned() } @@ -84,13 +86,11 @@ impl AnalysisResults { /// `None` if the analysis did not produce a result for that wire, or if /// the Hugr did not have a [Type](hugr_core::types::Type) for the specified wire /// `Some(e)` if [conversion to a concrete value](PartialValue::try_into_concrete) failed with error `e` - pub fn try_read_wire_concrete( + #[allow(clippy::type_complexity)] + pub fn try_read_wire_concrete>( &self, w: Wire, - ) -> Result>> - where - V2: TryFrom + TryFrom, Error = SE>, - { + ) -> Result>> { let v = self.read_out_wire(w).ok_or(None)?; let (_, typ) = self .hugr @@ -116,7 +116,7 @@ pub enum TailLoopTermination { } impl TailLoopTermination { - fn from_control_value(v: &PartialValue) -> Self { + fn from_control_value(v: &PartialValue) -> Self { let (may_continue, may_break) = (v.supports_tag(0), v.supports_tag(1)); if may_break { if may_continue { diff --git a/hugr-passes/src/dataflow/test.rs b/hugr-passes/src/dataflow/test.rs index 3af0097f7..1c4b4e439 100644 --- a/hugr-passes/src/dataflow/test.rs +++ b/hugr-passes/src/dataflow/test.rs @@ -1,10 +1,12 @@ +use std::convert::Infallible; + use ascent::{lattice::BoundedLattice, Lattice}; -use hugr_core::builder::{CFGBuilder, Container, DataflowHugr, ModuleBuilder}; +use hugr_core::builder::{inout_sig, CFGBuilder, Container, DataflowHugr, ModuleBuilder}; use hugr_core::hugr::views::{DescendantsGraph, HierarchyView}; use hugr_core::ops::handle::DfgID; -use hugr_core::ops::TailLoop; -use hugr_core::types::TypeRow; +use hugr_core::ops::{CallIndirect, TailLoop}; +use hugr_core::types::{ConstTypeError, TypeRow}; use hugr_core::{ builder::{endo_sig, DFGBuilder, Dataflow, DataflowSubContainer, HugrBuilder, SubContainer}, extension::{ @@ -19,7 +21,10 @@ use hugr_core::{ use hugr_core::{Hugr, Node, Wire}; use rstest::{fixture, rstest}; -use super::{AbstractValue, ConstLoader, DFContext, Machine, PartialValue, TailLoopTermination}; +use super::{ + AbstractValue, AsConcrete, ConstLoader, DFContext, LoadedFunction, Machine, PartialValue, Sum, + TailLoopTermination, +}; // ------- Minimal implementation of DFContext and AbstractValue ------- #[derive(Debug, Clone, PartialEq, Eq, Hash)] @@ -35,10 +40,22 @@ impl ConstLoader for TestContext { impl DFContext for TestContext {} // This allows testing creation of tuple/sum Values (only) -impl From for Value { - fn from(v: Void) -> Self { +impl AsConcrete for Value { + type ValErr = Infallible; + + type SumErr = ConstTypeError; + + fn from_value(v: Void) -> Result { match v {} } + + fn from_sum(value: Sum) -> Result { + Self::sum(value.tag, value.values, value.st) + } + + fn from_func(func: LoadedFunction) -> Result> { + Err(func) + } } fn pv_false() -> PartialValue { @@ -295,9 +312,7 @@ fn test_conditional() { let cond_r1: Value = results.try_read_wire_concrete(cond_o1).unwrap(); assert_eq!(cond_r1, Value::false_val()); - assert!(results - .try_read_wire_concrete::(cond_o2) - .is_err()); + assert!(results.try_read_wire_concrete::(cond_o2).is_err()); assert_eq!(results.case_reachable(case1.node()), Some(false)); // arg_pv is variant 1 or 2 only assert_eq!(results.case_reachable(case2.node()), Some(true)); @@ -547,3 +562,78 @@ fn test_module() { ); } } + +#[rstest] +#[case(pv_false(), pv_false())] +#[case(pv_false(), pv_true())] +#[case(pv_true(), pv_false())] +#[case(pv_true(), pv_true())] +fn call_indirect(#[case] inp1: PartialValue, #[case] inp2: PartialValue) { + let b2b = || Signature::new_endo(bool_t()); + let mut dfb = DFGBuilder::new(inout_sig(vec![bool_t(); 3], vec![bool_t(); 2])).unwrap(); + + let [id1, id2] = ["id1", "[id2]"].map(|name| { + let fb = dfb.define_function(name, b2b()).unwrap(); + let [inp] = fb.input_wires_arr(); + fb.finish_with_outputs([inp]).unwrap() + }); + + let [inp_direct, which, inp_indirect] = dfb.input_wires_arr(); + let [res1] = dfb + .call(id1.handle(), &[], [inp_direct]) + .unwrap() + .outputs_arr(); + + // We'll unconditionally load both functions, to demonstrate that it's + // the CallIndirect that matters, not just which functions are loaded. + let lf1 = dfb.load_func(id1.handle(), &[]).unwrap(); + let lf2 = dfb.load_func(id2.handle(), &[]).unwrap(); + let bool_func = || Type::new_function(b2b()); + let mut cond = dfb + .conditional_builder( + (vec![type_row![]; 2], which), + [(bool_func(), lf1), (bool_func(), lf2)], + bool_func().into(), + ) + .unwrap(); + let case_false = cond.case_builder(0).unwrap(); + let [f0, _f1] = case_false.input_wires_arr(); + case_false.finish_with_outputs([f0]).unwrap(); + let case_true = cond.case_builder(1).unwrap(); + let [_f0, f1] = case_true.input_wires_arr(); + case_true.finish_with_outputs([f1]).unwrap(); + let [tgt] = cond.finish_sub_container().unwrap().outputs_arr(); + let [res2] = dfb + .add_dataflow_op(CallIndirect { signature: b2b() }, [tgt, inp_indirect]) + .unwrap() + .outputs_arr(); + let h = dfb.finish_hugr_with_outputs([res1, res2]).unwrap(); + + let run = |which| { + Machine::new(&h).run( + TestContext, + [ + (0.into(), inp1.clone()), + (1.into(), which), + (2.into(), inp2.clone()), + ], + ) + }; + let (w1, w2) = (Wire::new(h.root(), 0), Wire::new(h.root(), 1)); + + // 1. Test with `which` unknown -> second output unknown + let results = run(PartialValue::Top); + assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); + assert_eq!(results.read_out_wire(w2), Some(PartialValue::Top)); + + // 2. Test with `which` selecting second function -> both passthrough + let results = run(pv_true()); + assert_eq!(results.read_out_wire(w1), Some(inp1.clone())); + assert_eq!(results.read_out_wire(w2), Some(inp2.clone())); + + //3. Test with `which` selecting first function -> alias + let results = run(pv_false()); + let out = Some(inp1.join(inp2)); + assert_eq!(results.read_out_wire(w1), out); + assert_eq!(results.read_out_wire(w2), out); +} diff --git a/hugr-passes/src/dataflow/value_row.rs b/hugr-passes/src/dataflow/value_row.rs index 50cf10318..43c842d91 100644 --- a/hugr-passes/src/dataflow/value_row.rs +++ b/hugr-passes/src/dataflow/value_row.rs @@ -5,25 +5,25 @@ use std::{ ops::{Index, IndexMut}, }; -use ascent::{lattice::BoundedLattice, Lattice}; +use ascent::Lattice; use itertools::zip_eq; use super::{AbstractValue, PartialValue}; #[derive(PartialEq, Clone, Debug, Eq, Hash)] -pub(super) struct ValueRow(Vec>); +pub(super) struct ValueRow(Vec>); -impl ValueRow { +impl ValueRow { pub fn new(len: usize) -> Self { - Self(vec![PartialValue::bottom(); len]) + Self(vec![PartialValue::Bottom; len]) } - pub fn set(mut self, idx: usize, v: PartialValue) -> Self { + pub fn set(mut self, idx: usize, v: PartialValue) -> Self { *self.0.get_mut(idx).unwrap() = v; self } - pub fn singleton(v: PartialValue) -> Self { + pub fn singleton(v: PartialValue) -> Self { Self(vec![v]) } @@ -34,25 +34,25 @@ impl ValueRow { &self, variant: usize, len: usize, - ) -> Option>> { + ) -> Option>> { let vals = self[0].variant_values(variant, len)?; Some(vals.into_iter().chain(self.0[1..].to_owned())) } } -impl FromIterator> for ValueRow { - fn from_iter>>(iter: T) -> Self { +impl FromIterator> for ValueRow { + fn from_iter>>(iter: T) -> Self { Self(iter.into_iter().collect()) } } -impl PartialOrd for ValueRow { +impl PartialOrd for ValueRow { fn partial_cmp(&self, other: &Self) -> Option { self.0.partial_cmp(&other.0) } } -impl Lattice for ValueRow { +impl Lattice for ValueRow { fn join_mut(&mut self, other: Self) -> bool { assert_eq!(self.0.len(), other.0.len()); let mut changed = false; @@ -72,30 +72,30 @@ impl Lattice for ValueRow { } } -impl IntoIterator for ValueRow { - type Item = PartialValue; +impl IntoIterator for ValueRow { + type Item = PartialValue; - type IntoIter = > as IntoIterator>::IntoIter; + type IntoIter = > as IntoIterator>::IntoIter; fn into_iter(self) -> Self::IntoIter { self.0.into_iter() } } -impl Index for ValueRow +impl Index for ValueRow where - Vec>: Index, + Vec>: Index, { - type Output = > as Index>::Output; + type Output = > as Index>::Output; fn index(&self, index: Idx) -> &Self::Output { self.0.index(index) } } -impl IndexMut for ValueRow +impl IndexMut for ValueRow where - Vec>: IndexMut, + Vec>: IndexMut, { fn index_mut(&mut self, index: Idx) -> &mut Self::Output { self.0.index_mut(index) From 63477565de0dbfb8027736cf905f6f148e2ddcab Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Agust=C3=ADn=20Borgna?= <121866228+aborgna-q@users.noreply.github.com> Date: Wed, 16 Apr 2025 14:52:21 +0100 Subject: [PATCH 06/18] feat: Make NodeHandle generic (#2092) Adds a generic node type to the `NodeHandle` type. This is a required change for #2029. drive-by: Implement the "Link the NodeHandles to the OpType" TODO --- hugr-core/src/ops.rs | 16 +++++++- hugr-core/src/ops/handle.rs | 73 ++++++++++++++++++++----------------- 2 files changed, 54 insertions(+), 35 deletions(-) diff --git a/hugr-core/src/ops.rs b/hugr-core/src/ops.rs index 0c7d3bb3f..ce0d44de0 100644 --- a/hugr-core/src/ops.rs +++ b/hugr-core/src/ops.rs @@ -9,6 +9,7 @@ pub mod module; pub mod sum; pub mod tag; pub mod validate; +use crate::core::HugrNode; use crate::extension::resolution::{ collect_op_extension, collect_op_types_extensions, ExtensionCollectionError, }; @@ -20,6 +21,7 @@ use crate::types::{EdgeKind, Signature, Substitution}; use crate::{Direction, OutgoingPort, Port}; use crate::{IncomingPort, PortIndex}; use derive_more::Display; +use handle::NodeHandle; use paste::paste; use portgraph::NodeIndex; @@ -41,7 +43,6 @@ pub use tag::OpTag; #[derive(Clone, Debug, PartialEq, serde::Serialize, serde::Deserialize)] #[cfg_attr(test, derive(proptest_derive::Arbitrary))] /// The concrete operation types for a node in the HUGR. -// TODO: Link the NodeHandles to the OpType. #[non_exhaustive] #[allow(missing_docs)] #[serde(tag = "op")] @@ -377,6 +378,19 @@ pub trait OpTrait: Sized + Clone { /// Tag identifying the operation. fn tag(&self) -> OpTag; + /// Tries to create a specific [`NodeHandle`] for a node with this operation + /// type. + /// + /// Fails if the operation's [`OpTrait::tag`] does not match the + /// [`NodeHandle::TAG`] of the requested handle. + fn try_node_handle(&self, node: N) -> Option + where + N: HugrNode, + H: NodeHandle + From, + { + H::TAG.is_superset(self.tag()).then(|| node.into()) + } + /// The signature of the operation. /// /// Only dataflow operations have a signature, otherwise returns None. diff --git a/hugr-core/src/ops/handle.rs b/hugr-core/src/ops/handle.rs index d7fe16419..a5a3c294a 100644 --- a/hugr-core/src/ops/handle.rs +++ b/hugr-core/src/ops/handle.rs @@ -1,4 +1,5 @@ //! Handles to nodes in HUGR. +use crate::core::HugrNode; use crate::types::{Type, TypeBound}; use crate::Node; @@ -9,12 +10,12 @@ use super::{AliasDecl, OpTag}; /// Common trait for handles to a node. /// Typically wrappers around [`Node`]. -pub trait NodeHandle: Clone { +pub trait NodeHandle: Clone { /// The most specific operation tag associated with the handle. const TAG: OpTag; /// Index of underlying node. - fn node(&self) -> Node; + fn node(&self) -> N; /// Operation tag for the handle. #[inline] @@ -23,7 +24,7 @@ pub trait NodeHandle: Clone { } /// Cast the handle to a different more general tag. - fn try_cast>(&self) -> Option { + fn try_cast + From>(&self) -> Option { T::TAG.is_superset(Self::TAG).then(|| self.node().into()) } @@ -36,30 +37,30 @@ pub trait NodeHandle: Clone { /// Trait for handles that contain children. /// /// The allowed children handles are defined by the associated type. -pub trait ContainerHandle: NodeHandle { +pub trait ContainerHandle: NodeHandle { /// Handle type for the children of this node. - type ChildrenHandle: NodeHandle; + type ChildrenHandle: NodeHandle; } #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DataflowOp](crate::ops::dataflow). -pub struct DataflowOpID(Node); +pub struct DataflowOpID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DFG](crate::ops::DFG) node. -pub struct DfgID(Node); +pub struct DfgID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [CFG](crate::ops::CFG) node. -pub struct CfgID(Node); +pub struct CfgID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a module [Module](crate::ops::Module) node. -pub struct ModuleRootID(Node); +pub struct ModuleRootID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [module op](crate::ops::module) node. -pub struct ModuleID(Node); +pub struct ModuleID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [def](crate::ops::OpType::FuncDefn) @@ -67,7 +68,7 @@ pub struct ModuleID(Node); /// /// The `DEF` const generic is used to indicate whether the function is /// defined or just declared. -pub struct FuncID(Node); +pub struct FuncID(N); #[derive(Debug, Clone, PartialEq, Eq)] /// Handle to an [AliasDefn](crate::ops::OpType::AliasDefn) @@ -75,15 +76,15 @@ pub struct FuncID(Node); /// /// The `DEF` const generic is used to indicate whether the function is /// defined or just declared. -pub struct AliasID { - node: Node, +pub struct AliasID { + node: N, name: SmolStr, bound: TypeBound, } -impl AliasID { +impl AliasID { /// Construct new AliasID - pub fn new(node: Node, name: SmolStr, bound: TypeBound) -> Self { + pub fn new(node: N, name: SmolStr, bound: TypeBound) -> Self { Self { node, name, bound } } @@ -99,27 +100,27 @@ impl AliasID { #[derive(DerFrom, Debug, Clone, PartialEq, Eq)] /// Handle to a [Const](crate::ops::OpType::Const) node. -pub struct ConstID(Node); +pub struct ConstID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [DataflowBlock](crate::ops::DataflowBlock) or [Exit](crate::ops::ExitBlock) node. -pub struct BasicBlockID(Node); +pub struct BasicBlockID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [Case](crate::ops::Case) node. -pub struct CaseID(Node); +pub struct CaseID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [TailLoop](crate::ops::TailLoop) node. -pub struct TailLoopID(Node); +pub struct TailLoopID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a [Conditional](crate::ops::Conditional) node. -pub struct ConditionalID(Node); +pub struct ConditionalID(N); #[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, DerFrom, Debug)] /// Handle to a dataflow container node. -pub struct DataflowParentID(Node); +pub struct DataflowParentID(N); /// Implements the `NodeHandle` trait for a tuple struct that contains just a /// NodeIndex. Takes the name of the struct, and the corresponding OpTag. @@ -131,11 +132,11 @@ macro_rules! impl_nodehandle { impl_nodehandle!($name, $tag, 0); }; ($name:ident, $tag:expr, $node_attr:tt) => { - impl NodeHandle for $name { + impl NodeHandle for $name { const TAG: OpTag = $tag; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { self.$node_attr } } @@ -156,35 +157,35 @@ impl_nodehandle!(ConstID, OpTag::Const); impl_nodehandle!(BasicBlockID, OpTag::DataflowBlock); -impl NodeHandle for FuncID { +impl NodeHandle for FuncID { const TAG: OpTag = OpTag::Function; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { self.0 } } -impl NodeHandle for AliasID { +impl NodeHandle for AliasID { const TAG: OpTag = OpTag::Alias; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { self.node } } -impl NodeHandle for Node { +impl NodeHandle for N { const TAG: OpTag = OpTag::Any; #[inline] - fn node(&self) -> Node { + fn node(&self) -> N { *self } } /// Implements the `ContainerHandle` trait, with the given child handle type. macro_rules! impl_containerHandle { - ($name:path, $children:ident) => { - impl ContainerHandle for $name { - type ChildrenHandle = $children; + ($name:ident, $children:ident) => { + impl ContainerHandle for $name { + type ChildrenHandle = $children; } }; } @@ -197,5 +198,9 @@ impl_containerHandle!(CaseID, DataflowOpID); impl_containerHandle!(ModuleRootID, ModuleID); impl_containerHandle!(CfgID, BasicBlockID); impl_containerHandle!(BasicBlockID, DataflowOpID); -impl_containerHandle!(FuncID, DataflowOpID); -impl_containerHandle!(AliasID, DataflowOpID); +impl ContainerHandle for FuncID { + type ChildrenHandle = DataflowOpID; +} +impl ContainerHandle for AliasID { + type ChildrenHandle = DataflowOpID; +} From 5b43c0d351720a2d1ba66053467d64573bbbb9c6 Mon Sep 17 00:00:00 2001 From: Seyon Sivarajah Date: Thu, 17 Apr 2025 10:38:12 +0100 Subject: [PATCH 07/18] feat!: remove ExtensionValue (#2093) Closes #1595 BREAKING CHANGE: `values` field in `Extension` and `ExtensionValue` struct/class removed in rust and python. Use 0-input ops that return constant values. --- hugr-core/src/extension.rs | 64 +------------------ .../src/extension/resolution/extension.rs | 11 +--- hugr-core/src/hugr/validate/test.rs | 8 +-- hugr-core/src/std_extensions/logic.rs | 26 +------- hugr-py/src/hugr/_serialization/extension.py | 21 ------ hugr-py/src/hugr/ext.py | 42 +----------- .../_json_defs/arithmetic/conversions.json | 1 - .../hugr/std/_json_defs/arithmetic/float.json | 1 - .../_json_defs/arithmetic/float/types.json | 1 - .../hugr/std/_json_defs/arithmetic/int.json | 1 - .../std/_json_defs/arithmetic/int/types.json | 1 - .../std/_json_defs/collections/array.json | 1 - .../hugr/std/_json_defs/collections/list.json | 1 - .../_json_defs/collections/static_array.json | 1 - hugr-py/src/hugr/std/_json_defs/logic.json | 28 -------- hugr-py/src/hugr/std/_json_defs/prelude.json | 1 - hugr-py/src/hugr/std/_json_defs/ptr.json | 1 - specification/schema/hugr_schema_live.json | 30 --------- .../schema/hugr_schema_strict_live.json | 30 --------- .../schema/testing_hugr_schema_live.json | 30 --------- .../testing_hugr_schema_strict_live.json | 30 --------- .../arithmetic/conversions.json | 1 - .../std_extensions/arithmetic/float.json | 1 - .../arithmetic/float/types.json | 1 - .../std_extensions/arithmetic/int.json | 1 - .../std_extensions/arithmetic/int/types.json | 1 - .../std_extensions/collections/array.json | 1 - .../std_extensions/collections/list.json | 1 - .../collections/static_array.json | 1 - specification/std_extensions/logic.json | 28 -------- specification/std_extensions/prelude.json | 1 - specification/std_extensions/ptr.json | 1 - 32 files changed, 7 insertions(+), 361 deletions(-) diff --git a/hugr-core/src/extension.rs b/hugr-core/src/extension.rs index b6e059050..23238ccfd 100644 --- a/hugr-core/src/extension.rs +++ b/hugr-core/src/extension.rs @@ -19,9 +19,8 @@ use derive_more::Display; use thiserror::Error; use crate::hugr::IdentList; -use crate::ops::constant::{ValueName, ValueNameRef}; use crate::ops::custom::{ExtensionOp, OpaqueOp}; -use crate::ops::{self, OpName, OpNameRef}; +use crate::ops::{OpName, OpNameRef}; use crate::types::type_param::{TypeArg, TypeArgError, TypeParam}; use crate::types::RowVariable; use crate::types::{check_typevar_decl, CustomType, Substitution, TypeBound, TypeName}; @@ -497,37 +496,6 @@ impl CustomConcrete for CustomType { } } -/// A constant value provided by a extension. -/// Must be an instance of a type available to the extension. -#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] -pub struct ExtensionValue { - extension: ExtensionId, - name: ValueName, - typed_value: ops::Value, -} - -impl ExtensionValue { - /// Returns a reference to the typed value of this [`ExtensionValue`]. - pub fn typed_value(&self) -> &ops::Value { - &self.typed_value - } - - /// Returns a mutable reference to the typed value of this [`ExtensionValue`]. - pub(super) fn typed_value_mut(&mut self) -> &mut ops::Value { - &mut self.typed_value - } - - /// Returns a reference to the name of this [`ExtensionValue`]. - pub fn name(&self) -> &str { - self.name.as_str() - } - - /// Returns a reference to the extension this [`ExtensionValue`] belongs to. - pub fn extension(&self) -> &ExtensionId { - &self.extension - } -} - /// A unique identifier for a extension. /// /// The actual [`Extension`] is stored externally. @@ -583,8 +551,6 @@ pub struct Extension { pub runtime_reqs: ExtensionSet, /// Types defined by this extension. types: BTreeMap, - /// Static values defined by this extension. - values: BTreeMap, /// Operation declarations with serializable definitions. // Note: serde will serialize this because we configure with `features=["rc"]`. // That will clone anything that has multiple references, but each @@ -608,7 +574,6 @@ impl Extension { version, runtime_reqs: Default::default(), types: Default::default(), - values: Default::default(), operations: Default::default(), } } @@ -680,11 +645,6 @@ impl Extension { self.types.get(type_name) } - /// Allows read-only access to the values in this Extension - pub fn get_value(&self, value_name: &ValueNameRef) -> Option<&ExtensionValue> { - self.values.get(value_name) - } - /// Returns the name of the extension. pub fn name(&self) -> &ExtensionId { &self.name @@ -705,25 +665,6 @@ impl Extension { self.types.iter() } - /// Add a named static value to the extension. - pub fn add_value( - &mut self, - name: impl Into, - typed_value: ops::Value, - ) -> Result<&mut ExtensionValue, ExtensionBuildError> { - let extension_value = ExtensionValue { - extension: self.name.clone(), - name: name.into(), - typed_value, - }; - match self.values.entry(extension_value.name.clone()) { - btree_map::Entry::Occupied(_) => { - Err(ExtensionBuildError::ValueExists(extension_value.name)) - } - btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension_value)), - } - } - /// Instantiate an [`ExtensionOp`] which references an [`OpDef`] in this extension. pub fn instantiate_extension_op( &self, @@ -784,9 +725,6 @@ pub enum ExtensionBuildError { /// Existing [`TypeDef`] #[error("Extension already has an type called {0}.")] TypeDefExists(TypeName), - /// Existing [`ExtensionValue`] - #[error("Extension already has an extension value called {0}.")] - ValueExists(ValueName), } /// A set of extensions identified by their unique [`ExtensionId`]. diff --git a/hugr-core/src/extension/resolution/extension.rs b/hugr-core/src/extension/resolution/extension.rs index 61adc1dea..05c0faf69 100644 --- a/hugr-core/src/extension/resolution/extension.rs +++ b/hugr-core/src/extension/resolution/extension.rs @@ -9,7 +9,7 @@ use std::sync::Arc; use crate::extension::{Extension, ExtensionId, ExtensionRegistry, OpDef, SignatureFunc, TypeDef}; -use super::types_mut::{resolve_signature_exts, resolve_value_exts}; +use super::types_mut::resolve_signature_exts; use super::{ExtensionResolutionError, WeakExtensionRegistry}; impl ExtensionRegistry { @@ -59,14 +59,7 @@ impl Extension { for type_def in self.types.values_mut() { resolve_typedef_exts(&self.name, type_def, extensions, &mut used_extensions)?; } - for val in self.values.values_mut() { - resolve_value_exts( - None, - val.typed_value_mut(), - extensions, - &mut used_extensions, - )?; - } + let ops = mem::take(&mut self.operations); for (op_id, mut op_def) in ops { // TODO: We should be able to clone the definition if needed by using `make_mut`, diff --git a/hugr-core/src/hugr/validate/test.rs b/hugr-core/src/hugr/validate/test.rs index ecb417ec5..37157020d 100644 --- a/hugr-core/src/hugr/validate/test.rs +++ b/hugr-core/src/hugr/validate/test.rs @@ -20,7 +20,6 @@ use crate::ops::handle::NodeHandle; use crate::ops::{self, OpType, Value}; use crate::std_extensions::logic::test::{and_op, or_op}; use crate::std_extensions::logic::LogicOp; -use crate::std_extensions::logic::{self}; use crate::types::type_param::{TypeArg, TypeArgError}; use crate::types::{ CustomType, FuncValueType, PolyFuncType, PolyFuncTypeRV, Signature, Type, TypeBound, TypeRV, @@ -307,12 +306,7 @@ fn test_local_const() { port_kind: EdgeKind::Value(bool_t()) }) ); - let const_op: ops::Const = logic::EXTENSION - .get_value(&logic::TRUE_NAME) - .unwrap() - .typed_value() - .clone() - .into(); + let const_op: ops::Const = ops::Value::from_bool(true).into(); // Second input of Xor from a constant let cst = h.add_node_with_parent(h.root(), const_op); let lcst = h.add_node_with_parent(h.root(), ops::LoadConstant { datatype: bool_t() }); diff --git a/hugr-core/src/std_extensions/logic.rs b/hugr-core/src/std_extensions/logic.rs index fcc8be9d3..20977cb51 100644 --- a/hugr-core/src/std_extensions/logic.rs +++ b/hugr-core/src/std_extensions/logic.rs @@ -124,13 +124,6 @@ pub const VERSION: semver::Version = semver::Version::new(0, 1, 0); fn extension() -> Arc { Extension::new_arc(EXTENSION_ID, VERSION, |extension, extension_ref| { LogicOp::load_all_ops(extension, extension_ref).unwrap(); - - extension - .add_value(FALSE_NAME, ops::Value::false_val()) - .unwrap(); - extension - .add_value(TRUE_NAME, ops::Value::true_val()) - .unwrap(); }) } @@ -172,12 +165,9 @@ fn read_inputs(consts: &[(IncomingPort, ops::Value)]) -> Option> { pub(crate) mod test { use std::sync::Arc; - use super::{extension, LogicOp, FALSE_NAME, TRUE_NAME}; + use super::{extension, LogicOp}; use crate::{ - extension::{ - prelude::bool_t, - simple_op::{MakeOpDef, MakeRegisteredOp}, - }, + extension::simple_op::{MakeOpDef, MakeRegisteredOp}, ops::{NamedOp, Value}, Extension, }; @@ -207,18 +197,6 @@ pub(crate) mod test { } } - #[test] - fn test_values() { - let r: Arc = extension(); - let false_val = r.get_value(&FALSE_NAME).unwrap(); - let true_val = r.get_value(&TRUE_NAME).unwrap(); - - for v in [false_val, true_val] { - let simpl = v.typed_value().get_type(); - assert_eq!(simpl, bool_t()); - } - } - /// Generate a logic extension "and" operation over [`crate::prelude::bool_t()`] pub(crate) fn and_op() -> LogicOp { LogicOp::And diff --git a/hugr-py/src/hugr/_serialization/extension.py b/hugr-py/src/hugr/_serialization/extension.py index 429bdd785..95e59754e 100644 --- a/hugr-py/src/hugr/_serialization/extension.py +++ b/hugr-py/src/hugr/_serialization/extension.py @@ -8,7 +8,6 @@ from hugr.hugr.base import Hugr from hugr.utils import deser_it -from .ops import Value from .serial_hugr import SerialHugr, serialization_version from .tys import ( ConfiguredBaseModel, @@ -20,7 +19,6 @@ ) if TYPE_CHECKING: - from .ops import Value from .serial_hugr import SerialHugr @@ -62,20 +60,6 @@ def deserialize(self, extension: ext.Extension) -> ext.TypeDef: ) -class ExtensionValue(ConfiguredBaseModel): - extension: ExtensionId - name: str - typed_value: Value - - def deserialize(self, extension: ext.Extension) -> ext.ExtensionValue: - return extension.add_extension_value( - ext.ExtensionValue( - name=self.name, - val=self.typed_value.deserialize(), - ) - ) - - # -------------------------------------- # --------------- OpDef ---------------- # -------------------------------------- @@ -124,7 +108,6 @@ class Extension(ConfiguredBaseModel): name: ExtensionId runtime_reqs: set[ExtensionId] types: dict[str, TypeDef] - values: dict[str, ExtensionValue] operations: dict[str, OpDef] @classmethod @@ -146,10 +129,6 @@ def deserialize(self) -> ext.Extension: assert k == o.name, "Operation name must match key" e.add_op_def(o.deserialize(e)) - for k, v in self.values.items(): - assert k == v.name, "Value name must match key" - e.add_extension_value(v.deserialize(e)) - return e diff --git a/hugr-py/src/hugr/ext.py b/hugr-py/src/hugr/ext.py index 494ea3c69..7bd02f982 100644 --- a/hugr-py/src/hugr/ext.py +++ b/hugr-py/src/hugr/ext.py @@ -8,7 +8,7 @@ from semver import Version import hugr._serialization.extension as ext_s -from hugr import ops, tys, val +from hugr import ops, tys from hugr.utils import ser_it __all__ = [ @@ -18,7 +18,6 @@ "FixedHugr", "OpDefSig", "OpDef", - "ExtensionValue", "Extension", "Version", ] @@ -246,23 +245,6 @@ def instantiate( return ops.ExtOp(self, concrete_signature, list(args or [])) -@dataclass -class ExtensionValue(ExtensionObject): - """A value defined in an :class:`Extension`.""" - - #: The name of the value. - name: str - #: Value payload. - val: val.Value - - def _to_serial(self) -> ext_s.ExtensionValue: - return ext_s.ExtensionValue( - extension=self.get_extension().name, - name=self.name, - typed_value=self.val._to_serial_root(), - ) - - T = TypeVar("T", bound=ops.RegisteredOp) @@ -278,8 +260,6 @@ class Extension: runtime_reqs: set[ExtensionId] = field(default_factory=set) #: Type definitions in the extension. types: dict[str, TypeDef] = field(default_factory=dict) - #: Values defined in the extension. - values: dict[str, ExtensionValue] = field(default_factory=dict) #: Operation definitions in the extension. operations: dict[str, OpDef] = field(default_factory=dict) @@ -295,7 +275,6 @@ def _to_serial(self) -> ext_s.Extension: version=self.version, # type: ignore[arg-type] runtime_reqs=self.runtime_reqs, types={k: v._to_serial() for k, v in self.types.items()}, - values={k: v._to_serial() for k, v in self.values.items()}, operations={k: v._to_serial() for k, v in self.operations.items()}, ) @@ -347,19 +326,6 @@ def add_type_def(self, type_def: TypeDef) -> TypeDef: self.types[type_def.name] = type_def return self.types[type_def.name] - def add_extension_value(self, extension_value: ExtensionValue) -> ExtensionValue: - """Add a value to the extension. - - Args: - extension_value: The value to add. - - Returns: - The added value, now associated with the extension. - """ - extension_value._extension = self - self.values[extension_value.name] = extension_value - return self.values[extension_value.name] - @dataclass class OperationNotFound(NotFound): """Operation not found in extension.""" @@ -406,12 +372,6 @@ def get_type(self, name: str) -> TypeDef: class ValueNotFound(NotFound): """Value not found in extension.""" - def get_value(self, name: str) -> ExtensionValue: - try: - return self.values[name] - except KeyError as e: - raise self.ValueNotFound(name) from e - T = TypeVar("T", bound=ops.RegisteredOp) def register_op( diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json index 9c0054354..1d310df25 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/conversions.json @@ -6,7 +6,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "bytecast_float64_to_int64": { "extension": "arithmetic.conversions", diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json index 31ccaaa59..8da056772 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/float.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "fabs": { "extension": "arithmetic.float", diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json index 56e35c50b..0c563c474 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/float/types.json @@ -14,6 +14,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json index 62d0a6663..5b1a81250 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/int.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "iabs": { "extension": "arithmetic.int", diff --git a/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json b/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json index 60cf69f63..36df125a6 100644 --- a/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json +++ b/hugr-py/src/hugr/std/_json_defs/arithmetic/int/types.json @@ -19,6 +19,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/hugr-py/src/hugr/std/_json_defs/collections/array.json b/hugr-py/src/hugr/std/_json_defs/collections/array.json index 21e405151..375e13c72 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/array.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/array.json @@ -25,7 +25,6 @@ } } }, - "values": {}, "operations": { "discard_empty": { "extension": "collections.array", diff --git a/hugr-py/src/hugr/std/_json_defs/collections/list.json b/hugr-py/src/hugr/std/_json_defs/collections/list.json index 0fbafc638..8a60d3544 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/list.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/list.json @@ -21,7 +21,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.list", diff --git a/hugr-py/src/hugr/std/_json_defs/collections/static_array.json b/hugr-py/src/hugr/std/_json_defs/collections/static_array.json index e4669f671..53b8e61c7 100644 --- a/hugr-py/src/hugr/std/_json_defs/collections/static_array.json +++ b/hugr-py/src/hugr/std/_json_defs/collections/static_array.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.static_array", diff --git a/hugr-py/src/hugr/std/_json_defs/logic.json b/hugr-py/src/hugr/std/_json_defs/logic.json index ad9f02019..ff29d2c21 100644 --- a/hugr-py/src/hugr/std/_json_defs/logic.json +++ b/hugr-py/src/hugr/std/_json_defs/logic.json @@ -3,34 +3,6 @@ "name": "logic", "runtime_reqs": [], "types": {}, - "values": { - "FALSE": { - "extension": "logic", - "name": "FALSE", - "typed_value": { - "v": "Sum", - "tag": 0, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - }, - "TRUE": { - "extension": "logic", - "name": "TRUE", - "typed_value": { - "v": "Sum", - "tag": 1, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - } - }, "operations": { "And": { "extension": "logic", diff --git a/hugr-py/src/hugr/std/_json_defs/prelude.json b/hugr-py/src/hugr/std/_json_defs/prelude.json index e11ba2388..ec392b155 100644 --- a/hugr-py/src/hugr/std/_json_defs/prelude.json +++ b/hugr-py/src/hugr/std/_json_defs/prelude.json @@ -44,7 +44,6 @@ } } }, - "values": {}, "operations": { "Barrier": { "extension": "prelude", diff --git a/hugr-py/src/hugr/std/_json_defs/ptr.json b/hugr-py/src/hugr/std/_json_defs/ptr.json index 18b1f26b6..614b6aecf 100644 --- a/hugr-py/src/hugr/std/_json_defs/ptr.json +++ b/hugr-py/src/hugr/std/_json_defs/ptr.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "New": { "extension": "ptr", diff --git a/specification/schema/hugr_schema_live.json b/specification/schema/hugr_schema_live.json index 9e7d8c40c..ea08dff5b 100644 --- a/specification/schema/hugr_schema_live.json +++ b/specification/schema/hugr_schema_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": true, "properties": { diff --git a/specification/schema/hugr_schema_strict_live.json b/specification/schema/hugr_schema_strict_live.json index 6f436f969..8b65bae94 100644 --- a/specification/schema/hugr_schema_strict_live.json +++ b/specification/schema/hugr_schema_strict_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": false, "properties": { diff --git a/specification/schema/testing_hugr_schema_live.json b/specification/schema/testing_hugr_schema_live.json index bc067d40e..91b121da6 100644 --- a/specification/schema/testing_hugr_schema_live.json +++ b/specification/schema/testing_hugr_schema_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": true, "properties": { diff --git a/specification/schema/testing_hugr_schema_strict_live.json b/specification/schema/testing_hugr_schema_strict_live.json index 47c9778d3..eae6a13a7 100644 --- a/specification/schema/testing_hugr_schema_strict_live.json +++ b/specification/schema/testing_hugr_schema_strict_live.json @@ -517,13 +517,6 @@ "title": "Types", "type": "object" }, - "values": { - "additionalProperties": { - "$ref": "#/$defs/ExtensionValue" - }, - "title": "Values", - "type": "object" - }, "operations": { "additionalProperties": { "$ref": "#/$defs/OpDef" @@ -537,7 +530,6 @@ "name", "runtime_reqs", "types", - "values", "operations" ], "title": "Extension", @@ -589,28 +581,6 @@ "title": "ExtensionOp", "type": "object" }, - "ExtensionValue": { - "properties": { - "extension": { - "title": "Extension", - "type": "string" - }, - "name": { - "title": "Name", - "type": "string" - }, - "typed_value": { - "$ref": "#/$defs/Value" - } - }, - "required": [ - "extension", - "name", - "typed_value" - ], - "title": "ExtensionValue", - "type": "object" - }, "ExtensionsArg": { "additionalProperties": false, "properties": { diff --git a/specification/std_extensions/arithmetic/conversions.json b/specification/std_extensions/arithmetic/conversions.json index 9c0054354..1d310df25 100644 --- a/specification/std_extensions/arithmetic/conversions.json +++ b/specification/std_extensions/arithmetic/conversions.json @@ -6,7 +6,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "bytecast_float64_to_int64": { "extension": "arithmetic.conversions", diff --git a/specification/std_extensions/arithmetic/float.json b/specification/std_extensions/arithmetic/float.json index 31ccaaa59..8da056772 100644 --- a/specification/std_extensions/arithmetic/float.json +++ b/specification/std_extensions/arithmetic/float.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "fabs": { "extension": "arithmetic.float", diff --git a/specification/std_extensions/arithmetic/float/types.json b/specification/std_extensions/arithmetic/float/types.json index 56e35c50b..0c563c474 100644 --- a/specification/std_extensions/arithmetic/float/types.json +++ b/specification/std_extensions/arithmetic/float/types.json @@ -14,6 +14,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/specification/std_extensions/arithmetic/int.json b/specification/std_extensions/arithmetic/int.json index 62d0a6663..5b1a81250 100644 --- a/specification/std_extensions/arithmetic/int.json +++ b/specification/std_extensions/arithmetic/int.json @@ -5,7 +5,6 @@ "arithmetic.int.types" ], "types": {}, - "values": {}, "operations": { "iabs": { "extension": "arithmetic.int", diff --git a/specification/std_extensions/arithmetic/int/types.json b/specification/std_extensions/arithmetic/int/types.json index 60cf69f63..36df125a6 100644 --- a/specification/std_extensions/arithmetic/int/types.json +++ b/specification/std_extensions/arithmetic/int/types.json @@ -19,6 +19,5 @@ } } }, - "values": {}, "operations": {} } diff --git a/specification/std_extensions/collections/array.json b/specification/std_extensions/collections/array.json index 21e405151..375e13c72 100644 --- a/specification/std_extensions/collections/array.json +++ b/specification/std_extensions/collections/array.json @@ -25,7 +25,6 @@ } } }, - "values": {}, "operations": { "discard_empty": { "extension": "collections.array", diff --git a/specification/std_extensions/collections/list.json b/specification/std_extensions/collections/list.json index 0fbafc638..8a60d3544 100644 --- a/specification/std_extensions/collections/list.json +++ b/specification/std_extensions/collections/list.json @@ -21,7 +21,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.list", diff --git a/specification/std_extensions/collections/static_array.json b/specification/std_extensions/collections/static_array.json index e4669f671..53b8e61c7 100644 --- a/specification/std_extensions/collections/static_array.json +++ b/specification/std_extensions/collections/static_array.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "get": { "extension": "collections.static_array", diff --git a/specification/std_extensions/logic.json b/specification/std_extensions/logic.json index ad9f02019..ff29d2c21 100644 --- a/specification/std_extensions/logic.json +++ b/specification/std_extensions/logic.json @@ -3,34 +3,6 @@ "name": "logic", "runtime_reqs": [], "types": {}, - "values": { - "FALSE": { - "extension": "logic", - "name": "FALSE", - "typed_value": { - "v": "Sum", - "tag": 0, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - }, - "TRUE": { - "extension": "logic", - "name": "TRUE", - "typed_value": { - "v": "Sum", - "tag": 1, - "vs": [], - "typ": { - "s": "Unit", - "size": 2 - } - } - } - }, "operations": { "And": { "extension": "logic", diff --git a/specification/std_extensions/prelude.json b/specification/std_extensions/prelude.json index e11ba2388..ec392b155 100644 --- a/specification/std_extensions/prelude.json +++ b/specification/std_extensions/prelude.json @@ -44,7 +44,6 @@ } } }, - "values": {}, "operations": { "Barrier": { "extension": "prelude", diff --git a/specification/std_extensions/ptr.json b/specification/std_extensions/ptr.json index 18b1f26b6..614b6aecf 100644 --- a/specification/std_extensions/ptr.json +++ b/specification/std_extensions/ptr.json @@ -19,7 +19,6 @@ } } }, - "values": {}, "operations": { "New": { "extension": "ptr", From 89c2680912b47950ffd73a7c29a21386fdd0aee7 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Tue, 22 Apr 2025 17:16:31 +0100 Subject: [PATCH 08/18] feat!: ComposablePass trait allowing sequencing and validation (#1895) Currently We have several "passes": monomorphization, dead function removal, constant folding. Each has its own code to allow setting a validation level (before and after that pass). This PR adds the ability chain (sequence) passes;, and to add validation before+after any pass or sequence; and commons up validation code. The top-level `constant_fold_pass` (etc.) functions are left as wrappers that do a single pass with validation only in test. I've left ConstFoldPass as always including DCE, but an alternative could be to return a sequence of the two - ATM that means a tuple `(ConstFoldPass, DeadCodeElimPass)`. I also wondered about including a method `add_entry_point` in ComposablePass (e.g. for ConstFoldPass, that means `with_inputs` but no inputs, i.e. all Top). I feel this is not applicable to *all* passes, but near enough. This could be done in a later PR but `add_entry_point` would need a no-op default for that to be a non-breaking change. So if we wouldn't be happy with the no-op default then I could just add it here... Finally...docs are extremely minimal ATM (this is hugr-passes), I am hoping that most of this is reasonably obvious (it doesn't really do a lot!), but please flag anything you think is particularly in need of a doc comment! BREAKING CHANGE: quite a lot of calls to current pass routines will break, specific cases include (a) `with_validation_level` should be done by wrapping a ValidatingPass around the receiver; (b) XXXPass::run() requires `use ...ComposablePass` (however, such calls will cease to do any validation). closes #1832 --- hugr-passes/src/composable.rs | 361 +++++++++++++++++++++ hugr-passes/src/const_fold.rs | 45 +-- hugr-passes/src/const_fold/test.rs | 1 + hugr-passes/src/dead_code.rs | 50 ++- hugr-passes/src/dead_funcs.rs | 77 ++--- hugr-passes/src/lib.rs | 12 +- hugr-passes/src/monomorphize.rs | 92 ++---- hugr-passes/src/replace_types.rs | 105 +++--- hugr-passes/src/replace_types/linearize.rs | 2 +- hugr-passes/src/untuple.rs | 70 ++-- 10 files changed, 550 insertions(+), 265 deletions(-) create mode 100644 hugr-passes/src/composable.rs diff --git a/hugr-passes/src/composable.rs b/hugr-passes/src/composable.rs new file mode 100644 index 000000000..fb3319155 --- /dev/null +++ b/hugr-passes/src/composable.rs @@ -0,0 +1,361 @@ +//! Compiler passes and utilities for composing them + +use std::{error::Error, marker::PhantomData}; + +use hugr_core::hugr::{hugrmut::HugrMut, ValidationError}; +use hugr_core::HugrView; +use itertools::Either; + +/// An optimization pass that can be sequenced with another and/or wrapped +/// e.g. by [ValidatingPass] +pub trait ComposablePass: Sized { + type Error: Error; + type Result; // Would like to default to () but currently unstable + + fn run(&self, hugr: &mut impl HugrMut) -> Result; + + fn map_err( + self, + f: impl Fn(Self::Error) -> E2, + ) -> impl ComposablePass { + ErrMapper::new(self, f) + } + + /// Returns a [ComposablePass] that does "`self` then `other`", so long as + /// `other::Err` can be combined with ours. + fn then>( + self, + other: P, + ) -> impl ComposablePass { + struct Sequence(P1, P2, PhantomData); + impl ComposablePass for Sequence + where + P1: ComposablePass, + P2: ComposablePass, + E: ErrorCombiner, + { + type Error = E; + + type Result = (P1::Result, P2::Result); + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let res1 = self.0.run(hugr).map_err(E::from_first)?; + let res2 = self.1.run(hugr).map_err(E::from_second)?; + Ok((res1, res2)) + } + } + + Sequence(self, other, PhantomData) + } +} + +/// Trait for combining the error types from two different passes +/// into a single error. +pub trait ErrorCombiner: Error { + fn from_first(a: A) -> Self; + fn from_second(b: B) -> Self; +} + +impl> ErrorCombiner for A { + fn from_first(a: A) -> Self { + a + } + + fn from_second(b: B) -> Self { + b.into() + } +} + +impl ErrorCombiner for Either { + fn from_first(a: A) -> Self { + Either::Left(a) + } + + fn from_second(b: B) -> Self { + Either::Right(b) + } +} + +// Note: in the short term we could wish for two more impls: +// impl ErrorCombiner for E +// impl ErrorCombiner for E +// however, these aren't possible as they conflict with +// impl> ErrorCombiner for A +// when A=E=Infallible, boo :-(. +// However this will become possible, indeed automatic, when Infallible is replaced +// by ! (never_type) as (unlike Infallible) ! converts Into anything + +// ErrMapper ------------------------------ +struct ErrMapper(P, F, PhantomData); + +impl E> ErrMapper { + fn new(pass: P, err_fn: F) -> Self { + Self(pass, err_fn, PhantomData) + } +} + +impl E> ComposablePass for ErrMapper { + type Error = E; + type Result = P::Result; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + self.0.run(hugr).map_err(&self.1) + } +} + +// ValidatingPass ------------------------------ + +/// Error from a [ValidatingPass] +#[derive(thiserror::Error, Debug)] +pub enum ValidatePassError { + #[error("Failed to validate input HUGR: {err}\n{pretty_hugr}")] + Input { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error("Failed to validate output HUGR: {err}\n{pretty_hugr}")] + Output { + #[source] + err: ValidationError, + pretty_hugr: String, + }, + #[error(transparent)] + Underlying(#[from] E), +} + +/// Runs an underlying pass, but with validation of the Hugr +/// both before and afterwards. +pub struct ValidatingPass

(P, bool); + +impl ValidatingPass

{ + pub fn new_default(underlying: P) -> Self { + // Self(underlying, cfg!(feature = "extension_inference")) + // Sadly, many tests fail with extension inference, hence: + Self(underlying, false) + } + + pub fn new_validating_extensions(underlying: P) -> Self { + Self(underlying, true) + } + + pub fn new(underlying: P, validate_extensions: bool) -> Self { + Self(underlying, validate_extensions) + } + + fn validation_impl( + &self, + hugr: &impl HugrView, + mk_err: impl FnOnce(ValidationError, String) -> ValidatePassError, + ) -> Result<(), ValidatePassError> { + match self.1 { + false => hugr.validate_no_extensions(), + true => hugr.validate(), + } + .map_err(|err| mk_err(err, hugr.mermaid_string())) + } +} + +impl ComposablePass for ValidatingPass

{ + type Error = ValidatePassError; + type Result = P::Result; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Input { + err, + pretty_hugr, + })?; + let res = self.0.run(hugr).map_err(ValidatePassError::Underlying)?; + self.validation_impl(hugr, |err, pretty_hugr| ValidatePassError::Output { + err, + pretty_hugr, + })?; + Ok(res) + } +} + +// IfThen ------------------------------ +/// [ComposablePass] that executes a first pass that returns a `bool` +/// result; and then, if-and-only-if that first result was true, +/// executes a second pass +pub struct IfThen(A, B, PhantomData); + +impl, B: ComposablePass, E: ErrorCombiner> + IfThen +{ + /// Make a new instance given the [ComposablePass] to run first + /// and (maybe) second + pub fn new(fst: A, opt_snd: B) -> Self { + Self(fst, opt_snd, PhantomData) + } +} + +impl, B: ComposablePass, E: ErrorCombiner> + ComposablePass for IfThen +{ + type Error = E; + + type Result = Option; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let res: bool = self.0.run(hugr).map_err(ErrorCombiner::from_first)?; + res.then(|| self.1.run(hugr).map_err(ErrorCombiner::from_second)) + .transpose() + } +} + +pub(crate) fn validate_if_test( + pass: P, + hugr: &mut impl HugrMut, +) -> Result> { + if cfg!(test) { + ValidatingPass::new_default(pass).run(hugr) + } else { + pass.run(hugr).map_err(ValidatePassError::Underlying) + } +} + +#[cfg(test)] +mod test { + use itertools::{Either, Itertools}; + use std::convert::Infallible; + + use hugr_core::builder::{ + Container, Dataflow, DataflowHugr, DataflowSubContainer, FunctionBuilder, HugrBuilder, + ModuleBuilder, + }; + use hugr_core::extension::prelude::{ + bool_t, usize_t, ConstUsize, MakeTuple, UnpackTuple, PRELUDE_ID, + }; + use hugr_core::hugr::hugrmut::HugrMut; + use hugr_core::ops::{handle::NodeHandle, Input, OpType, Output, DEFAULT_OPTYPE, DFG}; + use hugr_core::std_extensions::arithmetic::int_types::INT_TYPES; + use hugr_core::types::{Signature, TypeRow}; + use hugr_core::{Hugr, HugrView, IncomingPort}; + + use crate::const_fold::{ConstFoldError, ConstantFoldPass}; + use crate::untuple::{UntupleRecursive, UntupleResult}; + use crate::{DeadCodeElimPass, ReplaceTypes, UntuplePass}; + + use super::{validate_if_test, ComposablePass, IfThen, ValidatePassError, ValidatingPass}; + + #[test] + fn test_then() { + let mut mb = ModuleBuilder::new(); + let id1 = mb + .define_function("id1", Signature::new_endo(usize_t())) + .unwrap(); + let inps = id1.input_wires(); + let id1 = id1.finish_with_outputs(inps).unwrap(); + let id2 = mb + .define_function("id2", Signature::new_endo(usize_t())) + .unwrap(); + let inps = id2.input_wires(); + let id2 = id2.finish_with_outputs(inps).unwrap(); + let hugr = mb.finish_hugr().unwrap(); + + let dce = DeadCodeElimPass::default().with_entry_points([id1.node()]); + let cfold = + ConstantFoldPass::default().with_inputs(id2.node(), [(0, ConstUsize::new(2).into())]); + + cfold.run(&mut hugr.clone()).unwrap(); + + let exp_err = ConstFoldError::InvalidEntryPoint(id2.node(), DEFAULT_OPTYPE); + let r: Result<_, Either> = + dce.clone().then(cfold.clone()).run(&mut hugr.clone()); + assert_eq!(r, Err(Either::Right(exp_err.clone()))); + + let r = dce + .clone() + .map_err(|inf| match inf {}) + .then(cfold.clone()) + .run(&mut hugr.clone()); + assert_eq!(r, Err(exp_err)); + + let r2: Result<_, Either<_, _>> = cfold.then(dce).run(&mut hugr.clone()); + r2.unwrap(); + } + + #[test] + fn test_validation() { + let mut h = Hugr::new(DFG { + signature: Signature::new(usize_t(), bool_t()), + }); + let inp = h.add_node_with_parent( + h.root(), + Input { + types: usize_t().into(), + }, + ); + let outp = h.add_node_with_parent( + h.root(), + Output { + types: bool_t().into(), + }, + ); + h.connect(inp, 0, outp, 0); + let backup = h.clone(); + let err = backup.validate().unwrap_err(); + + let no_inputs: [(IncomingPort, _); 0] = []; + let cfold = ConstantFoldPass::default().with_inputs(backup.root(), no_inputs); + cfold.run(&mut h).unwrap(); + assert_eq!(h, backup); // Did nothing + + let r = ValidatingPass(cfold, false).run(&mut h); + assert!(matches!(r, Err(ValidatePassError::Input { err: e, .. }) if e == err)); + } + + #[test] + fn test_if_then() { + let tr = TypeRow::from(vec![usize_t(); 2]); + + let h = { + let sig = Signature::new_endo(tr.clone()).with_extension_delta(PRELUDE_ID); + let mut fb = FunctionBuilder::new("tupuntup", sig).unwrap(); + let [a, b] = fb.input_wires_arr(); + let tup = fb + .add_dataflow_op(MakeTuple::new(tr.clone()), [a, b]) + .unwrap(); + let untup = fb + .add_dataflow_op(UnpackTuple::new(tr.clone()), tup.outputs()) + .unwrap(); + fb.finish_hugr_with_outputs(untup.outputs()).unwrap() + }; + + let untup = UntuplePass::new(UntupleRecursive::Recursive); + { + // Change usize_t to INT_TYPES[6], and if that did anything (it will!), then Untuple + let mut repl = ReplaceTypes::default(); + let usize_custom_t = usize_t().as_extension().unwrap().clone(); + repl.replace_type(usize_custom_t, INT_TYPES[6].clone()); + let ifthen = IfThen::, _, _>::new(repl, untup.clone()); + + let mut h = h.clone(); + let r = validate_if_test(ifthen, &mut h).unwrap(); + assert_eq!( + r, + Some(UntupleResult { + rewrites_applied: 1 + }) + ); + let [tuple_in, tuple_out] = h.children(h.root()).collect_array().unwrap(); + assert_eq!(h.output_neighbours(tuple_in).collect_vec(), [tuple_out; 2]); + } + + // Change INT_TYPES[5] to INT_TYPES[6]; that won't do anything, so don't Untuple + let mut repl = ReplaceTypes::default(); + let i32_custom_t = INT_TYPES[5].as_extension().unwrap().clone(); + repl.replace_type(i32_custom_t, INT_TYPES[6].clone()); + let ifthen = IfThen::, _, _>::new(repl, untup); + let mut h = h; + let r = validate_if_test(ifthen, &mut h).unwrap(); + assert_eq!(r, None); + assert_eq!(h.children(h.root()).count(), 4); + let mktup = h + .output_neighbours(h.first_child(h.root()).unwrap()) + .next() + .unwrap(); + assert_eq!(h.get_optype(mktup), &OpType::from(MakeTuple::new(tr))); + } +} diff --git a/hugr-passes/src/const_fold.rs b/hugr-passes/src/const_fold.rs index e73e3cd0e..99ccc180c 100644 --- a/hugr-passes/src/const_fold.rs +++ b/hugr-passes/src/const_fold.rs @@ -21,12 +21,11 @@ use crate::dataflow::{ TailLoopTermination, }; use crate::dead_code::{DeadCodeElimPass, PreserveNode}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::{composable::validate_if_test, ComposablePass}; #[derive(Debug, Clone, Default)] /// A configuration for the Constant Folding pass. pub struct ConstantFoldPass { - validation: ValidationLevel, allow_increase_termination: bool, /// Each outer key Node must be either: /// - a FuncDefn child of the root, if the root is a module; or @@ -34,13 +33,10 @@ pub struct ConstantFoldPass { inputs: HashMap>, } -#[derive(Debug, Error)] +#[derive(Clone, Debug, Error, PartialEq)] #[non_exhaustive] /// Errors produced by [ConstantFoldPass]. pub enum ConstFoldError { - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), /// Error raised when a Node is specified as an entry-point but /// is neither a dataflow parent, nor a [CFG](OpType::CFG), nor /// a [Conditional](OpType::Conditional). @@ -49,12 +45,6 @@ pub enum ConstFoldError { } impl ConstantFoldPass { - /// Sets the validation level used before and after the pass is run - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Allows the pass to remove potentially-non-terminating [TailLoop]s and [CFG] if their /// result (if/when they do terminate) is either known or not needed. /// @@ -86,9 +76,19 @@ impl ConstantFoldPass { .extend(inputs.into_iter().map(|(p, v)| (p.into(), v))); self } +} + +impl ComposablePass for ConstantFoldPass { + type Error = ConstFoldError; + type Result = (); /// Run the Constant Folding pass. - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { + /// + /// # Errors + /// + /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] + /// was of an invalid [OpType] + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ConstFoldError> { let fresh_node = Node::from(portgraph::NodeIndex::new( hugr.nodes().max().map_or(0, |n| n.index() + 1), )); @@ -164,23 +164,10 @@ impl ConstantFoldPass { } }) }) - .run(hugr)?; + .run(hugr) + .map_err(|inf| match inf {})?; // TODO use into_ok when available Ok(()) } - - /// Run the pass using this configuration. - /// - /// # Errors - /// - /// [ConstFoldError::ValidationError] if the Hugr does not validate before/afnerwards - /// (if [Self::validation_level] is set, or in tests) - /// - /// [ConstFoldError::InvalidEntryPoint] if an entry-point added by [Self::with_inputs] - /// was of an invalid OpType - pub fn run(&self, hugr: &mut H) -> Result<(), ConstFoldError> { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } } /// Exhaustively apply constant folding to a HUGR. @@ -198,7 +185,7 @@ pub fn constant_fold_pass(h: &mut H) { } else { c }; - c.run(h).unwrap() + validate_if_test(c, h).unwrap() } struct ConstFoldContext; diff --git a/hugr-passes/src/const_fold/test.rs b/hugr-passes/src/const_fold/test.rs index 58e69c568..ff5cd93a5 100644 --- a/hugr-passes/src/const_fold/test.rs +++ b/hugr-passes/src/const_fold/test.rs @@ -32,6 +32,7 @@ use hugr_core::types::{Signature, SumType, Type, TypeBound, TypeRow, TypeRowRV}; use hugr_core::{type_row, Hugr, HugrView, IncomingPort, Node}; use crate::dataflow::{partial_from_const, DFContext, PartialValue}; +use crate::ComposablePass as _; use super::{constant_fold_pass, ConstFoldContext, ConstantFoldPass, ValueHandle}; diff --git a/hugr-passes/src/dead_code.rs b/hugr-passes/src/dead_code.rs index b714dd6fd..899e30243 100644 --- a/hugr-passes/src/dead_code.rs +++ b/hugr-passes/src/dead_code.rs @@ -1,13 +1,14 @@ //! Pass for removing dead code, i.e. that computes values that are then discarded use hugr_core::{hugr::hugrmut::HugrMut, ops::OpType, Hugr, HugrView, Node}; +use std::convert::Infallible; use std::fmt::{Debug, Formatter}; use std::{ collections::{HashMap, HashSet, VecDeque}, sync::Arc, }; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; /// Configuration for Dead Code Elimination pass #[derive(Clone)] @@ -18,7 +19,6 @@ pub struct DeadCodeElimPass { /// Callback identifying nodes that must be preserved even if their /// results are not used. Defaults to [PreserveNode::default_for]. preserve_callback: Arc, - validation: ValidationLevel, } impl Default for DeadCodeElimPass { @@ -26,7 +26,6 @@ impl Default for DeadCodeElimPass { Self { entry_points: Default::default(), preserve_callback: Arc::new(PreserveNode::default_for), - validation: ValidationLevel::default(), } } } @@ -39,13 +38,11 @@ impl Debug for DeadCodeElimPass { #[derive(Debug)] struct DCEDebug<'a> { entry_points: &'a Vec, - validation: ValidationLevel, } Debug::fmt( &DCEDebug { entry_points: &self.entry_points, - validation: self.validation, }, f, ) @@ -86,13 +83,6 @@ impl PreserveNode { } impl DeadCodeElimPass { - /// Sets the validation level used before and after the pass is run - #[allow(unused)] - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Allows setting a callback that determines whether a node must be preserved /// (even when its result is not used) pub fn set_preserve_callback(mut self, cb: Arc) -> Self { @@ -146,24 +136,6 @@ impl DeadCodeElimPass { needed } - pub fn run(&self, hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { - self.validation.run_validated_pass(hugr, |h, _| { - self.run_no_validate(h); - Ok(()) - }) - } - - fn run_no_validate(&self, hugr: &mut impl HugrMut) { - let needed = self.find_needed_nodes(&*hugr); - let remove = hugr - .nodes() - .filter(|n| !needed.contains(n)) - .collect::>(); - for n in remove { - hugr.remove_node(n); - } - } - fn must_preserve( &self, h: &impl HugrView, @@ -185,6 +157,22 @@ impl DeadCodeElimPass { } } +impl ComposablePass for DeadCodeElimPass { + type Error = Infallible; + type Result = (); + + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), Infallible> { + let needed = self.find_needed_nodes(&*hugr); + let remove = hugr + .nodes() + .filter(|n| !needed.contains(n)) + .collect::>(); + for n in remove { + hugr.remove_node(n); + } + Ok(()) + } +} #[cfg(test)] mod test { use std::sync::Arc; @@ -196,6 +184,8 @@ mod test { use hugr_core::{ops::Value, type_row, HugrView}; use itertools::Itertools; + use crate::ComposablePass; + use super::{DeadCodeElimPass, PreserveNode}; #[test] diff --git a/hugr-passes/src/dead_funcs.rs b/hugr-passes/src/dead_funcs.rs index b114a9e42..7071d5335 100644 --- a/hugr-passes/src/dead_funcs.rs +++ b/hugr-passes/src/dead_funcs.rs @@ -10,7 +10,10 @@ use hugr_core::{ }; use petgraph::visit::{Dfs, Walker}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::{ + composable::{validate_if_test, ValidatePassError}, + ComposablePass, +}; use super::call_graph::{CallGraph, CallGraphNode}; @@ -26,9 +29,6 @@ pub enum RemoveDeadFuncsError { /// The invalid node. node: N, }, - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), } fn reachable_funcs<'a, H: HugrView>( @@ -64,17 +64,10 @@ fn reachable_funcs<'a, H: HugrView>( #[derive(Debug, Clone, Default)] /// A configuration for the Dead Function Removal pass. pub struct RemoveDeadFuncsPass { - validation: ValidationLevel, entry_points: Vec, } impl RemoveDeadFuncsPass { - /// Sets the validation level used before and after the pass is run - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Adds new entry points - these must be [FuncDefn] nodes /// that are children of the [Module] at the root of the Hugr. /// @@ -87,16 +80,32 @@ impl RemoveDeadFuncsPass { self.entry_points.extend(entry_points); self } +} - /// Runs the pass (see [remove_dead_funcs]) with this configuration - pub fn run(&self, hugr: &mut H) -> Result<(), RemoveDeadFuncsError> { - self.validation.run_validated_pass(hugr, |hugr: &mut H, _| { - remove_dead_funcs(hugr, self.entry_points.iter().cloned()) - }) +impl ComposablePass for RemoveDeadFuncsPass { + type Error = RemoveDeadFuncsError; + type Result = (); + fn run(&self, hugr: &mut impl HugrMut) -> Result<(), RemoveDeadFuncsError> { + let reachable = reachable_funcs( + &CallGraph::new(hugr), + hugr, + self.entry_points.iter().cloned(), + )? + .collect::>(); + let unreachable = hugr + .nodes() + .filter(|n| { + OpTag::Function.is_superset(hugr.get_optype(*n).tag()) && !reachable.contains(n) + }) + .collect::>(); + for n in unreachable { + hugr.remove_subtree(n); + } + Ok(()) } } -/// Delete from the Hugr any functions that are not used by either [Call] or +/// Deletes from the Hugr any functions that are not used by either [Call] or /// [LoadFunction] nodes in reachable parts. /// /// For [Module]-rooted Hugrs, `entry_points` may provide a list of entry points, @@ -118,16 +127,11 @@ impl RemoveDeadFuncsPass { pub fn remove_dead_funcs( h: &mut impl HugrMut, entry_points: impl IntoIterator, -) -> Result<(), RemoveDeadFuncsError> { - let reachable = reachable_funcs(&CallGraph::new(h), h, entry_points)?.collect::>(); - let unreachable = h - .nodes() - .filter(|n| OpTag::Function.is_superset(h.get_optype(*n).tag()) && !reachable.contains(n)) - .collect::>(); - for n in unreachable { - h.remove_subtree(n); - } - Ok(()) +) -> Result<(), ValidatePassError> { + validate_if_test( + RemoveDeadFuncsPass::default().with_module_entry_points(entry_points), + h, + ) } #[cfg(test)] @@ -142,7 +146,7 @@ mod test { }; use hugr_core::{extension::prelude::usize_t, types::Signature, HugrView}; - use super::RemoveDeadFuncsPass; + use super::remove_dead_funcs; #[rstest] #[case([], vec![])] // No entry_points removes everything! @@ -182,15 +186,14 @@ mod test { }) .collect::>(); - RemoveDeadFuncsPass::default() - .with_module_entry_points( - entry_points - .into_iter() - .map(|name| *avail_funcs.get(name).unwrap()) - .collect::>(), - ) - .run(&mut hugr) - .unwrap(); + remove_dead_funcs( + &mut hugr, + entry_points + .into_iter() + .map(|name| *avail_funcs.get(name).unwrap()) + .collect::>(), + ) + .unwrap(); let remaining_funcs = hugr .nodes() diff --git a/hugr-passes/src/lib.rs b/hugr-passes/src/lib.rs index 961c4da47..83ff71b67 100644 --- a/hugr-passes/src/lib.rs +++ b/hugr-passes/src/lib.rs @@ -1,6 +1,8 @@ //! Compilation passes acting on the HUGR program representation. pub mod call_graph; +pub mod composable; +pub use composable::ComposablePass; pub mod const_fold; pub mod dataflow; pub mod dead_code; @@ -21,19 +23,11 @@ pub mod untuple; )] #[allow(deprecated)] pub use monomorphize::remove_polyfuncs; -// TODO: Deprecated re-export. Remove on a breaking release. -#[deprecated( - since = "0.14.1", - note = "Use `hugr_passes::MonomorphizePass` instead." -)] -#[allow(deprecated)] -pub use monomorphize::monomorphize; -pub use monomorphize::{MonomorphizeError, MonomorphizePass}; +pub use monomorphize::{monomorphize, MonomorphizePass}; pub mod replace_types; pub use replace_types::ReplaceTypes; pub mod nest_cfgs; pub mod non_local; -pub mod validation; pub use force_order::{force_order, force_order_by_key}; pub use lower::{lower_ops, replace_many_ops}; pub use non_local::{ensure_no_nonlocal_edges, nonlocal_edges}; diff --git a/hugr-passes/src/monomorphize.rs b/hugr-passes/src/monomorphize.rs index 4f4e9bda2..875ee9355 100644 --- a/hugr-passes/src/monomorphize.rs +++ b/hugr-passes/src/monomorphize.rs @@ -1,5 +1,6 @@ use std::{ collections::{hash_map::Entry, HashMap}, + convert::Infallible, fmt::Write, ops::Deref, }; @@ -12,7 +13,9 @@ use hugr_core::{ use hugr_core::hugr::{hugrmut::HugrMut, Hugr, HugrView, OpType}; use itertools::Itertools as _; -use thiserror::Error; + +use crate::composable::{validate_if_test, ValidatePassError}; +use crate::ComposablePass; /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. @@ -30,26 +33,8 @@ use thiserror::Error; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -#[deprecated( - since = "0.14.1", - note = "Use `hugr_passes::MonomorphizePass` instead." -)] -// TODO: Deprecated. Remove on a breaking release and rename private `monomorphize_ref` to `monomorphize`. -pub fn monomorphize(mut h: Hugr) -> Hugr { - monomorphize_ref(&mut h); - h -} - -fn monomorphize_ref(h: &mut impl HugrMut) { - let root = h.root(); - // If the root is a polymorphic function, then there are no external calls, so nothing to do - if !is_polymorphic_funcdefn(h.get_optype(root)) { - mono_scan(h, root, None, &mut HashMap::new()); - if !h.get_optype(root).is_module() { - #[allow(deprecated)] // TODO remove in next breaking release and update docs - remove_polyfuncs_ref(h); - } - } +pub fn monomorphize(hugr: &mut impl HugrMut) -> Result<(), ValidatePassError> { + validate_if_test(MonomorphizePass, hugr) } /// Removes any polymorphic [FuncDefn]s from the Hugr. Note that if these have @@ -254,8 +239,6 @@ fn instantiate( mono_tgt } -use crate::validation::{ValidatePassError, ValidationLevel}; - /// Replaces calls to polymorphic functions with calls to new monomorphic /// instantiations of the polymorphic ones. /// @@ -271,38 +254,25 @@ use crate::validation::{ValidatePassError, ValidationLevel}; /// children of the root node. We make best effort to ensure that names (derived /// from parent function names and concrete type args) of new functions are unique /// whenever the names of their parents are unique, but this is not guaranteed. -#[derive(Debug, Clone, Default)] -pub struct MonomorphizePass { - validation: ValidationLevel, -} - -#[derive(Debug, Error)] -#[non_exhaustive] -/// Errors produced by [MonomorphizePass]. -pub enum MonomorphizeError { - #[error(transparent)] - #[allow(missing_docs)] - ValidationError(#[from] ValidatePassError), -} - -impl MonomorphizePass { - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - - /// Run the Monomorphization pass. - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result<(), MonomorphizeError> { - monomorphize_ref(hugr); +#[derive(Debug, Clone)] +pub struct MonomorphizePass; + +impl ComposablePass for MonomorphizePass { + type Error = Infallible; + type Result = (); + + fn run(&self, h: &mut impl HugrMut) -> Result<(), Self::Error> { + let root = h.root(); + // If the root is a polymorphic function, then there are no external calls, so nothing to do + if !is_polymorphic_funcdefn(h.get_optype(root)) { + mono_scan(h, root, None, &mut HashMap::new()); + if !h.get_optype(root).is_module() { + #[allow(deprecated)] // TODO remove in next breaking release and update docs + remove_polyfuncs_ref(h); + } + } Ok(()) } - - /// Run the pass using specified configuration. - pub fn run(&self, hugr: &mut H) -> Result<(), MonomorphizeError> { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } } struct TypeArgsList<'a>(&'a [TypeArg]); @@ -387,9 +357,9 @@ mod test { use hugr_core::{Hugr, HugrView, Node}; use rstest::rstest; - use crate::remove_dead_funcs; + use crate::{monomorphize, remove_dead_funcs}; - use super::{is_polymorphic, mangle_inner_func, mangle_name, MonomorphizePass}; + use super::{is_polymorphic, mangle_inner_func, mangle_name}; fn pair_type(ty: Type) -> Type { Type::new_tuple(vec![ty.clone(), ty]) @@ -410,7 +380,7 @@ mod test { let [i1] = dfg_builder.input_wires_arr(); let hugr = dfg_builder.finish_hugr_with_outputs([i1]).unwrap(); let mut hugr2 = hugr.clone(); - MonomorphizePass::default().run(&mut hugr2).unwrap(); + monomorphize(&mut hugr2).unwrap(); assert_eq!(hugr, hugr2); } @@ -472,7 +442,7 @@ mod test { .count(), 3 ); - MonomorphizePass::default().run(&mut hugr)?; + monomorphize(&mut hugr)?; let mono = hugr; mono.validate()?; @@ -493,7 +463,7 @@ mod test { ["double", "main", "triple"] ); let mut mono2 = mono.clone(); - MonomorphizePass::default().run(&mut mono2)?; + monomorphize(&mut mono2)?; assert_eq!(mono2, mono); // Idempotent @@ -601,7 +571,7 @@ mod test { .outputs_arr(); let mut hugr = outer.finish_hugr_with_outputs([e1, e2]).unwrap(); - MonomorphizePass::default().run(&mut hugr).unwrap(); + monomorphize(&mut hugr).unwrap(); let mono_hugr = hugr; mono_hugr.validate().unwrap(); let funcs = list_funcs(&mono_hugr); @@ -662,7 +632,7 @@ mod test { let mono = mono.finish_with_outputs([a, b]).unwrap(); let c = dfg.call(mono.handle(), &[], dfg.input_wires()).unwrap(); let mut hugr = dfg.finish_hugr_with_outputs(c.outputs()).unwrap(); - MonomorphizePass::default().run(&mut hugr)?; + monomorphize(&mut hugr)?; let mono_hugr = hugr; let mut funcs = list_funcs(&mono_hugr); @@ -719,7 +689,7 @@ mod test { module_builder.finish_hugr().unwrap() }; - MonomorphizePass::default().run(&mut hugr).unwrap(); + monomorphize(&mut hugr).unwrap(); remove_dead_funcs(&mut hugr, []).unwrap(); let funcs = list_funcs(&hugr); diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 3ed7337a9..e81a640e3 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -26,7 +26,7 @@ use hugr_core::types::{ }; use hugr_core::{Hugr, HugrView, Node, Wire}; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; mod linearize; pub use linearize::{CallbackHandler, DelegatingLinearizer, LinearizeError, Linearizer}; @@ -143,7 +143,6 @@ pub struct ReplaceTypes { ParametricType, Arc Result, ReplaceTypesError>>, >, - validation: ValidationLevel, } impl Default for ReplaceTypes { @@ -184,8 +183,6 @@ pub enum ReplaceTypesError { #[error(transparent)] SignatureError(#[from] SignatureError), #[error(transparent)] - ValidationError(#[from] ValidatePassError), - #[error(transparent)] ConstError(#[from] ConstTypeError), #[error(transparent)] LinearizeError(#[from] LinearizeError), @@ -203,16 +200,9 @@ impl ReplaceTypes { param_ops: Default::default(), consts: Default::default(), param_consts: Default::default(), - validation: Default::default(), } } - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; - self - } - /// Configures this instance to replace occurrences of type `src` with `dest`. /// Note that if `src` is an instance of a *parametrized* [TypeDef], this takes /// precedence over [Self::replace_parametrized_type] where the `src`s overlap. Thus, this @@ -323,36 +313,6 @@ impl ReplaceTypes { self.param_consts.insert(src_ty.into(), Arc::new(const_fn)); } - /// Run the pass using specified configuration. - pub fn run(&self, hugr: &mut H) -> Result { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr)) - } - - fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { - let mut changed = false; - for n in hugr.nodes().collect::>() { - changed |= self.change_node(hugr, n)?; - let new_dfsig = hugr.get_optype(n).dataflow_signature(); - if let Some(new_sig) = new_dfsig - .filter(|_| changed && n != hugr.root()) - .map(Cow::into_owned) - { - for outp in new_sig.output_ports() { - if !new_sig.out_port_type(outp).unwrap().copyable() { - let targets = hugr.linked_inputs(n, outp).collect::>(); - if targets.len() != 1 { - hugr.disconnect(n, outp); - let src = Wire::new(n, outp); - self.linearize.insert_copy_discard(hugr, src, &targets)?; - } - } - } - } - } - Ok(changed) - } - fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { match hugr.optype_mut(n) { OpType::FuncDefn(FuncDefn { signature, .. }) @@ -472,11 +432,40 @@ impl ReplaceTypes { false } }), - Value::Function { hugr } => self.run_no_validate(&mut **hugr), + Value::Function { hugr } => self.run(&mut **hugr), } } } +impl ComposablePass for ReplaceTypes { + type Error = ReplaceTypesError; + type Result = bool; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let mut changed = false; + for n in hugr.nodes().collect::>() { + changed |= self.change_node(hugr, n)?; + let new_dfsig = hugr.get_optype(n).dataflow_signature(); + if let Some(new_sig) = new_dfsig + .filter(|_| changed && n != hugr.root()) + .map(Cow::into_owned) + { + for outp in new_sig.output_ports() { + if !new_sig.out_port_type(outp).unwrap().copyable() { + let targets = hugr.linked_inputs(n, outp).collect::>(); + if targets.len() != 1 { + hugr.disconnect(n, outp); + let src = Wire::new(n, outp); + self.linearize.insert_copy_discard(hugr, src, &targets)?; + } + } + } + } + } + Ok(changed) + } +} + pub mod handlers; #[derive(Clone, Hash, PartialEq, Eq)] @@ -532,29 +521,26 @@ mod test { use hugr_core::extension::prelude::{ bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, }; - use hugr_core::extension::simple_op::MakeExtensionOp; - use hugr_core::extension::{TypeDefBound, Version}; - - use hugr_core::ops::constant::OpaqueValue; - use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; - use hugr_core::std_extensions::arithmetic::int_types::ConstInt; - use hugr_core::std_extensions::arithmetic::{conversions::ConvertOpDef, int_types::INT_TYPES}; + use hugr_core::extension::{simple_op::MakeExtensionOp, TypeDefBound, Version}; + use hugr_core::hugr::{IdentList, ValidationError}; + use hugr_core::ops::{ + constant::OpaqueValue, ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value, + }; + use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; + use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; use hugr_core::std_extensions::collections::array::{ array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue, }; use hugr_core::std_extensions::collections::list::{ list_type, list_type_def, ListOp, ListValue, }; - - use hugr_core::hugr::ValidationError; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; - use hugr_core::{hugr::IdentList, type_row, Extension, HugrView}; + use hugr_core::{type_row, Extension, HugrView}; use itertools::Itertools; use rstest::rstest; - use crate::validation::ValidatePassError; + use crate::ComposablePass; - use super::ReplaceTypesError; use super::{handlers::list_const, NodeTemplate, ReplaceTypes}; const PACKED_VEC: &str = "PackedVec"; @@ -979,13 +965,16 @@ mod test { let cu = cst.value().downcast_ref::().unwrap(); Ok(ConstInt::new_u(6, cu.value())?.into()) }); + + let mut h = backup.clone(); + repl.run(&mut h).unwrap(); // No validation here assert!( - matches!(repl.run(&mut backup.clone()), Err(ReplaceTypesError::ValidationError(ValidatePassError::OutputError { - err: ValidationError::IncompatiblePorts {from, to, ..}, .. - })) if backup.get_optype(from).is_const() && to == c.node()) + matches!(h.validate(), Err(ValidationError::IncompatiblePorts {from, to, ..}) + if backup.get_optype(from).is_const() && to == c.node()) ); repl.replace_consts_parametrized(array_type_def(), array_const); let mut h = backup; - repl.run(&mut h).unwrap(); // Includes validation + repl.run(&mut h).unwrap(); + h.validate_no_extensions().unwrap(); } } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 5b4da7184..bc508bd53 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -377,7 +377,7 @@ mod test { use crate::replace_types::handlers::linearize_array; use crate::replace_types::{LinearizeError, NodeTemplate, ReplaceTypesError}; - use crate::ReplaceTypes; + use crate::{ComposablePass, ReplaceTypes}; const LIN_T: &str = "Lin"; diff --git a/hugr-passes/src/untuple.rs b/hugr-passes/src/untuple.rs index dbe04edd1..874fd9ec3 100644 --- a/hugr-passes/src/untuple.rs +++ b/hugr-passes/src/untuple.rs @@ -10,19 +10,19 @@ use hugr_core::hugr::views::SiblingSubgraph; use hugr_core::hugr::SimpleReplacementError; use hugr_core::ops::{NamedOp, OpTrait, OpType}; use hugr_core::types::Type; -use hugr_core::{HugrView, SimpleReplacement}; +use hugr_core::{HugrView, Node, SimpleReplacement}; use itertools::Itertools; -use crate::validation::{ValidatePassError, ValidationLevel}; +use crate::ComposablePass; /// Configuration enum for the untuple rewrite pass. /// /// Indicates whether the pattern match should traverse the HUGR recursively. #[derive(Debug, Clone, Copy, PartialEq, Eq, Default)] pub enum UntupleRecursive { - /// Traverse the HUGR recursively. + /// Traverse the HUGR recursively, i.e. consider the entire subtree Recursive, - /// Do not traverse the HUGR recursively. + /// Do not traverse the HUGR recursively, i.e. consider only the sibling subgraph #[default] NonRecursive, } @@ -48,22 +48,20 @@ pub enum UntupleRecursive { pub struct UntuplePass { /// Whether to traverse the HUGR recursively. recursive: UntupleRecursive, - /// The level of validation to perform on the rewrite. - validation: ValidationLevel, + /// Parent node under which to operate; None indicates the Hugr root + parent: Option, } #[derive(Debug, derive_more::Display, derive_more::Error, derive_more::From)] #[non_exhaustive] /// Errors produced by [UntuplePass]. pub enum UntupleError { - /// An error occurred while validating the rewrite. - ValidationError(ValidatePassError), /// Rewriting the circuit failed. RewriteError(SimpleReplacementError), } /// Result type for the untuple pass. -#[derive(Debug, Clone, Copy, Default)] +#[derive(Debug, Clone, Copy, Default, PartialEq)] pub struct UntupleResult { /// Number of `MakeTuple` rewrites applied. pub rewrites_applied: usize, @@ -71,16 +69,16 @@ pub struct UntupleResult { impl UntuplePass { /// Create a new untuple pass with the given configuration. - pub fn new(recursive: UntupleRecursive, validation: ValidationLevel) -> Self { + pub fn new(recursive: UntupleRecursive) -> Self { Self { recursive, - validation, + parent: None, } } - /// Sets the validation level used before and after the pass is run. - pub fn validation_level(mut self, level: ValidationLevel) -> Self { - self.validation = level; + /// Sets the parent node to optimize (overwrites any previous setting) + pub fn set_parent(mut self, parent: impl Into>) -> Self { + self.parent = parent.into(); self } @@ -90,31 +88,6 @@ impl UntuplePass { self } - /// Run the pass using specified configuration. - pub fn run( - &self, - hugr: &mut H, - parent: H::Node, - ) -> Result { - self.validation - .run_validated_pass(hugr, |hugr: &mut H, _| self.run_no_validate(hugr, parent)) - } - - /// Run the Monomorphization pass. - fn run_no_validate( - &self, - hugr: &mut H, - parent: H::Node, - ) -> Result { - let rewrites = self.find_rewrites(hugr, parent); - let rewrites_applied = rewrites.len(); - // The rewrites are independent, so we can always apply them all. - for rewrite in rewrites { - hugr.apply_rewrite(rewrite)?; - } - Ok(UntupleResult { rewrites_applied }) - } - /// Find tuple pack operations followed by tuple unpack operations /// and generate rewrites to remove them. /// @@ -148,6 +121,22 @@ impl UntuplePass { } } +impl ComposablePass for UntuplePass { + type Error = UntupleError; + + type Result = UntupleResult; + + fn run(&self, hugr: &mut impl HugrMut) -> Result { + let rewrites = self.find_rewrites(hugr, self.parent.unwrap_or(hugr.root())); + let rewrites_applied = rewrites.len(); + // The rewrites are independent, so we can always apply them all. + for rewrite in rewrites { + hugr.apply_rewrite(rewrite)?; + } + Ok(UntupleResult { rewrites_applied }) + } +} + /// Returns true if the given optype is a MakeTuple operation. /// /// Boilerplate required due to https://github.com/CQCL/hugr/issues/1496 @@ -421,7 +410,8 @@ mod test { let parent = hugr.root(); let res = pass - .run(&mut hugr, parent) + .set_parent(parent) + .run(&mut hugr) .unwrap_or_else(|e| panic!("{e}")); assert_eq!(res.rewrites_applied, expected_rewrites); assert_eq!(hugr.children(parent).count(), remaining_nodes); From d8a5d6794526f22bc99d7a5489cbcc2d39e3c59a Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Apr 2025 10:55:43 +0100 Subject: [PATCH 09/18] feat!: ReplaceTypes: allow lowering ops into a Call to a function already in the Hugr (#2094) There are two issues: * Errors. The previous NodeTemplates still always work, but the Call one can fail if the Hugr doesn't contain the target function node. ATM there is no channel for reporting that error so I've had to panic. Otherwise it's an even-more-breaking change to add an error type to `NodeTemplate::add()` and `NodeTemplate::add_hugr()`. Should we? (I note `HugrMut::connect` panics if the node isn't there, but could make the `NodeTemplate::add` builder method return a BuildError...and propagate that everywhere of course) * There's a big limitation in `linearize_array` that it'll break if the *element* says it should be copied/discarded via a NodeTemplate::Call, as `linearize_array` puts the elementwise copy/discard function into a *nested Hugr* (`Value::Function`) that won't contain the function. This could be fixed via lifting those to toplevel FuncDefns with name-mangling, but I'd rather leave that for #2086 .... BREAKING CHANGE: Add new variant NodeTemplate::Call; LinearizeError no longer derives Eq. --- hugr-passes/src/replace_types.rs | 234 ++++++++++++++++----- hugr-passes/src/replace_types/handlers.rs | 4 +- hugr-passes/src/replace_types/linearize.rs | 104 +++++++-- 3 files changed, 268 insertions(+), 74 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index e81a640e3..df4c14075 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -15,16 +15,17 @@ use hugr_core::builder::{BuildError, BuildHandle, Dataflow}; use hugr_core::extension::{ExtensionId, OpDef, SignatureError, TypeDef}; use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::ops::constant::{OpaqueValue, Sum}; -use hugr_core::ops::handle::DataflowOpID; +use hugr_core::ops::handle::{DataflowOpID, FuncID}; use hugr_core::ops::{ AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock, ExitBlock, ExtensionOp, FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop, Value, CFG, DFG, }; use hugr_core::types::{ - ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer, + ConstTypeError, CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeRow, + TypeTransformer, }; -use hugr_core::{Hugr, HugrView, Node, Wire}; +use hugr_core::{Direction, Hugr, HugrView, Node, PortIndex, Wire}; use crate::ComposablePass; @@ -45,21 +46,37 @@ pub enum NodeTemplate { /// Note this will be of limited use before [monomorphization](super::monomorphize()) /// because the new subtree will not be able to use type variables present in the /// parent Hugr or previous op. - // TODO: store also a vec, and update Hugr::validate to take &[TypeParam]s - // (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709 CompoundOp(Box), - // TODO allow also Call to a Node in the existing Hugr - // (can't see any other way to achieve multiple calls to the same decl. - // So client should add the functions before replacement, then remove unused ones afterwards.) + /// A Call to an existing function. + Call(Node, Vec), } impl NodeTemplate { /// Adds this instance to the specified [HugrMut] as a new node or subtree under a /// given parent, returning the unique new child (of that parent) thus created - pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Node { + /// + /// # Panics + /// + /// * If `parent` is not in the `hugr` + /// + /// # Errors + /// + /// * If `self` is a [Self::Call] and the target Node either + /// * is neither a [FuncDefn] nor a [FuncDecl] + /// * has a [`signature`] which the type-args of the [Self::Call] do not match + /// + /// [`signature`]: hugr_core::types::PolyFuncType + pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Result { match self { - NodeTemplate::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type), - NodeTemplate::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root, + NodeTemplate::SingleOp(op_type) => Ok(hugr.add_node_with_parent(parent, op_type)), + NodeTemplate::CompoundOp(new_h) => Ok(hugr.insert_hugr(parent, *new_h).new_root), + NodeTemplate::Call(target, type_args) => { + let c = call(hugr, target, type_args)?; + let tgt_port = c.called_function_port(); + let n = hugr.add_node_with_parent(parent, c); + hugr.connect(target, 0, n, tgt_port); + Ok(n) + } } } @@ -72,10 +89,15 @@ impl NodeTemplate { match self { NodeTemplate::SingleOp(opty) => dfb.add_dataflow_op(opty, inputs), NodeTemplate::CompoundOp(h) => dfb.add_hugr_with_wires(*h, inputs), + // Really we should check whether func points at a FuncDecl or FuncDefn and create + // the appropriate variety of FuncID but it doesn't matter for the purpose of making a Call. + NodeTemplate::Call(func, type_args) => { + dfb.call(&FuncID::::from(func), &type_args, inputs) + } } } - fn replace(&self, hugr: &mut impl HugrMut, n: Node) { + fn replace(&self, hugr: &mut impl HugrMut, n: Node) -> Result<(), BuildError> { assert_eq!(hugr.children(n).count(), 0); let new_optype = match self.clone() { NodeTemplate::SingleOp(op_type) => op_type, @@ -88,19 +110,57 @@ impl NodeTemplate { } root_opty } + NodeTemplate::Call(func, type_args) => { + let c = call(hugr, func, type_args)?; + let static_inport = c.called_function_port(); + // insert an input for the Call static input + hugr.insert_ports(n, Direction::Incoming, static_inport.index(), 1); + // connect the function to (what will be) the call + hugr.connect(func, 0, n, static_inport); + c.into() + } }; *hugr.optype_mut(n) = new_optype; + Ok(()) } - fn signature(&self) -> Option> { - match self { + fn check_signature( + &self, + inputs: &TypeRow, + outputs: &TypeRow, + ) -> Result<(), Option> { + let sig = match self { NodeTemplate::SingleOp(op_type) => op_type, NodeTemplate::CompoundOp(hugr) => hugr.root_type(), + NodeTemplate::Call(_, _) => return Ok(()), // no way to tell + } + .dataflow_signature(); + if sig.as_deref().map(Signature::io) == Some((inputs, outputs)) { + Ok(()) + } else { + Err(sig.map(Cow::into_owned)) } - .dataflow_signature() } } +fn call>( + h: &H, + func: Node, + type_args: Vec, +) -> Result { + let func_sig = match h.get_optype(func) { + OpType::FuncDecl(fd) => fd.signature.clone(), + OpType::FuncDefn(fd) => fd.signature.clone(), + _ => { + return Err(BuildError::UnexpectedType { + node: func, + op_desc: "func defn/decl", + }) + } + }; + Ok(Call::try_new(func_sig, type_args)?) +} + /// A configuration of what types, ops, and constants should be replaced with what. /// May be applied to a Hugr via [Self::run]. /// @@ -186,6 +246,8 @@ pub enum ReplaceTypesError { ConstError(#[from] ConstTypeError), #[error(transparent)] LinearizeError(#[from] LinearizeError), + #[error("Replacement op for {0} could not be added because {1}")] + AddTemplateError(Node, BuildError), } impl ReplaceTypes { @@ -370,8 +432,11 @@ impl ReplaceTypes { OpType::Const(Const { value, .. }) => self.change_value(value), OpType::ExtensionOp(ext_op) => Ok( + // Copy/discard insertion done by caller if let Some(replacement) = self.op_map.get(&OpHashWrapper::from(&*ext_op)) { - replacement.replace(hugr, n); // Copy/discard insertion done by caller + replacement + .replace(hugr, n) + .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; true } else { let def = ext_op.def_arc(); @@ -382,7 +447,9 @@ impl ReplaceTypes { .get(&def.as_ref().into()) .and_then(|rep_fn| rep_fn(&args)) { - replacement.replace(hugr, n); + replacement + .replace(hugr, n) + .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; true } else { if ch { @@ -515,24 +582,22 @@ mod test { use std::sync::Arc; use hugr_core::builder::{ - inout_sig, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, + inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + FunctionBuilder, HugrBuilder, ModuleBuilder, SubContainer, TailLoopBuilder, }; use hugr_core::extension::prelude::{ - bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, + bool_t, option_type, qb_t, usize_t, ConstUsize, UnwrapBuilder, PRELUDE_ID, }; - use hugr_core::extension::{simple_op::MakeExtensionOp, TypeDefBound, Version}; + use hugr_core::extension::{simple_op::MakeExtensionOp, ExtensionSet, TypeDefBound, Version}; + use hugr_core::hugr::hugrmut::HugrMut; use hugr_core::hugr::{IdentList, ValidationError}; - use hugr_core::ops::{ - constant::OpaqueValue, ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value, - }; - use hugr_core::std_extensions::arithmetic::conversions::ConvertOpDef; + use hugr_core::ops::constant::OpaqueValue; + use hugr_core::ops::{ExtensionOp, NamedOp, OpTrait, OpType, Tag, Value}; + use hugr_core::std_extensions::arithmetic::conversions::{self, ConvertOpDef}; use hugr_core::std_extensions::arithmetic::int_types::{ConstInt, INT_TYPES}; - use hugr_core::std_extensions::collections::array::{ - array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue, - }; - use hugr_core::std_extensions::collections::list::{ - list_type, list_type_def, ListOp, ListValue, + use hugr_core::std_extensions::collections::{ + array::{self, array_type, array_type_def, ArrayOp, ArrayOpDef, ArrayValue}, + list::{list_type, list_type_def, ListOp, ListValue}, }; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; use hugr_core::{type_row, Extension, HugrView}; @@ -601,30 +666,37 @@ mod test { ) } - fn lowerer(ext: &Arc) -> ReplaceTypes { - fn lowered_read(args: &[TypeArg]) -> Option { - let ty = just_elem_type(args); - let mut dfb = DFGBuilder::new(inout_sig( - vec![array_type(64, ty.clone()), i64_t()], - ty.clone(), - )) + fn lowered_read( + elem_ty: Type, + new: impl Fn(Signature) -> Result, + ) -> T { + let mut dfb = new(Signature::new( + vec![array_type(64, elem_ty.clone()), i64_t()], + elem_ty.clone(), + ) + .with_extension_delta(ExtensionSet::from_iter([ + PRELUDE_ID, + array::EXTENSION_ID, + conversions::EXTENSION_ID, + ]))) + .unwrap(); + let [val, idx] = dfb.input_wires_arr(); + let [idx] = dfb + .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) + .unwrap() + .outputs_arr(); + let [opt] = dfb + .add_dataflow_op(ArrayOpDef::get.to_concrete(elem_ty.clone(), 64), [val, idx]) + .unwrap() + .outputs_arr(); + let [res] = dfb + .build_unwrap_sum(1, option_type(Type::from(elem_ty)), opt) .unwrap(); - let [val, idx] = dfb.input_wires_arr(); - let [idx] = dfb - .add_dataflow_op(ConvertOpDef::itousize.without_log_width(), [idx]) - .unwrap() - .outputs_arr(); - let [opt] = dfb - .add_dataflow_op(ArrayOpDef::get.to_concrete(ty.clone(), 64), [val, idx]) - .unwrap() - .outputs_arr(); - let [res] = dfb - .build_unwrap_sum(1, option_type(Type::from(ty.clone())), opt) - .unwrap(); - Some(NodeTemplate::CompoundOp(Box::new( - dfb.finish_hugr_with_outputs([res]).unwrap(), - ))) - } + dfb.set_outputs([res]).unwrap(); + dfb + } + + fn lowerer(ext: &Arc) -> ReplaceTypes { let pv = ext.get_type(PACKED_VEC).unwrap(); let mut lw = ReplaceTypes::default(); lw.replace_type(pv.instantiate([bool_t().into()]).unwrap(), i64_t()); @@ -640,7 +712,13 @@ mod test { .into(), ), ); - lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), Box::new(lowered_read)); + lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), |type_args| { + Some(NodeTemplate::CompoundOp(Box::new( + lowered_read(just_elem_type(type_args).clone(), DFGBuilder::new) + .finish_hugr() + .unwrap(), + ))) + }); lw } @@ -977,4 +1055,52 @@ mod test { repl.run(&mut h).unwrap(); h.validate_no_extensions().unwrap(); } + + #[test] + fn op_to_call() { + let e = ext(); + let pv = e.get_type(PACKED_VEC).unwrap(); + let inner = pv.instantiate([usize_t().into()]).unwrap(); + let outer = pv + .instantiate([Type::new_extension(inner.clone()).into()]) + .unwrap(); + let mut dfb = DFGBuilder::new(inout_sig(vec![outer.into(), i64_t()], usize_t())).unwrap(); + let [outer, idx] = dfb.input_wires_arr(); + let [inner] = dfb + .add_dataflow_op(read_op(&e, inner.clone().into()), [outer, idx]) + .unwrap() + .outputs_arr(); + let res = dfb + .add_dataflow_op(read_op(&e, usize_t()), [inner, idx]) + .unwrap(); + let mut h = dfb.finish_hugr_with_outputs(res.outputs()).unwrap(); + let read_func = h + .insert_hugr( + h.root(), + lowered_read(Type::new_var_use(0, TypeBound::Copyable), |sig| { + FunctionBuilder::new( + "lowered_read", + PolyFuncType::new([TypeBound::Copyable.into()], sig), + ) + }) + .finish_hugr() + .unwrap(), + ) + .new_root; + + let mut lw = lowerer(&e); + lw.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args| { + Some(NodeTemplate::Call(read_func, args.to_owned())) + }); + lw.run(&mut h).unwrap(); + + assert_eq!(h.output_neighbours(read_func).count(), 2); + let ext_op_names = h + .nodes() + .filter_map(|n| h.get_optype(n).as_extension_op()) + .map(|e| e.def().name()) + .sorted() + .collect_vec(); + assert_eq!(ext_op_names, ["get", "itousize", "panic",]); + } } diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index e835a2d9b..b6e6e6780 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -92,7 +92,7 @@ pub fn linearize_array( let [to_discard] = dfb.input_wires_arr(); lin.copy_discard_op(ty, 0)? .add(&mut dfb, [to_discard]) - .unwrap(); + .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))?; let ret = dfb.add_load_value(Value::unary_unit_sum()); dfb.finish_hugr_with_outputs([ret]).unwrap() }; @@ -162,7 +162,7 @@ pub fn linearize_array( let mut copies = lin .copy_discard_op(ty, num_outports)? .add(&mut dfb, [elem]) - .unwrap() + .map_err(|e| LinearizeError::NestedTemplateError(ty.clone(), e))? .outputs(); let copy0 = copies.next().unwrap(); // We'll return this directly diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index bc508bd53..5c4a4a707 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -1,10 +1,9 @@ -use std::borrow::Cow; use std::iter::repeat; use std::{collections::HashMap, sync::Arc}; use hugr_core::builder::{ - inout_sig, ConditionalBuilder, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, - HugrBuilder, + inout_sig, BuildError, ConditionalBuilder, DFGBuilder, Dataflow, DataflowHugr, + DataflowSubContainer, HugrBuilder, }; use hugr_core::extension::{SignatureError, TypeDef}; use hugr_core::std_extensions::collections::array::array_type_def; @@ -76,9 +75,11 @@ pub trait Linearizer { tgt_parent, }); } + let typ = typ.clone(); // Stop borrowing hugr in order to add_hugr to it let copy_discard_op = self - .copy_discard_op(typ, targets.len())? - .add_hugr(hugr, src_parent); + .copy_discard_op(&typ, targets.len())? + .add_hugr(hugr, src_parent) + .map_err(|e| LinearizeError::NestedTemplateError(typ, e))?; for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { hugr.connect(copy_discard_op, n, *tgt_node, *tgt_port); } @@ -133,7 +134,7 @@ impl Default for DelegatingLinearizer { // rather than passing a &DelegatingLinearizer directly) pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); -#[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] +#[derive(Clone, Debug, thiserror::Error, PartialEq)] #[allow(missing_docs)] #[non_exhaustive] pub enum LinearizeError { @@ -163,6 +164,10 @@ pub enum LinearizeError { /// Neither does linearization make sense for copyable types #[error("Type {_0} is copyable")] CopyableType(Type), + /// Error may be returned by a callback for e.g. a container because it could + /// not generate a [NodeTemplate] because of a problem with an element + #[error("Could not generate NodeTemplate for contained type {0} because {1}")] + NestedTemplateError(Type, BuildError), } impl DelegatingLinearizer { @@ -185,8 +190,10 @@ impl DelegatingLinearizer { /// /// * [LinearizeError::CopyableType] If `typ` is /// [Copyable](hugr_core::types::TypeBound::Copyable) - /// * [LinearizeError::WrongSignature] if `copy` or `discard` do not have the - /// expected inputs or outputs + /// * [LinearizeError::WrongSignature] if `copy` or `discard` do not have the expected + /// inputs or outputs (for [NodeTemplate::SingleOp] and [NodeTemplate::CompoundOp] + /// only: the signature for a [NodeTemplate::Call] cannot be checked until it is used + /// in a Hugr). pub fn register_simple( &mut self, cty: CustomType, @@ -230,18 +237,12 @@ impl DelegatingLinearizer { } fn check_sig(tmpl: &NodeTemplate, typ: &Type, num_outports: usize) -> Result<(), LinearizeError> { - let sig = tmpl.signature(); - if sig.as_ref().is_some_and(|sig| { - sig.io() == (&typ.clone().into(), &vec![typ.clone(); num_outports].into()) - }) { - Ok(()) - } else { - Err(LinearizeError::WrongSignature { + tmpl.check_signature(&typ.clone().into(), &vec![typ.clone(); num_outports].into()) + .map_err(|sig| LinearizeError::WrongSignature { typ: typ.clone(), num_outports, - sig: sig.map(Cow::into_owned), + sig, }) - } } impl Linearizer for DelegatingLinearizer { @@ -353,7 +354,10 @@ mod test { use std::iter::successors; use std::sync::Arc; - use hugr_core::builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer}; + use hugr_core::builder::{ + inout_sig, BuildError, Container, DFGBuilder, Dataflow, DataflowHugr, DataflowSubContainer, + HugrBuilder, + }; use hugr_core::extension::prelude::{option_type, usize_t}; use hugr_core::extension::simple_op::MakeExtensionOp; @@ -768,4 +772,68 @@ mod test { )); assert_eq!(copy_sig.input[2..], copy_sig.output[1..]); } + + #[test] + fn call_ok_except_in_array() { + let (e, _) = ext_lowerer(); + let lin_ct = e.get_type(LIN_T).unwrap().instantiate([]).unwrap(); + let lin_t: Type = lin_ct.clone().into(); + + // A simple Hugr that discards a usize_t, with a "drop" function + let mut dfb = DFGBuilder::new(inout_sig(usize_t(), type_row![])).unwrap(); + let discard_fn = { + let mut fb = dfb + .define_function( + "drop", + Signature::new(lin_t.clone(), type_row![]) + .with_extension_delta(e.name().clone()), + ) + .unwrap(); + let ins = fb.input_wires(); + fb.add_dataflow_op( + ExtensionOp::new(e.get_op("discard").unwrap().clone(), []).unwrap(), + ins, + ) + .unwrap(); + fb.finish_with_outputs([]).unwrap() + } + .node(); + let backup = dfb.finish_hugr().unwrap(); + + let mut lower_discard_to_call = ReplaceTypes::default(); + // The `copy_fn` here will break completely, but we don't use it + lower_discard_to_call + .linearizer() + .register_simple( + lin_ct.clone(), + NodeTemplate::Call(backup.root(), vec![]), + NodeTemplate::Call(discard_fn, vec![]), + ) + .unwrap(); + + // Ok to lower usize_t to lin_t and call that function + { + let mut lowerer = lower_discard_to_call.clone(); + lowerer.replace_type(usize_t().as_extension().unwrap().clone(), lin_t.clone()); + let mut h = backup.clone(); + lowerer.run(&mut h).unwrap(); + assert_eq!(h.output_neighbours(discard_fn).count(), 1); + } + + // But if we lower usize_t to array, the call will fail + lower_discard_to_call.replace_type( + usize_t().as_extension().unwrap().clone(), + array_type(4, lin_ct.into()), + ); + let r = lower_discard_to_call.run(&mut backup.clone()); + assert!(matches!( + r, + Err(ReplaceTypesError::LinearizeError( + LinearizeError::NestedTemplateError( + nested_t, + BuildError::UnexpectedType { node, .. } + ) + )) if nested_t == lin_t && node == discard_fn + )); + } } From 2430f5620d81672ef8f574b159834fe2d184be38 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 9 Apr 2025 11:49:16 +0100 Subject: [PATCH 10/18] (breaking) callback takes &mut CallbackHandler --- hugr-passes/src/replace_types/handlers.rs | 2 +- hugr-passes/src/replace_types/linearize.rs | 9 ++++++--- 2 files changed, 7 insertions(+), 4 deletions(-) diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index e835a2d9b..f0b07fec1 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -78,7 +78,7 @@ fn runtime_reqs(h: &Hugr) -> ExtensionSet { pub fn linearize_array( args: &[TypeArg], num_outports: usize, - lin: &CallbackHandler, + lin: &mut CallbackHandler, ) -> Result { // Require known length i.e. usable only after monomorphization, due to no-variables limitation // restriction on NodeTemplate::CompoundOp diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 7b83717d0..6977ef7e4 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -115,7 +115,9 @@ pub struct DelegatingLinearizer { // including lowering of the copy/discard operations to...whatever. copy_discard_parametric: HashMap< ParametricType, - Arc Result>, + Arc< + dyn Fn(&[TypeArg], usize, &mut CallbackHandler) -> Result, + >, >, } @@ -217,7 +219,7 @@ impl DelegatingLinearizer { pub fn register_callback( &mut self, src: &TypeDef, - copy_discard_fn: impl Fn(&[TypeArg], usize, &CallbackHandler) -> Result + copy_discard_fn: impl Fn(&[TypeArg], usize, &mut CallbackHandler) -> Result + 'static, ) { // We could look for `src`s TypeDefBound being explicit Copyable, otherwise @@ -325,7 +327,8 @@ impl Linearizer for DelegatingLinearizer { .copy_discard_parametric .get(&cty.into()) .ok_or_else(|| LinearizeError::NeedCopyDiscard(typ.clone()))?; - let tmpl = copy_discard_fn(cty.args(), num_outports, &CallbackHandler(self))?; + let tmpl = + copy_discard_fn(cty.args(), num_outports, &mut CallbackHandler(self))?; check_sig(&tmpl, typ, num_outports)?; Ok(tmpl) } From 15f446710ff4de8937a4ee9222293d9d555c9321 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Apr 2025 13:37:21 +0100 Subject: [PATCH 11/18] rm trait Linearizer, CallbackHandler stores cache and (alone) provides methods --- hugr-passes/src/replace_types.rs | 5 +- hugr-passes/src/replace_types/handlers.rs | 2 +- hugr-passes/src/replace_types/linearize.rs | 213 +++++++++++---------- 3 files changed, 121 insertions(+), 99 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 3ed7337a9..995a99ca4 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -29,7 +29,7 @@ use hugr_core::{Hugr, HugrView, Node, Wire}; use crate::validation::{ValidatePassError, ValidationLevel}; mod linearize; -pub use linearize::{CallbackHandler, DelegatingLinearizer, LinearizeError, Linearizer}; +pub use linearize::{CallbackHandler, DelegatingLinearizer, LinearizeError}; /// A recipe for creating a dataflow Node - as a new child of a [DataflowParent] /// or in order to replace an existing node. @@ -331,6 +331,7 @@ impl ReplaceTypes { fn run_no_validate(&self, hugr: &mut impl HugrMut) -> Result { let mut changed = false; + let mut cache = HashMap::new(); for n in hugr.nodes().collect::>() { changed |= self.change_node(hugr, n)?; let new_dfsig = hugr.get_optype(n).dataflow_signature(); @@ -344,7 +345,7 @@ impl ReplaceTypes { if targets.len() != 1 { hugr.disconnect(n, outp); let src = Wire::new(n, outp); - self.linearize.insert_copy_discard(hugr, src, &targets)?; + self.linearize.handler(hugr, &mut cache).insert_copy_discard(src, &targets)?; } } } diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index f0b07fec1..065132f00 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -18,7 +18,7 @@ use hugr_core::{type_row, Hugr, HugrView}; use itertools::Itertools; use super::{ - CallbackHandler, LinearizeError, Linearizer, NodeTemplate, ReplaceTypes, ReplaceTypesError, + CallbackHandler, LinearizeError, NodeTemplate, ReplaceTypes, ReplaceTypesError, }; /// Handler for [ListValue] constants that updates the element type and diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 6977ef7e4..23f25291f 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -9,98 +9,11 @@ use hugr_core::builder::{ use hugr_core::extension::{SignatureError, TypeDef}; use hugr_core::std_extensions::collections::array::array_type_def; use hugr_core::types::{CustomType, Signature, Type, TypeArg, TypeEnum, TypeRow}; -use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, HugrView, IncomingPort, Node, Wire}; +use hugr_core::{hugr::hugrmut::HugrMut, Hugr, ops::Tag, HugrView, IncomingPort, Node, Wire}; use itertools::Itertools; use super::{handlers::linearize_array, NodeTemplate, ParametricType}; -/// Trait for things that know how to wire up linear outports to other than one -/// target. Used to restore Hugr validity when a [ReplaceTypes](super::ReplaceTypes) -/// results in types of such outports changing from [Copyable] to linear (i.e. -/// [hugr_core::types::TypeBound::Any]). -/// -/// Note that this is not really effective before [monomorphization]: if a -/// function polymorphic over a [Copyable] becomes called with a -/// non-Copyable type argument, [Linearizer] cannot insert copy/discard -/// operations for such a case. However, following [monomorphization], there -/// would be a specific instantiation of the function for the -/// type-that-becomes-linear, into which copy/discard can be inserted. -/// -/// [monomorphization]: crate::monomorphize() -/// [Copyable]: hugr_core::types::TypeBound::Copyable -pub trait Linearizer { - /// Insert copy or discard operations (as appropriate) enough to wire `src` - /// up to all `targets`. - /// - /// The default implementation - /// * if `targets.len() == 1`, wires `src` to the unique target - /// * otherwise, makes a single call to [Self::copy_discard_op], inserts that op, - /// and wires its outputs 1:1 to each target - /// - /// # Errors - /// - /// Most variants of [LinearizeError] can be raised, specifically including - /// [LinearizeError::CopyableType] if the type is [Copyable], in which case the Hugr - /// will be unchanged. - /// - /// [Copyable]: hugr_core::types::TypeBound::Copyable - /// - /// # Panics - /// - /// if `src` is not a valid Wire (does not identify a dataflow out-port) - fn insert_copy_discard( - &self, - hugr: &mut impl HugrMut, - src: Wire, - targets: &[(Node, IncomingPort)], - ) -> Result<(), LinearizeError> { - let sig = hugr.signature(src.node()).unwrap(); - let typ = sig.port_type(src.source()).unwrap(); - let (tgt_node, tgt_inport) = if targets.len() == 1 { - *targets.first().unwrap() - } else { - // Fail fast if the edges are nonlocal. (TODO transform to local edges!) - let src_parent = hugr - .get_parent(src.node()) - .expect("Root node cannot have out edges"); - if let Some((tgt, tgt_parent)) = targets.iter().find_map(|(tgt, _)| { - let tgt_parent = hugr - .get_parent(*tgt) - .expect("Root node cannot have incoming edges"); - (tgt_parent != src_parent).then_some((*tgt, tgt_parent)) - }) { - return Err(LinearizeError::NoLinearNonLocalEdges { - src: src.node(), - src_parent, - tgt, - tgt_parent, - }); - } - let copy_discard_op = self - .copy_discard_op(typ, targets.len())? - .add_hugr(hugr, src_parent); - for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { - hugr.connect(copy_discard_op, n, *tgt_node, *tgt_port); - } - (copy_discard_op, 0.into()) - }; - hugr.connect(src.node(), src.source(), tgt_node, tgt_inport); - Ok(()) - } - - /// Gets an [NodeTemplate] for copying or discarding a value of type `typ`, i.e. - /// a recipe for a node with one input of that type and the specified number of - /// outports. - /// - /// Implementations are free to panic if `num_outports == 1`, such calls should never - /// occur as source/target can be directly wired without any node/op being required. - fn copy_discard_op( - &self, - typ: &Type, - num_outports: usize, - ) -> Result; -} - /// A configuration for implementing [Linearizer] by delegating to /// type-specific callbacks, and by composing them in order to handle compound types /// such as [TypeEnum::Sum]s. @@ -129,11 +42,17 @@ impl Default for DelegatingLinearizer { } } +type FuncId = String; + /// Implementation of [Linearizer] passed to callbacks, (e.g.) so that callbacks for /// handling collection types can use it to generate copy/discards of elements. // (Note, this is its own type just to give a bit of room for future expansion, // rather than passing a &DelegatingLinearizer directly) -pub struct CallbackHandler<'a>(#[allow(dead_code)] &'a DelegatingLinearizer); +pub struct CallbackHandler<'a> { + hugr: &'a mut Hugr, + cache: &'a mut HashMap, + lin: &'a DelegatingLinearizer +} #[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] #[allow(missing_docs)] @@ -245,11 +164,19 @@ fn check_sig(tmpl: &NodeTemplate, typ: &Type, num_outports: usize) -> Result<(), } } -impl Linearizer for DelegatingLinearizer { +impl DelegatingLinearizer { + pub fn handler<'a>(&'a self, hugr: &'a mut impl HugrMut, cache: &'a mut HashMap) -> CallbackHandler<'a> { + // ALAN ugh, can we avoid hugr_mut() here? Maybe by *not* storing the hugr-mut in the + // CallbackHandler (==> NodeTemplate::Call contains FuncID *or* Node) ? + CallbackHandler { hugr: hugr.hugr_mut(), cache, lin: self} + } + fn copy_discard_op( &self, typ: &Type, num_outports: usize, + hugr: &mut impl HugrMut, + cache: &mut HashMap, ) -> Result { if typ.copyable() { return Err(LinearizeError::CopyableType(typ.clone())); @@ -275,7 +202,7 @@ impl Linearizer for DelegatingLinearizer { let inp_copies = if ty.copyable() { repeat(inp).take(num_outports).collect::>() } else { - self.copy_discard_op(ty, num_outports)? + self.copy_discard_op(ty, num_outports, hugr, cache)? .add(&mut case_b, [inp]) .unwrap() .outputs() @@ -328,7 +255,7 @@ impl Linearizer for DelegatingLinearizer { .get(&cty.into()) .ok_or_else(|| LinearizeError::NeedCopyDiscard(typ.clone()))?; let tmpl = - copy_discard_fn(cty.args(), num_outports, &mut CallbackHandler(self))?; + copy_discard_fn(cty.args(), num_outports, &mut self.handler(hugr, cache))?; check_sig(&tmpl, typ, num_outports)?; Ok(tmpl) } @@ -339,13 +266,107 @@ impl Linearizer for DelegatingLinearizer { } } -impl Linearizer for CallbackHandler<'_> { - fn copy_discard_op( - &self, +/// Trait for things that know how to wire up linear outports to other than one +/// target. Used to restore Hugr validity when a [ReplaceTypes](super::ReplaceTypes) +/// results in types of such outports changing from [Copyable] to linear (i.e. +/// [hugr_core::types::TypeBound::Any]). +/// +/// Note that this is not really effective before [monomorphization]: if a +/// function polymorphic over a [Copyable] becomes called with a +/// non-Copyable type argument, [Linearizer] cannot insert copy/discard +/// operations for such a case. However, following [monomorphization], there +/// would be a specific instantiation of the function for the +/// type-that-becomes-linear, into which copy/discard can be inserted. +/// +/// [monomorphization]: crate::monomorphize() +/// [Copyable]: hugr_core::types::TypeBound::Copyable +impl CallbackHandler<'_> { + /// Callbacks can use this to make a function in the Hugr. + /// The first call for a given `id` will call `body`, which must return + /// a [FuncDefn]-rooted Hugr, and insert that into the underlying Hugr; + /// the node containing the newly-inserted FuncDefn is returned. + /// + /// A second call with the same `id` will return the node from the first + /// call, without executing `body`. + pub fn make_function(&mut self, id: FuncId, body: impl Fn() -> Hugr) -> Node { + if let Some(n) = self.cache.get(&id) {return *n;} + let h = body(); + let n = self.hugr.insert_hugr(self.hugr.root(), h).new_root; + self.cache.insert(id, n); + n + } + + /// Insert copy or discard operations (as appropriate) enough to wire `src` + /// up to all `targets`. + /// + /// The default implementation + /// * if `targets.len() == 1`, wires `src` to the unique target + /// * otherwise, makes a single call to [Self::copy_discard_op], inserts that op, + /// and wires its outputs 1:1 to each target + /// + /// # Errors + /// + /// Most variants of [LinearizeError] can be raised, specifically including + /// [LinearizeError::CopyableType] if the type is [Copyable], in which case the Hugr + /// will be unchanged. + /// + /// [Copyable]: hugr_core::types::TypeBound::Copyable + /// + /// # Panics + /// + /// if `src` is not a valid Wire (does not identify a dataflow out-port) + pub fn insert_copy_discard( + &mut self, + src: Wire, + targets: &[(Node, IncomingPort)], + ) -> Result<(), LinearizeError> { + let sig = self.hugr.signature(src.node()).unwrap(); + // Must clone here to avoid borrowing part of `self` + let typ = sig.port_type(src.source()).unwrap().clone(); + let (tgt_node, tgt_inport) = if targets.len() == 1 { + *targets.first().unwrap() + } else { + // Fail fast if the edges are nonlocal. (TODO transform to local edges!) + let src_parent = self.hugr + .get_parent(src.node()) + .expect("Root node cannot have out edges"); + if let Some((tgt, tgt_parent)) = targets.iter().find_map(|(tgt, _)| { + let tgt_parent = self.hugr + .get_parent(*tgt) + .expect("Root node cannot have incoming edges"); + (tgt_parent != src_parent).then_some((*tgt, tgt_parent)) + }) { + return Err(LinearizeError::NoLinearNonLocalEdges { + src: src.node(), + src_parent, + tgt, + tgt_parent, + }); + } + let copy_discard_op = self + .copy_discard_op(&typ, targets.len())? + .add_hugr(self.hugr, src_parent); + for (n, (tgt_node, tgt_port)) in targets.iter().enumerate() { + self.hugr.connect(copy_discard_op, n, *tgt_node, *tgt_port); + } + (copy_discard_op, 0.into()) + }; + self.hugr.connect(src.node(), src.source(), tgt_node, tgt_inport); + Ok(()) + } + + /// Gets an [NodeTemplate] for copying or discarding a value of type `typ`, i.e. + /// a recipe for a node with one input of that type and the specified number of + /// outports. + /// + /// Implementations are free to panic if `num_outports == 1`, such calls should never + /// occur as source/target can be directly wired without any node/op being required. + pub fn copy_discard_op( + &mut self, typ: &Type, num_outports: usize, ) -> Result { - self.0.copy_discard_op(typ, num_outports) + self.lin.copy_discard_op(typ, num_outports, self.hugr, self.cache) } } From 28ac1b14e8e1be5da9b52964e920a3062d3d5f41 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Apr 2025 13:38:07 +0100 Subject: [PATCH 12/18] Tidy up, moving copy_discard_op --- hugr-passes/src/replace_types.rs | 4 +- hugr-passes/src/replace_types/handlers.rs | 4 +- hugr-passes/src/replace_types/linearize.rs | 221 +++++++++++---------- 3 files changed, 115 insertions(+), 114 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 995a99ca4..c257c2f86 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -345,7 +345,9 @@ impl ReplaceTypes { if targets.len() != 1 { hugr.disconnect(n, outp); let src = Wire::new(n, outp); - self.linearize.handler(hugr, &mut cache).insert_copy_discard(src, &targets)?; + self.linearize + .handler(hugr, &mut cache) + .insert_copy_discard(src, &targets)?; } } } diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index 065132f00..17cddc0a8 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -17,9 +17,7 @@ use hugr_core::types::{SumType, Transformable, Type, TypeArg}; use hugr_core::{type_row, Hugr, HugrView}; use itertools::Itertools; -use super::{ - CallbackHandler, LinearizeError, NodeTemplate, ReplaceTypes, ReplaceTypesError, -}; +use super::{CallbackHandler, LinearizeError, NodeTemplate, ReplaceTypes, ReplaceTypesError}; /// Handler for [ListValue] constants that updates the element type and /// recursively [ReplaceTypes::change_value]s the elements of the list. diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 23f25291f..cf3fefe82 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -9,7 +9,7 @@ use hugr_core::builder::{ use hugr_core::extension::{SignatureError, TypeDef}; use hugr_core::std_extensions::collections::array::array_type_def; use hugr_core::types::{CustomType, Signature, Type, TypeArg, TypeEnum, TypeRow}; -use hugr_core::{hugr::hugrmut::HugrMut, Hugr, ops::Tag, HugrView, IncomingPort, Node, Wire}; +use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, Hugr, HugrView, IncomingPort, Node, Wire}; use itertools::Itertools; use super::{handlers::linearize_array, NodeTemplate, ParametricType}; @@ -51,7 +51,7 @@ type FuncId = String; pub struct CallbackHandler<'a> { hugr: &'a mut Hugr, cache: &'a mut HashMap, - lin: &'a DelegatingLinearizer + lin: &'a DelegatingLinearizer, } #[derive(Clone, Debug, thiserror::Error, PartialEq, Eq)] @@ -147,6 +147,20 @@ impl DelegatingLinearizer { self.copy_discard_parametric .insert(src.into(), Arc::new(copy_discard_fn)); } + + pub fn handler<'a>( + &'a self, + hugr: &'a mut impl HugrMut, + cache: &'a mut HashMap, + ) -> CallbackHandler<'a> { + // ALAN ugh, can we avoid hugr_mut() here? Maybe by *not* storing the hugr-mut in the + // CallbackHandler (==> NodeTemplate::Call contains FuncID *or* Node) ? + CallbackHandler { + hugr: hugr.hugr_mut(), + cache, + lin: self, + } + } } fn check_sig(tmpl: &NodeTemplate, typ: &Type, num_outports: usize) -> Result<(), LinearizeError> { @@ -164,108 +178,6 @@ fn check_sig(tmpl: &NodeTemplate, typ: &Type, num_outports: usize) -> Result<(), } } -impl DelegatingLinearizer { - pub fn handler<'a>(&'a self, hugr: &'a mut impl HugrMut, cache: &'a mut HashMap) -> CallbackHandler<'a> { - // ALAN ugh, can we avoid hugr_mut() here? Maybe by *not* storing the hugr-mut in the - // CallbackHandler (==> NodeTemplate::Call contains FuncID *or* Node) ? - CallbackHandler { hugr: hugr.hugr_mut(), cache, lin: self} - } - - fn copy_discard_op( - &self, - typ: &Type, - num_outports: usize, - hugr: &mut impl HugrMut, - cache: &mut HashMap, - ) -> Result { - if typ.copyable() { - return Err(LinearizeError::CopyableType(typ.clone())); - }; - assert!(num_outports != 1); - - match typ.as_type_enum() { - TypeEnum::Sum(sum_type) => { - let variants = sum_type - .variants() - .map(|trv| trv.clone().try_into()) - .collect::, _>>()?; - let mut cb = ConditionalBuilder::new( - variants.clone(), - vec![], - vec![sum_type.clone().into(); num_outports], - ) - .unwrap(); - for (tag, variant) in variants.iter().enumerate() { - let mut case_b = cb.case_builder(tag).unwrap(); - let mut elems_for_copy = vec![vec![]; num_outports]; - for (inp, ty) in case_b.input_wires().zip_eq(variant.iter()) { - let inp_copies = if ty.copyable() { - repeat(inp).take(num_outports).collect::>() - } else { - self.copy_discard_op(ty, num_outports, hugr, cache)? - .add(&mut case_b, [inp]) - .unwrap() - .outputs() - .collect() - }; - for (src, elems) in inp_copies.into_iter().zip_eq(elems_for_copy.iter_mut()) - { - elems.push(src) - } - } - let t = Tag::new(tag, variants.clone()); - let outputs = elems_for_copy - .into_iter() - .map(|elems| { - let [copy] = case_b - .add_dataflow_op(t.clone(), elems) - .unwrap() - .outputs_arr(); - copy - }) - .collect::>(); // must collect to end borrow of `case_b` by closure - case_b.finish_with_outputs(outputs).unwrap(); - } - Ok(NodeTemplate::CompoundOp(Box::new( - cb.finish_hugr().unwrap(), - ))) - } - TypeEnum::Extension(cty) => match self.copy_discard.get(cty) { - Some((copy, discard)) => Ok(if num_outports == 0 { - discard.clone() - } else { - let mut dfb = - DFGBuilder::new(inout_sig(typ.clone(), vec![typ.clone(); num_outports])) - .unwrap(); - let [mut src] = dfb.input_wires_arr(); - let mut outputs = vec![]; - for _ in 0..num_outports - 1 { - let [out0, out1] = copy.clone().add(&mut dfb, [src]).unwrap().outputs_arr(); - outputs.push(out0); - src = out1; - } - outputs.push(src); - NodeTemplate::CompoundOp(Box::new( - dfb.finish_hugr_with_outputs(outputs).unwrap(), - )) - }), - None => { - let copy_discard_fn = self - .copy_discard_parametric - .get(&cty.into()) - .ok_or_else(|| LinearizeError::NeedCopyDiscard(typ.clone()))?; - let tmpl = - copy_discard_fn(cty.args(), num_outports, &mut self.handler(hugr, cache))?; - check_sig(&tmpl, typ, num_outports)?; - Ok(tmpl) - } - }, - TypeEnum::Function(_) => panic!("Ruled out above as copyable"), - _ => Err(LinearizeError::UnsupportedType(typ.clone())), - } - } -} - /// Trait for things that know how to wire up linear outports to other than one /// target. Used to restore Hugr validity when a [ReplaceTypes](super::ReplaceTypes) /// results in types of such outports changing from [Copyable] to linear (i.e. @@ -285,11 +197,13 @@ impl CallbackHandler<'_> { /// The first call for a given `id` will call `body`, which must return /// a [FuncDefn]-rooted Hugr, and insert that into the underlying Hugr; /// the node containing the newly-inserted FuncDefn is returned. - /// + /// /// A second call with the same `id` will return the node from the first /// call, without executing `body`. pub fn make_function(&mut self, id: FuncId, body: impl Fn() -> Hugr) -> Node { - if let Some(n) = self.cache.get(&id) {return *n;} + if let Some(n) = self.cache.get(&id) { + return *n; + } let h = body(); let n = self.hugr.insert_hugr(self.hugr.root(), h).new_root; self.cache.insert(id, n); @@ -327,11 +241,13 @@ impl CallbackHandler<'_> { *targets.first().unwrap() } else { // Fail fast if the edges are nonlocal. (TODO transform to local edges!) - let src_parent = self.hugr + let src_parent = self + .hugr .get_parent(src.node()) .expect("Root node cannot have out edges"); if let Some((tgt, tgt_parent)) = targets.iter().find_map(|(tgt, _)| { - let tgt_parent = self.hugr + let tgt_parent = self + .hugr .get_parent(*tgt) .expect("Root node cannot have incoming edges"); (tgt_parent != src_parent).then_some((*tgt, tgt_parent)) @@ -351,7 +267,8 @@ impl CallbackHandler<'_> { } (copy_discard_op, 0.into()) }; - self.hugr.connect(src.node(), src.source(), tgt_node, tgt_inport); + self.hugr + .connect(src.node(), src.source(), tgt_node, tgt_inport); Ok(()) } @@ -366,7 +283,91 @@ impl CallbackHandler<'_> { typ: &Type, num_outports: usize, ) -> Result { - self.lin.copy_discard_op(typ, num_outports, self.hugr, self.cache) + if typ.copyable() { + return Err(LinearizeError::CopyableType(typ.clone())); + }; + assert!(num_outports != 1); + + match typ.as_type_enum() { + TypeEnum::Sum(sum_type) => { + let variants = sum_type + .variants() + .map(|trv| trv.clone().try_into()) + .collect::, _>>()?; + let mut cb = ConditionalBuilder::new( + variants.clone(), + vec![], + vec![sum_type.clone().into(); num_outports], + ) + .unwrap(); + for (tag, variant) in variants.iter().enumerate() { + let mut case_b = cb.case_builder(tag).unwrap(); + let mut elems_for_copy = vec![vec![]; num_outports]; + for (inp, ty) in case_b.input_wires().zip_eq(variant.iter()) { + let inp_copies = if ty.copyable() { + repeat(inp).take(num_outports).collect::>() + } else { + self.copy_discard_op(ty, num_outports)? + .add(&mut case_b, [inp]) + .unwrap() + .outputs() + .collect() + }; + for (src, elems) in inp_copies.into_iter().zip_eq(elems_for_copy.iter_mut()) + { + elems.push(src) + } + } + let t = Tag::new(tag, variants.clone()); + let outputs = elems_for_copy + .into_iter() + .map(|elems| { + let [copy] = case_b + .add_dataflow_op(t.clone(), elems) + .unwrap() + .outputs_arr(); + copy + }) + .collect::>(); // must collect to end borrow of `case_b` by closure + case_b.finish_with_outputs(outputs).unwrap(); + } + Ok(NodeTemplate::CompoundOp(Box::new( + cb.finish_hugr().unwrap(), + ))) + } + TypeEnum::Extension(cty) => match self.lin.copy_discard.get(cty) { + Some((copy, discard)) => Ok(if num_outports == 0 { + discard.clone() + } else { + let mut dfb = + DFGBuilder::new(inout_sig(typ.clone(), vec![typ.clone(); num_outports])) + .unwrap(); + let [mut src] = dfb.input_wires_arr(); + let mut outputs = vec![]; + for _ in 0..num_outports - 1 { + let [out0, out1] = copy.clone().add(&mut dfb, [src]).unwrap().outputs_arr(); + outputs.push(out0); + src = out1; + } + outputs.push(src); + NodeTemplate::CompoundOp(Box::new( + dfb.finish_hugr_with_outputs(outputs).unwrap(), + )) + }), + None => { + let copy_discard_fn = self + .lin + .copy_discard_parametric + .get(&cty.into()) + .ok_or_else(|| LinearizeError::NeedCopyDiscard(typ.clone()))?; + let tmpl = copy_discard_fn(cty.args(), num_outports, self)?; + check_sig(&tmpl, typ, num_outports)?; + Ok(tmpl) + } + }, + TypeEnum::Function(_) => panic!("Ruled out above as copyable"), + _ => Err(LinearizeError::UnsupportedType(typ.clone())), + } } } From 7a1d2b0b7c362f498abb71b78b9bd63dd4c404da Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Apr 2025 23:06:18 +0100 Subject: [PATCH 13/18] Define replace_types::CallbackHandler similarly, duplicate create_function --- hugr-passes/src/replace_types.rs | 105 ++++++++++++++++----- hugr-passes/src/replace_types/handlers.rs | 4 +- hugr-passes/src/replace_types/linearize.rs | 4 +- 3 files changed, 84 insertions(+), 29 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 9f8886ca5..ed8559db5 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -30,7 +30,11 @@ use hugr_core::{Direction, Hugr, HugrView, Node, PortIndex, Wire}; use crate::ComposablePass; mod linearize; -pub use linearize::{CallbackHandler, DelegatingLinearizer, LinearizeError}; +pub use linearize::{DelegatingLinearizer, LinearizeError}; + +/// Key passed to [CallbackHandler::makee_function] to de-duplicate +/// attempts to add the same function +pub type FuncId = String; /// A recipe for creating a dataflow Node - as a new child of a [DataflowParent] /// or in order to replace an existing node. @@ -194,7 +198,10 @@ pub struct ReplaceTypes { param_types: HashMap Option>>, linearize: DelegatingLinearizer, op_map: HashMap, - param_ops: HashMap Option>>, + param_ops: HashMap< + ParametricOp, + Arc Option>, + >, consts: HashMap< CustomType, Arc Result>, @@ -343,7 +350,7 @@ impl ReplaceTypes { pub fn replace_parametrized_op( &mut self, src: &OpDef, - dest_fn: impl Fn(&[TypeArg]) -> Option + 'static, + dest_fn: impl Fn(&[TypeArg], &mut CallbackHandler) -> Option + 'static, ) { self.param_ops.insert(src.into(), Arc::new(dest_fn)); } @@ -375,7 +382,12 @@ impl ReplaceTypes { self.param_consts.insert(src_ty.into(), Arc::new(const_fn)); } - fn change_node(&self, hugr: &mut impl HugrMut, n: Node) -> Result { + fn change_node( + &self, + hugr: &mut impl HugrMut, + cache: &mut HashMap, + n: Node, + ) -> Result { match hugr.optype_mut(n) { OpType::FuncDefn(FuncDefn { signature, .. }) | OpType::FuncDecl(FuncDecl { signature, .. }) => signature.body_mut().transform(self), @@ -439,13 +451,13 @@ impl ReplaceTypes { .map_err(|e| ReplaceTypesError::AddTemplateError(n, e))?; true } else { - let def = ext_op.def_arc(); + let def = ext_op.def_arc().clone(); let mut args = ext_op.args().to_vec(); let ch = args.transform(self)?; if let Some(replacement) = self .param_ops .get(&def.as_ref().into()) - .and_then(|rep_fn| rep_fn(&args)) + .and_then(|rep_fn| rep_fn(&args, &mut self.handler(hugr, cache))) { replacement .replace(hugr, n) @@ -453,7 +465,8 @@ impl ReplaceTypes { true } else { if ch { - *ext_op = ExtensionOp::new(def.clone(), args)?; + // can't use ext_op here, as it can't be borrowed while passing self to `rep_fn` + hugr.replace_op(n, ExtensionOp::new(def, args)?).unwrap(); } ch } @@ -467,6 +480,20 @@ impl ReplaceTypes { } } + fn handler<'a>( + &'a self, + hugr: &'a mut impl HugrMut, + cache: &'a mut HashMap, + ) -> CallbackHandler<'a> { + // ALAN ugh, can we avoid hugr_mut() here? Maybe by *not* storing the hugr-mut in the + // CallbackHandler (==> NodeTemplate::Call contains FuncID *or* Node) ? + CallbackHandler { + hugr: hugr.hugr_mut(), + cache, + repl: self, + } + } + /// Modifies the specified Value in-place according to current configuration. /// Returns whether the value has changed (conservative over-approximation). pub fn change_value(&self, value: &mut Value) -> Result { @@ -504,6 +531,39 @@ impl ReplaceTypes { } } +/// struct passed to callbacks registered via [ReplaceTypes::replace_parametrized_op]. +/// The callbacks may use this to create functions to be called via [NodeTemplate::Call]. +pub struct CallbackHandler<'a> { + hugr: &'a mut Hugr, + cache: &'a mut HashMap, + repl: &'a ReplaceTypes, +} + +impl CallbackHandler<'_> { + /// Callbacks can use this to make a function in the Hugr. + /// The first call for a given `id` will call `body`, which must return + /// a [FuncDefn]-rooted Hugr, and insert that into the underlying Hugr; + /// the node containing the newly-inserted FuncDefn is returned. + /// + /// A second call with the same `id` will return the node from the first + /// call, without executing `body`. + pub fn make_function(&mut self, id: FuncId, body: impl Fn() -> Hugr) -> Node { + if let Some(n) = self.cache.get(&id) { + return *n; + } + let h = body(); + let n = self.hugr.insert_hugr(self.hugr.root(), h).new_root; + self.cache.insert(id, n); + n + } + + /// Allows access to the [ReplaceTypes] i.e. which implements [TypeTransformer] + /// to pass to [Type::transform] + pub fn replace_types(&self) -> &ReplaceTypes { + &self.repl + } +} + impl ComposablePass for ReplaceTypes { type Error = ReplaceTypesError; type Result = bool; @@ -512,7 +572,7 @@ impl ComposablePass for ReplaceTypes { let mut changed = false; let mut cache = HashMap::new(); for n in hugr.nodes().collect::>() { - changed |= self.change_node(hugr, n)?; + changed |= self.change_node(hugr, &mut cache, n)?; let new_dfsig = hugr.get_optype(n).dataflow_signature(); if let Some(new_sig) = new_dfsig .filter(|_| changed && n != hugr.root()) @@ -715,7 +775,7 @@ mod test { .into(), ), ); - lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), |type_args| { + lw.replace_parametrized_op(ext.get_op(READ).unwrap().as_ref(), |type_args, _| { Some(NodeTemplate::CompoundOp(Box::new( lowered_read(just_elem_type(type_args).clone(), DFGBuilder::new) .finish_hugr() @@ -988,20 +1048,17 @@ mod test { option_contents(just_elem_type(args)).map(list_type) }); // and read> to get - the latter has the expected option return type - lowerer.replace_parametrized_op( - e.get_op(READ).unwrap().as_ref(), - Box::new(|args: &[TypeArg]| { - option_contents(just_elem_type(args)).map(|elem| { - NodeTemplate::SingleOp( - ListOp::get - .with_type(elem) - .to_extension_op() - .unwrap() - .into(), - ) - }) - }), - ); + lowerer.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), |args, _| { + option_contents(just_elem_type(args)).map(|elem| { + NodeTemplate::SingleOp( + ListOp::get + .with_type(elem) + .to_extension_op() + .unwrap() + .into(), + ) + }) + }); assert!(lowerer.run(&mut h).unwrap()); // list -> read -> usz just becomes list -> read -> qb // list> -> read> -> opt becomes list -> get -> opt @@ -1092,7 +1149,7 @@ mod test { .new_root; let mut lw = lowerer(&e); - lw.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args| { + lw.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args, _| { Some(NodeTemplate::Call(read_func, args.to_owned())) }); lw.run(&mut h).unwrap(); diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index 9c3f17896..3ac7e9d39 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -17,7 +17,7 @@ use hugr_core::types::{SumType, Transformable, Type, TypeArg}; use hugr_core::{type_row, Hugr, HugrView}; use itertools::Itertools; -use super::{CallbackHandler, LinearizeError, NodeTemplate, ReplaceTypes, ReplaceTypesError}; +use super::{linearize, LinearizeError, NodeTemplate, ReplaceTypes, ReplaceTypesError}; /// Handler for [ListValue] constants that updates the element type and /// recursively [ReplaceTypes::change_value]s the elements of the list. @@ -76,7 +76,7 @@ fn runtime_reqs(h: &Hugr) -> ExtensionSet { pub fn linearize_array( args: &[TypeArg], num_outports: usize, - lin: &mut CallbackHandler, + lin: &mut linearize::CallbackHandler, ) -> Result { // Require known length i.e. usable only after monomorphization, due to no-variables limitation // restriction on NodeTemplate::CompoundOp diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 28b7e0cd9..1db5ebf4b 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -11,7 +11,7 @@ use hugr_core::types::{CustomType, Signature, Type, TypeArg, TypeEnum, TypeRow}; use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, Hugr, HugrView, IncomingPort, Node, Wire}; use itertools::Itertools; -use super::{handlers::linearize_array, NodeTemplate, ParametricType}; +use super::{handlers::linearize_array, FuncId, NodeTemplate, ParametricType}; /// A configuration for implementing [Linearizer] by delegating to /// type-specific callbacks, and by composing them in order to handle compound types @@ -41,8 +41,6 @@ impl Default for DelegatingLinearizer { } } -type FuncId = String; - /// Implementation of [Linearizer] passed to callbacks, (e.g.) so that callbacks for /// handling collection types can use it to generate copy/discards of elements. // (Note, this is its own type just to give a bit of room for future expansion, From abc037f96ac7a04a4c3018e5a81c4172d85b777b Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Wed, 23 Apr 2025 23:24:24 +0100 Subject: [PATCH 14/18] Missing doc...needs more updates --- hugr-passes/src/replace_types.rs | 2 +- hugr-passes/src/replace_types/linearize.rs | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index ed8559db5..a6579a2f8 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -560,7 +560,7 @@ impl CallbackHandler<'_> { /// Allows access to the [ReplaceTypes] i.e. which implements [TypeTransformer] /// to pass to [Type::transform] pub fn replace_types(&self) -> &ReplaceTypes { - &self.repl + self.repl } } diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index 1db5ebf4b..ae155c96f 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -152,6 +152,9 @@ impl DelegatingLinearizer { .insert(src.into(), Arc::new(copy_discard_fn)); } + /// Obtains a [CallbackHandler] as would be passed to a callback registered with + /// [Self::register_callback] allowing to insert copy/discard ops and call + /// [CallbackHandler::make_function] pub fn handler<'a>( &'a self, hugr: &'a mut impl HugrMut, From 42daba50676d73845d8e42d7f7affa81dd316df0 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 25 Apr 2025 15:48:23 +0100 Subject: [PATCH 15/18] Combine the two CallbackHandler structs by (erm...) deref --- hugr-passes/src/replace_types.rs | 30 +++++++------ hugr-passes/src/replace_types/handlers.rs | 7 ++- hugr-passes/src/replace_types/linearize.rs | 52 +++++++--------------- 3 files changed, 39 insertions(+), 50 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index a6579a2f8..8a38bee17 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -4,6 +4,7 @@ //! use std::borrow::Cow; use std::collections::HashMap; +use std::ops::Deref; use std::sync::Arc; use handlers::list_const; @@ -200,7 +201,7 @@ pub struct ReplaceTypes { op_map: HashMap, param_ops: HashMap< ParametricOp, - Arc Option>, + Arc) -> Option>, >, consts: HashMap< CustomType, @@ -350,7 +351,8 @@ impl ReplaceTypes { pub fn replace_parametrized_op( &mut self, src: &OpDef, - dest_fn: impl Fn(&[TypeArg], &mut CallbackHandler) -> Option + 'static, + dest_fn: impl Fn(&[TypeArg], &mut CallbackHandler) -> Option + + 'static, ) { self.param_ops.insert(src.into(), Arc::new(dest_fn)); } @@ -484,13 +486,13 @@ impl ReplaceTypes { &'a self, hugr: &'a mut impl HugrMut, cache: &'a mut HashMap, - ) -> CallbackHandler<'a> { + ) -> CallbackHandler<'a, ReplaceTypes> { // ALAN ugh, can we avoid hugr_mut() here? Maybe by *not* storing the hugr-mut in the // CallbackHandler (==> NodeTemplate::Call contains FuncID *or* Node) ? CallbackHandler { hugr: hugr.hugr_mut(), cache, - repl: self, + deref: self, } } @@ -533,13 +535,21 @@ impl ReplaceTypes { /// struct passed to callbacks registered via [ReplaceTypes::replace_parametrized_op]. /// The callbacks may use this to create functions to be called via [NodeTemplate::Call]. -pub struct CallbackHandler<'a> { +pub struct CallbackHandler<'a, T> { hugr: &'a mut Hugr, cache: &'a mut HashMap, - repl: &'a ReplaceTypes, + deref: &'a T, } -impl CallbackHandler<'_> { +impl<'a, T> Deref for CallbackHandler<'a, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + &self.deref + } +} + +impl CallbackHandler<'_, T> { /// Callbacks can use this to make a function in the Hugr. /// The first call for a given `id` will call `body`, which must return /// a [FuncDefn]-rooted Hugr, and insert that into the underlying Hugr; @@ -556,12 +566,6 @@ impl CallbackHandler<'_> { self.cache.insert(id, n); n } - - /// Allows access to the [ReplaceTypes] i.e. which implements [TypeTransformer] - /// to pass to [Type::transform] - pub fn replace_types(&self) -> &ReplaceTypes { - self.repl - } } impl ComposablePass for ReplaceTypes { diff --git a/hugr-passes/src/replace_types/handlers.rs b/hugr-passes/src/replace_types/handlers.rs index 3ac7e9d39..95ac70e52 100644 --- a/hugr-passes/src/replace_types/handlers.rs +++ b/hugr-passes/src/replace_types/handlers.rs @@ -17,7 +17,10 @@ use hugr_core::types::{SumType, Transformable, Type, TypeArg}; use hugr_core::{type_row, Hugr, HugrView}; use itertools::Itertools; -use super::{linearize, LinearizeError, NodeTemplate, ReplaceTypes, ReplaceTypesError}; +use super::{ + CallbackHandler, DelegatingLinearizer, LinearizeError, NodeTemplate, ReplaceTypes, + ReplaceTypesError, +}; /// Handler for [ListValue] constants that updates the element type and /// recursively [ReplaceTypes::change_value]s the elements of the list. @@ -76,7 +79,7 @@ fn runtime_reqs(h: &Hugr) -> ExtensionSet { pub fn linearize_array( args: &[TypeArg], num_outports: usize, - lin: &mut linearize::CallbackHandler, + lin: &mut CallbackHandler, ) -> Result { // Require known length i.e. usable only after monomorphization, due to no-variables limitation // restriction on NodeTemplate::CompoundOp diff --git a/hugr-passes/src/replace_types/linearize.rs b/hugr-passes/src/replace_types/linearize.rs index ae155c96f..342a75546 100644 --- a/hugr-passes/src/replace_types/linearize.rs +++ b/hugr-passes/src/replace_types/linearize.rs @@ -8,9 +8,10 @@ use hugr_core::builder::{ use hugr_core::extension::{SignatureError, TypeDef}; use hugr_core::std_extensions::collections::array::array_type_def; use hugr_core::types::{CustomType, Signature, Type, TypeArg, TypeEnum, TypeRow}; -use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, Hugr, HugrView, IncomingPort, Node, Wire}; +use hugr_core::{hugr::hugrmut::HugrMut, ops::Tag, HugrView, IncomingPort, Node, Wire}; use itertools::Itertools; +use super::CallbackHandler; use super::{handlers::linearize_array, FuncId, NodeTemplate, ParametricType}; /// A configuration for implementing [Linearizer] by delegating to @@ -28,7 +29,11 @@ pub struct DelegatingLinearizer { copy_discard_parametric: HashMap< ParametricType, Arc< - dyn Fn(&[TypeArg], usize, &mut CallbackHandler) -> Result, + dyn Fn( + &[TypeArg], + usize, + &mut CallbackHandler, + ) -> Result, >, >, } @@ -41,16 +46,6 @@ impl Default for DelegatingLinearizer { } } -/// Implementation of [Linearizer] passed to callbacks, (e.g.) so that callbacks for -/// handling collection types can use it to generate copy/discards of elements. -// (Note, this is its own type just to give a bit of room for future expansion, -// rather than passing a &DelegatingLinearizer directly) -pub struct CallbackHandler<'a> { - hugr: &'a mut Hugr, - cache: &'a mut HashMap, - lin: &'a DelegatingLinearizer, -} - #[derive(Clone, Debug, thiserror::Error, PartialEq)] #[allow(missing_docs)] #[non_exhaustive] @@ -142,7 +137,11 @@ impl DelegatingLinearizer { pub fn register_callback( &mut self, src: &TypeDef, - copy_discard_fn: impl Fn(&[TypeArg], usize, &mut CallbackHandler) -> Result + copy_discard_fn: impl Fn( + &[TypeArg], + usize, + &mut CallbackHandler, + ) -> Result + 'static, ) { // We could look for `src`s TypeDefBound being explicit Copyable, otherwise @@ -159,13 +158,13 @@ impl DelegatingLinearizer { &'a self, hugr: &'a mut impl HugrMut, cache: &'a mut HashMap, - ) -> CallbackHandler<'a> { + ) -> CallbackHandler<'a, DelegatingLinearizer> { // ALAN ugh, can we avoid hugr_mut() here? Maybe by *not* storing the hugr-mut in the // CallbackHandler (==> NodeTemplate::Call contains FuncID *or* Node) ? CallbackHandler { hugr: hugr.hugr_mut(), cache, - lin: self, + deref: self, } } } @@ -193,24 +192,7 @@ fn check_sig(tmpl: &NodeTemplate, typ: &Type, num_outports: usize) -> Result<(), /// /// [monomorphization]: crate::monomorphize() /// [Copyable]: hugr_core::types::TypeBound::Copyable -impl CallbackHandler<'_> { - /// Callbacks can use this to make a function in the Hugr. - /// The first call for a given `id` will call `body`, which must return - /// a [FuncDefn]-rooted Hugr, and insert that into the underlying Hugr; - /// the node containing the newly-inserted FuncDefn is returned. - /// - /// A second call with the same `id` will return the node from the first - /// call, without executing `body`. - pub fn make_function(&mut self, id: FuncId, body: impl Fn() -> Hugr) -> Node { - if let Some(n) = self.cache.get(&id) { - return *n; - } - let h = body(); - let n = self.hugr.insert_hugr(self.hugr.root(), h).new_root; - self.cache.insert(id, n); - n - } - +impl CallbackHandler<'_, DelegatingLinearizer> { /// Insert copy or discard operations (as appropriate) enough to wire `src` /// up to all `targets`. /// @@ -337,7 +319,7 @@ impl CallbackHandler<'_> { cb.finish_hugr().unwrap(), ))) } - TypeEnum::Extension(cty) => match self.lin.copy_discard.get(cty) { + TypeEnum::Extension(cty) => match self.copy_discard.get(cty) { Some((copy, discard)) => Ok(if num_outports == 0 { discard.clone() } else { @@ -358,7 +340,7 @@ impl CallbackHandler<'_> { }), None => { let copy_discard_fn = self - .lin + .deref .copy_discard_parametric .get(&cty.into()) .ok_or_else(|| LinearizeError::NeedCopyDiscard(typ.clone()))?; From cb2e138c9b918baa6bfb43def2d20ea03bd1ec44 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 25 Apr 2025 15:52:17 +0100 Subject: [PATCH 16/18] separate get_function from make_function, removing callback --- hugr-passes/src/replace_types.rs | 28 +++++++++++++++------------- 1 file changed, 15 insertions(+), 13 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 8a38bee17..a43c3a37c 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -3,6 +3,7 @@ //! Replace types with other types across the Hugr. See [ReplaceTypes] and [Linearizer]. //! use std::borrow::Cow; +use std::collections::hash_map::Entry; use std::collections::HashMap; use std::ops::Deref; use std::sync::Arc; @@ -550,21 +551,22 @@ impl<'a, T> Deref for CallbackHandler<'a, T> { } impl CallbackHandler<'_, T> { - /// Callbacks can use this to make a function in the Hugr. - /// The first call for a given `id` will call `body`, which must return - /// a [FuncDefn]-rooted Hugr, and insert that into the underlying Hugr; - /// the node containing the newly-inserted FuncDefn is returned. + /// Returns any Node previously created by [Self::make_function] with the same `id` + pub fn get_function(&self, id: FuncId) -> Option { + self.cache.get(&id).copied() + } + + /// Callbacks can use this to make a function in the Hugr, if none already + /// exists for the same `id` - check using `get_function` first. + /// + /// # Panics /// - /// A second call with the same `id` will return the node from the first - /// call, without executing `body`. - pub fn make_function(&mut self, id: FuncId, body: impl Fn() -> Hugr) -> Node { - if let Some(n) = self.cache.get(&id) { - return *n; + /// if `make_function` has already been called with the same `id` + pub fn make_function(&mut self, id: FuncId, body: Hugr) -> Node { + match self.cache.entry(id.clone()) { + Entry::Occupied(_) => panic!("Key {id} already present"), + Entry::Vacant(ve) => *ve.insert(self.hugr.insert_hugr(self.hugr.root(), body).new_root), } - let h = body(); - let n = self.hugr.insert_hugr(self.hugr.root(), h).new_root; - self.cache.insert(id, n); - n } } From 1ac57a59d6b9a4f38ac7ae4ede0e882becbd610c Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 25 Apr 2025 16:02:14 +0100 Subject: [PATCH 17/18] Generalize test to check lazy too --- hugr-passes/src/replace_types.rs | 56 ++++++++++++++++++++------------ 1 file changed, 36 insertions(+), 20 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index a43c3a37c..7467d25e5 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -669,7 +669,7 @@ mod test { list::{list_type, list_type_def, ListOp, ListValue}, }; use hugr_core::types::{PolyFuncType, Signature, SumType, Type, TypeArg, TypeBound, TypeRow}; - use hugr_core::{type_row, Extension, HugrView}; + use hugr_core::{type_row, Extension, Hugr, HugrView}; use itertools::Itertools; use rstest::rstest; @@ -1122,8 +1122,19 @@ mod test { h.validate_no_extensions().unwrap(); } - #[test] - fn op_to_call() { + fn make_read_func() -> Hugr { + lowered_read(Type::new_var_use(0, TypeBound::Copyable), |sig| { + FunctionBuilder::new( + "lowered_read", + PolyFuncType::new([TypeBound::Copyable.into()], sig), + ) + }) + .finish_hugr() + .unwrap() + } + + #[rstest] + fn op_to_call(#[values(false, true)] create_lazy: bool) { let e = ext(); let pv = e.get_type(PACKED_VEC).unwrap(); let inner = pv.instantiate([usize_t().into()]).unwrap(); @@ -1140,27 +1151,32 @@ mod test { .add_dataflow_op(read_op(&e, usize_t()), [inner, idx]) .unwrap(); let mut h = dfb.finish_hugr_with_outputs(res.outputs()).unwrap(); - let read_func = h - .insert_hugr( - h.root(), - lowered_read(Type::new_var_use(0, TypeBound::Copyable), |sig| { - FunctionBuilder::new( - "lowered_read", - PolyFuncType::new([TypeBound::Copyable.into()], sig), - ) - }) - .finish_hugr() - .unwrap(), - ) - .new_root; + let read_func = (!create_lazy).then(|| h.insert_hugr(h.root(), make_read_func()).new_root); let mut lw = lowerer(&e); - lw.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args, _| { - Some(NodeTemplate::Call(read_func, args.to_owned())) - }); + if let Some(read_func) = read_func { + lw.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args, _| { + Some(NodeTemplate::Call(read_func, args.to_owned())) + }); + } else { + lw.replace_parametrized_op(e.get_op(READ).unwrap().as_ref(), move |args, rt| { + let name = "test_read_func".to_string(); + let read_func = rt + .get_function(name.clone()) + .unwrap_or_else(|| rt.make_function(name, make_read_func())); + Some(NodeTemplate::Call(read_func, args.to_owned())) + }); + } lw.run(&mut h).unwrap(); - assert_eq!(h.output_neighbours(read_func).count(), 2); + let [func_node] = h + .nodes() + .filter(|n| h.get_optype(*n).is_func_defn()) + .collect_array() + .unwrap(); + assert!(read_func.is_none_or(|rf| rf == func_node)); + + assert_eq!(h.output_neighbours(func_node).count(), 2); let ext_op_names = h .nodes() .filter_map(|n| h.get_optype(n).as_extension_op()) From d073f35ec1e46021c7d871c002c783bc6a8b5f51 Mon Sep 17 00:00:00 2001 From: Alan Lawrence Date: Fri, 25 Apr 2025 17:50:06 +0100 Subject: [PATCH 18/18] clippy --- hugr-passes/src/replace_types.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hugr-passes/src/replace_types.rs b/hugr-passes/src/replace_types.rs index 7467d25e5..7e313f4cb 100644 --- a/hugr-passes/src/replace_types.rs +++ b/hugr-passes/src/replace_types.rs @@ -542,11 +542,11 @@ pub struct CallbackHandler<'a, T> { deref: &'a T, } -impl<'a, T> Deref for CallbackHandler<'a, T> { +impl Deref for CallbackHandler<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { - &self.deref + self.deref } }