Skip to content

Commit 0c68de8

Browse files
authored
fix: Make SumType::Unit(N) equal to SumType::General([(); N]) (#2250)
`SumType` can represent the same sum of empty tuples in two different ways, but its equality implementation didn't consider them equal. This PR fixes this by manually implementing `PartialEq`, as well as `Hash` (to preserve correctness). We already do that on the python side https://github.com/CQCL/hugr/blob/940895cf3eae57bb6cae6d13f96bf4f065cb4d2b/hugr-py/src/hugr/tys.py#L311-L312 Ideally we'd make the general-sum-of-empty-tuples unrepresentable, but since `SumType` is an enum (with public member), we are not able to avoid arbitrary mutation.
1 parent 70bec95 commit 0c68de8

5 files changed

+42
-23
lines changed

hugr-core/src/types.rs

Lines changed: 34 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,7 @@ pub(crate) fn least_upper_bound(mut tags: impl Iterator<Item = TypeBound>) -> Ty
171171
.into_inner()
172172
}
173173

174-
#[derive(Clone, PartialEq, Debug, Eq, Hash, Serialize, Deserialize)]
174+
#[derive(Clone, Debug, Eq, Serialize, Deserialize)]
175175
#[serde(tag = "s")]
176176
#[non_exhaustive]
177177
/// Representation of a Sum type.
@@ -186,6 +186,18 @@ pub enum SumType {
186186
General { rows: Vec<TypeRowRV> },
187187
}
188188

189+
impl std::hash::Hash for SumType {
190+
fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
191+
self.variants().for_each(|v| v.hash(state));
192+
}
193+
}
194+
195+
impl PartialEq for SumType {
196+
fn eq(&self, other: &Self) -> bool {
197+
self.num_variants() == other.num_variants() && self.variants().eq(other.variants())
198+
}
199+
}
200+
189201
impl std::fmt::Display for SumType {
190202
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191203
if self.num_variants() == 0 {
@@ -839,6 +851,7 @@ pub(crate) fn check_typevar_decl(
839851

840852
#[cfg(test)]
841853
pub(crate) mod test {
854+
use std::hash::{Hash, Hasher};
842855
use std::sync::Weak;
843856

844857
use super::*;
@@ -917,20 +930,26 @@ pub(crate) mod test {
917930

918931
#[test]
919932
fn sum_variants() {
920-
{
921-
let variants: Vec<TypeRowRV> = vec![
922-
TypeRV::UNIT.into(),
923-
vec![TypeRV::new_row_var_use(0, TypeBound::Any)].into(),
924-
];
925-
let t = SumType::new(variants.clone());
926-
assert_eq!(variants, t.variants().cloned().collect_vec());
927-
}
928-
{
929-
assert_eq!(
930-
vec![&TypeRV::EMPTY_TYPEROW; 3],
931-
SumType::new_unary(3).variants().collect_vec()
932-
);
933-
}
933+
let variants: Vec<TypeRowRV> = vec![
934+
TypeRV::UNIT.into(),
935+
vec![TypeRV::new_row_var_use(0, TypeBound::Any)].into(),
936+
];
937+
let t = SumType::new(variants.clone());
938+
assert_eq!(variants, t.variants().cloned().collect_vec());
939+
940+
let empty_rows = vec![TypeRV::EMPTY_TYPEROW; 3];
941+
let sum_unary = SumType::new_unary(3);
942+
let sum_general = SumType::General {
943+
rows: empty_rows.clone(),
944+
};
945+
assert_eq!(&empty_rows, &sum_unary.variants().cloned().collect_vec());
946+
assert_eq!(sum_general, sum_unary);
947+
948+
let mut hasher_general = std::hash::DefaultHasher::new();
949+
sum_general.hash(&mut hasher_general);
950+
let mut hasher_unary = std::hash::DefaultHasher::new();
951+
sum_unary.hash(&mut hasher_unary);
952+
assert_eq!(hasher_general.finish(), hasher_unary.finish());
934953
}
935954

936955
pub(super) struct FnTransformer<T>(pub(super) T);

hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm14_2.snap

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ expression: mod_str
55
; ModuleID = 'test_context'
66
source_filename = "test_context"
77

8-
@sa.c.911aa16d.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] }
8+
@sa.c.d2dddd66.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] }
99

1010
define { i64, [0 x i1] }* @_hl.main.1() {
1111
alloca_block:
1212
br label %entry_block
1313

1414
entry_block: ; preds = %alloca_block
15-
ret { i64, [0 x i1] }* bitcast ({ i64, [10 x i1] }* @sa.c.911aa16d.0 to { i64, [0 x i1] }*)
15+
ret { i64, [0 x i1] }* bitcast ({ i64, [10 x i1] }* @sa.c.d2dddd66.0 to { i64, [0 x i1] }*)
1616
}

hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@llvm14_3.snap

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,12 @@ expression: mod_str
55
; ModuleID = 'test_context'
66
source_filename = "test_context"
77

8-
@sa.d.4c6da27.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] }
8+
@sa.d.eee08a59.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] }
99

