Skip to content

Commit 116d003

Browse files
authored
feat!: Share Extensions under Arcs (#1647)
Extensions are defined once and shared throughout `ExtensionRegistr`ies, `Package`s, and soon `Hugr`s (#1613). This _write-once then share around_ is a good usecase for `Arc`s, specially since the definitions are mostly read and rarely cloned. This is a requisite for #1613, to avoid cloning all extensions for each new hugr. BREAKING CHANGE: `ExtensionRegistry` and `Package` now wrap `Extension`s in `Arc`s.
1 parent e621054 commit 116d003

File tree

23 files changed

+138
-98
lines changed

23 files changed

+138
-98
lines changed

hugr-cli/tests/validate.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@
44
//! calling the CLI binary, which Miri doesn't support.
55
#![cfg(all(test, not(miri)))]
66

7+
use std::sync::Arc;
8+
79
use assert_cmd::Command;
810
use assert_fs::{fixture::FileWriteStr, NamedTempFile};
911
use hugr::builder::{DFGBuilder, DataflowSubContainer, ModuleBuilder};
@@ -49,7 +51,7 @@ fn test_package(#[default(BOOL_T)] id_type: Type) -> Package {
4951
let hugr = module.hugr().clone(); // unvalidated
5052

5153
let rdr = std::fs::File::open(FLOAT_EXT_FILE).unwrap();
52-
let float_ext: hugr::Extension = serde_json::from_reader(rdr).unwrap();
54+
let float_ext: Arc<hugr::Extension> = serde_json::from_reader(rdr).unwrap();
5355
Package::new(vec![hugr], vec![float_ext]).unwrap()
5456
}
5557

hugr-core/src/extension.rs

Lines changed: 37 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,11 @@ pub mod declarative;
3838

3939
/// Extension Registries store extensions to be looked up e.g. during validation.
4040
#[derive(Clone, Debug, PartialEq)]
41-
pub struct ExtensionRegistry(BTreeMap<ExtensionId, Extension>);
41+
pub struct ExtensionRegistry(BTreeMap<ExtensionId, Arc<Extension>>);
4242

4343
impl ExtensionRegistry {
4444
/// Gets the Extension with the given name
45-
pub fn get(&self, name: &str) -> Option<&Extension> {
45+
pub fn get(&self, name: &str) -> Option<&Arc<Extension>> {
4646
self.0.get(name)
4747
}
4848

@@ -51,9 +51,9 @@ impl ExtensionRegistry {
5151
self.0.contains_key(name)
5252
}
5353

54-
/// Makes a new ExtensionRegistry, validating all the extensions in it
54+
/// Makes a new [ExtensionRegistry], validating all the extensions in it.
5555
pub fn try_new(
56-
value: impl IntoIterator<Item = Extension>,
56+
value: impl IntoIterator<Item = Arc<Extension>>,
5757
) -> Result<Self, ExtensionRegistryError> {
5858
let mut res = ExtensionRegistry(BTreeMap::new());
5959

@@ -70,20 +70,28 @@ impl ExtensionRegistry {
7070
ext.validate(&res)
7171
.map_err(|e| ExtensionRegistryError::InvalidSignature(ext.name().clone(), e))?;
7272
}
73+
7374
Ok(res)
7475
}
7576

7677
/// Registers a new extension to the registry.
7778
///
7879
/// Returns a reference to the registered extension if successful.
79-
pub fn register(&mut self, extension: Extension) -> Result<&Extension, ExtensionRegistryError> {
80+
pub fn register(
81+
&mut self,
82+
extension: impl Into<Arc<Extension>>,
83+
) -> Result<(), ExtensionRegistryError> {
84+
let extension = extension.into();
8085
match self.0.entry(extension.name().clone()) {
8186
btree_map::Entry::Occupied(prev) => Err(ExtensionRegistryError::AlreadyRegistered(
8287
extension.name().clone(),
8388
prev.get().version().clone(),
8489
extension.version().clone(),
8590
)),
86-
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)),
91+
btree_map::Entry::Vacant(ve) => {
92+
ve.insert(extension);
93+
Ok(())
94+
}
8795
}
8896
}
8997

@@ -93,21 +101,24 @@ impl ExtensionRegistry {
93101
/// If versions match, the original extension is kept.
94102
/// Returns a reference to the registered extension if successful.
95103
///
96-
/// Avoids cloning the extension unless required. For a reference version see
104+
/// Takes an Arc to the extension. To avoid cloning Arcs unless necessary, see
97105
/// [`ExtensionRegistry::register_updated_ref`].
98106
pub fn register_updated(
99107
&mut self,
100-
extension: Extension,
101-
) -> Result<&Extension, ExtensionRegistryError> {
108+
extension: impl Into<Arc<Extension>>,
109+
) -> Result<(), ExtensionRegistryError> {
110+
let extension = extension.into();
102111
match self.0.entry(extension.name().clone()) {
103112
btree_map::Entry::Occupied(mut prev) => {
104113
if prev.get().version() < extension.version() {
105114
*prev.get_mut() = extension;
106115
}
107-
Ok(prev.into_mut())
108116
}
109-
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension)),
117+
btree_map::Entry::Vacant(ve) => {
118+
ve.insert(extension);
119+
}
110120
}
121+
Ok(())
111122
}
112123

113124
/// Registers a new extension to the registry, keeping most up to date if
@@ -117,21 +128,23 @@ impl ExtensionRegistry {
117128
/// If versions match, the original extension is kept. Returns a reference
118129
/// to the registered extension if successful.
119130
///
120-
/// Clones the extension if required. For no-cloning version see
131+
/// Clones the Arc only when required. For no-cloning version see
121132
/// [`ExtensionRegistry::register_updated`].
122133
pub fn register_updated_ref(
123134
&mut self,
124-
extension: &Extension,
125-
) -> Result<&Extension, ExtensionRegistryError> {
135+
extension: &Arc<Extension>,
136+
) -> Result<(), ExtensionRegistryError> {
126137
match self.0.entry(extension.name().clone()) {
127138
btree_map::Entry::Occupied(mut prev) => {
128139
if prev.get().version() < extension.version() {
129140
*prev.get_mut() = extension.clone();
130141
}
131-
Ok(prev.into_mut())
132142
}
133-
btree_map::Entry::Vacant(ve) => Ok(ve.insert(extension.clone())),
143+
btree_map::Entry::Vacant(ve) => {
144+
ve.insert(extension.clone());
145+
}
134146
}
147+
Ok(())
135148
}
136149

137150
/// Returns the number of extensions in the registry.
@@ -145,20 +158,20 @@ impl ExtensionRegistry {
145158
}
146159

147160
/// Returns an iterator over the extensions in the registry.
148-
pub fn iter(&self) -> impl Iterator<Item = (&ExtensionId, &Extension)> {
161+
pub fn iter(&self) -> impl Iterator<Item = (&ExtensionId, &Arc<Extension>)> {
149162
self.0.iter()
150163
}
151164

152165
/// Delete an extension from the registry and return it if it was present.
153-
pub fn remove_extension(&mut self, name: &ExtensionId) -> Option<Extension> {
166+
pub fn remove_extension(&mut self, name: &ExtensionId) -> Option<Arc<Extension>> {
154167
self.0.remove(name)
155168
}
156169
}
157170

158171
impl IntoIterator for ExtensionRegistry {
159-
type Item = (ExtensionId, Extension);
172+
type Item = (ExtensionId, Arc<Extension>);
160173

161-
type IntoIter = <BTreeMap<ExtensionId, Extension> as IntoIterator>::IntoIter;
174+
type IntoIter = <BTreeMap<ExtensionId, Arc<Extension>> as IntoIterator>::IntoIter;
162175

163176
fn into_iter(self) -> Self::IntoIter {
164177
self.0.into_iter()
@@ -646,10 +659,10 @@ pub mod test {
646659

647660
let ext_1_id = ExtensionId::new("ext1").unwrap();
648661
let ext_2_id = ExtensionId::new("ext2").unwrap();
649-
let ext1 = Extension::new(ext_1_id.clone(), Version::new(1, 0, 0));
650-
let ext1_1 = Extension::new(ext_1_id.clone(), Version::new(1, 1, 0));
651-
let ext1_2 = Extension::new(ext_1_id.clone(), Version::new(0, 2, 0));
652-
let ext2 = Extension::new(ext_2_id, Version::new(1, 0, 0));
662+
let ext1 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(1, 0, 0)));
663+
let ext1_1 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(1, 1, 0)));
664+
let ext1_2 = Arc::new(Extension::new(ext_1_id.clone(), Version::new(0, 2, 0)));
665+
let ext2 = Arc::new(Extension::new(ext_2_id, Version::new(1, 0, 0)));
653666

654667
reg.register(ext1.clone()).unwrap();
655668
reg_ref.register(ext1.clone()).unwrap();

hugr-core/src/extension/declarative.rs

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -136,8 +136,8 @@ impl ExtensionSetDeclaration {
136136
registry,
137137
};
138138
let ext = decl.make_extension(&self.imports, ctx)?;
139-
let ext = registry.register(ext)?;
140-
scope.insert(ext.name())
139+
scope.insert(ext.name());
140+
registry.register(ext)?;
141141
}
142142

143143
Ok(())
@@ -272,6 +272,7 @@ mod test {
272272
use itertools::Itertools;
273273
use rstest::rstest;
274274
use std::path::PathBuf;
275+
use std::sync::Arc;
275276

276277
use crate::extension::PRELUDE_REGISTRY;
277278
use crate::std_extensions;
@@ -406,7 +407,7 @@ extensions:
406407
fn new_extensions<'a>(
407408
reg: &'a ExtensionRegistry,
408409
dependencies: &'a ExtensionRegistry,
409-
) -> impl Iterator<Item = (&'a ExtensionId, &'a Extension)> {
410+
) -> impl Iterator<Item = (&'a ExtensionId, &'a Arc<Extension>)> {
410411
reg.iter()
411412
.filter(move |(id, _)| !dependencies.contains(id) && *id != &PRELUDE_ID)
412413
}

hugr-core/src/extension/op_def.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -617,7 +617,7 @@ pub(super) mod test {
617617
assert_eq!(def.misc.len(), 1);
618618

619619
let reg =
620-
ExtensionRegistry::try_new([PRELUDE.to_owned(), EXTENSION.to_owned(), e]).unwrap();
620+
ExtensionRegistry::try_new([PRELUDE.clone(), EXTENSION.clone(), e.into()]).unwrap();
621621
let e = reg.get(&EXT_ID).unwrap();
622622

623623
let list_usize =

hugr-core/src/extension/prelude.rs

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
//! Prelude extension - available in all contexts, defining common types,
22
//! operations and constants.
3+
use std::sync::Arc;
4+
35
use itertools::Itertools;
46
use lazy_static::lazy_static;
57

@@ -38,7 +40,7 @@ pub const PRELUDE_ID: ExtensionId = ExtensionId::new_unchecked("prelude");
3840
/// Extension version.
3941
pub const VERSION: semver::Version = semver::Version::new(0, 1, 0);
4042
lazy_static! {
41-
static ref PRELUDE_DEF: Extension = {
43+
static ref PRELUDE_DEF: Arc<Extension> = {
4244
let mut prelude = Extension::new(PRELUDE_ID, VERSION);
4345
prelude
4446
.add_type(
@@ -106,14 +108,15 @@ lazy_static! {
106108
LiftDef.add_to_extension(&mut prelude).unwrap();
107109
array::ArrayOpDef::load_all_ops(&mut prelude).unwrap();
108110
array::ArrayScanDef.add_to_extension(&mut prelude).unwrap();
109-
prelude
111+
112+
Arc::new(prelude)
110113
};
111114
/// An extension registry containing only the prelude
112115
pub static ref PRELUDE_REGISTRY: ExtensionRegistry =
113-
ExtensionRegistry::try_new([PRELUDE_DEF.to_owned()]).unwrap();
116+
ExtensionRegistry::try_new([PRELUDE_DEF.clone()]).unwrap();
114117

115118
/// Prelude extension
116-
pub static ref PRELUDE: &'static Extension = PRELUDE_REGISTRY.get(&PRELUDE_ID).unwrap();
119+
pub static ref PRELUDE: Arc<Extension> = PRELUDE_REGISTRY.get(&PRELUDE_ID).unwrap().clone();
117120

118121
}
119122

hugr-core/src/extension/simple_op.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,8 @@ impl<T: MakeRegisteredOp> From<T> for OpType {
272272

273273
#[cfg(test)]
274274
mod test {
275+
use std::sync::Arc;
276+
275277
use crate::{const_extension_ids, type_row, types::Signature};
276278

277279
use super::*;
@@ -313,13 +315,13 @@ mod test {
313315
}
314316

315317
lazy_static! {
316-
static ref EXT: Extension = {
318+
static ref EXT: Arc<Extension> = {
317319
let mut e = Extension::new_test(EXT_ID.clone());
318320
DummyEnum::Dumb.add_to_extension(&mut e).unwrap();
319-
e
321+
Arc::new(e)
320322
};
321323
static ref DUMMY_REG: ExtensionRegistry =
322-
ExtensionRegistry::try_new([EXT.to_owned()]).unwrap();
324+
ExtensionRegistry::try_new([EXT.clone()]).unwrap();
323325
}
324326
impl MakeRegisteredOp for DummyEnum {
325327
fn extension_id(&self) -> ExtensionId {

hugr-core/src/hugr/rewrite/inline_dfg.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -257,9 +257,9 @@ mod test {
257257
let [q, p] = swap.outputs_arr();
258258
let cx = h.add_dataflow_op(test_quantum_extension::cx_gate(), [q, p])?;
259259
let reg = ExtensionRegistry::try_new([
260-
test_quantum_extension::EXTENSION.to_owned(),
261-
PRELUDE.to_owned(),
262-
float_types::EXTENSION.to_owned(),
260+
test_quantum_extension::EXTENSION.clone(),
261+
PRELUDE.clone(),
262+
float_types::EXTENSION.clone(),
263263
])
264264
.unwrap();
265265

hugr-core/src/hugr/validate/test.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,7 @@ fn invalid_types() {
386386
TypeDefBound::any(),
387387
)
388388
.unwrap();
389-
let reg = ExtensionRegistry::try_new([e, PRELUDE.to_owned()]).unwrap();
389+
let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()]).unwrap();
390390

391391
let validate_to_sig_error = |t: CustomType| {
392392
let (h, def) = identity_hugr_with_type(Type::new_extension(t));
@@ -643,7 +643,7 @@ fn instantiate_row_variables() -> Result<(), Box<dyn std::error::Error>> {
643643
let eval2 = dfb.add_dataflow_op(eval2, [par_func, a, b])?;
644644
dfb.finish_hugr_with_outputs(
645645
eval2.outputs(),
646-
&ExtensionRegistry::try_new([PRELUDE.to_owned(), e]).unwrap(),
646+
&ExtensionRegistry::try_new([PRELUDE.clone(), e.into()]).unwrap(),
647647
)?;
648648
Ok(())
649649
}
@@ -683,7 +683,7 @@ fn row_variables() -> Result<(), Box<dyn std::error::Error>> {
683683
let par_func = fb.add_dataflow_op(par, [func_arg, id_usz])?;
684684
fb.finish_hugr_with_outputs(
685685
par_func.outputs(),
686-
&ExtensionRegistry::try_new([PRELUDE.to_owned(), e]).unwrap(),
686+
&ExtensionRegistry::try_new([PRELUDE.clone(), e.into()]).unwrap(),
687687
)?;
688688
Ok(())
689689
}
@@ -763,7 +763,7 @@ fn test_polymorphic_call() -> Result<(), Box<dyn std::error::Error>> {
763763
f.finish_with_outputs([tup])?
764764
};
765765

766-
let reg = ExtensionRegistry::try_new([e, PRELUDE.to_owned()])?;
766+
let reg = ExtensionRegistry::try_new([e.into(), PRELUDE.clone()])?;
767767
let [func, tup] = d.input_wires_arr();
768768
let call = d.call(
769769
f.handle(),

hugr-core/src/ops/custom.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ impl ExtensionOp {
4646
args: impl Into<Vec<TypeArg>>,
4747
exts: &ExtensionRegistry,
4848
) -> Result<Self, SignatureError> {
49-
let args = args.into();
49+
let args: Vec<TypeArg> = args.into();
5050
let signature = def.compute_signature(&args, exts)?;
5151
Ok(Self {
5252
def,
@@ -62,7 +62,7 @@ impl ExtensionOp {
6262
opaque: &OpaqueOp,
6363
exts: &ExtensionRegistry,
6464
) -> Result<Self, SignatureError> {
65-
let args = args.into();
65+
let args: Vec<TypeArg> = args.into();
6666
// TODO skip computation depending on config
6767
// see https://github.com/CQCL/hugr/issues/1363
6868
let signature = match def.compute_signature(&args, exts) {
@@ -421,7 +421,7 @@ mod test {
421421
SignatureFunc::MissingComputeFunc,
422422
)
423423
.unwrap();
424-
let registry = ExtensionRegistry::try_new([ext]).unwrap();
424+
let registry = ExtensionRegistry::try_new([ext.into()]).unwrap();
425425
let opaque_val = OpaqueOp::new(
426426
ext_id.clone(),
427427
val_name,

hugr-core/src/package.rs

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
use derive_more::{Display, Error, From};
44
use std::collections::HashMap;
55
use std::path::Path;
6+
use std::sync::Arc;
67
use std::{fs, io, mem};
78

89
use crate::builder::{Container, Dataflow, DataflowSubContainer, ModuleBuilder};
@@ -19,7 +20,7 @@ pub struct Package {
1920
/// Module HUGRs included in the package.
2021
pub modules: Vec<Hugr>,
2122
/// Extensions to validate against.
22-
pub extensions: Vec<Extension>,
23+
pub extensions: Vec<Arc<Extension>>,
2324
}
2425

2526
impl Package {
@@ -32,7 +33,7 @@ impl Package {
3233
/// Returns an error if any of the HUGRs does not have a `Module` root.
3334
pub fn new(
3435
modules: impl IntoIterator<Item = Hugr>,
35-
extensions: impl IntoIterator<Item = Extension>,
36+
extensions: impl IntoIterator<Item = Arc<Extension>>,
3637
) -> Result<Self, PackageError> {
3738
let modules: Vec<Hugr> = modules.into_iter().collect();
3839
for (idx, module) in modules.iter().enumerate() {
@@ -62,7 +63,7 @@ impl Package {
6263
/// Returns an error if any of the HUGRs cannot be wrapped in a module.
6364
pub fn from_hugrs(
6465
modules: impl IntoIterator<Item = Hugr>,
65-
extensions: impl IntoIterator<Item = Extension>,
66+
extensions: impl IntoIterator<Item = Arc<Extension>>,
6667
) -> Result<Self, PackageError> {
6768
let modules: Vec<Hugr> = modules
6869
.into_iter()
@@ -378,7 +379,7 @@ mod test {
378379

379380
Package {
380381
modules: vec![hugr0, hugr1],
381-
extensions: vec![ext1, ext2],
382+
extensions: vec![ext1.into(), ext2.into()],
382383
}
383384
}
384385

0 commit comments

Comments
 (0)