Skip to content

Commit d2868f1

Browse files
topolarityKristofferC
authored andcommitted
staticdata: Memoize type_in_worklist query (#57917)
When pre-compiling `stdlib/` this cache has a 91% hit rate, so this seems fairly profitable. It also dramatically improves some pathological cases, a few of which have been hit in the wild (arguably due to inference bugs) Without this PR, this package takes exponentially long to pre-compile: ```julia function BigType(N) (N == 0) && return Nothing T = BigType(N-1) return Pair{T,T} end foo(::Type{T}) where T = T precompile(foo, (Type{BigType(40)},)) ``` For an in-the-wild test case hit by a customer, this reduces pre-compilation time from over an hour to just ~two and a half minutes. Resolves #53331. (cherry picked from commit 89271dc)
1 parent 1b0e334 commit d2868f1

File tree

2 files changed

+94
-52
lines changed

2 files changed

+94
-52
lines changed

src/staticdata.c

Lines changed: 50 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -92,6 +92,22 @@ static const size_t WORLD_AGE_REVALIDATION_SENTINEL = 0x1;
9292
JL_DLLEXPORT size_t jl_require_world = ~(size_t)0;
9393
JL_DLLEXPORT _Atomic(size_t) jl_first_image_replacement_world = ~(size_t)0;
9494

95+
// This structure is used to store hash tables for the memoization
96+
// of queries in staticdata.c (currently only `type_in_worklist`).
97+
typedef struct {
98+
htable_t type_in_worklist;
99+
} jl_query_cache;
100+
101+
static void init_query_cache(jl_query_cache *cache)
102+
{
103+
htable_new(&cache->type_in_worklist, 0);
104+
}
105+
106+
static void destroy_query_cache(jl_query_cache *cache)
107+
{
108+
htable_free(&cache->type_in_worklist);
109+
}
110+
95111
#include "staticdata_utils.c"
96112
#include "precompile_utils.c"
97113

@@ -555,6 +571,7 @@ typedef struct {
555571
jl_array_t *method_roots_list;
556572
htable_t method_roots_index;
557573
uint64_t worklist_key;
574+
jl_query_cache *query_cache;
558575
jl_ptls_t ptls;
559576
jl_image_t *image;
560577
int8_t incremental;
@@ -702,14 +719,13 @@ static int jl_needs_serialization(jl_serializer_state *s, jl_value_t *v) JL_NOTS
702719
return 1;
703720
}
704721

705-
706-
static int caching_tag(jl_value_t *v) JL_NOTSAFEPOINT
722+
static int caching_tag(jl_value_t *v, jl_query_cache *query_cache) JL_NOTSAFEPOINT
707723
{
708724
if (jl_is_method_instance(v)) {
709725
jl_method_instance_t *mi = (jl_method_instance_t*)v;
710726
jl_value_t *m = mi->def.value;
711727
if (jl_is_method(m) && jl_object_in_image(m))
712-
return 1 + type_in_worklist(mi->specTypes);
728+
return 1 + type_in_worklist(mi->specTypes, query_cache);
713729
}
714730
if (jl_is_binding(v)) {
715731
jl_globalref_t *gr = ((jl_binding_t*)v)->globalref;
@@ -724,24 +740,24 @@ static int caching_tag(jl_value_t *v) JL_NOTSAFEPOINT
724740
if (jl_is_tuple_type(dt) ? !dt->isconcretetype : dt->hasfreetypevars)
725741
return 0; // aka !is_cacheable from jltypes.c
726742
if (jl_object_in_image((jl_value_t*)dt->name))
727-
return 1 + type_in_worklist(v);
743+
return 1 + type_in_worklist(v, query_cache);
728744
}
729745
jl_value_t *dtv = jl_typeof(v);
730746
if (jl_is_datatype_singleton((jl_datatype_t*)dtv)) {
731-
return 1 - type_in_worklist(dtv); // these are already recached in the datatype in the image
747+
return 1 - type_in_worklist(dtv, query_cache); // these are already recached in the datatype in the image
732748
}
733749
return 0;
734750
}
735751

736-
static int needs_recaching(jl_value_t *v) JL_NOTSAFEPOINT
752+
static int needs_recaching(jl_value_t *v, jl_query_cache *query_cache) JL_NOTSAFEPOINT
737753
{
738-
return caching_tag(v) == 2;
754+
return caching_tag(v, query_cache) == 2;
739755
}
740756

741-
static int needs_uniquing(jl_value_t *v) JL_NOTSAFEPOINT
757+
static int needs_uniquing(jl_value_t *v, jl_query_cache *query_cache) JL_NOTSAFEPOINT
742758
{
743759
assert(!jl_object_in_image(v));
744-
return caching_tag(v) == 1;
760+
return caching_tag(v, query_cache) == 1;
745761
}
746762

747763
static void record_field_change(jl_value_t **addr, jl_value_t *newval) JL_NOTSAFEPOINT
@@ -861,7 +877,7 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
861877
jl_datatype_t *dt = (jl_datatype_t*)v;
862878
// ensure all type parameters are recached
863879
jl_queue_for_serialization_(s, (jl_value_t*)dt->parameters, 1, 1);
864-
if (jl_is_datatype_singleton(dt) && needs_uniquing(dt->instance)) {
880+
if (jl_is_datatype_singleton(dt) && needs_uniquing(dt->instance, s->query_cache)) {
865881
assert(jl_needs_serialization(s, dt->instance)); // should be true, since we visited dt
866882
// do not visit dt->instance for our template object as it leads to unwanted cycles here
867883
// (it may get serialized from elsewhere though)
@@ -872,7 +888,7 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
872888
if (s->incremental && jl_is_method_instance(v)) {
873889
jl_method_instance_t *mi = (jl_method_instance_t*)v;
874890
jl_value_t *def = mi->def.value;
875-
if (needs_uniquing(v)) {
891+
if (needs_uniquing(v, s->query_cache)) {
876892
// we only need 3 specific fields of this (the rest are not used)
877893
jl_queue_for_serialization(s, mi->def.value);
878894
jl_queue_for_serialization(s, mi->specTypes);
@@ -887,7 +903,7 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
887903
record_field_change((jl_value_t**)&mi->cache, NULL);
888904
}
889905
else {
890-
assert(!needs_recaching(v));
906+
assert(!needs_recaching(v, s->query_cache));
891907
}
892908
// n.b. opaque closures cannot be inspected and relied upon like a
893909
// normal method since they can get improperly introduced by generated
@@ -897,7 +913,7 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
897913
// error now.
898914
}
899915
if (s->incremental && jl_is_binding(v)) {
900-
if (needs_uniquing(v)) {
916+
if (needs_uniquing(v, s->query_cache)) {
901917
jl_binding_t *b = (jl_binding_t*)v;
902918
jl_queue_for_serialization(s, b->globalref->mod);
903919
jl_queue_for_serialization(s, b->globalref->name);
@@ -1121,9 +1137,9 @@ static void jl_queue_for_serialization_(jl_serializer_state *s, jl_value_t *v, i
11211137
// Items that require postorder traversal must visit their children prior to insertion into
11221138
// the worklist/serialization_order (and also before their first use)
11231139
if (s->incremental && !immediate) {
1124-
if (jl_is_datatype(t) && needs_uniquing(v))
1140+
if (jl_is_datatype(t) && needs_uniquing(v, s->query_cache))
11251141
immediate = 1;
1126-
if (jl_is_datatype_singleton((jl_datatype_t*)t) && needs_uniquing(v))
1142+
if (jl_is_datatype_singleton((jl_datatype_t*)t) && needs_uniquing(v, s->query_cache))
11271143
immediate = 1;
11281144
}
11291145

@@ -1286,7 +1302,7 @@ static uintptr_t _backref_id(jl_serializer_state *s, jl_value_t *v, jl_array_t *
12861302

12871303
static void record_uniquing(jl_serializer_state *s, jl_value_t *fld, uintptr_t offset) JL_NOTSAFEPOINT
12881304
{
1289-
if (s->incremental && jl_needs_serialization(s, fld) && needs_uniquing(fld)) {
1305+
if (s->incremental && jl_needs_serialization(s, fld) && needs_uniquing(fld, s->query_cache)) {
12901306
if (jl_is_datatype(fld) || jl_is_datatype_singleton((jl_datatype_t*)jl_typeof(fld)))
12911307
arraylist_push(&s->uniquing_types, (void*)(uintptr_t)offset);
12921308
else if (jl_is_method_instance(fld) || jl_is_binding(fld))
@@ -1510,7 +1526,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
15101526
// write header
15111527
if (object_id_expected)
15121528
write_uint(f, jl_object_id(v));
1513-
if (s->incremental && jl_needs_serialization(s, (jl_value_t*)t) && needs_uniquing((jl_value_t*)t))
1529+
if (s->incremental && jl_needs_serialization(s, (jl_value_t*)t) && needs_uniquing((jl_value_t*)t, s->query_cache))
15141530
arraylist_push(&s->uniquing_types, (void*)(uintptr_t)(ios_pos(f)|1));
15151531
if (f == s->const_data)
15161532
write_uint(s->const_data, ((uintptr_t)t->smalltag << 4) | GC_OLD_MARKED | GC_IN_IMAGE);
@@ -1521,7 +1537,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
15211537
layout_table.items[item] = (void*)(reloc_offset | (f == s->const_data)); // store the inverse mapping of `serialization_order` (`id` => object-as-streampos)
15221538

15231539
if (s->incremental) {
1524-
if (needs_uniquing(v)) {
1540+
if (needs_uniquing(v, s->query_cache)) {
15251541
if (jl_typetagis(v, jl_binding_type)) {
15261542
jl_binding_t *b = (jl_binding_t*)v;
15271543
if (b->globalref == NULL)
@@ -1550,7 +1566,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
15501566
assert(jl_is_datatype_singleton(t) && "unreachable");
15511567
}
15521568
}
1553-
else if (needs_recaching(v)) {
1569+
else if (needs_recaching(v, s->query_cache)) {
15541570
arraylist_push(jl_is_datatype(v) ? &s->fixup_types : &s->fixup_objs, (void*)reloc_offset);
15551571
}
15561572
}
@@ -1985,7 +2001,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
19852001
}
19862002
}
19872003
void *superidx = ptrhash_get(&serialization_order, dt->super);
1988-
if (s->incremental && superidx != HT_NOTFOUND && from_seroder_entry(superidx) > item && needs_uniquing((jl_value_t*)dt->super))
2004+
if (s->incremental && superidx != HT_NOTFOUND && from_seroder_entry(superidx) > item && needs_uniquing((jl_value_t*)dt->super, s->query_cache))
19892005
arraylist_push(&s->uniquing_super, dt->super);
19902006
}
19912007
else if (jl_is_typename(v)) {
@@ -2919,13 +2935,14 @@ JL_DLLEXPORT jl_value_t *jl_as_global_root(jl_value_t *val, int insert)
29192935
static void jl_prepare_serialization_data(jl_array_t *mod_array, jl_array_t *newly_inferred,
29202936
/* outputs */ jl_array_t **extext_methods JL_REQUIRE_ROOTED_SLOT,
29212937
jl_array_t **new_ext_cis JL_REQUIRE_ROOTED_SLOT,
2922-
jl_array_t **edges JL_REQUIRE_ROOTED_SLOT)
2938+
jl_array_t **edges JL_REQUIRE_ROOTED_SLOT,
2939+
jl_query_cache *query_cache)
29232940
{
29242941
// extext_methods: [method1, ...], worklist-owned "extending external" methods added to functions owned by modules outside the worklist
29252942
// edges: [caller1, ext_targets, ...] for worklist-owned methods calling external methods
29262943

29272944
// Save the inferred code from newly inferred, external methods
2928-
*new_ext_cis = queue_external_cis(newly_inferred);
2945+
*new_ext_cis = queue_external_cis(newly_inferred, query_cache);
29292946

29302947
// Collect method extensions and edges data
29312948
*extext_methods = jl_alloc_vec_any(0);
@@ -2955,7 +2972,8 @@ static void jl_prepare_serialization_data(jl_array_t *mod_array, jl_array_t *new
29552972
// In addition to the system image (where `worklist = NULL`), this can also save incremental images with external linkage
29562973
static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
29572974
jl_array_t *worklist, jl_array_t *extext_methods,
2958-
jl_array_t *new_ext_cis, jl_array_t *edges)
2975+
jl_array_t *new_ext_cis, jl_array_t *edges,
2976+
jl_query_cache *query_cache)
29592977
{
29602978
htable_new(&field_replace, 0);
29612979
htable_new(&bits_replace, 0);
@@ -3062,6 +3080,7 @@ static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
30623080
ios_mem(&gvar_record, 0);
30633081
ios_mem(&fptr_record, 0);
30643082
jl_serializer_state s = {0};
3083+
s.query_cache = query_cache;
30653084
s.incremental = !(worklist == NULL);
30663085
s.s = &sysimg;
30673086
s.const_data = &const_data;
@@ -3422,11 +3441,14 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
34223441
int64_t datastartpos = 0;
34233442
JL_GC_PUSH4(&mod_array, &extext_methods, &new_ext_cis, &edges);
34243443

3444+
jl_query_cache query_cache;
3445+
init_query_cache(&query_cache);
3446+
34253447
if (worklist) {
34263448
mod_array = jl_get_loaded_modules(); // __toplevel__ modules loaded in this session (from Base.loaded_modules_array)
34273449
// Generate _native_data`
34283450
if (_native_data != NULL) {
3429-
jl_prepare_serialization_data(mod_array, newly_inferred, &extext_methods, &new_ext_cis, NULL);
3451+
jl_prepare_serialization_data(mod_array, newly_inferred, &extext_methods, &new_ext_cis, NULL, &query_cache);
34303452
jl_precompile_toplevel_module = (jl_module_t*)jl_array_ptr_ref(worklist, jl_array_len(worklist)-1);
34313453
*_native_data = jl_precompile_worklist(worklist, extext_methods, new_ext_cis);
34323454
jl_precompile_toplevel_module = NULL;
@@ -3457,7 +3479,7 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
34573479
assert((ct->reentrant_timing & 0b1110) == 0);
34583480
ct->reentrant_timing |= 0b1000;
34593481
if (worklist) {
3460-
jl_prepare_serialization_data(mod_array, newly_inferred, &extext_methods, &new_ext_cis, &edges);
3482+
jl_prepare_serialization_data(mod_array, newly_inferred, &extext_methods, &new_ext_cis, &edges, &query_cache);
34613483
if (!emit_split) {
34623484
write_int32(f, 0); // No clone_targets
34633485
write_padding(f, LLT_ALIGN(ios_pos(f), JL_CACHE_BYTE_ALIGNMENT) - ios_pos(f));
@@ -3469,7 +3491,7 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
34693491
}
34703492
if (_native_data != NULL)
34713493
native_functions = *_native_data;
3472-
jl_save_system_image_to_stream(ff, mod_array, worklist, extext_methods, new_ext_cis, edges);
3494+
jl_save_system_image_to_stream(ff, mod_array, worklist, extext_methods, new_ext_cis, edges, &query_cache);
34733495
if (_native_data != NULL)
34743496
native_functions = NULL;
34753497
// make sure we don't run any Julia code concurrently before this point
@@ -3498,6 +3520,8 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
34983520
}
34993521
}
35003522

3523+
destroy_query_cache(&query_cache);
3524+
35013525
JL_GC_POP();
35023526
*s = f;
35033527
if (emit_split)

src/staticdata_utils.c

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -131,63 +131,81 @@ JL_DLLEXPORT void jl_push_newly_inferred(jl_value_t* ci)
131131
JL_UNLOCK(&newly_inferred_mutex);
132132
}
133133

134-
135134
// compute whether a type references something internal to worklist
136135
// and thus could not have existed before deserialize
137136
// and thus does not need delayed unique-ing
138-
static int type_in_worklist(jl_value_t *v) JL_NOTSAFEPOINT
137+
static int type_in_worklist(jl_value_t *v, jl_query_cache *cache) JL_NOTSAFEPOINT
139138
{
140139
if (jl_object_in_image(v))
141140
return 0; // fast-path for rejection
141+
142+
void *cached = HT_NOTFOUND;
143+
if (cache != NULL)
144+
cached = ptrhash_get(&cache->type_in_worklist, v);
145+
146+
// fast-path for memoized results
147+
if (cached != HT_NOTFOUND)
148+
return cached == v;
149+
150+
int result = 0;
142151
if (jl_is_uniontype(v)) {
143152
jl_uniontype_t *u = (jl_uniontype_t*)v;
144-
return type_in_worklist(u->a) ||
145-
type_in_worklist(u->b);
153+
result = type_in_worklist(u->a, cache) ||
154+
type_in_worklist(u->b, cache);
146155
}
147156
else if (jl_is_unionall(v)) {
148157
jl_unionall_t *ua = (jl_unionall_t*)v;
149-
return type_in_worklist((jl_value_t*)ua->var) ||
150-
type_in_worklist(ua->body);
158+
result = type_in_worklist((jl_value_t*)ua->var, cache) ||
159+
type_in_worklist(ua->body, cache);
151160
}
152161
else if (jl_is_typevar(v)) {
153162
jl_tvar_t *tv = (jl_tvar_t*)v;
154-
return type_in_worklist(tv->lb) ||
155-
type_in_worklist(tv->ub);
163+
result = type_in_worklist(tv->lb, cache) ||
164+
type_in_worklist(tv->ub, cache);
156165
}
157166
else if (jl_is_vararg(v)) {
158167
jl_vararg_t *tv = (jl_vararg_t*)v;
159-
if (tv->T && type_in_worklist(tv->T))
160-
return 1;
161-
if (tv->N && type_in_worklist(tv->N))
162-
return 1;
168+
result = ((tv->T && type_in_worklist(tv->T, cache)) ||
169+
(tv->N && type_in_worklist(tv->N, cache)));
163170
}
164171
else if (jl_is_datatype(v)) {
165172
jl_datatype_t *dt = (jl_datatype_t*)v;
166-
if (!jl_object_in_image((jl_value_t*)dt->name))
167-
return 1;
168-
jl_svec_t *tt = dt->parameters;
169-
size_t i, l = jl_svec_len(tt);
170-
for (i = 0; i < l; i++)
171-
if (type_in_worklist(jl_tparam(dt, i)))
172-
return 1;
173+
if (!jl_object_in_image((jl_value_t*)dt->name)) {
174+
result = 1;
175+
}
176+
else {
177+
jl_svec_t *tt = dt->parameters;
178+
size_t i, l = jl_svec_len(tt);
179+
for (i = 0; i < l; i++) {
180+
if (type_in_worklist(jl_tparam(dt, i), cache)) {
181+
result = 1;
182+
break;
183+
}
184+
}
185+
}
173186
}
174187
else {
175-
return type_in_worklist(jl_typeof(v));
188+
return type_in_worklist(jl_typeof(v), cache);
176189
}
177-
return 0;
190+
191+
// Memoize result
192+
if (cache != NULL)
193+
ptrhash_put(&cache->type_in_worklist, (void*)v, result ? (void*)v : NULL);
194+
195+
return result;
178196
}
179197

180198
// When we infer external method instances, ensure they link back to the
181199
// package. Otherwise they might be, e.g., for external macros.
182200
// Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable
183-
static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited, arraylist_t *stack)
201+
static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited, arraylist_t *stack, jl_query_cache *query_cache)
184202
{
185203
jl_module_t *mod = mi->def.module;
186204
if (jl_is_method(mod))
187205
mod = ((jl_method_t*)mod)->module;
188206
assert(jl_is_module(mod));
189207
uint8_t is_precompiled = jl_atomic_load_relaxed(&mi->flags) & JL_MI_FLAGS_MASK_PRECOMPILED;
190-
if (is_precompiled || !jl_object_in_image((jl_value_t*)mod) || type_in_worklist(mi->specTypes)) {
208+
if (is_precompiled || !jl_object_in_image((jl_value_t*)mod) || type_in_worklist(mi->specTypes, query_cache)) {
191209
return 1;
192210
}
193211
if (!mi->backedges) {
@@ -211,7 +229,7 @@ static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited,
211229
jl_code_instance_t *be;
212230
i = get_next_edge(mi->backedges, i, NULL, &be);
213231
JL_GC_PROMISE_ROOTED(be); // get_next_edge propagates the edge for us here
214-
int child_found = has_backedge_to_worklist(jl_get_ci_mi(be), visited, stack);
232+
int child_found = has_backedge_to_worklist(jl_get_ci_mi(be), visited, stack, query_cache);
215233
if (child_found == 1 || child_found == 2) {
216234
// found what we were looking for, so terminate early
217235
found = 1;
@@ -243,7 +261,7 @@ static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited,
243261
// from the worklist or explicitly added by a `precompile` statement, and
244262
// (4) are the most recently computed result for that method.
245263
// These will be preserved in the image.
246-
static jl_array_t *queue_external_cis(jl_array_t *list)
264+
static jl_array_t *queue_external_cis(jl_array_t *list, jl_query_cache *query_cache)
247265
{
248266
if (list == NULL)
249267
return NULL;
@@ -262,7 +280,7 @@ static jl_array_t *queue_external_cis(jl_array_t *list)
262280
jl_method_instance_t *mi = jl_get_ci_mi(ci);
263281
jl_method_t *m = mi->def.method;
264282
if (ci->owner == jl_nothing && jl_atomic_load_relaxed(&ci->inferred) && jl_is_method(m) && jl_object_in_image((jl_value_t*)m->module)) {
265-
int found = has_backedge_to_worklist(mi, &visited, &stack);
283+
int found = has_backedge_to_worklist(mi, &visited, &stack, query_cache);
266284
assert(found == 0 || found == 1 || found == 2);
267285
assert(stack.len == 0);
268286
if (found == 1 && jl_atomic_load_relaxed(&ci->max_world) == ~(size_t)0) {

0 commit comments

Comments
 (0)