Skip to content

Commit 4ae3f5e

Browse files
authored
Add trim_mode parameter to JIT type-inference entrypoint (#58817)
Resolves #58786. I think this is only a partial fix, since we can still end up loading code from pkgimages that has been poorly inferred due to running without these `InferenceParams`. However, many of the common scenarios (such as JLL's depending on each other) seem to be OK since we have a targeted heuristic that adds `__init__()` to a pkgimage only if the module has inference enabled.
1 parent 4d40c64 commit 4ae3f5e

File tree

8 files changed

+32
-24
lines changed

8 files changed

+32
-24
lines changed

Compiler/src/bootstrap.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ function activate_codegen!()
1010
Core.eval(Compiler, quote
1111
let typeinf_world_age = Base.tls_world_age()
1212
@eval Core.OptimizedGenerics.CompilerPlugins.typeinf(::Nothing, mi::MethodInstance, source_mode::UInt8) =
13-
Base.invoke_in_world($(Expr(:$, :typeinf_world_age)), typeinf_ext_toplevel, mi, Base.tls_world_age(), source_mode)
13+
Base.invoke_in_world($(Expr(:$, :typeinf_world_age)), typeinf_ext_toplevel, mi, Base.tls_world_age(), source_mode, Compiler.TRIM_NO)
1414
end
1515
end)
1616
end
@@ -67,7 +67,7 @@ function bootstrap!()
6767
end
6868
mi = specialize_method(m.method, Tuple{params...}, m.sparams)
6969
#isa_compileable_sig(mi) || println(stderr, "WARNING: inferring `", mi, "` which isn't expected to be called.")
70-
typeinf_ext_toplevel(mi, world, isa_compileable_sig(mi) ? SOURCE_MODE_ABI : SOURCE_MODE_NOT_REQUIRED)
70+
typeinf_ext_toplevel(mi, world, isa_compileable_sig(mi) ? SOURCE_MODE_ABI : SOURCE_MODE_NOT_REQUIRED, TRIM_NO)
7171
end
7272
end
7373
end

Compiler/src/typeinfer.jl

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1564,8 +1564,9 @@ function typeinf_ext_toplevel(interp::AbstractInterpreter, mi::MethodInstance, s
15641564
end
15651565

15661566
# This is a bridge for the C code calling `jl_typeinf_func()` on a single Method match
1567-
function typeinf_ext_toplevel(mi::MethodInstance, world::UInt, source_mode::UInt8)
1568-
interp = NativeInterpreter(world)
1567+
function typeinf_ext_toplevel(mi::MethodInstance, world::UInt, source_mode::UInt8, trim_mode::UInt8)
1568+
inf_params = InferenceParams(; force_enable_inference = trim_mode != TRIM_NO)
1569+
interp = NativeInterpreter(world; inf_params)
15691570
return typeinf_ext_toplevel(interp, mi, source_mode)
15701571
end
15711572

@@ -1648,11 +1649,11 @@ end
16481649

16491650
# This is a bridge for the C code calling `jl_typeinf_func()` on set of Method matches
16501651
# The trim_mode can be any of:
1651-
const TRIM_NO = 0
1652-
const TRIM_SAFE = 1
1653-
const TRIM_UNSAFE = 2
1654-
const TRIM_UNSAFE_WARN = 3
1655-
function typeinf_ext_toplevel(methods::Vector{Any}, worlds::Vector{UInt}, trim_mode::Int)
1652+
const TRIM_NO = 0x0
1653+
const TRIM_SAFE = 0x1
1654+
const TRIM_UNSAFE = 0x2
1655+
const TRIM_UNSAFE_WARN = 0x3
1656+
function typeinf_ext_toplevel(methods::Vector{Any}, worlds::Vector{UInt}, trim_mode::UInt8)
16561657
inf_params = InferenceParams(; force_enable_inference = trim_mode != TRIM_NO)
16571658

16581659
# Create an "invokelatest" queue to enable eager compilation of speculative

src/aotcompile.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -712,7 +712,7 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
712712
fargs[2] = (jl_value_t*)worlds;
713713
jl_array_data(worlds, size_t)[0] = jl_typeinf_world;
714714
jl_array_data(worlds, size_t)[compiler_world] = world; // might overwrite previous
715-
fargs[3] = jl_box_long(trim);
715+
fargs[3] = jl_box_uint8(trim);
716716
size_t last_age = ct->world_age;
717717
ct->world_age = jl_typeinf_world;
718718
codeinfos = (jl_array_t*)jl_apply(fargs, 4);
@@ -2488,7 +2488,7 @@ void jl_get_llvmf_defn_impl(jl_llvmf_dump_t *dump, jl_method_instance_t *mi, jl_
24882488
jl_method_instance_t *mi = jl_get_specialization1((jl_tupletype_t*)sigt, latestworld, 0);
24892489
if (mi == nullptr)
24902490
continue;
2491-
jl_code_instance_t *codeinst = jl_type_infer(mi, latestworld, SOURCE_MODE_NOT_REQUIRED);
2491+
jl_code_instance_t *codeinst = jl_type_infer(mi, latestworld, SOURCE_MODE_NOT_REQUIRED, jl_options.trim);
24922492
if (codeinst == nullptr || compiled_functions.count(codeinst))
24932493
continue;
24942494
orc::ThreadSafeModule decl_m = jl_create_ts_module("extern", ctx, DL, TT);

src/gf.c

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ static jl_code_instance_t *jl_method_inferred_with_abi(jl_method_instance_t *mi
403403
// returns the inferred source, and may cache the result in mi
404404
// if successful, also updates the mi argument to describe the validity of this src
405405
// if inference doesn't occur (or can't finish), returns NULL instead
406-
jl_code_instance_t *jl_type_infer(jl_method_instance_t *mi, size_t world, uint8_t source_mode)
406+
jl_code_instance_t *jl_type_infer(jl_method_instance_t *mi, size_t world, uint8_t source_mode, uint8_t trim_mode)
407407
{
408408
if (jl_typeinf_func == NULL) {
409409
if (source_mode == SOURCE_MODE_ABI)
@@ -427,11 +427,12 @@ jl_code_instance_t *jl_type_infer(jl_method_instance_t *mi, size_t world, uint8_
427427
return NULL;
428428
JL_TIMING(INFERENCE, INFERENCE);
429429
jl_value_t **fargs;
430-
JL_GC_PUSHARGS(fargs, 4);
430+
JL_GC_PUSHARGS(fargs, 5);
431431
fargs[0] = (jl_value_t*)jl_typeinf_func;
432432
fargs[1] = (jl_value_t*)mi;
433433
fargs[2] = jl_box_ulong(world);
434434
fargs[3] = jl_box_uint8(source_mode);
435+
fargs[4] = jl_box_uint8(trim_mode);
435436
int last_errno = errno;
436437
#ifdef _OS_WINDOWS_
437438
DWORD last_error = GetLastError();
@@ -458,7 +459,7 @@ jl_code_instance_t *jl_type_infer(jl_method_instance_t *mi, size_t world, uint8_
458459
// allocate another bit for the counter.
459460
ct->reentrant_timing += 0b10;
460461
JL_TRY {
461-
ci = (jl_code_instance_t*)jl_apply(fargs, 4);
462+
ci = (jl_code_instance_t*)jl_apply(fargs, 5);
462463
}
463464
JL_CATCH {
464465
jl_value_t *e = jl_current_exception(ct);
@@ -3196,7 +3197,7 @@ jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *mi, size_t
31963197
int should_skip_inference = !jl_is_method(mi->def.method) || jl_symbol_name(mi->def.method->name)[0] == '@';
31973198

31983199
if (!should_skip_inference) {
3199-
codeinst = jl_type_infer(mi, world, SOURCE_MODE_ABI);
3200+
codeinst = jl_type_infer(mi, world, SOURCE_MODE_ABI, jl_options.trim);
32003201
}
32013202
}
32023203

@@ -3516,7 +3517,7 @@ static void _generate_from_hint(jl_method_instance_t *mi, size_t world)
35163517
{
35173518
jl_value_t *codeinst = jl_rettype_inferred_native(mi, world, world);
35183519
if (codeinst == jl_nothing) {
3519-
(void)jl_type_infer(mi, world, SOURCE_MODE_NOT_REQUIRED);
3520+
(void)jl_type_infer(mi, world, SOURCE_MODE_NOT_REQUIRED, jl_options.trim);
35203521
codeinst = jl_rettype_inferred_native(mi, world, world);
35213522
}
35223523
if (codeinst != jl_nothing) {
@@ -3559,10 +3560,10 @@ JL_DLLEXPORT void jl_compile_method_instance(jl_method_instance_t *mi, jl_tuplet
35593560
miflags = jl_atomic_load_relaxed(&mi2->flags) | JL_MI_FLAGS_MASK_PRECOMPILED;
35603561
jl_atomic_store_relaxed(&mi2->flags, miflags);
35613562
if (jl_rettype_inferred_native(mi2, world, world) == jl_nothing)
3562-
(void)jl_type_infer(mi2, world, SOURCE_MODE_NOT_REQUIRED);
3563+
(void)jl_type_infer(mi2, world, SOURCE_MODE_NOT_REQUIRED, jl_options.trim);
35633564
if (jl_typeinf_func && jl_atomic_load_relaxed(&mi->def.method->primary_world) <= tworld) {
35643565
if (jl_rettype_inferred_native(mi2, tworld, tworld) == jl_nothing)
3565-
(void)jl_type_infer(mi2, tworld, SOURCE_MODE_NOT_REQUIRED);
3566+
(void)jl_type_infer(mi2, tworld, SOURCE_MODE_NOT_REQUIRED, jl_options.trim);
35663567
}
35673568
}
35683569
}

src/julia_internal.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,7 @@ JL_DLLEXPORT void jl_engine_fulfill(jl_code_instance_t *ci, jl_code_info_t *src)
693693
void jl_engine_sweep(jl_ptls_t *gc_all_tls_states) JL_NOTSAFEPOINT;
694694
int jl_engine_hasreserved(jl_method_instance_t *m, jl_value_t *owner) JL_NOTSAFEPOINT;
695695

696-
JL_DLLEXPORT jl_code_instance_t *jl_type_infer(jl_method_instance_t *li JL_PROPAGATES_ROOT, size_t world, uint8_t source_mode);
696+
JL_DLLEXPORT jl_code_instance_t *jl_type_infer(jl_method_instance_t *li JL_PROPAGATES_ROOT, size_t world, uint8_t source_mode, uint8_t trim_mode);
697697
JL_DLLEXPORT jl_code_info_t *jl_gdbcodetyped1(jl_method_instance_t *mi, size_t world);
698698
JL_DLLEXPORT jl_code_instance_t *jl_compile_method_internal(jl_method_instance_t *meth JL_PROPAGATES_ROOT, size_t world);
699699
JL_DLLEXPORT jl_code_instance_t *jl_get_method_inferred(

src/runtime_ccall.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -403,7 +403,7 @@ void *jl_get_abi_converter(jl_task_t *ct, void *data)
403403
}
404404
JL_UNLOCK(&cfun_lock);
405405
// next, try to figure out what the target should look like (outside of the lock since this is very slow)
406-
codeinst = mi ? jl_type_infer(mi, world, SOURCE_MODE_ABI) : nullptr;
406+
codeinst = mi ? jl_type_infer(mi, world, SOURCE_MODE_ABI, jl_options.trim) : nullptr;
407407
// relock for the remainder of the function
408408
JL_LOCK(&cfun_lock);
409409
} while (jl_atomic_load_acquire(&jl_world_counter) != world); // restart entirely, since jl_world_counter changed thus jl_get_specialization1 might have changed

src/toplevel.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -757,7 +757,7 @@ JL_DLLEXPORT jl_value_t *jl_toplevel_eval_flex(jl_module_t *JL_NONNULL m, jl_val
757757
size_t world = jl_atomic_load_acquire(&jl_world_counter);
758758
ct->world_age = world;
759759
if (!has_defs && jl_get_module_infer(m) != 0) {
760-
(void)jl_type_infer(mfunc, world, SOURCE_MODE_ABI);
760+
(void)jl_type_infer(mfunc, world, SOURCE_MODE_ABI, jl_options.trim);
761761
}
762762
result = jl_invoke(/*func*/NULL, /*args*/NULL, /*nargs*/0, mfunc);
763763
ct->world_age = last_age;

test/trimming/basic_jll.jl

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
using Libdl
22
using Zstd_jll # Note this uses the vendored older non-LazyLibrary version of Zstd_jll
33

4+
# JLL usage at build-time should function as expected
5+
Zstd_jll.__init__()
6+
const build_ver = unsafe_string(ccall((:ZSTD_versionString, libzstd), Cstring, ()))
7+
48
function print_string(fptr::Ptr{Cvoid})
59
println(Core.stdout, unsafe_string(ccall(fptr, Cstring, ())))
610
end
@@ -9,10 +13,12 @@ function @main(args::Vector{String})::Cint
913
# Test the basic "Hello, world!"
1014
println(Core.stdout, "Julia! Hello, world!")
1115

12-
# Make sure that JLL's are working as expected
13-
println(Core.stdout, unsafe_string(ccall((:ZSTD_versionString, libzstd), Cstring, ())))
16+
# JLL usage at run-time should function as expected
17+
ver = unsafe_string(ccall((:ZSTD_versionString, libzstd), Cstring, ()))
18+
println(Core.stdout, ver)
19+
@assert ver == build_ver
1420

15-
# Add an indirection via `@cfunction`
21+
# Add an indirection via `@cfunction` / 1-arg ccall
1622
cfunc = @cfunction(print_string, Cvoid, (Ptr{Cvoid},))
1723
fptr = dlsym(Zstd_jll.libzstd_handle, :ZSTD_versionString)
1824
ccall(cfunc, Cvoid, (Ptr{Cvoid},), fptr)

0 commit comments

Comments
 (0)