Skip to content

Commit e7921da

Browse files
authored
expand use of egal for testing type equality (#39604)
fixes #39565
1 parent 91b7c76 commit e7921da

File tree

4 files changed

+42
-38
lines changed

4 files changed

+42
-38
lines changed

src/builtins.c

Lines changed: 30 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -126,13 +126,27 @@ static int NOINLINE compare_fields(jl_value_t *a, jl_value_t *b, jl_datatype_t *
126126
return 1;
127127
}
128128

129-
static int egal_types(jl_value_t *a, jl_value_t *b, jl_typeenv_t *env) JL_NOTSAFEPOINT
129+
static int egal_types(jl_value_t *a, jl_value_t *b, jl_typeenv_t *env, int tvar_names) JL_NOTSAFEPOINT
130130
{
131131
if (a == b)
132132
return 1;
133133
jl_datatype_t *dt = (jl_datatype_t*)jl_typeof(a);
134134
if (dt != (jl_datatype_t*)jl_typeof(b))
135135
return 0;
136+
if (dt == jl_datatype_type) {
137+
jl_datatype_t *dta = (jl_datatype_t*)a;
138+
jl_datatype_t *dtb = (jl_datatype_t*)b;
139+
if (dta->name != dtb->name)
140+
return 0;
141+
size_t i, l = jl_nparams(dta);
142+
if (jl_nparams(dtb) != l)
143+
return 0;
144+
for (i = 0; i < l; i++) {
145+
if (!egal_types(jl_tparam(dta, i), jl_tparam(dtb, i), env, tvar_names))
146+
return 0;
147+
}
148+
return 1;
149+
}
136150
if (dt == jl_tvar_type) {
137151
jl_typeenv_t *pe = env;
138152
while (pe != NULL) {
@@ -142,49 +156,39 @@ static int egal_types(jl_value_t *a, jl_value_t *b, jl_typeenv_t *env) JL_NOTSAF
142156
}
143157
return 0;
144158
}
145-
if (dt == jl_uniontype_type) {
146-
return egal_types(((jl_uniontype_t*)a)->a, ((jl_uniontype_t*)b)->a, env) &&
147-
egal_types(((jl_uniontype_t*)a)->b, ((jl_uniontype_t*)b)->b, env);
148-
}
149159
if (dt == jl_unionall_type) {
150160
jl_unionall_t *ua = (jl_unionall_t*)a;
151161
jl_unionall_t *ub = (jl_unionall_t*)b;
152-
if (ua->var->name != ub->var->name)
162+
if (tvar_names && ua->var->name != ub->var->name)
153163
return 0;
154-
if (!(egal_types(ua->var->lb, ub->var->lb, env) && egal_types(ua->var->ub, ub->var->ub, env)))
164+
if (!(egal_types(ua->var->lb, ub->var->lb, env, tvar_names) && egal_types(ua->var->ub, ub->var->ub, env, tvar_names)))
155165
return 0;
156166
jl_typeenv_t e = { ua->var, (jl_value_t*)ub->var, env };
157-
return egal_types(ua->body, ub->body, &e);
167+
return egal_types(ua->body, ub->body, &e, tvar_names);
158168
}
159-
if (dt == jl_datatype_type) {
160-
jl_datatype_t *dta = (jl_datatype_t*)a;
161-
jl_datatype_t *dtb = (jl_datatype_t*)b;
162-
if (dta->name != dtb->name)
163-
return 0;
164-
size_t i, l = jl_nparams(dta);
165-
if (jl_nparams(dtb) != l)
166-
return 0;
167-
for (i = 0; i < l; i++) {
168-
if (!egal_types(jl_tparam(dta, i), jl_tparam(dtb, i), env))
169-
return 0;
170-
}
171-
return 1;
169+
if (dt == jl_uniontype_type) {
170+
return egal_types(((jl_uniontype_t*)a)->a, ((jl_uniontype_t*)b)->a, env, tvar_names) &&
171+
egal_types(((jl_uniontype_t*)a)->b, ((jl_uniontype_t*)b)->b, env, tvar_names);
172172
}
173-
if (dt == jl_vararg_type)
174-
{
173+
if (dt == jl_vararg_type) {
175174
jl_vararg_t *vma = (jl_vararg_t*)a;
176175
jl_vararg_t *vmb = (jl_vararg_t*)b;
177176
jl_value_t *vmaT = vma->T ? vma->T : (jl_value_t*)jl_any_type;
178177
jl_value_t *vmbT = vmb->T ? vmb->T : (jl_value_t*)jl_any_type;
179-
if (!egal_types(vmaT, vmbT, env))
178+
if (!egal_types(vmaT, vmbT, env, tvar_names))
180179
return 0;
181180
if (vma->N && vmb->N)
182-
return egal_types(vma->N, vmb->N, env);
181+
return egal_types(vma->N, vmb->N, env, tvar_names);
183182
return !vma->N && !vmb->N;
184183
}
185184
return jl_egal(a, b);
186185
}
187186

187+
JL_DLLEXPORT int jl_types_egal(jl_value_t *a, jl_value_t *b)
188+
{
189+
return egal_types(a, b, NULL, 0);
190+
}
191+
188192
JL_DLLEXPORT int jl_egal(jl_value_t *a JL_MAYBE_UNROOTED, jl_value_t *b JL_MAYBE_UNROOTED) JL_NOTSAFEPOINT
189193
{
190194
// warning: a,b may NOT have been gc-rooted by the caller
@@ -219,7 +223,7 @@ JL_DLLEXPORT int jl_egal(jl_value_t *a JL_MAYBE_UNROOTED, jl_value_t *b JL_MAYBE
219223
if (nf == 0 || !dt->layout->haspadding)
220224
return bits_equal(a, b, sz);
221225
if (dt == jl_unionall_type)
222-
return egal_types(a, b, NULL);
226+
return egal_types(a, b, NULL, 1);
223227
return compare_fields(a, b, dt);
224228
}
225229

src/julia_internal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -473,6 +473,7 @@ jl_svec_t *jl_outer_unionall_vars(jl_value_t *u);
473473
jl_value_t *jl_type_intersection_env_s(jl_value_t *a, jl_value_t *b, jl_svec_t **penv, int *issubty);
474474
jl_value_t *jl_type_intersection_env(jl_value_t *a, jl_value_t *b, jl_svec_t **penv);
475475
int jl_subtype_matching(jl_value_t *a, jl_value_t *b, jl_svec_t **penv);
476+
JL_DLLEXPORT int jl_types_egal(jl_value_t *a, jl_value_t *b);
476477
// specificity comparison assuming !(a <: b) and !(b <: a)
477478
JL_DLLEXPORT int jl_type_morespecific_no_subtype(jl_value_t *a, jl_value_t *b);
478479
jl_value_t *jl_instantiate_type_with(jl_value_t *t, jl_value_t **env, size_t n);

src/subtype.c

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1811,7 +1811,7 @@ JL_DLLEXPORT int jl_subtype_env(jl_value_t *x, jl_value_t *y, jl_value_t **env,
18111811
if (x == y ||
18121812
(jl_typeof(x) == jl_typeof(y) &&
18131813
(jl_is_unionall(y) || jl_is_uniontype(y)) &&
1814-
jl_egal(x, y))) {
1814+
jl_types_egal(x, y))) {
18151815
if (envsz != 0) { // quickly copy env from x
18161816
jl_unionall_t *ua = (jl_unionall_t*)x;
18171817
int i;
@@ -1877,7 +1877,9 @@ JL_DLLEXPORT int jl_subtype(jl_value_t *x, jl_value_t *y)
18771877

18781878
JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b)
18791879
{
1880-
if (obviously_egal(a, b))
1880+
if (a == b)
1881+
return 1;
1882+
if (jl_typeof(a) == jl_typeof(b) && jl_types_egal(a, b))
18811883
return 1;
18821884
if (obviously_unequal(a, b))
18831885
return 0;
@@ -1896,11 +1898,6 @@ JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b)
18961898
if (b == (jl_value_t*)jl_any_type || a == jl_bottom_type) {
18971899
subtype_ab = 1;
18981900
}
1899-
else if (jl_typeof(a) == jl_typeof(b) &&
1900-
(jl_is_unionall(b) || jl_is_uniontype(b)) &&
1901-
jl_egal(a, b)) {
1902-
subtype_ab = 1;
1903-
}
19041901
else if (jl_obvious_subtype(a, b, &subtype_ab)) {
19051902
#ifdef NDEBUG
19061903
if (subtype_ab == 0)
@@ -1915,11 +1912,6 @@ JL_DLLEXPORT int jl_types_equal(jl_value_t *a, jl_value_t *b)
19151912
if (a == (jl_value_t*)jl_any_type || b == jl_bottom_type) {
19161913
subtype_ba = 1;
19171914
}
1918-
else if (jl_typeof(b) == jl_typeof(a) &&
1919-
(jl_is_unionall(a) || jl_is_uniontype(a)) &&
1920-
jl_egal(b, a)) {
1921-
subtype_ba = 1;
1922-
}
19231915
else if (jl_obvious_subtype(b, a, &subtype_ba)) {
19241916
#ifdef NDEBUG
19251917
if (subtype_ba == 0)

test/subtype.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1872,3 +1872,10 @@ g39218(a, b) = (@nospecialize; if a isa AB39218 && b isa AB39218; f39218(a, b);
18721872
# issue #39521
18731873
@test Tuple{Type{Tuple{A}} where A, DataType, DataType} <: Tuple{Vararg{B}} where B
18741874
@test Tuple{DataType, Type{Tuple{A}} where A, DataType} <: Tuple{Vararg{B}} where B
1875+
1876+
let A = Tuple{Type{<:Union{Number, T}}, Ref{T}} where T,
1877+
B = Tuple{Type{<:Union{Number, T}}, Ref{T}} where T
1878+
# TODO: these are caught by the egal check, but the core algorithm gets them wrong
1879+
@test A == B
1880+
@test A <: B
1881+
end

0 commit comments

Comments
 (0)