Skip to content

Commit 0258553

Browse files
authored
small optimization to subtyping (#41672)
Zero and copy only the used portion of the union state buffer.
1 parent 29c9ea0 commit 0258553

File tree

1 file changed

+65
-36
lines changed

1 file changed

+65
-36
lines changed

src/subtype.c

Lines changed: 65 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,19 @@ extern "C" {
4242
// TODO: the stack probably needs to be artificially large because of some
4343
// deeper problem (see #21191) and could be shrunk once that is fixed
4444
// Bit-stack state consulted while taking Union types apart: each call site
// (e.g. pick_union_element) reads the bit at index `depth` and then advances
// `depth` by one.
// This hunk narrows `depth`/`more` from int to int16_t and adds `used`.
typedef struct {
45-
int depth;
46-
int more;
45+
int16_t depth; // index of the next bit of `stack` to read
46+
int16_t more; // NOTE(review): appears to hold 1 + the index of the bit to set on the next pass (intersect_all computes `set = more - 1`) — confirm
47+
int16_t used; // number of leading bits of `stack` that are initialized; a bit at index >= used is cleared on first access instead of zeroing the whole buffer up front
4748
uint32_t stack[100]; // stack of bits represented as a bit vector
4849
} jl_unionstate_t;
4950

51+
// Compact snapshot of a jl_unionstate_t, filled by push_unionstate and
// restored by pop_unionstate. The three scalar fields are copied verbatim;
// instead of the full 100-word bit vector, only the used prefix is kept.
typedef struct {
52+
int16_t depth; // copy of the source state's `depth`
53+
int16_t more; // copy of the source state's `more`
54+
int16_t used; // copy of the source state's `used`; also determines how many bytes `stack` holds
55+
void *stack; // alloca'd buffer of (used+7)/8 bytes copied from the start of the source bit vector
56+
} jl_saved_unionstate_t;
57+
5058
// Linked list storing the type variable environment. A new jl_varbinding_t
5159
// is pushed for each UnionAll type we encounter. `lb` and `ub` are updated
5260
// during the computation.
@@ -68,14 +76,14 @@ typedef struct jl_varbinding_t {
6876
// and we would need to return `intersect(var,other)`. in this case
6977
// we choose to over-estimate the intersection by returning the var.
7078
int8_t constraintkind;
71-
int depth0; // # of invariant constructors nested around the UnionAll type for this var
79+
int8_t intvalued; // must be integer-valued; i.e. occurs as N in Vararg{_,N}
80+
int16_t depth0; // # of invariant constructors nested around the UnionAll type for this var
7281
// when this variable's integer value is compared to that of another,
7382
// it equals `other + offset`. used by vararg length parameters.
74-
int offset;
83+
int16_t offset;
7584
// array of typevars that our bounds depend on, whose UnionAlls need to be
7685
// moved outside ours.
7786
jl_array_t *innervars;
78-
int intvalued; // must be integer-valued; i.e. occurs as N in Vararg{_,N}
7987
struct jl_varbinding_t *prev;
8088
} jl_varbinding_t;
8189

@@ -129,6 +137,23 @@ static void statestack_set(jl_unionstate_t *st, int i, int val) JL_NOTSAFEPOINT
129137
st->stack[i>>5] &= ~(1u<<(i&31));
130138
}
131139

140+
// Save the live portion of the union state `src` into `saved`
// (a jl_saved_unionstate_t*): copies depth/more/used and snapshots the first
// (used+7)/8 bytes of the bit stack into alloca'd storage.
//
// This has to be a macro rather than a function: the alloca'd buffer must
// live in the caller's stack frame so that it is still valid when the state
// is restored later in the same function.
// Both arguments are evaluated more than once, so pass side-effect-free
// expressions only.
// There is deliberately no semicolon after `while (0)`: the caller supplies
// it, which lets the macro act as a single statement (e.g. as the body of an
// unbraced `if` that is followed by `else`).
// NOTE(review): copying whole bytes from the start of the uint32_t array
// lines up with the `stack[i>>5]` / `1u<<(i&31)` bit numbering only on
// little-endian targets — confirm if big-endian support is ever needed.
#define push_unionstate(saved, src)                                     \
    do {                                                                \
        (saved)->depth = (src)->depth;                                  \
        (saved)->more = (src)->more;                                    \
        (saved)->used = (src)->used;                                    \
        (saved)->stack = alloca(((src)->used+7)/8);                     \
        memcpy((saved)->stack, &(src)->stack, ((src)->used+7)/8);       \
    } while (0)
148+
149+
// Restore the union state `dst` from the snapshot `saved`
// (a jl_saved_unionstate_t*): copies depth/more/used back and rewrites the
// first (used+7)/8 bytes of the bit stack from the saved buffer. Bytes of
// `dst->stack` past that prefix are left untouched.
//
// Both arguments are evaluated more than once, so pass side-effect-free
// expressions only.
// There is deliberately no semicolon after `while (0)`: the caller supplies
// it, which lets the macro act as a single statement (e.g. as the body of an
// unbraced `if` that is followed by `else`).
#define pop_unionstate(dst, saved)                                      \
    do {                                                                \
        (dst)->depth = (saved)->depth;                                  \
        (dst)->more = (saved)->more;                                    \
        (dst)->used = (saved)->used;                                    \
        memcpy(&(dst)->stack, (saved)->stack, ((saved)->used+7)/8);     \
    } while (0)
156+
132157
typedef struct {
133158
int8_t *buf;
134159
int rdepth;
@@ -486,6 +511,10 @@ static jl_value_t *pick_union_element(jl_value_t *u JL_PROPAGATES_ROOT, jl_stenv
486511
{
487512
jl_unionstate_t *state = R ? &e->Runions : &e->Lunions;
488513
do {
514+
if (state->depth >= state->used) {
515+
statestack_set(state, state->used, 0);
516+
state->used++;
517+
}
489518
int ui = statestack_get(state, state->depth);
490519
state->depth++;
491520
if (ui == 0) {
@@ -514,20 +543,19 @@ static int subtype_ccheck(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
514543
return 1;
515544
if (x == (jl_value_t*)jl_any_type && jl_is_datatype(y))
516545
return 0;
517-
jl_unionstate_t oldLunions = e->Lunions;
518-
jl_unionstate_t oldRunions = e->Runions;
546+
jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
547+
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
519548
int sub;
520-
memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
521-
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
549+
e->Lunions.used = e->Runions.used = 0;
522550
e->Runions.depth = 0;
523551
e->Runions.more = 0;
524552
e->Lunions.depth = 0;
525553
e->Lunions.more = 0;
526554

527555
sub = forall_exists_subtype(x, y, e, 0);
528556

529-
e->Runions = oldRunions;
530-
e->Lunions = oldLunions;
557+
pop_unionstate(&e->Runions, &oldRunions);
558+
pop_unionstate(&e->Lunions, &oldLunions);
531559
return sub;
532560
}
533561

@@ -731,8 +759,8 @@ static jl_unionall_t *unalias_unionall(jl_unionall_t *u, jl_stenv_t *e)
731759
static int subtype_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_t *e, int8_t R, int param)
732760
{
733761
u = unalias_unionall(u, e);
734-
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0,
735-
R ? e->Rinvdepth : e->invdepth, 0, NULL, 0, e->vars };
762+
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0,
763+
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
736764
JL_GC_PUSH4(&u, &vb.lb, &vb.ub, &vb.innervars);
737765
e->vars = &vb;
738766
int ans;
@@ -1148,6 +1176,10 @@ static int subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, int param)
11481176
// union against the variable before trying to take it apart to see if there are any
11491177
// variables lurking inside.
11501178
jl_unionstate_t *state = &e->Runions;
1179+
if (state->depth >= state->used) {
1180+
statestack_set(state, state->used, 0);
1181+
state->used++;
1182+
}
11511183
ui = statestack_get(state, state->depth);
11521184
state->depth++;
11531185
if (ui == 0)
@@ -1310,21 +1342,21 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
13101342
(is_definite_length_tuple_type(x) && is_indefinite_length_tuple_type(y)))
13111343
return 0;
13121344

1313-
jl_unionstate_t oldLunions = e->Lunions;
1314-
memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
1345+
jl_saved_unionstate_t oldLunions; push_unionstate(&oldLunions, &e->Lunions);
1346+
e->Lunions.used = 0;
13151347
int sub;
13161348

13171349
if (!jl_has_free_typevars(x) || !jl_has_free_typevars(y)) {
1318-
jl_unionstate_t oldRunions = e->Runions;
1319-
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
1350+
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
1351+
e->Runions.used = 0;
13201352
e->Runions.depth = 0;
13211353
e->Runions.more = 0;
13221354
e->Lunions.depth = 0;
13231355
e->Lunions.more = 0;
13241356

13251357
sub = forall_exists_subtype(x, y, e, 2);
13261358

1327-
e->Runions = oldRunions;
1359+
pop_unionstate(&e->Runions, &oldRunions);
13281360
}
13291361
else {
13301362
int lastset = 0;
@@ -1342,13 +1374,13 @@ static int forall_exists_equal(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
13421374
}
13431375
}
13441376

1345-
e->Lunions = oldLunions;
1377+
pop_unionstate(&e->Lunions, &oldLunions);
13461378
return sub && subtype(y, x, e, 0);
13471379
}
13481380

13491381
static int exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, jl_value_t *saved, jl_savedenv_t *se, int param)
13501382
{
1351-
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
1383+
e->Runions.used = 0;
13521384
int lastset = 0;
13531385
while (1) {
13541386
e->Runions.depth = 0;
@@ -1379,7 +1411,7 @@ static int forall_exists_subtype(jl_value_t *x, jl_value_t *y, jl_stenv_t *e, in
13791411
JL_GC_PUSH1(&saved);
13801412
save_env(e, &saved, &se);
13811413

1382-
memset(e->Lunions.stack, 0, sizeof(e->Lunions.stack));
1414+
e->Lunions.used = 0;
13831415
int lastset = 0;
13841416
int sub;
13851417
while (1) {
@@ -1415,6 +1447,7 @@ static void init_stenv(jl_stenv_t *e, jl_value_t **env, int envsz)
14151447
e->emptiness_only = 0;
14161448
e->Lunions.depth = 0; e->Runions.depth = 0;
14171449
e->Lunions.more = 0; e->Runions.more = 0;
1450+
e->Lunions.used = 0; e->Runions.used = 0;
14181451
}
14191452

14201453
// subtyping entry points
@@ -2084,14 +2117,14 @@ static jl_value_t *intersect_aside(jl_value_t *x, jl_value_t *y, jl_stenv_t *e,
20842117
if (y == (jl_value_t*)jl_any_type && !jl_is_typevar(x))
20852118
return x;
20862119

2087-
jl_unionstate_t oldRunions = e->Runions;
2120+
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
20882121
int savedepth = e->invdepth, Rsavedepth = e->Rinvdepth;
20892122
// TODO: this doesn't quite make sense
20902123
e->invdepth = e->Rinvdepth = d;
20912124

20922125
jl_value_t *res = intersect_all(x, y, e);
20932126

2094-
e->Runions = oldRunions;
2127+
pop_unionstate(&e->Runions, &oldRunions);
20952128
e->invdepth = savedepth;
20962129
e->Rinvdepth = Rsavedepth;
20972130
return res;
@@ -2102,10 +2135,10 @@ static jl_value_t *intersect_union(jl_value_t *x, jl_uniontype_t *u, jl_stenv_t
21022135
if (param == 2 || (!jl_has_free_typevars(x) && !jl_has_free_typevars((jl_value_t*)u))) {
21032136
jl_value_t *a=NULL, *b=NULL;
21042137
JL_GC_PUSH2(&a, &b);
2105-
jl_unionstate_t oldRunions = e->Runions;
2138+
jl_saved_unionstate_t oldRunions; push_unionstate(&oldRunions, &e->Runions);
21062139
a = R ? intersect_all(x, u->a, e) : intersect_all(u->a, x, e);
21072140
b = R ? intersect_all(x, u->b, e) : intersect_all(u->b, x, e);
2108-
e->Runions = oldRunions;
2141+
pop_unionstate(&e->Runions, &oldRunions);
21092142
jl_value_t *i = simple_join(a,b);
21102143
JL_GC_POP();
21112144
return i;
@@ -2600,8 +2633,8 @@ static jl_value_t *intersect_unionall(jl_value_t *t, jl_unionall_t *u, jl_stenv_
26002633
{
26012634
jl_value_t *res=NULL, *res2=NULL, *save=NULL, *save2=NULL;
26022635
jl_savedenv_t se, se2;
2603-
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0,
2604-
R ? e->Rinvdepth : e->invdepth, 0, NULL, 0, e->vars };
2636+
jl_varbinding_t vb = { u->var, u->var->lb, u->var->ub, R, 0, 0, 0, 0, 0,
2637+
R ? e->Rinvdepth : e->invdepth, 0, NULL, e->vars };
26052638
JL_GC_PUSH6(&res, &save2, &vb.lb, &vb.ub, &save, &vb.innervars);
26062639
save_env(e, &save, &se);
26072640
res = intersect_unionall_(t, u, e, R, param, &vb);
@@ -3159,7 +3192,7 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
31593192
{
31603193
e->Runions.depth = 0;
31613194
e->Runions.more = 0;
3162-
memset(e->Runions.stack, 0, sizeof(e->Runions.stack));
3195+
e->Runions.used = 0;
31633196
jl_value_t **is;
31643197
JL_GC_PUSHARGS(is, 3);
31653198
jl_value_t **saved = &is[2];
@@ -3176,11 +3209,8 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
31763209
save_env(e, saved, &se);
31773210
}
31783211
while (e->Runions.more) {
3179-
if (e->emptiness_only && ii != jl_bottom_type) {
3180-
free_env(&se);
3181-
JL_GC_POP();
3182-
return ii;
3183-
}
3212+
if (e->emptiness_only && ii != jl_bottom_type)
3213+
break;
31843214
e->Runions.depth = 0;
31853215
int set = e->Runions.more - 1;
31863216
e->Runions.more = 0;
@@ -3209,9 +3239,8 @@ static jl_value_t *intersect_all(jl_value_t *x, jl_value_t *y, jl_stenv_t *e)
32093239
}
32103240
total_iter++;
32113241
if (niter > 3 || total_iter > 400000) {
3212-
free_env(&se);
3213-
JL_GC_POP();
3214-
return y;
3242+
ii = y;
3243+
break;
32153244
}
32163245
}
32173246
free_env(&se);

0 commit comments

Comments
 (0)