1
+ use std:: str:: FromStr ;
2
+
3
+ use itertools:: Itertools ;
1
4
use strum_macros:: EnumIter ;
2
5
use strum_macros:: EnumString ;
3
6
use strum_macros:: IntoStaticStr ;
@@ -17,8 +20,10 @@ use crate::ops::ExtensionOp;
17
20
use crate :: ops:: NamedOp ;
18
21
use crate :: ops:: OpName ;
19
22
use crate :: type_row;
23
+ use crate :: types:: FuncTypeBase ;
20
24
use crate :: types:: FuncValueType ;
21
25
26
+ use crate :: types:: RowVariable ;
22
27
use crate :: types:: TypeBound ;
23
28
24
29
use crate :: types:: Type ;
@@ -28,6 +33,7 @@ use crate::extension::SignatureError;
28
33
use crate :: types:: PolyFuncTypeRV ;
29
34
30
35
use crate :: types:: type_param:: TypeArg ;
36
+ use crate :: types:: TypeRV ;
31
37
use crate :: Extension ;
32
38
33
39
use super :: PRELUDE_ID ;
@@ -46,6 +52,7 @@ pub enum ArrayOpDef {
46
52
pop_left,
47
53
pop_right,
48
54
discard_empty,
55
+ repeat,
49
56
}
50
57
51
58
/// Static parameters for array operations. Includes array size. Type is part of the type scheme.
@@ -118,6 +125,14 @@ impl ArrayOpDef {
118
125
let standard_params = vec ! [ TypeParam :: max_nat( ) , TypeBound :: Any . into( ) ] ;
119
126
120
127
match self {
128
+ repeat => {
129
+ let func =
130
+ Type :: new_function ( FuncValueType :: new ( type_row ! [ ] , elem_ty_var. clone ( ) ) ) ;
131
+ PolyFuncTypeRV :: new (
132
+ standard_params,
133
+ FuncValueType :: new ( vec ! [ func] , array_ty. clone ( ) ) ,
134
+ )
135
+ }
121
136
get => {
122
137
let params = vec ! [ TypeParam :: max_nat( ) , TypeBound :: Copyable . into( ) ] ;
123
138
let copy_elem_ty = Type :: new_var_use ( 1 , TypeBound :: Copyable ) ;
@@ -179,6 +194,10 @@ impl MakeOpDef for ArrayOpDef {
179
194
fn description ( & self ) -> String {
180
195
match self {
181
196
ArrayOpDef :: new_array => "Create a new array from elements" ,
197
+ ArrayOpDef :: repeat => {
198
+ "Creates a new array whose elements are initialised by calling \
199
+ the given function n times"
200
+ }
182
201
ArrayOpDef :: get => "Get an element from an array" ,
183
202
ArrayOpDef :: set => "Set an element in an array" ,
184
203
ArrayOpDef :: swap => "Swap two elements in an array" ,
@@ -246,7 +265,7 @@ impl MakeExtensionOp for ArrayOp {
246
265
) ;
247
266
vec ! [ ty_arg]
248
267
}
249
- new_array | pop_left | pop_right | get | set | swap => {
268
+ new_array | repeat | pop_left | pop_right | get | set | swap => {
250
269
vec ! [ TypeArg :: BoundedNat { n: self . size } , ty_arg]
251
270
}
252
271
}
@@ -312,6 +331,192 @@ pub fn new_array_op(element_ty: Type, size: u64) -> ExtensionOp {
312
331
op. to_extension_op ( ) . unwrap ( )
313
332
}
314
333
334
+ /// Name of the operation for the combined map/fold operation
335
+ pub const ARRAY_SCAN_OP_ID : OpName = OpName :: new_inline ( "scan" ) ;
336
+
337
+ /// Definition of the array scan op.
338
+ #[ derive( Clone , Copy , Debug , Hash , PartialEq , Eq ) ]
339
+ pub struct ArrayScanDef ;
340
+
341
+ impl NamedOp for ArrayScanDef {
342
+ fn name ( & self ) -> OpName {
343
+ ARRAY_SCAN_OP_ID
344
+ }
345
+ }
346
+
347
+ impl FromStr for ArrayScanDef {
348
+ type Err = ( ) ;
349
+
350
+ fn from_str ( s : & str ) -> Result < Self , Self :: Err > {
351
+ if s == ArrayScanDef . name ( ) {
352
+ Ok ( Self )
353
+ } else {
354
+ Err ( ( ) )
355
+ }
356
+ }
357
+ }
358
+
359
+ impl ArrayScanDef {
360
+ /// To avoid recursion when defining the extension, take the type definition as an argument.
361
+ fn signature_from_def ( & self , array_def : & TypeDef ) -> SignatureFunc {
362
+ // array<N, T1>, (T1, *A -> T2, *A), -> array<N, T2>, *A
363
+ let params = vec ! [
364
+ TypeParam :: max_nat( ) ,
365
+ TypeBound :: Any . into( ) ,
366
+ TypeBound :: Any . into( ) ,
367
+ TypeParam :: new_list( TypeBound :: Any ) ,
368
+ ] ;
369
+ let n = TypeArg :: new_var_use ( 0 , TypeParam :: max_nat ( ) ) ;
370
+ let t1 = Type :: new_var_use ( 1 , TypeBound :: Any ) ;
371
+ let t2 = Type :: new_var_use ( 2 , TypeBound :: Any ) ;
372
+ let s = TypeRV :: new_row_var_use ( 3 , TypeBound :: Any ) ;
373
+ PolyFuncTypeRV :: new (
374
+ params,
375
+ FuncTypeBase :: < RowVariable > :: new (
376
+ vec ! [
377
+ instantiate( array_def, n. clone( ) , t1. clone( ) ) . into( ) ,
378
+ Type :: new_function( FuncTypeBase :: <RowVariable >:: new(
379
+ vec![ t1. into( ) , s. clone( ) ] ,
380
+ vec![ t2. clone( ) . into( ) , s. clone( ) ] ,
381
+ ) )
382
+ . into( ) ,
383
+ s. clone( ) ,
384
+ ] ,
385
+ vec ! [ instantiate( array_def, n, t2) . into( ) , s] ,
386
+ ) ,
387
+ )
388
+ . into ( )
389
+ }
390
+ }
391
+
392
+ impl MakeOpDef for ArrayScanDef {
393
+ fn from_def ( op_def : & OpDef ) -> Result < Self , OpLoadError >
394
+ where
395
+ Self : Sized ,
396
+ {
397
+ crate :: extension:: simple_op:: try_from_name ( op_def. name ( ) , op_def. extension ( ) )
398
+ }
399
+
400
+ fn signature ( & self ) -> SignatureFunc {
401
+ self . signature_from_def ( array_type_def ( ) )
402
+ }
403
+
404
+ fn extension ( & self ) -> ExtensionId {
405
+ PRELUDE_ID
406
+ }
407
+
408
+ fn description ( & self ) -> String {
409
+ "A combination of map and foldl. Applies a function to each element \
410
+ of the array with an accumulator that is passed through from start to \
411
+ finish. Returns the resulting array and the final state of the \
412
+ accumulator."
413
+ . into ( )
414
+ }
415
+
416
+ /// Add an operation implemented as a [MakeOpDef], which can provide the data
417
+ /// required to define an [OpDef], to an extension.
418
+ //
419
+ // This method is re-defined here since we need to pass the array type def while
420
+ // computing the signature, to avoid recursive loops initializing the extension.
421
+ fn add_to_extension (
422
+ & self ,
423
+ extension : & mut Extension ,
424
+ ) -> Result < ( ) , crate :: extension:: ExtensionBuildError > {
425
+ let sig = self . signature_from_def ( extension. get_type ( ARRAY_TYPE_NAME ) . unwrap ( ) ) ;
426
+ let def = extension. add_op ( self . name ( ) , self . description ( ) , sig) ?;
427
+
428
+ self . post_opdef ( def) ;
429
+
430
+ Ok ( ( ) )
431
+ }
432
+ }
433
+
434
+ /// Definition of the array scan op.
435
+ #[ derive( Clone , Debug , PartialEq ) ]
436
+ pub struct ArrayScan {
437
+ /// The element type of the input array.
438
+ src_ty : Type ,
439
+ /// The target element type of the output array.
440
+ tgt_ty : Type ,
441
+ /// The accumulator types.
442
+ acc_tys : Vec < Type > ,
443
+ /// Size of the array.
444
+ size : u64 ,
445
+ }
446
+
447
+ impl ArrayScan {
448
+ fn new ( src_ty : Type , tgt_ty : Type , acc_tys : Vec < Type > , size : u64 ) -> Self {
449
+ ArrayScan {
450
+ src_ty,
451
+ tgt_ty,
452
+ acc_tys,
453
+ size,
454
+ }
455
+ }
456
+ }
457
+
458
+ impl NamedOp for ArrayScan {
459
+ fn name ( & self ) -> OpName {
460
+ ARRAY_SCAN_OP_ID
461
+ }
462
+ }
463
+
464
+ impl MakeExtensionOp for ArrayScan {
465
+ fn from_extension_op ( ext_op : & ExtensionOp ) -> Result < Self , OpLoadError >
466
+ where
467
+ Self : Sized ,
468
+ {
469
+ let def = ArrayScanDef :: from_def ( ext_op. def ( ) ) ?;
470
+ def. instantiate ( ext_op. args ( ) )
471
+ }
472
+
473
+ fn type_args ( & self ) -> Vec < TypeArg > {
474
+ vec ! [
475
+ TypeArg :: BoundedNat { n: self . size } ,
476
+ self . src_ty. clone( ) . into( ) ,
477
+ self . tgt_ty. clone( ) . into( ) ,
478
+ TypeArg :: Sequence {
479
+ elems: self . acc_tys. clone( ) . into_iter( ) . map_into( ) . collect( ) ,
480
+ } ,
481
+ ]
482
+ }
483
+ }
484
+
485
+ impl MakeRegisteredOp for ArrayScan {
486
+ fn extension_id ( & self ) -> ExtensionId {
487
+ PRELUDE_ID
488
+ }
489
+
490
+ fn registry < ' s , ' r : ' s > ( & ' s self ) -> & ' r crate :: extension:: ExtensionRegistry {
491
+ & PRELUDE_REGISTRY
492
+ }
493
+ }
494
+
495
+ impl HasDef for ArrayScan {
496
+ type Def = ArrayScanDef ;
497
+ }
498
+
499
+ impl HasConcrete for ArrayScanDef {
500
+ type Concrete = ArrayScan ;
501
+
502
+ fn instantiate ( & self , type_args : & [ TypeArg ] ) -> Result < Self :: Concrete , OpLoadError > {
503
+ match type_args {
504
+ [ TypeArg :: BoundedNat { n } , TypeArg :: Type { ty : src_ty } , TypeArg :: Type { ty : tgt_ty } , TypeArg :: Sequence { elems : acc_tys } ] =>
505
+ {
506
+ let acc_tys: Result < _ , OpLoadError > = acc_tys
507
+ . iter ( )
508
+ . map ( |acc_ty| match acc_ty {
509
+ TypeArg :: Type { ty } => Ok ( ty. clone ( ) ) ,
510
+ _ => Err ( SignatureError :: InvalidTypeArgs . into ( ) ) ,
511
+ } )
512
+ . collect ( ) ;
513
+ Ok ( ArrayScan :: new ( src_ty. clone ( ) , tgt_ty. clone ( ) , acc_tys?, * n) )
514
+ }
515
+ _ => Err ( SignatureError :: InvalidTypeArgs . into ( ) ) ,
516
+ }
517
+ }
518
+ }
519
+
315
520
#[ cfg( test) ]
316
521
mod tests {
317
522
use strum:: IntoEnumIterator ;
@@ -320,6 +525,7 @@ mod tests {
320
525
builder:: { inout_sig, DFGBuilder , Dataflow , DataflowHugr } ,
321
526
extension:: prelude:: { BOOL_T , QB_T } ,
322
527
ops:: { OpTrait , OpType } ,
528
+ types:: Signature ,
323
529
} ;
324
530
325
531
use super :: * ;
@@ -459,4 +665,89 @@ mod tests {
459
665
)
460
666
) ;
461
667
}
668
+
669
+ #[ test]
670
+ fn test_repeat ( ) {
671
+ let size = 2 ;
672
+ let element_ty = QB_T ;
673
+ let op = ArrayOpDef :: repeat. to_concrete ( element_ty. clone ( ) , size) ;
674
+
675
+ let optype: OpType = op. into ( ) ;
676
+
677
+ let sig = optype. dataflow_signature ( ) . unwrap ( ) ;
678
+
679
+ assert_eq ! (
680
+ sig. io( ) ,
681
+ (
682
+ & vec![ Type :: new_function( Signature :: new( vec![ ] , vec![ QB_T ] ) ) ] . into( ) ,
683
+ & vec![ array_type( size, element_ty. clone( ) ) ] . into( ) ,
684
+ )
685
+ ) ;
686
+ }
687
+
688
+ #[ test]
689
+ fn test_scan_def ( ) {
690
+ let op = ArrayScan :: new ( BOOL_T , QB_T , vec ! [ USIZE_T ] , 2 ) ;
691
+ let optype: OpType = op. clone ( ) . into ( ) ;
692
+ let new_op: ArrayScan = optype. cast ( ) . unwrap ( ) ;
693
+ assert_eq ! ( new_op, op) ;
694
+ }
695
+
696
+ #[ test]
697
+ fn test_scan_map ( ) {
698
+ let size = 2 ;
699
+ let src_ty = QB_T ;
700
+ let tgt_ty = BOOL_T ;
701
+
702
+ let op = ArrayScan :: new ( src_ty. clone ( ) , tgt_ty. clone ( ) , vec ! [ ] , size) ;
703
+ let optype: OpType = op. into ( ) ;
704
+ let sig = optype. dataflow_signature ( ) . unwrap ( ) ;
705
+
706
+ assert_eq ! (
707
+ sig. io( ) ,
708
+ (
709
+ & vec![
710
+ array_type( size, src_ty. clone( ) ) ,
711
+ Type :: new_function( Signature :: new( vec![ src_ty] , vec![ tgt_ty. clone( ) ] ) )
712
+ ]
713
+ . into( ) ,
714
+ & vec![ array_type( size, tgt_ty) ] . into( ) ,
715
+ )
716
+ ) ;
717
+ }
718
+
719
+ #[ test]
720
+ fn test_scan_accs ( ) {
721
+ let size = 2 ;
722
+ let src_ty = QB_T ;
723
+ let tgt_ty = BOOL_T ;
724
+ let acc_ty1 = USIZE_T ;
725
+ let acc_ty2 = QB_T ;
726
+
727
+ let op = ArrayScan :: new (
728
+ src_ty. clone ( ) ,
729
+ tgt_ty. clone ( ) ,
730
+ vec ! [ acc_ty1. clone( ) , acc_ty2. clone( ) ] ,
731
+ size,
732
+ ) ;
733
+ let optype: OpType = op. into ( ) ;
734
+ let sig = optype. dataflow_signature ( ) . unwrap ( ) ;
735
+
736
+ assert_eq ! (
737
+ sig. io( ) ,
738
+ (
739
+ & vec![
740
+ array_type( size, src_ty. clone( ) ) ,
741
+ Type :: new_function( Signature :: new(
742
+ vec![ src_ty, acc_ty1. clone( ) , acc_ty2. clone( ) ] ,
743
+ vec![ tgt_ty. clone( ) , acc_ty1. clone( ) , acc_ty2. clone( ) ]
744
+ ) ) ,
745
+ acc_ty1. clone( ) ,
746
+ acc_ty2. clone( )
747
+ ]
748
+ . into( ) ,
749
+ & vec![ array_type( size, tgt_ty) , acc_ty1, acc_ty2] . into( ) ,
750
+ )
751
+ ) ;
752
+ }
462
753
}
0 commit comments