1010
define { i64, [0 x { i1, i64 }] }* @_hl.main.1() {
1111
alloca_block:
1212
br label %entry_block
1313

1414
entry_block: ; preds = %alloca_block
15-
ret { i64, [0 x { i1, i64 }] }* bitcast ({ i64, [10 x { i1, i64 }] }* @sa.d.4c6da27.0 to { i64, [0 x { i1, i64 }] }*)
15+
ret { i64, [0 x { i1, i64 }] }* bitcast ({ i64, [10 x { i1, i64 }] }* @sa.d.eee08a59.0 to { i64, [0 x { i1, i64 }] }*)
1616
}

hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_2.snap

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ expression: mod_str
55
; ModuleID = 'test_context'
66
source_filename = "test_context"
77

8-
@sa.c.911aa16d.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] }
8+
@sa.c.d2dddd66.0 = constant { i64, [10 x i1] } { i64 10, [10 x i1] [i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false, i1 true, i1 false] }
99

1010
define { i64, [0 x i1] }* @_hl.main.1() {
1111
alloca_block:
@@ -14,7 +14,7 @@ alloca_block:
1414
br label %entry_block
1515

1616
entry_block: ; preds = %alloca_block
17-
store { i64, [0 x i1] }* bitcast ({ i64, [10 x i1] }* @sa.c.911aa16d.0 to { i64, [0 x i1] }*), { i64, [0 x i1] }** %"5_0", align 8
17+
store { i64, [0 x i1] }* bitcast ({ i64, [10 x i1] }* @sa.c.d2dddd66.0 to { i64, [0 x i1] }*), { i64, [0 x i1] }** %"5_0", align 8
1818
%"5_01" = load { i64, [0 x i1] }*, { i64, [0 x i1] }** %"5_0", align 8
1919
store { i64, [0 x i1] }* %"5_01", { i64, [0 x i1] }** %"0", align 8
2020
%"02" = load { i64, [0 x i1] }*, { i64, [0 x i1] }** %"0", align 8

hugr-llvm/src/extension/collections/snapshots/hugr_llvm__extension__collections__static_array__test__static_array_const_codegen@pre-mem2reg@llvm14_3.snap

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ expression: mod_str
55
; ModuleID = 'test_context'
66
source_filename = "test_context"
77

8-
@sa.d.4c6da27.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] }
8+
@sa.d.eee08a59.0 = constant { i64, [10 x { i1, i64 }] } { i64 10, [10 x { i1, i64 }] [{ i1, i64 } { i1 true, i64 0 }, { i1, i64 } { i1 true, i64 1 }, { i1, i64 } { i1 true, i64 2 }, { i1, i64 } { i1 true, i64 3 }, { i1, i64 } { i1 true, i64 4 }, { i1, i64 } { i1 true, i64 5 }, { i1, i64 } { i1 true, i64 6 }, { i1, i64 } { i1 true, i64 7 }, { i1, i64 } { i1 true, i64 8 }, { i1, i64 } { i1 true, i64 9 }] }
99

1010
define { i64, [0 x { i1, i64 }] }* @_hl.main.1() {
1111
alloca_block:
@@ -14,7 +14,7 @@ alloca_block:
1414
br label %entry_block
1515

1616
entry_block: ; preds = %alloca_block
17-
store { i64, [0 x { i1, i64 }] }* bitcast ({ i64, [10 x { i1, i64 }] }* @sa.d.4c6da27.0 to { i64, [0 x { i1, i64 }] }*), { i64, [0 x { i1, i64 }] }** %"5_0", align 8
17+
store { i64, [0 x { i1, i64 }] }* bitcast ({ i64, [10 x { i1, i64 }] }* @sa.d.eee08a59.0 to { i64, [0 x { i1, i64 }] }*), { i64, [0 x { i1, i64 }] }** %"5_0", align 8
1818
%"5_01" = load { i64, [0 x { i1, i64 }] }*, { i64, [0 x { i1, i64 }] }** %"5_0", align 8
1919
store { i64, [0 x { i1, i64 }] }* %"5_01", { i64, [0 x { i1, i64 }] }** %"0", align 8
2020
%"02" = load { i64, [0 x { i1, i64 }] }*, { i64, [0 x { i1, i64 }] }** %"0", align 8

0 commit comments

Comments
 (0)