Skip to content

Commit 9ad9e6d

Browse files
acl-cqcdoug-q
andauthored
feat!: Extend LowerTypes pass to linearize by inserting copy/discard (#2018)
* Does not handle nonlocal edges. This is much more involved, not part of this PR. I note #1912 * handlers for arrays follow in #2023 BREAKING CHANGE: `OpReplacement` renamed to `NodeTemplate`. Note this is *not* a breaking change if this PR goes in the same release as #1989 (which introduced `OpReplacement`) --------- Co-authored-by: Douglas Wilson <douglas.wilson@quantinuum.com>
1 parent 079585e commit 9ad9e6d

File tree

2 files changed

+774
-28
lines changed

2 files changed

+774
-28
lines changed

hugr-passes/src/replace_types.rs

Lines changed: 126 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,47 @@
11
#![allow(clippy::type_complexity)]
22
#![warn(missing_docs)]
3-
//! Replace types with other types across the Hugr.
3+
//! Replace types with other types across the Hugr. See [ReplaceTypes] and [Linearizer].
44
//!
5-
//! Parametrized types and ops will be reparametrized taking into account the replacements,
6-
//! but any ops taking/returning the replaced types *not* as a result of parametrization,
7-
//! will also need to be replaced - see [ReplaceTypes::replace_op]. (Similarly [Const]s.)
5+
use std::borrow::Cow;
86
use std::collections::HashMap;
97
use std::sync::Arc;
108

119
use thiserror::Error;
1210

11+
use hugr_core::builder::{BuildError, BuildHandle, Dataflow};
1312
use hugr_core::extension::{ExtensionId, OpDef, SignatureError, TypeDef};
1413
use hugr_core::hugr::hugrmut::HugrMut;
1514
use hugr_core::ops::constant::{OpaqueValue, Sum};
15+
use hugr_core::ops::handle::DataflowOpID;
1616
use hugr_core::ops::{
1717
AliasDefn, Call, CallIndirect, Case, Conditional, Const, DataflowBlock, ExitBlock, ExtensionOp,
18-
FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpType, Output, Tag, TailLoop, Value,
19-
CFG, DFG,
18+
FuncDecl, FuncDefn, Input, LoadConstant, LoadFunction, OpTrait, OpType, Output, Tag, TailLoop,
19+
Value, CFG, DFG,
2020
};
21-
use hugr_core::types::{CustomType, Transformable, Type, TypeArg, TypeEnum, TypeTransformer};
22-
use hugr_core::{Hugr, Node};
21+
use hugr_core::types::{
22+
CustomType, Signature, Transformable, Type, TypeArg, TypeEnum, TypeTransformer,
23+
};
24+
use hugr_core::{Hugr, HugrView, Node, Wire};
2325

2426
use crate::validation::{ValidatePassError, ValidationLevel};
2527

26-
/// A thing with which an Op (i.e. node) can be replaced
28+
mod linearize;
29+
pub use linearize::{CallbackHandler, DelegatingLinearizer, LinearizeError, Linearizer};
30+
31+
/// A recipe for creating a dataflow Node - as a new child of a [DataflowParent]
32+
/// or in order to replace an existing node.
33+
///
34+
/// [DataflowParent]: hugr_core::ops::OpTag::DataflowParent
2735
#[derive(Clone, Debug, PartialEq)]
28-
pub enum OpReplacement {
29-
/// Keep the same node, change only the op (updating types of inputs/outputs)
36+
pub enum NodeTemplate {
37+
/// A single node - so if replacing an existing node, change only the op
3038
SingleOp(OpType),
31-
/// Defines a sub-Hugr to splice in place of the op - a [CFG], [Conditional], [DFG]
32-
/// or [TailLoop], which must have the same inputs and outputs as the original op,
33-
/// modulo replacement.
39+
/// Defines a sub-Hugr to insert, whose root becomes (or replaces) the desired Node.
40+
/// The root must be a [CFG], [Conditional], [DFG] or [TailLoop].
3441
// Not a FuncDefn, nor Case/DataflowBlock
35-
/// Note this will be of limited use before [monomorphization](super::monomorphize()) because
36-
/// the sub-Hugr will not be able to use type variables present in the op.
42+
/// Note this will be of limited use before [monomorphization](super::monomorphize())
43+
/// because the new subtree will not be able to use type variables present in the
44+
/// parent Hugr or previous op.
3745
// TODO: store also a vec<TypeParam>, and update Hugr::validate to take &[TypeParam]s
3846
// (defaulting to empty list) - see https://github.com/CQCL/hugr/issues/709
3947
CompoundOp(Box<Hugr>),
@@ -42,12 +50,33 @@ pub enum OpReplacement {
4250
// So client should add the functions before replacement, then remove unused ones afterwards.)
4351
}
4452

45-
impl OpReplacement {
53+
impl NodeTemplate {
54+
/// Adds this instance to the specified [HugrMut] as a new node or subtree under a
55+
/// given parent, returning the unique new child (of that parent) thus created
56+
pub fn add_hugr(self, hugr: &mut impl HugrMut, parent: Node) -> Node {
57+
match self {
58+
NodeTemplate::SingleOp(op_type) => hugr.add_node_with_parent(parent, op_type),
59+
NodeTemplate::CompoundOp(new_h) => hugr.insert_hugr(parent, *new_h).new_root,
60+
}
61+
}
62+
63+
/// Adds this instance to the specified [Dataflow] builder as a new node or subtree
64+
pub fn add(
65+
self,
66+
dfb: &mut impl Dataflow,
67+
inputs: impl IntoIterator<Item = Wire>,
68+
) -> Result<BuildHandle<DataflowOpID>, BuildError> {
69+
match self {
70+
NodeTemplate::SingleOp(opty) => dfb.add_dataflow_op(opty, inputs),
71+
NodeTemplate::CompoundOp(h) => dfb.add_hugr_with_wires(*h, inputs),
72+
}
73+
}
74+
4675
fn replace(&self, hugr: &mut impl HugrMut, n: Node) {
4776
assert_eq!(hugr.children(n).count(), 0);
4877
let new_optype = match self.clone() {
49-
OpReplacement::SingleOp(op_type) => op_type,
50-
OpReplacement::CompoundOp(new_h) => {
78+
NodeTemplate::SingleOp(op_type) => op_type,
79+
NodeTemplate::CompoundOp(new_h) => {
5180
let new_root = hugr.insert_hugr(n, *new_h).new_root;
5281
let children = hugr.children(new_root).collect::<Vec<_>>();
5382
let root_opty = hugr.remove_node(new_root);
@@ -59,16 +88,50 @@ impl OpReplacement {
5988
};
6089
*hugr.optype_mut(n) = new_optype;
6190
}
91+
92+
fn signature(&self) -> Option<Cow<'_, Signature>> {
93+
match self {
94+
NodeTemplate::SingleOp(op_type) => op_type,
95+
NodeTemplate::CompoundOp(hugr) => hugr.root_type(),
96+
}
97+
.dataflow_signature()
98+
}
6299
}
63100

64101
/// A configuration of what types, ops, and constants should be replaced with what.
65102
/// May be applied to a Hugr via [Self::run].
103+
///
104+
/// Parametrized types and ops will be reparametrized taking into account the
105+
/// replacements, but any ops taking/returning the replaced types *not* as a result of
106+
/// parametrization, will also need to be replaced - see [Self::replace_op].
107+
/// Similarly [Const]s.
108+
///
109+
/// Types that are [Copyable](hugr_core::types::TypeBound::Copyable) may also be replaced
110+
/// with types that are not, see [Linearizer].
111+
///
112+
/// Note that although this pass may be used before [monomorphization], there are some
113+
/// limitations (that do not apply if done after [monomorphization]):
114+
/// * [NodeTemplate::CompoundOp] only works for operations that do not use type variables
115+
/// * "Overrides" of specific instantiations of polymorphic types will not be detected if
116+
/// the instantiations are created inside polymorphic functions. For example, suppose
117+
/// we [Self::replace_type] type `A` with `X`, [Self::replace_parametrized_type]
118+
/// container `MyList` with `List`, and [Self::replace_type] `MyList<A>` with
119+
/// `SpecialListOfXs`. If a function `foo` polymorphic over a type variable `T` dealing
120+
/// with `MyList<T>`s, that is called with type argument `A`, then `foo<T>` will be
121+
/// updated to deal with `List<T>`s and the call `foo<A>` updated to `foo<X>`, but this
122+
/// will still result in using `List<X>` rather than `SpecialListOfXs`. (However this
123+
/// would be fine *after* [monomorphization]: the monomorphic definition of `foo_A`
124+
/// would use `SpecialListOfXs`.)
125+
/// * See also limitations noted for [Linearizer].
126+
///
127+
/// [monomorphization]: super::monomorphize()
66128
#[derive(Clone, Default)]
67129
pub struct ReplaceTypes {
68130
type_map: HashMap<CustomType, Type>,
69131
param_types: HashMap<ParametricType, Arc<dyn Fn(&[TypeArg]) -> Option<Type>>>,
70-
op_map: HashMap<OpHashWrapper, OpReplacement>,
71-
param_ops: HashMap<ParametricOp, Arc<dyn Fn(&[TypeArg]) -> Option<OpReplacement>>>,
132+
linearize: DelegatingLinearizer,
133+
op_map: HashMap<OpHashWrapper, NodeTemplate>,
134+
param_ops: HashMap<ParametricOp, Arc<dyn Fn(&[TypeArg]) -> Option<NodeTemplate>>>,
72135
consts: HashMap<
73136
CustomType,
74137
Arc<dyn Fn(&OpaqueValue, &ReplaceTypes) -> Result<Value, ReplaceTypesError>>,
@@ -109,6 +172,8 @@ pub enum ReplaceTypesError {
109172
SignatureError(#[from] SignatureError),
110173
#[error(transparent)]
111174
ValidationError(#[from] ValidatePassError),
175+
#[error(transparent)]
176+
LinearizeError(#[from] LinearizeError),
112177
}
113178

114179
impl ReplaceTypes {
@@ -157,16 +222,33 @@ impl ReplaceTypes {
157222
// (depending on arguments - i.e. if src's TypeDefBound is anything other than
158223
// `TypeDefBound::Explicit(TypeBound::Copyable)`) but that seems an annoying
159224
// overapproximation. Moreover, these depend upon the *return type* of the Fn.
225+
// It would be too awkward to require:
226+
// dest_fn: impl Fn(&TypeArg) -> (Type,
227+
// Fn(&Linearizer) -> NodeTemplate, // copy
228+
// Fn(&Linearizer) -> NodeTemplate)` // discard
160229
self.param_types.insert(src.into(), Arc::new(dest_fn));
161230
}
162231

232+
/// Allows to configure how to deal with types/wires that were [Copyable]
233+
/// but have become linear as a result of type-changing. Specifically,
234+
/// the [Linearizer] is used whenever lowering produces an outport which both
235+
/// * has a non-[Copyable] type - perhaps a direct substitution, or perhaps e.g.
236+
/// as a result of changing the element type of a collection such as an [`array`]
237+
/// * has other than one connected inport,
238+
///
239+
/// [Copyable]: hugr_core::types::TypeBound::Copyable
240+
/// [`array`]: hugr_core::std_extensions::collections::array::array_type
241+
pub fn linearizer(&mut self) -> &mut DelegatingLinearizer {
242+
&mut self.linearize
243+
}
244+
163245
/// Configures this instance to change occurrences of `src` to `dest`.
164246
/// Note that if `src` is an instance of a *parametrized* [OpDef], this takes
165247
/// precedence over [Self::replace_parametrized_op] where the `src`s overlap. Thus,
166248
/// this should only be used on already-*[monomorphize](super::monomorphize())d*
167249
/// Hugrs, as substitution (parametric polymorphism) happening later will not respect
168250
/// this replacement.
169-
pub fn replace_op(&mut self, src: &ExtensionOp, dest: OpReplacement) {
251+
pub fn replace_op(&mut self, src: &ExtensionOp, dest: NodeTemplate) {
170252
self.op_map.insert(OpHashWrapper::from(src), dest);
171253
}
172254

@@ -179,7 +261,7 @@ impl ReplaceTypes {
179261
pub fn replace_parametrized_op(
180262
&mut self,
181263
src: &OpDef,
182-
dest_fn: impl Fn(&[TypeArg]) -> Option<OpReplacement> + 'static,
264+
dest_fn: impl Fn(&[TypeArg]) -> Option<NodeTemplate> + 'static,
183265
) {
184266
self.param_ops.insert(src.into(), Arc::new(dest_fn));
185267
}
@@ -221,6 +303,22 @@ impl ReplaceTypes {
221303
let mut changed = false;
222304
for n in hugr.nodes().collect::<Vec<_>>() {
223305
changed |= self.change_node(hugr, n)?;
306+
let new_dfsig = hugr.get_optype(n).dataflow_signature();
307+
if let Some(new_sig) = new_dfsig
308+
.filter(|_| changed && n != hugr.root())
309+
.map(Cow::into_owned)
310+
{
311+
for outp in new_sig.output_ports() {
312+
if !new_sig.out_port_type(outp).unwrap().copyable() {
313+
let targets = hugr.linked_inputs(n, outp).collect::<Vec<_>>();
314+
if targets.len() != 1 {
315+
hugr.disconnect(n, outp);
316+
let src = Wire::new(n, outp);
317+
self.linearize.insert_copy_discard(hugr, src, &targets)?;
318+
}
319+
}
320+
}
321+
}
224322
}
225323
Ok(changed)
226324
}
@@ -452,7 +550,7 @@ mod test {
452550
use hugr_core::{hugr::IdentList, type_row, Extension, HugrView};
453551
use itertools::Itertools;
454552

455-
use super::{handlers::list_const, OpReplacement, ReplaceTypes};
553+
use super::{handlers::list_const, NodeTemplate, ReplaceTypes};
456554

457555
const PACKED_VEC: &str = "PackedVec";
458556
const READ: &str = "read";
@@ -513,7 +611,7 @@ mod test {
513611
}
514612

515613
fn lowerer(ext: &Arc<Extension>) -> ReplaceTypes {
516-
fn lowered_read(args: &[TypeArg]) -> Option<OpReplacement> {
614+
fn lowered_read(args: &[TypeArg]) -> Option<NodeTemplate> {
517615
let ty = just_elem_type(args);
518616
let mut dfb = DFGBuilder::new(inout_sig(
519617
vec![array_type(64, ty.clone()), i64_t()],
@@ -532,7 +630,7 @@ mod test {
532630
let [res] = dfb
533631
.build_unwrap_sum(1, option_type(Type::from(ty.clone())), opt)
534632
.unwrap();
535-
Some(OpReplacement::CompoundOp(Box::new(
633+
Some(NodeTemplate::CompoundOp(Box::new(
536634
dfb.finish_hugr_with_outputs([res]).unwrap(),
537635
)))
538636
}
@@ -545,7 +643,7 @@ mod test {
545643
);
546644
lw.replace_op(
547645
&read_op(ext, bool_t()),
548-
OpReplacement::SingleOp(
646+
NodeTemplate::SingleOp(
549647
ExtensionOp::new(ext.get_op("lowered_read_bool").unwrap().clone(), [])
550648
.unwrap()
551649
.into(),
@@ -824,7 +922,7 @@ mod test {
824922
e.get_op(READ).unwrap().as_ref(),
825923
Box::new(|args: &[TypeArg]| {
826924
option_contents(just_elem_type(args)).map(|elem| {
827-
OpReplacement::SingleOp(
925+
NodeTemplate::SingleOp(
828926
ListOp::get
829927
.with_type(elem)
830928
.to_extension_op()

0 commit comments

Comments
 (0)