Skip to content

Commit a1a70f1

Browse files
authored
feat(core, llvm): add array unpack operations (#2339)
Dual of new_array Closes #1947
1 parent 2a5413a commit a1a70f1

14 files changed

+1452
-1125
lines changed

hugr-core/src/std_extensions/collections/array.rs

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ pub const ARRAY_VALUENAME: TypeName = TypeName::new_inline("array");
4343
/// Reported unique name of the extension
4444
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_unchecked("collections.array");
4545
/// Extension version.
46-
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
46+
pub const VERSION: semver::Version = semver::Version::new(0, 1, 1);
4747

4848
/// A linear, fixed-length collection of values.
4949
///
@@ -197,7 +197,31 @@ pub trait ArrayOpBuilder: GenericArrayOpBuilder {
197197
) -> Result<Wire, BuildError> {
198198
self.add_new_generic_array::<Array>(elem_ty, values)
199199
}
200-
200+
/// Adds an array unpack operation to the dataflow graph.
201+
///
202+
/// This operation unpacks an array into individual elements.
203+
///
204+
/// # Arguments
205+
///
206+
/// * `elem_ty` - The type of the elements in the array.
207+
/// * `size` - The size of the array.
208+
/// * `input` - The wire representing the array to unpack.
209+
///
210+
/// # Errors
211+
///
212+
/// If building the operation fails.
213+
///
214+
/// # Returns
215+
///
216+
/// A vector of wires representing the individual elements from the array.
217+
fn add_array_unpack(
218+
&mut self,
219+
elem_ty: Type,
220+
size: u64,
221+
input: Wire,
222+
) -> Result<Vec<Wire>, BuildError> {
223+
self.add_generic_array_unpack::<Array>(elem_ty, size, input)
224+
}
201225
/// Adds an array clone operation to the dataflow graph and return the wires
202226
/// representing the originala and cloned array.
203227
///

hugr-core/src/std_extensions/collections/array/array_op.rs

Lines changed: 32 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ use crate::utils::Never;
2121

2222
use super::array_kind::ArrayKind;
2323

24-
/// Array operation definitions. Generic over the conrete array implementation.
24+
/// Array operation definitions. Generic over the concrete array implementation.
2525
#[derive(Clone, Copy, Debug, Hash, PartialEq, Eq, IntoStaticStr, EnumIter, EnumString)]
2626
#[allow(non_camel_case_types)]
2727
#[non_exhaustive]
@@ -58,6 +58,10 @@ pub enum GenericArrayOpDef<AK: ArrayKind> {
5858
/// references `AK` to ensure that the type parameter is used.
5959
#[strum(disabled)]
6060
_phantom(PhantomData<AK>, Never),
61+
/// Unpacks an array into its individual elements:
62+
/// `unpack<SIZE><elemty>: array<SIZE, elemty> -> (elemty)^SIZE`
63+
/// where `SIZE` must be statically known (not a variable)
64+
unpack,
6165
}
6266

6367
/// Static parameters for array operations. Includes array size. Type is part of the type scheme.
@@ -76,6 +80,10 @@ impl<AK: ArrayKind> SignatureFromArgs for GenericArrayOpDef<AK> {
7680
params,
7781
FuncValueType::new(vec![elem_ty_var.clone(); n as usize], array_ty),
7882
),
83+
GenericArrayOpDef::unpack => PolyFuncTypeRV::new(
84+
params,
85+
FuncValueType::new(array_ty, vec![elem_ty_var.clone(); n as usize]),
86+
),
7987
GenericArrayOpDef::pop_left | GenericArrayOpDef::pop_right => {
8088
let popped_array_ty = AK::ty(n - 1, elem_ty_var.clone());
8189
PolyFuncTypeRV::new(
@@ -124,9 +132,9 @@ impl<AK: ArrayKind> GenericArrayOpDef<AK> {
124132
_extension_ref: &Weak<Extension>,
125133
) -> SignatureFunc {
126134
use GenericArrayOpDef::{
127-
_phantom, discard_empty, get, new_array, pop_left, pop_right, set, swap,
135+
_phantom, discard_empty, get, new_array, pop_left, pop_right, set, swap, unpack,
128136
};
129-
if let new_array | pop_left | pop_right = self {
137+
if let new_array | unpack | pop_left | pop_right = self {
130138
// implements SignatureFromArgs
131139
// signature computed dynamically, so can rely on type definition in extension.
132140
(*self).into()
@@ -184,7 +192,7 @@ impl<AK: ArrayKind> GenericArrayOpDef<AK> {
184192
),
185193
),
186194
_phantom(_, never) => match *never {},
187-
new_array | pop_left | pop_right => unreachable!(),
195+
new_array | unpack | pop_left | pop_right => unreachable!(),
188196
}
189197
.into()
190198
}
@@ -218,6 +226,7 @@ impl<AK: ArrayKind> MakeOpDef for GenericArrayOpDef<AK> {
218226
fn description(&self) -> String {
219227
match self {
220228
GenericArrayOpDef::new_array => "Create a new array from elements",
229+
GenericArrayOpDef::unpack => "Unpack an array into its elements",
221230
GenericArrayOpDef::get => "Get an element from an array",
222231
GenericArrayOpDef::set => "Set an element in an array",
223232
GenericArrayOpDef::swap => "Swap two elements in an array",
@@ -250,7 +259,7 @@ impl<AK: ArrayKind> MakeOpDef for GenericArrayOpDef<AK> {
250259
}
251260

252261
#[derive(Clone, Debug, PartialEq)]
253-
/// Concrete array operation. Generic over the actual array implemenation.
262+
/// Concrete array operation. Generic over the actual array implementation.
254263
pub struct GenericArrayOp<AK: ArrayKind> {
255264
/// The operation definition.
256265
pub def: GenericArrayOpDef<AK>,
@@ -275,7 +284,7 @@ impl<AK: ArrayKind> MakeExtensionOp for GenericArrayOp<AK> {
275284

276285
fn type_args(&self) -> Vec<TypeArg> {
277286
use GenericArrayOpDef::{
278-
_phantom, discard_empty, get, new_array, pop_left, pop_right, set, swap,
287+
_phantom, discard_empty, get, new_array, pop_left, pop_right, set, swap, unpack,
279288
};
280289
let ty_arg = TypeArg::Type {
281290
ty: self.elem_ty.clone(),
@@ -288,7 +297,7 @@ impl<AK: ArrayKind> MakeExtensionOp for GenericArrayOp<AK> {
288297
);
289298
vec![ty_arg]
290299
}
291-
new_array | pop_left | pop_right | get | set | swap => {
300+
new_array | unpack | pop_left | pop_right | get | set | swap => {
292301
vec![TypeArg::BoundedNat { n: self.size }, ty_arg]
293302
}
294303
_phantom(_, never) => match never {},
@@ -379,6 +388,22 @@ mod tests {
379388
b.finish_hugr_with_outputs(out.outputs()).unwrap();
380389
}
381390

391+
#[rstest]
392+
#[case(Array)]
393+
#[case(ValueArray)]
394+
/// Test building a HUGR involving an unpack operation.
395+
fn test_unpack<AK: ArrayKind>(#[case] _kind: AK) {
396+
let mut b = DFGBuilder::new(inout_sig(AK::ty(2, qb_t()), vec![qb_t(), qb_t()])).unwrap();
397+
398+
let [array] = b.input_wires_arr();
399+
400+
let op = GenericArrayOpDef::<AK>::unpack.to_concrete(qb_t(), 2);
401+
402+
let out = b.add_dataflow_op(op, [array]).unwrap();
403+
404+
b.finish_hugr_with_outputs(out.outputs()).unwrap();
405+
}
406+
382407
#[rstest]
383408
#[case(Array)]
384409
#[case(ValueArray)]

hugr-core/src/std_extensions/collections/array/op_builder.rs

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,32 @@ pub trait GenericArrayOpBuilder: Dataflow {
4949
Ok(out)
5050
}
5151

52+
/// Adds an array unpack operation to the dataflow graph.
53+
///
54+
/// This operation unpacks an array into individual elements.
55+
///
56+
/// # Arguments
57+
///
58+
/// * `elem_ty` - The type of the elements in the array.
59+
/// * `size` - The size of the array.
60+
/// * `input` - The wire representing the array.
61+
///
62+
/// # Errors
63+
///
64+
/// Returns an error if building the operation fails.
65+
///
66+
/// # Returns
67+
///
68+
/// A vector of wires representing the individual elements of the array.
69+
fn add_generic_array_unpack<AK: ArrayKind>(
70+
&mut self,
71+
elem_ty: Type,
72+
size: u64,
73+
input: Wire,
74+
) -> Result<Vec<Wire>, BuildError> {
75+
let op = GenericArrayOpDef::<AK>::unpack.instantiate(&[size.into(), elem_ty.into()])?;
76+
Ok(self.add_dataflow_op(op, vec![input])?.outputs().collect())
77+
}
5278
/// Adds an array clone operation to the dataflow graph and return the wires
5379
/// representing the originala and cloned array.
5480
///
@@ -283,6 +309,17 @@ pub fn build_all_array_ops_generic<B: Dataflow, AK: ArrayKind>(mut builder: B) -
283309
let us0 = builder.add_load_value(ConstUsize::new(0));
284310
let us1 = builder.add_load_value(ConstUsize::new(1));
285311
let us2 = builder.add_load_value(ConstUsize::new(2));
312+
let arr = builder
313+
.add_new_generic_array::<AK>(usize_t(), [us1, us2])
314+
.unwrap();
315+
316+
// Add array unpack operation
317+
let [_us1, _us2] = builder
318+
.add_generic_array_unpack::<AK>(usize_t(), 2, arr)
319+
.unwrap()
320+
.try_into()
321+
.unwrap();
322+
286323
let arr = builder
287324
.add_new_generic_array::<AK>(usize_t(), [us1, us2])
288325
.unwrap();

hugr-core/src/std_extensions/collections/value_array.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ pub const VALUE_ARRAY_VALUENAME: TypeName = TypeName::new_inline("value_array");
3131
/// Reported unique name of the extension
3232
pub const EXTENSION_ID: ExtensionId = ExtensionId::new_static_unchecked("collections.value_array");
3333
/// Extension version.
34-
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
34+
pub const VERSION: semver::Version = semver::Version::new(0, 1, 1);
3535

3636
/// A fixed-length collection of values.
3737
///

hugr-llvm/src/extension/collections/array.rs

Lines changed: 98 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -455,6 +455,24 @@ pub fn emit_array_op<'c, H: HugrView<Node = Node>>(
455455
}
456456
outputs.finish(ctx.builder(), [array_v.into()])
457457
}
458+
ArrayOpDef::unpack => {
459+
let [array_v] = inputs
460+
.try_into()
461+
.map_err(|_| anyhow!("ArrayOpDef::unpack expects one argument"))?;
462+
let (array_ptr, array_offset) = decompose_array_fat_pointer(builder, array_v)?;
463+
464+
let mut result = Vec::with_capacity(size as usize);
465+
let usize_t = usize_ty(&ctx.typing_session());
466+
467+
for i in 0..size {
468+
let idx = builder.build_int_add(array_offset, usize_t.const_int(i, false), "")?;
469+
let elem_addr = unsafe { builder.build_in_bounds_gep(array_ptr, &[idx], "")? };
470+
let elem_v = builder.build_load(elem_addr, "")?;
471+
result.push(elem_v);
472+
}
473+
474+
outputs.finish(ctx.builder(), result)
475+
}
458476
ArrayOpDef::get => {
459477
let [array_v, index_v] = inputs
460478
.try_into()
@@ -999,6 +1017,41 @@ mod test {
9991017
check_emission!(hugr, llvm_ctx);
10001018
}
10011019

1020+
// #[rstest]
1021+
// #[case(1, 2, 3)]
1022+
// #[case(0, 0, 0)]
1023+
// #[case(10, 20, 30)]
1024+
// fn exec_unpack_and_sum(mut exec_ctx: TestContext, #[case] a: u64, #[case] b: u64, #[case] expected: u64) {
1025+
// let hugr = SimpleHugrConfig::new()
1026+
// .with_extensions(exec_registry())
1027+
// .with_outs(vec![usize_t()])
1028+
// .finish(|mut builder| {
1029+
// // Create an array with the test values
1030+
// let values = vec![ConstUsize::new(a).into(), ConstUsize::new(b).into()];
1031+
// let arr = builder.add_load_value(array::ArrayValue::new(usize_t(), values));
1032+
1033+
// // Unpack the array
1034+
// let [val_a, val_b] = builder.add_array_unpack(usize_t(), 2, arr).unwrap().try_into().unwrap();
1035+
1036+
// // Add the values
1037+
// let sum = {
1038+
// let int_ty = int_type(6);
1039+
// let a_int = builder.cast(val_a, int_ty.clone()).unwrap();
1040+
// let b_int = builder.cast(val_b, int_ty.clone()).unwrap();
1041+
// let sum_int = builder.add_iadd(6, a_int, b_int).unwrap();
1042+
// builder.cast(sum_int, usize_t()).unwrap()
1043+
// };
1044+
1045+
// builder.finish_hugr_with_outputs([sum]).unwrap()
1046+
// });
1047+
// exec_ctx.add_extensions(|cge| {
1048+
// cge.add_default_prelude_extensions()
1049+
// .add_default_array_extensions()
1050+
// .add_default_int_extensions()
1051+
// });
1052+
// assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
1053+
// }
1054+
10021055
fn exec_registry() -> ExtensionRegistry {
10031056
ExtensionRegistry::new([
10041057
int_types::EXTENSION.to_owned(),
@@ -1398,6 +1451,51 @@ mod test {
13981451
assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
13991452
}
14001453

1454+
#[rstest]
1455+
#[case(&[], 0)]
1456+
#[case(&[1, 2], 3)]
1457+
#[case(&[6, 6, 6], 18)]
1458+
fn exec_unpack(
1459+
mut exec_ctx: TestContext,
1460+
#[case] array_contents: &[u64],
1461+
#[case] expected: u64,
1462+
) {
1463+
// We build a HUGR that:
1464+
// - Loads an array with the given contents
1465+
// - Unpacks all the elements
1466+
// - Returns the sum of the elements
1467+
1468+
let int_ty = int_type(6);
1469+
let hugr = SimpleHugrConfig::new()
1470+
.with_outs(int_ty.clone())
1471+
.with_extensions(exec_registry())
1472+
.finish(|mut builder| {
1473+
let array = array::ArrayValue::new(
1474+
int_ty.clone(),
1475+
array_contents
1476+
.iter()
1477+
.map(|&i| ConstInt::new_u(6, i).unwrap().into())
1478+
.collect_vec(),
1479+
);
1480+
let array = builder.add_load_value(array);
1481+
let unpacked = builder
1482+
.add_array_unpack(int_ty.clone(), array_contents.len() as u64, array)
1483+
.unwrap();
1484+
let mut r = builder.add_load_value(ConstInt::new_u(6, 0).unwrap());
1485+
for elem in unpacked {
1486+
r = builder.add_iadd(6, r, elem).unwrap();
1487+
}
1488+
1489+
builder.finish_hugr_with_outputs([r]).unwrap()
1490+
});
1491+
exec_ctx.add_extensions(|cge| {
1492+
cge.add_default_prelude_extensions()
1493+
.add_default_array_extensions()
1494+
.add_default_int_extensions()
1495+
});
1496+
assert_eq!(expected, exec_ctx.exec_hugr_u64(hugr, "main"));
1497+
}
1498+
14011499
#[rstest]
14021500
#[case(5, 42, 0)]
14031501
#[case(5, 42, 1)]

0 commit comments

Comments
 (0)