Skip to content

Commit 3076a96

Browse files
authored
Add cfunction support for --trim (JuliaLang#58812)
1 parent b4a6288 commit 3076a96

File tree

5 files changed

+79
-33
lines changed

5 files changed

+79
-33
lines changed

Compiler/src/typeinfer.jl

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1480,16 +1480,30 @@ function collectinvokes!(workqueue::CompilationQueue, ci::CodeInfo, sptypes::Vec
14801480
# No dynamic dispatch to resolve / enqueue
14811481
continue
14821482
end
1483+
elseif isexpr(stmt, :cfunction) && length(stmt.args) == 5
1484+
(pointer_type, f, rt, at, call_type) = stmt.args
1485+
linfo = ci.parent
14831486

1484-
let workqueue = invokelatest_queue
1485-
# make a best-effort attempt to enqueue the relevant code for the finalizer
1486-
mi = compileable_specialization_for_call(workqueue.interp, atype)
1487-
mi === nothing && continue
1487+
linfo isa MethodInstance || continue
1488+
at isa SimpleVector || continue
14881489

1489-
push!(workqueue, mi)
1490+
ft = argextype(f, ci, sptypes)
1491+
argtypes = Any[ft]
1492+
for i = 1:length(at)
1493+
push!(argtypes, sp_type_rewrap(at[i], linfo, #= isreturn =# false))
14901494
end
1495+
atype = argtypes_to_type(argtypes)
1496+
else
1497+
# TODO: handle other StmtInfo like OpaqueClosure?
1498+
continue
1499+
end
1500+
let workqueue = invokelatest_queue
1501+
# make a best-effort attempt to enqueue the relevant code for the dynamic invokelatest call
1502+
mi = compileable_specialization_for_call(workqueue.interp, atype)
1503+
mi === nothing && continue
1504+
1505+
push!(workqueue, mi)
14911506
end
1492-
# TODO: handle other StmtInfo like @cfunction and OpaqueClosure?
14931507
end
14941508
end
14951509

Compiler/src/verifytrim.jl

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ using ..Compiler:
1414
argextype, empty!, error, get, get_ci_mi, get_world_counter, getindex, getproperty,
1515
hasintersect, haskey, in, isdispatchelem, isempty, isexpr, iterate, length, map!, max,
1616
pop!, popfirst!, push!, pushfirst!, reinterpret, reverse!, reverse, setindex!,
17-
setproperty!, similar, singleton_type, sptypes_from_meth_instance,
17+
setproperty!, similar, singleton_type, sptypes_from_meth_instance, sp_type_rewrap,
1818
unsafe_pointer_to_objref, widenconst, isconcretetype,
1919
# misc
2020
@nospecialize, @assert, C_NULL
@@ -256,9 +256,27 @@ function verify_codeinstance!(interp::NativeInterpreter, codeinst::CodeInstance,
256256
warn = true # downgrade must-throw calls to be only a warning
257257
end
258258
elseif isexpr(stmt, :cfunction)
259+
length(stmt.args) != 5 && continue # required by IR legality
260+
(pointer_type, f, rt, at, call_type) = stmt.args
261+
262+
at isa SimpleVector || continue # required by IR legality
263+
ft = argextype(f, codeinfo, sptypes)
264+
argtypes = Any[ft]
265+
for i = 1:length(at)
266+
push!(argtypes, sp_type_rewrap(at[i], get_ci_mi(codeinst), #= isreturn =# false))
267+
end
268+
atype = argtypes_to_type(argtypes)
269+
270+
mi = compileable_specialization_for_call(interp, atype)
271+
if mi !== nothing
272+
# n.b.: Codegen may choose unpredictably to emit this `@cfunction` as a dynamic invoke or a full
273+
# dynamic call, but in either case it guarantees that the required adapter(s) are emitted. All
274+
# that we are required to verify here is that the callee CodeInstance is covered.
275+
ci = get(caches, mi, nothing)
276+
ci isa CodeInstance && continue
277+
end
278+
259279
error = "unresolved cfunction"
260-
#TODO: parse the cfunction expression to check the target is defined
261-
warn = true
262280
elseif isexpr(stmt, :foreigncall)
263281
foreigncall = stmt.args[1]
264282
if foreigncall isa QuoteNode

Compiler/test/verifytrim.jl

Lines changed: 15 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -36,36 +36,39 @@ let infos = typeinf_ext_toplevel(Any[Core.svec(Nothing, Tuple{typeof(finalizer),
3636
\[1\] finalizer\(f::Any, o::Any\)""", repr)
3737
end
3838

39+
# test that basic `cfunction` generation is allowed, when the dispatch target can be resolved
3940
make_cfunction() = @cfunction(+, Float64, (Int64,Int64))
41+
let infos = typeinf_ext_toplevel(Any[Core.svec(Ptr{Cvoid}, Tuple{typeof(make_cfunction)})], [Base.get_world_counter()], TRIM_UNSAFE)
42+
errors, parents = get_verify_typeinf_trim(infos)
43+
@test isempty(errors)
44+
end
4045

4146
# use TRIM_UNSAFE to bypass verifier inside typeinf_ext_toplevel
42-
let infos = typeinf_ext_toplevel(Any[Core.svec(Ptr{Cvoid}, Tuple{typeof(make_cfunction)})], [Base.get_world_counter()], TRIM_UNSAFE)
47+
make_cfunction_bad(@nospecialize(f::Any)) = @cfunction($f, Float64, (Int64,Int64))::Base.CFunction
48+
let infos = typeinf_ext_toplevel(Any[Core.svec(Base.CFunction, Tuple{typeof(make_cfunction_bad), Any})], [Base.get_world_counter()], TRIM_UNSAFE)
4349
errors, parents = get_verify_typeinf_trim(infos)
44-
@test_broken isempty(errors) # missing cfunction
50+
@test !isempty(errors) # missing cfunction
4551

46-
desc = only(errors)
47-
@test desc.first
48-
desc = desc.second
52+
(is_warning, desc) = only(errors)
53+
@test !is_warning
4954
@test desc isa CallMissing
5055
@test occursin("cfunction", desc.desc)
5156
repr = sprint(verify_print_error, desc, parents)
52-
@test occursin(
53-
r"""^unresolved cfunction from statement \$\(Expr\(:cfunction, Ptr{Nothing}, :\(\$\(QuoteNode\(\+\)\)\), Float64, :\(svec\(Int64, Int64\)::Core.SimpleVector\), :\(:ccall\)\)\)::Ptr{Nothing}
57+
@test occursin(r"""^unresolved cfunction from statement \$\(Expr\(:cfunction, Base.CFunction, :\(f::Any\), Float64, :\(svec\(Int64, Int64\)::Core.SimpleVector\), :\(:ccall\)\)\)::Base.CFunction
5458
Stacktrace:
55-
\[1\] make_cfunction\(\)""", repr)
56-
59+
\[1\] make_cfunction_bad\(f::Any\)""", repr)
5760
resize!(infos, 1)
5861
@test infos[1] isa Core.SimpleVector && infos[1][1] isa Type && infos[1][2] isa Type
5962
errors, parents = get_verify_typeinf_trim(infos)
6063
desc = only(errors)
6164
@test !desc.first
6265
desc = desc.second
6366
@test desc isa CCallableMissing
64-
@test desc.rt == Ptr{Cvoid}
65-
@test desc.sig == Tuple{typeof(make_cfunction)}
67+
@test desc.rt == Base.CFunction
68+
@test desc.sig == Tuple{typeof(make_cfunction_bad), Any}
6669
@test occursin("unresolved ccallable", desc.desc)
6770
repr = sprint(verify_print_error, desc, parents)
68-
@test repr == "unresolved ccallable for Tuple{$(typeof(make_cfunction))} => Ptr{Nothing}\n\n"
71+
@test repr == "unresolved ccallable for Tuple{$(typeof(make_cfunction_bad)), Any} => Base.CFunction\n\n"
6972
end
7073

7174
let infos = typeinf_ext_toplevel(Any[Core.svec(Base.SecretBuffer, Tuple{Type{Base.SecretBuffer}})], [Base.get_world_counter()], TRIM_UNSAFE)

src/codegen.cpp

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6356,8 +6356,8 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaidx_
63566356
}
63576357
else if (head == jl_cfunction_sym) {
63586358
assert(nargs == 5);
6359-
jl_cgval_t fexpr_rt = emit_expr(ctx, args[1]);
6360-
return emit_cfunction(ctx, args[0], fexpr_rt, args[2], (jl_svec_t*)args[3]);
6359+
jl_cgval_t fexpr_val = emit_expr(ctx, args[1]);
6360+
return emit_cfunction(ctx, args[0], fexpr_val, args[2], (jl_svec_t*)args[3]);
63616361
}
63626362
else if (head == jl_assign_sym) {
63636363
assert(nargs == 2);
@@ -7576,7 +7576,7 @@ static const char *derive_sigt_name(jl_value_t *jargty)
75767576
// Get the LLVM Function* for the C-callable entry point for a certain function
75777577
// and argument types.
75787578
// here argt does not include the leading function type argument
7579-
static jl_cgval_t emit_cfunction(jl_codectx_t &ctx, jl_value_t *output_type, const jl_cgval_t &fexpr_rt, jl_value_t *declrt, jl_svec_t *argt)
7579+
static jl_cgval_t emit_cfunction(jl_codectx_t &ctx, jl_value_t *output_type, const jl_cgval_t &fexpr_val, jl_value_t *declrt, jl_svec_t *argt)
75807580
{
75817581
jl_unionall_t *unionall_env = (jl_is_method(ctx.linfo->def.method) && jl_is_unionall(ctx.linfo->def.method->sig))
75827582
? (jl_unionall_t*)ctx.linfo->def.method->sig
@@ -7634,8 +7634,8 @@ static jl_cgval_t emit_cfunction(jl_codectx_t &ctx, jl_value_t *output_type, con
76347634
// compute+verify the dispatch signature, and see if it depends on the environment sparams
76357635
bool approx = false;
76367636
sigt = (jl_value_t*)jl_alloc_svec(nargt + 1);
7637-
jl_svecset(sigt, 0, fexpr_rt.typ);
7638-
if (!fexpr_rt.constant && (!jl_is_concrete_type(fexpr_rt.typ) || jl_is_kind(fexpr_rt.typ)))
7637+
jl_svecset(sigt, 0, fexpr_val.typ);
7638+
if (!fexpr_val.constant && (!jl_is_concrete_type(fexpr_val.typ) || jl_is_kind(fexpr_val.typ)))
76397639
approx = true;
76407640
for (size_t i = 0; i < nargt; i++) {
76417641
jl_value_t *jargty = jl_svecref(argt, i);
@@ -7664,25 +7664,25 @@ static jl_cgval_t emit_cfunction(jl_codectx_t &ctx, jl_value_t *output_type, con
76647664
unionall_env = NULL;
76657665
}
76667666

7667-
bool nest = (!fexpr_rt.constant || unionall_env);
7667+
bool nest = (!fexpr_val.constant || unionall_env);
76687668
if (ctx.emission_context.TargetTriple.isAArch64() || ctx.emission_context.TargetTriple.isARM() || ctx.emission_context.TargetTriple.isPPC64()) {
76697669
if (nest) {
76707670
emit_error(ctx, "cfunction: closures are not supported on this platform");
76717671
JL_GC_POP();
76727672
return jl_cgval_t();
76737673
}
76747674
}
7675-
const char *name = derive_sigt_name(fexpr_rt.typ);
7675+
const char *name = derive_sigt_name(fexpr_val.typ);
76767676
Value *F = gen_cfun_wrapper(
76777677
jl_Module, ctx.emission_context,
7678-
sig, fexpr_rt.constant, name,
7678+
sig, fexpr_val.constant, name,
76797679
declrt, sigt,
76807680
unionall_env, sparam_vals, &closure_types);
76817681
bool outboxed;
76827682
if (nest) {
76837683
// F is actually an init_trampoline function that returns the real address
76847684
// Now fill in the nest parameters
7685-
Value *fobj = boxed(ctx, fexpr_rt);
7685+
Value *fobj = boxed(ctx, fexpr_val);
76867686
jl_svec_t *fill = jl_emptysvec;
76877687
if (closure_types) {
76887688
assert(ctx.spvals_ptr);
@@ -7722,7 +7722,7 @@ static jl_cgval_t emit_cfunction(jl_codectx_t &ctx, jl_value_t *output_type, con
77227722
jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, tbaa);
77237723
ai.decorateInst(ctx.builder.CreateStore(F, derived_strct));
77247724
ai.decorateInst(ctx.builder.CreateStore(
7725-
ctx.builder.CreatePtrToInt(literal_pointer_val(ctx, fexpr_rt.constant), ctx.types().T_size),
7725+
ctx.builder.CreatePtrToInt(literal_pointer_val(ctx, fexpr_val.constant), ctx.types().T_size),
77267726
ctx.builder.CreateConstInBoundsGEP1_32(ctx.types().T_size, derived_strct, 1)));
77277727
ai.decorateInst(ctx.builder.CreateStore(Constant::getNullValue(ctx.types().T_size),
77287728
ctx.builder.CreateConstInBoundsGEP1_32(ctx.types().T_size, derived_strct, 2)));

test/trimming/basic_jll.jl

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,21 @@
11
using Libdl
22
using Zstd_jll # Note this uses the vendored older non-LazyLibrary version of Zstd_jll
33

4+
function print_string(fptr::Ptr{Cvoid})
5+
println(Core.stdout, unsafe_string(ccall(fptr, Cstring, ())))
6+
end
7+
48
function @main(args::Vector{String})::Cint
9+
# Test the basic "Hello, world!"
510
println(Core.stdout, "Julia! Hello, world!")
6-
fptr = dlsym(Zstd_jll.libzstd_handle, :ZSTD_versionString)
7-
println(Core.stdout, unsafe_string(ccall(fptr, Cstring, ())))
11+
12+
# Make sure that JLL's are working as expected
813
println(Core.stdout, unsafe_string(ccall((:ZSTD_versionString, libzstd), Cstring, ())))
14+
15+
# Add an indirection via `@cfunction`
16+
cfunc = @cfunction(print_string, Cvoid, (Ptr{Cvoid},))
17+
fptr = dlsym(Zstd_jll.libzstd_handle, :ZSTD_versionString)
18+
ccall(cfunc, Cvoid, (Ptr{Cvoid},), fptr)
19+
920
return 0
1021
end

0 commit comments

Comments
 (0)