Skip to content

Commit d20e34f

Browse files
committed
Normalize (simplify) UnionAlls when used as type parameter
1 parent 36270e9 commit d20e34f

File tree

2 files changed

+88
-1
lines changed

2 files changed

+88
-1
lines changed

src/jltypes.c

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1231,6 +1231,88 @@ static jl_value_t *normalize_vararg(jl_value_t *va)
12311231
return va;
12321232
}
12331233

1234+
int _may_substitute_ub(jl_value_t *v, jl_tvar_t *var, int inside_inv, int *cov_count) JL_NOTSAFEPOINT
1235+
{
1236+
if (v == (jl_value_t*)var) {
1237+
if (inside_inv) {
1238+
return 0;
1239+
} else {
1240+
(*cov_count)++;
1241+
return *cov_count <= 1 || jl_is_concrete_type(var->ub);
1242+
}
1243+
}
1244+
else if (jl_is_uniontype(v)) {
1245+
return _may_substitute_ub(((jl_uniontype_t*)v)->a, var, inside_inv, cov_count) &&
1246+
_may_substitute_ub(((jl_uniontype_t*)v)->b, var, inside_inv, cov_count);
1247+
}
1248+
else if (jl_is_unionall(v)) {
1249+
jl_unionall_t *ua = (jl_unionall_t*)v;
1250+
if (ua->var == var)
1251+
return 1;
1252+
return _may_substitute_ub(ua->var->lb, var, inside_inv, cov_count) &&
1253+
_may_substitute_ub(ua->var->ub, var, inside_inv, cov_count) &&
1254+
_may_substitute_ub(ua->body, var, inside_inv, cov_count);
1255+
}
1256+
else if (jl_is_datatype(v)) {
1257+
int istuple = jl_is_tuple_type(v);
1258+
int isva = jl_is_vararg_type(v);
1259+
for (size_t i = 0; i < jl_nparams(v); i++) {
1260+
int invar = isva ? i == 1 : !istuple;
1261+
int ins_i = inside_inv || invar;
1262+
int old_count = *cov_count;
1263+
if (!_may_substitute_ub(jl_tparam(v,i), var, ins_i, cov_count))
1264+
return 0;
1265+
if (isva && i == 0 && *cov_count > old_count && !jl_is_concrete_type(var->ub))
1266+
return 0;
1267+
}
1268+
return 1;
1269+
}
1270+
return 1;
1271+
}
1272+
1273+
// Check whether `var` may be replaced with its upper bound `ub` in `v where var<:ub`
1274+
// Conditions:
1275+
// * `var` does not appear in invariant position
1276+
// * `var` appears at most once (in covariant position) and not in a `Vararg`
1277+
// unless the upper bound is concrete (diagonal rule)
1278+
int may_substitute_ub(jl_value_t *v, jl_tvar_t *var) JL_NOTSAFEPOINT
1279+
{
1280+
int cov_count = 0;
1281+
return _may_substitute_ub(v, var, 0, &cov_count);
1282+
}
1283+
1284+
jl_value_t *normalize_unionalls(jl_value_t *t)
1285+
{
1286+
JL_GC_PUSH1(&t);
1287+
if (jl_is_uniontype(t)) {
1288+
jl_uniontype_t *u = (jl_uniontype_t *) t;
1289+
jl_value_t *a = NULL;
1290+
jl_value_t *b = NULL;
1291+
JL_GC_PUSH2(&a, &b);
1292+
a = normalize_unionalls(u->a);
1293+
b = normalize_unionalls(u->b);
1294+
if (a != u->a || b != u->b) {
1295+
t = jl_new_struct(jl_uniontype_type, a, b);
1296+
}
1297+
JL_GC_POP();
1298+
}
1299+
else if (jl_is_unionall(t)) {
1300+
jl_unionall_t *u = (jl_unionall_t *) t;
1301+
jl_value_t *body = normalize_unionalls(u->body);
1302+
if (body != u->body) {
1303+
JL_GC_PUSH1(&body);
1304+
t = jl_new_struct(jl_unionall_type, u->var, body);
1305+
JL_GC_POP();
1306+
u = (jl_unionall_t *) t;
1307+
}
1308+
1309+
if (u->var->lb == u->var->ub || may_substitute_ub(body, u->var))
1310+
t = jl_instantiate_unionall(u, u->var->ub);
1311+
}
1312+
JL_GC_POP();
1313+
return t;
1314+
}
1315+
12341316
static jl_value_t *_jl_instantiate_type_in_env(jl_value_t *ty, jl_unionall_t *env, jl_value_t **vals, jl_typeenv_t *prev, jl_typestack_t *stack);
12351317

12361318
static jl_value_t *inst_datatype_inner(jl_datatype_t *dt, jl_svec_t *p, jl_value_t **iparams, size_t ntp,
@@ -1240,6 +1322,11 @@ static jl_value_t *inst_datatype_inner(jl_datatype_t *dt, jl_svec_t *p, jl_value
12401322
jl_typename_t *tn = dt->name;
12411323
int istuple = (tn == jl_tuple_typename);
12421324
int isnamedtuple = (tn == jl_namedtuple_typename);
1325+
if (dt->name != jl_type_typename) {
1326+
for (size_t i = 0; i < ntp; i++)
1327+
iparams[i] = normalize_unionalls(iparams[i]);
1328+
}
1329+
12431330
// check type cache
12441331
if (cacheable) {
12451332
size_t i;

test/subtype.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,7 +140,7 @@ function test_diagonal()
140140
@test !issub(Type{Tuple{T,Any} where T}, Type{Tuple{T,T}} where T)
141141
@test !issub(Type{Tuple{T,Any,T} where T}, Type{Tuple{T,T,T}} where T)
142142
@test_broken issub(Type{Tuple{T} where T}, Type{Tuple{T}} where T)
143-
@test_broken issub(Ref{Tuple{T} where T}, Ref{Tuple{T}} where T)
143+
@test issub(Ref{Tuple{T} where T}, Ref{Tuple{T}} where T)
144144
@test !issub(Type{Tuple{T,T} where T}, Type{Tuple{T,T}} where T)
145145
@test !issub(Type{Tuple{T,T,T} where T}, Type{Tuple{T,T,T}} where T)
146146
@test isequal_type(Ref{Tuple{T, T} where Int<:T<:Int},

0 commit comments

Comments
 (0)