Skip to content

Commit 649589c

Browse files
authored
feat: Add array repeat and scan ops (#1633)
Closes #1627
1 parent 6bd094f commit 649589c

File tree

4 files changed

+659
-1
lines changed

4 files changed

+659
-1
lines changed

hugr-core/src/extension/prelude.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ lazy_static! {
101101
NoopDef.add_to_extension(&mut prelude).unwrap();
102102
LiftDef.add_to_extension(&mut prelude).unwrap();
103103
array::ArrayOpDef::load_all_ops(&mut prelude).unwrap();
104+
array::ArrayScanDef.add_to_extension(&mut prelude).unwrap();
104105
prelude
105106
};
106107
/// An extension registry containing only the prelude

hugr-core/src/extension/prelude/array.rs

Lines changed: 292 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,6 @@
1+
use std::str::FromStr;
2+
3+
use itertools::Itertools;
14
use strum_macros::EnumIter;
25
use strum_macros::EnumString;
36
use strum_macros::IntoStaticStr;
@@ -17,8 +20,10 @@ use crate::ops::ExtensionOp;
1720
use crate::ops::NamedOp;
1821
use crate::ops::OpName;
1922
use crate::type_row;
23+
use crate::types::FuncTypeBase;
2024
use crate::types::FuncValueType;
2125

26+
use crate::types::RowVariable;
2227
use crate::types::TypeBound;
2328

2429
use crate::types::Type;
@@ -28,6 +33,7 @@ use crate::extension::SignatureError;
2833
use crate::types::PolyFuncTypeRV;
2934

3035
use crate::types::type_param::TypeArg;
36+
use crate::types::TypeRV;
3137
use crate::Extension;
3238

3339
use super::PRELUDE_ID;
@@ -46,6 +52,7 @@ pub enum ArrayOpDef {
4652
pop_left,
4753
pop_right,
4854
discard_empty,
55+
repeat,
4956
}
5057

5158
/// Static parameters for array operations. Includes array size. Type is part of the type scheme.
@@ -118,6 +125,14 @@ impl ArrayOpDef {
118125
let standard_params = vec![TypeParam::max_nat(), TypeBound::Any.into()];
119126

120127
match self {
128+
repeat => {
129+
let func =
130+
Type::new_function(FuncValueType::new(type_row![], elem_ty_var.clone()));
131+
PolyFuncTypeRV::new(
132+
standard_params,
133+
FuncValueType::new(vec![func], array_ty.clone()),
134+
)
135+
}
121136
get => {
122137
let params = vec![TypeParam::max_nat(), TypeBound::Copyable.into()];
123138
let copy_elem_ty = Type::new_var_use(1, TypeBound::Copyable);
@@ -179,6 +194,10 @@ impl MakeOpDef for ArrayOpDef {
179194
fn description(&self) -> String {
180195
match self {
181196
ArrayOpDef::new_array => "Create a new array from elements",
197+
ArrayOpDef::repeat => {
198+
"Creates a new array whose elements are initialised by calling \
199+
the given function n times"
200+
}
182201
ArrayOpDef::get => "Get an element from an array",
183202
ArrayOpDef::set => "Set an element in an array",
184203
ArrayOpDef::swap => "Swap two elements in an array",
@@ -246,7 +265,7 @@ impl MakeExtensionOp for ArrayOp {
246265
);
247266
vec![ty_arg]
248267
}
249-
new_array | pop_left | pop_right | get | set | swap => {
268+
new_array | repeat | pop_left | pop_right | get | set | swap => {
250269
vec![TypeArg::BoundedNat { n: self.size }, ty_arg]
251270
}
252271
}
@@ -312,6 +331,192 @@ pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp {
312331
op.to_extension_op().unwrap()
313332
}
314333

334+
/// Name of the operation for the combined map/fold operation
335+
pub const ARRAY_SCAN_OP_ID: OpName = OpName::new_inline("scan");
336+
337+
/// Definition of the array scan op.
338+
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq)]
339+
pub struct ArrayScanDef;
340+
341+
impl NamedOp for ArrayScanDef {
342+
fn name(&self) -> OpName {
343+
ARRAY_SCAN_OP_ID
344+
}
345+
}
346+
347+
impl FromStr for ArrayScanDef {
348+
type Err = ();
349+
350+
fn from_str(s: &str) -> Result<Self, Self::Err> {
351+
if s == ArrayScanDef.name() {
352+
Ok(Self)
353+
} else {
354+
Err(())
355+
}
356+
}
357+
}
358+
359+
impl ArrayScanDef {
360+
/// To avoid recursion when defining the extension, take the type definition as an argument.
361+
fn signature_from_def(&self, array_def: &TypeDef) -> SignatureFunc {
362+
// array<N, T1>, (T1, *A -> T2, *A), -> array<N, T2>, *A
363+
let params = vec![
364+
TypeParam::max_nat(),
365+
TypeBound::Any.into(),
366+
TypeBound::Any.into(),
367+
TypeParam::new_list(TypeBound::Any),
368+
];
369+
let n = TypeArg::new_var_use(0, TypeParam::max_nat());
370+
let t1 = Type::new_var_use(1, TypeBound::Any);
371+
let t2 = Type::new_var_use(2, TypeBound::Any);
372+
let s = TypeRV::new_row_var_use(3, TypeBound::Any);
373+
PolyFuncTypeRV::new(
374+
params,
375+
FuncTypeBase::<RowVariable>::new(
376+
vec![
377+
instantiate(array_def, n.clone(), t1.clone()).into(),
378+
Type::new_function(FuncTypeBase::<RowVariable>::new(
379+
vec![t1.into(), s.clone()],
380+
vec![t2.clone().into(), s.clone()],
381+
))
382+
.into(),
383+
s.clone(),
384+
],
385+
vec![instantiate(array_def, n, t2).into(), s],
386+
),
387+
)
388+
.into()
389+
}
390+
}
391+
392+
impl MakeOpDef for ArrayScanDef {
393+
fn from_def(op_def: &OpDef) -> Result<Self, OpLoadError>
394+
where
395+
Self: Sized,
396+
{
397+
crate::extension::simple_op::try_from_name(op_def.name(), op_def.extension())
398+
}
399+
400+
fn signature(&self) -> SignatureFunc {
401+
self.signature_from_def(array_type_def())
402+
}
403+
404+
fn extension(&self) -> ExtensionId {
405+
PRELUDE_ID
406+
}
407+
408+
fn description(&self) -> String {
409+
"A combination of map and foldl. Applies a function to each element \
410+
of the array with an accumulator that is passed through from start to \
411+
finish. Returns the resulting array and the final state of the \
412+
accumulator."
413+
.into()
414+
}
415+
416+
/// Add an operation implemented as a [MakeOpDef], which can provide the data
417+
/// required to define an [OpDef], to an extension.
418+
//
419+
// This method is re-defined here since we need to pass the array type def while
420+
// computing the signature, to avoid recursive loops initializing the extension.
421+
fn add_to_extension(
422+
&self,
423+
extension: &mut Extension,
424+
) -> Result<(), crate::extension::ExtensionBuildError> {
425+
let sig = self.signature_from_def(extension.get_type(ARRAY_TYPE_NAME).unwrap());
426+
let def = extension.add_op(self.name(), self.description(), sig)?;
427+
428+
self.post_opdef(def);
429+
430+
Ok(())
431+
}
432+
}
433+
434+
/// Definition of the array scan op.
435+
#[derive(Clone, Debug, PartialEq)]
436+
pub struct ArrayScan {
437+
/// The element type of the input array.
438+
src_ty: Type,
439+
/// The target element type of the output array.
440+
tgt_ty: Type,
441+
/// The accumulator types.
442+
acc_tys: Vec<Type>,
443+
/// Size of the array.
444+
size: u64,
445+
}
446+
447+
impl ArrayScan {
448+
fn new(src_ty: Type, tgt_ty: Type, acc_tys: Vec<Type>, size: u64) -> Self {
449+
ArrayScan {
450+
src_ty,
451+
tgt_ty,
452+
acc_tys,
453+
size,
454+
}
455+
}
456+
}
457+
458+
impl NamedOp for ArrayScan {
459+
fn name(&self) -> OpName {
460+
ARRAY_SCAN_OP_ID
461+
}
462+
}
463+
464+
impl MakeExtensionOp for ArrayScan {
465+
fn from_extension_op(ext_op: &ExtensionOp) -> Result<Self, OpLoadError>
466+
where
467+
Self: Sized,
468+
{
469+
let def = ArrayScanDef::from_def(ext_op.def())?;
470+
def.instantiate(ext_op.args())
471+
}
472+
473+
fn type_args(&self) -> Vec<TypeArg> {
474+
vec![
475+
TypeArg::BoundedNat { n: self.size },
476+
self.src_ty.clone().into(),
477+
self.tgt_ty.clone().into(),
478+
TypeArg::Sequence {
479+
elems: self.acc_tys.clone().into_iter().map_into().collect(),
480+
},
481+
]
482+
}
483+
}
484+
485+
impl MakeRegisteredOp for ArrayScan {
486+
fn extension_id(&self) -> ExtensionId {
487+
PRELUDE_ID
488+
}
489+
490+
fn registry<'s, 'r: 's>(&'s self) -> &'r crate::extension::ExtensionRegistry {
491+
&PRELUDE_REGISTRY
492+
}
493+
}
494+
495+
impl HasDef for ArrayScan {
496+
type Def = ArrayScanDef;
497+
}
498+
499+
impl HasConcrete for ArrayScanDef {
500+
type Concrete = ArrayScan;
501+
502+
fn instantiate(&self, type_args: &[TypeArg]) -> Result<Self::Concrete, OpLoadError> {
503+
match type_args {
504+
[TypeArg::BoundedNat { n }, TypeArg::Type { ty: src_ty }, TypeArg::Type { ty: tgt_ty }, TypeArg::Sequence { elems: acc_tys }] =>
505+
{
506+
let acc_tys: Result<_, OpLoadError> = acc_tys
507+
.iter()
508+
.map(|acc_ty| match acc_ty {
509+
TypeArg::Type { ty } => Ok(ty.clone()),
510+
_ => Err(SignatureError::InvalidTypeArgs.into()),
511+
})
512+
.collect();
513+
Ok(ArrayScan::new(src_ty.clone(), tgt_ty.clone(), acc_tys?, *n))
514+
}
515+
_ => Err(SignatureError::InvalidTypeArgs.into()),
516+
}
517+
}
518+
}
519+
315520
#[cfg(test)]
316521
mod tests {
317522
use strum::IntoEnumIterator;
@@ -320,6 +525,7 @@ mod tests {
320525
builder::{inout_sig, DFGBuilder, Dataflow, DataflowHugr},
321526
extension::prelude::{BOOL_T, QB_T},
322527
ops::{OpTrait, OpType},
528+
types::Signature,
323529
};
324530

325531
use super::*;
@@ -459,4 +665,89 @@ mod tests {
459665
)
460666
);
461667
}
668+
669+
#[test]
670+
fn test_repeat() {
671+
let size = 2;
672+
let element_ty = QB_T;
673+
let op = ArrayOpDef::repeat.to_concrete(element_ty.clone(), size);
674+
675+
let optype: OpType = op.into();
676+
677+
let sig = optype.dataflow_signature().unwrap();
678+
679+
assert_eq!(
680+
sig.io(),
681+
(
682+
&vec![Type::new_function(Signature::new(vec![], vec![QB_T]))].into(),
683+
&vec![array_type(size, element_ty.clone())].into(),
684+
)
685+
);
686+
}
687+
688+
#[test]
689+
fn test_scan_def() {
690+
let op = ArrayScan::new(BOOL_T, QB_T, vec![USIZE_T], 2);
691+
let optype: OpType = op.clone().into();
692+
let new_op: ArrayScan = optype.cast().unwrap();
693+
assert_eq!(new_op, op);
694+
}
695+
696+
#[test]
697+
fn test_scan_map() {
698+
let size = 2;
699+
let src_ty = QB_T;
700+
let tgt_ty = BOOL_T;
701+
702+
let op = ArrayScan::new(src_ty.clone(), tgt_ty.clone(), vec![], size);
703+
let optype: OpType = op.into();
704+
let sig = optype.dataflow_signature().unwrap();
705+
706+
assert_eq!(
707+
sig.io(),
708+
(
709+
&vec![
710+
array_type(size, src_ty.clone()),
711+
Type::new_function(Signature::new(vec![src_ty], vec![tgt_ty.clone()]))
712+
]
713+
.into(),
714+
&vec![array_type(size, tgt_ty)].into(),
715+
)
716+
);
717+
}
718+
719+
#[test]
720+
fn test_scan_accs() {
721+
let size = 2;
722+
let src_ty = QB_T;
723+
let tgt_ty = BOOL_T;
724+
let acc_ty1 = USIZE_T;
725+
let acc_ty2 = QB_T;
726+
727+
let op = ArrayScan::new(
728+
src_ty.clone(),
729+
tgt_ty.clone(),
730+
vec![acc_ty1.clone(), acc_ty2.clone()],
731+
size,
732+
);
733+
let optype: OpType = op.into();
734+
let sig = optype.dataflow_signature().unwrap();
735+
736+
assert_eq!(
737+
sig.io(),
738+
(
739+
&vec![
740+
array_type(size, src_ty.clone()),
741+
Type::new_function(Signature::new(
742+
vec![src_ty, acc_ty1.clone(), acc_ty2.clone()],
743+
vec![tgt_ty.clone(), acc_ty1.clone(), acc_ty2.clone()]
744+
)),
745+
acc_ty1.clone(),
746+
acc_ty2.clone()
747+
]
748+
.into(),
749+
&vec![array_type(size, tgt_ty), acc_ty1, acc_ty2].into(),
750+
)
751+
);
752+
}
462753
}

0 commit comments

Comments
 (0)