Skip to content

Commit ee1e525

Browse files
committed
staticdata: Memoize type_in_worklist query (#57917)
When pre-compiling `stdlib/`, this cache has a 91% hit rate, so this seems fairly profitable. It also dramatically improves some pathological cases, a few of which have been hit in the wild (arguably due to inference bugs).

Without this PR, this package takes exponentially long to pre-compile:

```julia
function BigType(N)
    (N == 0) && return Nothing
    T = BigType(N-1)
    return Pair{T,T}
end
foo(::Type{T}) where T = T
precompile(foo, (Type{BigType(40)},))
```

For an in-the-wild test case hit by a customer, this reduces pre-compilation time from over an hour to just ~two and a half minutes. Resolves #53331.
1 parent 618c74e commit ee1e525

File tree

2 files changed

+93
-51
lines changed

2 files changed

+93
-51
lines changed

src/staticdata.c

Lines changed: 49 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,22 @@ External links:
9090

9191
static const size_t WORLD_AGE_REVALIDATION_SENTINEL = 0x1;
9292

93+
// This structure is used to store hash tables for the memoization
94+
// of queries in staticdata.c (currently only `type_in_worklist`).
95+
typedef struct {
96+
// Memo table for `type_in_worklist`, keyed by the pointer of the queried value.
// A key is mapped to itself when the memoized answer was true and to NULL
// when it was false; a missing key means the query has not been memoized.
htable_t type_in_worklist;
97+
} jl_query_cache;
98+
99+
// Initialize `cache` by creating its `type_in_worklist` hash table with htable_new.
// Pair each call with destroy_query_cache once the cache is no longer needed.
static void init_query_cache(jl_query_cache *cache)
100+
{
101+
htable_new(&cache->type_in_worklist, 0);
102+
}
103+
104+
// Release the `type_in_worklist` hash table held by `cache` via htable_free.
// Only the table is freed here; the `jl_query_cache` struct itself is not.
static void destroy_query_cache(jl_query_cache *cache)
105+
{
106+
htable_free(&cache->type_in_worklist);
107+
}
108+
93109
#include "staticdata_utils.c"
94110
#include "precompile_utils.c"
95111

@@ -512,6 +528,7 @@ typedef struct {
512528
jl_array_t *link_ids_gctags;
513529
jl_array_t *link_ids_gvars;
514530
jl_array_t *link_ids_external_fnvars;
531+
jl_query_cache *query_cache;
515532
jl_ptls_t ptls;
516533
// Set (implemented as a hashmap of MethodInstances to themselves) of which MethodInstances have (forward) edges
517534
// to other MethodInstances.
@@ -658,38 +675,37 @@ static int jl_needs_serialization(jl_serializer_state *s, jl_value_t *v) JL_NOTS
658675
return 1;
659676
}
660677

661-
662-
// Classify `v` with a small integer tag: 0, 1, or 2.
// `needs_uniquing` tests this tag for 1 and `needs_recaching` tests it for 2.
// `query_cache` is only forwarded to `type_in_worklist` for memoization.
//   - MethodInstance whose method is in the image: 1, plus 1 if its specTypes
//     reference the worklist.
//   - DataType that is cacheable and whose TypeName is in the image: 1, plus 1
//     if the type references the worklist.
//   - Value whose type is a singleton datatype: 1, minus 1 if that type
//     references the worklist.
//   - Anything else: 0.
static int caching_tag(jl_value_t *v) JL_NOTSAFEPOINT
678+
static int caching_tag(jl_value_t *v, jl_query_cache *query_cache) JL_NOTSAFEPOINT
663679
{
664680
if (jl_is_method_instance(v)) {
665681
jl_method_instance_t *mi = (jl_method_instance_t*)v;
666682
jl_value_t *m = mi->def.value;
667683
if (jl_is_method(m) && jl_object_in_image(m))
668-
return 1 + type_in_worklist(mi->specTypes);
684+
return 1 + type_in_worklist(mi->specTypes, query_cache);
669685
}
670686
if (jl_is_datatype(v)) {
671687
jl_datatype_t *dt = (jl_datatype_t*)v;
672688
if (jl_is_tuple_type(dt) ? !dt->isconcretetype : dt->hasfreetypevars)
673689
return 0; // aka !is_cacheable from jltypes.c
674690
if (jl_object_in_image((jl_value_t*)dt->name))
675-
return 1 + type_in_worklist(v);
691+
return 1 + type_in_worklist(v, query_cache);
676692
}
677693
// A method instance or datatype that did not return above falls through to
// the singleton check on its own type.
jl_value_t *dtv = jl_typeof(v);
678694
if (jl_is_datatype_singleton((jl_datatype_t*)dtv)) {
679-
return 1 - type_in_worklist(dtv); // these are already recached in the datatype in the image
695+
return 1 - type_in_worklist(dtv, query_cache); // these are already recached in the datatype in the image
680696
}
681697
return 0;
682698
}
683699

684-
// Returns nonzero exactly when caching_tag(v) is 2.
// `query_cache` is passed through to caching_tag unchanged.
static int needs_recaching(jl_value_t *v) JL_NOTSAFEPOINT
700+
static int needs_recaching(jl_value_t *v, jl_query_cache *query_cache) JL_NOTSAFEPOINT
685701
{
686-
return caching_tag(v) == 2;
702+
return caching_tag(v, query_cache) == 2;
687703
}
688704

689-
// Returns nonzero exactly when caching_tag(v) is 1.
// `v` must not itself be an object in the image (asserted below).
// `query_cache` is passed through to caching_tag unchanged.
static int needs_uniquing(jl_value_t *v) JL_NOTSAFEPOINT
705+
static int needs_uniquing(jl_value_t *v, jl_query_cache *query_cache) JL_NOTSAFEPOINT
690706
{
691707
assert(!jl_object_in_image(v));
692-
return caching_tag(v) == 1;
708+
return caching_tag(v, query_cache) == 1;
693709
}
694710

695711
static void record_field_change(jl_value_t **addr, jl_value_t *newval) JL_NOTSAFEPOINT
@@ -774,7 +790,7 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
774790
jl_datatype_t *dt = (jl_datatype_t*)v;
775791
// ensure all type parameters are recached
776792
jl_queue_for_serialization_(s, (jl_value_t*)dt->parameters, 1, 1);
777-
if (jl_is_datatype_singleton(dt) && needs_uniquing(dt->instance)) {
793+
if (jl_is_datatype_singleton(dt) && needs_uniquing(dt->instance, s->query_cache)) {
778794
assert(jl_needs_serialization(s, dt->instance)); // should be true, since we visited dt
779795
// do not visit dt->instance for our template object as it leads to unwanted cycles here
780796
// (it may get serialized from elsewhere though)
@@ -785,7 +801,7 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
785801
if (s->incremental && jl_is_method_instance(v)) {
786802
jl_method_instance_t *mi = (jl_method_instance_t*)v;
787803
jl_value_t *def = mi->def.value;
788-
if (needs_uniquing(v)) {
804+
if (needs_uniquing(v, s->query_cache)) {
789805
// we only need 3 specific fields of this (the rest are not used)
790806
jl_queue_for_serialization(s, mi->def.value);
791807
jl_queue_for_serialization(s, mi->specTypes);
@@ -801,7 +817,7 @@ static void jl_insert_into_serialization_queue(jl_serializer_state *s, jl_value_
801817
record_field_change((jl_value_t**)&mi->cache, NULL);
802818
}
803819
else {
804-
assert(!needs_recaching(v));
820+
assert(!needs_recaching(v, s->query_cache));
805821
}
806822
// n.b. opaque closures cannot be inspected and relied upon like a
807823
// normal method since they can get improperly introduced by generated
@@ -947,9 +963,9 @@ static void jl_queue_for_serialization_(jl_serializer_state *s, jl_value_t *v, i
947963
// Items that require postorder traversal must visit their children prior to insertion into
948964
// the worklist/serialization_order (and also before their first use)
949965
if (s->incremental && !immediate) {
950-
if (jl_is_datatype(t) && needs_uniquing(v))
966+
if (jl_is_datatype(t) && needs_uniquing(v, s->query_cache))
951967
immediate = 1;
952-
if (jl_is_datatype_singleton((jl_datatype_t*)t) && needs_uniquing(v))
968+
if (jl_is_datatype_singleton((jl_datatype_t*)t) && needs_uniquing(v, s->query_cache))
953969
immediate = 1;
954970
}
955971

@@ -1113,7 +1129,7 @@ static uintptr_t _backref_id(jl_serializer_state *s, jl_value_t *v, jl_array_t *
11131129

11141130
static void record_uniquing(jl_serializer_state *s, jl_value_t *fld, uintptr_t offset) JL_NOTSAFEPOINT
11151131
{
1116-
if (s->incremental && jl_needs_serialization(s, fld) && needs_uniquing(fld)) {
1132+
if (s->incremental && jl_needs_serialization(s, fld) && needs_uniquing(fld, s->query_cache)) {
11171133
if (jl_is_datatype(fld) || jl_is_datatype_singleton((jl_datatype_t*)jl_typeof(fld)))
11181134
arraylist_push(&s->uniquing_types, (void*)(uintptr_t)offset);
11191135
else if (jl_is_method_instance(fld))
@@ -1297,7 +1313,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
12971313
// write header
12981314
if (object_id_expected)
12991315
write_uint(f, jl_object_id(v));
1300-
if (s->incremental && jl_needs_serialization(s, (jl_value_t*)t) && needs_uniquing((jl_value_t*)t))
1316+
if (s->incremental && jl_needs_serialization(s, (jl_value_t*)t) && needs_uniquing((jl_value_t*)t, s->query_cache))
13011317
arraylist_push(&s->uniquing_types, (void*)(uintptr_t)(ios_pos(f)|1));
13021318
if (f == s->const_data)
13031319
write_uint(s->const_data, ((uintptr_t)t->smalltag << 4) | GC_OLD_MARKED | GC_IN_IMAGE);
@@ -1308,7 +1324,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
13081324
layout_table.items[item] = (void*)(reloc_offset | (f == s->const_data)); // store the inverse mapping of `serialization_order` (`id` => object-as-streampos)
13091325

13101326
if (s->incremental) {
1311-
if (needs_uniquing(v)) {
1327+
if (needs_uniquing(v, s->query_cache)) {
13121328
if (jl_is_method_instance(v)) {
13131329
assert(f == s->s);
13141330
jl_method_instance_t *mi = (jl_method_instance_t*)v;
@@ -1329,7 +1345,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
13291345
assert(jl_is_datatype_singleton(t) && "unreachable");
13301346
}
13311347
}
1332-
else if (needs_recaching(v)) {
1348+
else if (needs_recaching(v, s->query_cache)) {
13331349
arraylist_push(jl_is_datatype(v) ? &s->fixup_types : &s->fixup_objs, (void*)reloc_offset);
13341350
}
13351351
else if (jl_typetagis(v, jl_binding_type)) {
@@ -1731,7 +1747,7 @@ static void jl_write_values(jl_serializer_state *s) JL_GC_DISABLED
17311747
}
17321748
}
17331749
void *superidx = ptrhash_get(&serialization_order, dt->super);
1734-
if (s->incremental && superidx != HT_NOTFOUND && (char*)superidx - 1 - (char*)HT_NOTFOUND > item && needs_uniquing((jl_value_t*)dt->super))
1750+
if (s->incremental && superidx != HT_NOTFOUND && (char*)superidx - 1 - (char*)HT_NOTFOUND > item && needs_uniquing((jl_value_t*)dt->super, s->query_cache))
17351751
arraylist_push(&s->uniquing_super, dt->super);
17361752
}
17371753
else if (jl_is_typename(v)) {
@@ -2540,7 +2556,8 @@ JL_DLLEXPORT jl_value_t *jl_as_global_root(jl_value_t *val, int insert)
25402556

25412557
static void jl_prepare_serialization_data(jl_array_t *mod_array, jl_array_t *newly_inferred, uint64_t worklist_key,
25422558
/* outputs */ jl_array_t **extext_methods, jl_array_t **new_ext_cis,
2543-
jl_array_t **method_roots_list, jl_array_t **ext_targets, jl_array_t **edges)
2559+
jl_array_t **method_roots_list, jl_array_t **ext_targets, jl_array_t **edges,
2560+
jl_query_cache *query_cache)
25442561
{
25452562
// extext_methods: [method1, ...], worklist-owned "extending external" methods added to functions owned by modules outside the worklist
25462563
// ext_targets: [invokesig1, callee1, matches1, ...] non-worklist callees of worklist-owned methods
@@ -2551,7 +2568,7 @@ static void jl_prepare_serialization_data(jl_array_t *mod_array, jl_array_t *new
25512568
assert(edges_map == NULL);
25522569

25532570
// Save the inferred code from newly inferred, external methods
2554-
*new_ext_cis = queue_external_cis(newly_inferred);
2571+
*new_ext_cis = queue_external_cis(newly_inferred, query_cache);
25552572

25562573
// Collect method extensions and edges data
25572574
JL_GC_PUSH1(&edges_map);
@@ -2590,7 +2607,8 @@ static void jl_prepare_serialization_data(jl_array_t *mod_array, jl_array_t *new
25902607
static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
25912608
jl_array_t *worklist, jl_array_t *extext_methods,
25922609
jl_array_t *new_ext_cis, jl_array_t *method_roots_list,
2593-
jl_array_t *ext_targets, jl_array_t *edges) JL_GC_DISABLED
2610+
jl_array_t *ext_targets, jl_array_t *edges,
2611+
jl_query_cache *query_cache) JL_GC_DISABLED
25942612
{
25952613
htable_new(&field_replace, 0);
25962614
// strip metadata and IR when requested
@@ -2617,6 +2635,7 @@ static void jl_save_system_image_to_stream(ios_t *f, jl_array_t *mod_array,
26172635
ios_mem(&gvar_record, 0);
26182636
ios_mem(&fptr_record, 0);
26192637
jl_serializer_state s = {0};
2638+
s.query_cache = query_cache;
26202639
s.incremental = !(worklist == NULL);
26212640
s.s = &sysimg;
26222641
s.const_data = &const_data;
@@ -2947,12 +2966,15 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
29472966
int64_t datastartpos = 0;
29482967
JL_GC_PUSH6(&mod_array, &extext_methods, &new_ext_cis, &method_roots_list, &ext_targets, &edges);
29492968

2969+
jl_query_cache query_cache;
2970+
init_query_cache(&query_cache);
2971+
29502972
if (worklist) {
29512973
mod_array = jl_get_loaded_modules(); // __toplevel__ modules loaded in this session (from Base.loaded_modules_array)
29522974
// Generate `_native_data`
29532975
if (_native_data != NULL) {
29542976
jl_prepare_serialization_data(mod_array, newly_inferred, jl_worklist_key(worklist),
2955-
&extext_methods, &new_ext_cis, NULL, NULL, NULL);
2977+
&extext_methods, &new_ext_cis, NULL, NULL, NULL, &query_cache);
29562978
jl_precompile_toplevel_module = (jl_module_t*)jl_array_ptr_ref(worklist, jl_array_len(worklist)-1);
29572979
*_native_data = jl_precompile_worklist(worklist, extext_methods, new_ext_cis);
29582980
jl_precompile_toplevel_module = NULL;
@@ -2981,7 +3003,7 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
29813003
if (worklist) {
29823004
htable_new(&relocatable_ext_cis, 0);
29833005
jl_prepare_serialization_data(mod_array, newly_inferred, jl_worklist_key(worklist),
2984-
&extext_methods, &new_ext_cis, &method_roots_list, &ext_targets, &edges);
3006+
&extext_methods, &new_ext_cis, &method_roots_list, &ext_targets, &edges, &query_cache);
29853007
if (!emit_split) {
29863008
write_int32(f, 0); // No clone_targets
29873009
write_padding(f, LLT_ALIGN(ios_pos(f), JL_CACHE_BYTE_ALIGNMENT) - ios_pos(f));
@@ -2993,7 +3015,7 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
29933015
}
29943016
if (_native_data != NULL)
29953017
native_functions = *_native_data;
2996-
jl_save_system_image_to_stream(ff, mod_array, worklist, extext_methods, new_ext_cis, method_roots_list, ext_targets, edges);
3018+
jl_save_system_image_to_stream(ff, mod_array, worklist, extext_methods, new_ext_cis, method_roots_list, ext_targets, edges, &query_cache);
29973019
if (_native_data != NULL)
29983020
native_functions = NULL;
29993021
if (worklist)
@@ -3024,6 +3046,8 @@ JL_DLLEXPORT void jl_create_system_image(void **_native_data, jl_array_t *workli
30243046
}
30253047
}
30263048

3049+
destroy_query_cache(&query_cache);
3050+
30273051
JL_GC_POP();
30283052
*s = f;
30293053
if (emit_split)

src/staticdata_utils.c

Lines changed: 44 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -104,62 +104,80 @@ JL_DLLEXPORT void jl_push_newly_inferred(jl_value_t* ci)
104104
JL_UNLOCK(&newly_inferred_mutex);
105105
}
106106

107-
108107
// Returns 1 if `v` references something internal to the worklist, else 0.
// `cache` may be NULL, in which case no memoization is performed; otherwise
// results for union, unionall, typevar, vararg and datatype values are stored
// in `cache->type_in_worklist` and reused on later queries for the same pointer.
// compute whether a type references something internal to worklist
109108
// and thus could not have existed before deserialize
110109
// and thus does not need delayed unique-ing
111-
static int type_in_worklist(jl_value_t *v) JL_NOTSAFEPOINT
110+
static int type_in_worklist(jl_value_t *v, jl_query_cache *cache) JL_NOTSAFEPOINT
112111
{
113112
if (jl_object_in_image(v))
114113
return 0; // fast-path for rejection
114+
115+
// Look up a previously memoized answer, if a cache was supplied.
void *cached = HT_NOTFOUND;
116+
if (cache != NULL)
117+
cached = ptrhash_get(&cache->type_in_worklist, v);
118+
119+
// fast-path for memoized results
120+
// A stored value equal to `v` encodes "true"; a stored NULL encodes "false"
// (this mirrors the ptrhash_put at the end of this function).
if (cached != HT_NOTFOUND)
121+
return cached == v;
122+
123+
int result = 0;
115124
if (jl_is_uniontype(v)) {
116125
jl_uniontype_t *u = (jl_uniontype_t*)v;
117-
return type_in_worklist(u->a) ||
118-
type_in_worklist(u->b);
126+
result = type_in_worklist(u->a, cache) ||
127+
type_in_worklist(u->b, cache);
119128
}
120129
else if (jl_is_unionall(v)) {
121130
jl_unionall_t *ua = (jl_unionall_t*)v;
122-
return type_in_worklist((jl_value_t*)ua->var) ||
123-
type_in_worklist(ua->body);
131+
result = type_in_worklist((jl_value_t*)ua->var, cache) ||
132+
type_in_worklist(ua->body, cache);
124133
}
125134
else if (jl_is_typevar(v)) {
126135
jl_tvar_t *tv = (jl_tvar_t*)v;
127-
return type_in_worklist(tv->lb) ||
128-
type_in_worklist(tv->ub);
136+
result = type_in_worklist(tv->lb, cache) ||
137+
type_in_worklist(tv->ub, cache);
129138
}
130139
else if (jl_is_vararg(v)) {
131140
jl_vararg_t *tv = (jl_vararg_t*)v;
132-
if (tv->T && type_in_worklist(tv->T))
133-
return 1;
134-
if (tv->N && type_in_worklist(tv->N))
135-
return 1;
141+
result = ((tv->T && type_in_worklist(tv->T, cache)) ||
142+
(tv->N && type_in_worklist(tv->N, cache)));
136143
}
137144
else if (jl_is_datatype(v)) {
138145
jl_datatype_t *dt = (jl_datatype_t*)v;
139-
if (!jl_object_in_image((jl_value_t*)dt->name))
140-
return 1;
141-
jl_svec_t *tt = dt->parameters;
142-
size_t i, l = jl_svec_len(tt);
143-
for (i = 0; i < l; i++)
144-
if (type_in_worklist(jl_tparam(dt, i)))
145-
return 1;
146+
// A datatype whose TypeName is not in the image is a hit outright;
// otherwise the answer depends on whether any type parameter is a hit.
if (!jl_object_in_image((jl_value_t*)dt->name)) {
147+
result = 1;
148+
}
149+
else {
150+
jl_svec_t *tt = dt->parameters;
151+
size_t i, l = jl_svec_len(tt);
152+
for (i = 0; i < l; i++) {
153+
if (type_in_worklist(jl_tparam(dt, i), cache)) {
154+
result = 1;
155+
break;
156+
}
157+
}
158+
}
146159
}
147160
else {
148-
return type_in_worklist(jl_typeof(v));
161+
// Any other kind of value: answer for its type instead. This path returns
// directly, so no entry is stored for `v` itself.
return type_in_worklist(jl_typeof(v), cache);
149162
}
150-
return 0;
163+
164+
// Memoize result
165+
if (cache != NULL)
166+
ptrhash_put(&cache->type_in_worklist, (void*)v, result ? (void*)v : NULL);
167+
168+
return result;
151169
}
152170

153171
// When we infer external method instances, ensure they link back to the
154172
// package. Otherwise they might be, e.g., for external macros.
155173
// Implements Tarjan's SCC (strongly connected components) algorithm, simplified to remove the count variable
156-
static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited, arraylist_t *stack)
174+
static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited, arraylist_t *stack, jl_query_cache *query_cache)
157175
{
158176
jl_module_t *mod = mi->def.module;
159177
if (jl_is_method(mod))
160178
mod = ((jl_method_t*)mod)->module;
161179
assert(jl_is_module(mod));
162-
if (jl_atomic_load_relaxed(&mi->precompiled) || !jl_object_in_image((jl_value_t*)mod) || type_in_worklist(mi->specTypes)) {
180+
if (jl_atomic_load_relaxed(&mi->precompiled) || !jl_object_in_image((jl_value_t*)mod) || type_in_worklist(mi->specTypes, query_cache)) {
163181
return 1;
164182
}
165183
if (!mi->backedges) {
@@ -182,7 +200,7 @@ static int has_backedge_to_worklist(jl_method_instance_t *mi, htable_t *visited,
182200
while (i < n) {
183201
jl_method_instance_t *be;
184202
i = get_next_edge(mi->backedges, i, NULL, &be);
185-
int child_found = has_backedge_to_worklist(be, visited, stack);
203+
int child_found = has_backedge_to_worklist(be, visited, stack, query_cache);
186204
if (child_found == 1 || child_found == 2) {
187205
// found what we were looking for, so terminate early
188206
found = 1;
@@ -225,7 +243,7 @@ static int is_relocatable_ci(htable_t *relocatable_ext_cis, jl_code_instance_t *
225243
// from the worklist or explicitly added by a `precompile` statement, and
226244
// (4) are the most recently computed result for that method.
227245
// These will be preserved in the image.
228-
static jl_array_t *queue_external_cis(jl_array_t *list)
246+
static jl_array_t *queue_external_cis(jl_array_t *list, jl_query_cache *query_cache)
229247
{
230248
if (list == NULL)
231249
return NULL;
@@ -246,7 +264,7 @@ static jl_array_t *queue_external_cis(jl_array_t *list)
246264
jl_method_instance_t *mi = ci->def;
247265
jl_method_t *m = mi->def.method;
248266
if (jl_atomic_load_relaxed(&ci->inferred) && jl_is_method(m) && jl_object_in_image((jl_value_t*)m->module)) {
249-
int found = has_backedge_to_worklist(mi, &visited, &stack);
267+
int found = has_backedge_to_worklist(mi, &visited, &stack, query_cache);
250268
assert(found == 0 || found == 1 || found == 2);
251269
assert(stack.len == 0);
252270
if (found == 1 && jl_atomic_load_relaxed(&ci->max_world) == ~(size_t)0) {

0 commit comments

Comments
 (0)