Skip to content

Commit 5cba88c

Browse files
authored
allow ccall library name to be non-constant (#37123)
Fixes #36458
1 parent 0bff5bf commit 5cba88c

File tree

10 files changed

+148
-28
lines changed

10 files changed

+148
-28
lines changed

NEWS.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,9 @@ New language features
1010
and associativity as other arrow-like operators ([#36666]).
1111
* Compilation and type inference can now be enabled or disabled at the module level
1212
using the experimental macro `Base.Experimental.@compiler_options` ([#37041]).
13+
* The library name passed to `ccall` or `@ccall` can now be an expression involving
14+
global variables and function calls. The expression will be evaluated the first
15+
time the `ccall` executes ([#36458]).
1316

1417
Language changes
1518
----------------

doc/src/manual/calling-c-and-fortran-code.md

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -878,7 +878,15 @@ it must be handled in other ways.
878878

879879
## Non-constant Function Specifications
880880

881-
A `(name, library)` function specification must be a constant expression. However, it is possible
881+
In some cases, the exact name or path of the needed library is not known in advance and must
882+
be computed at run time. To handle such cases, the library component of a `(name, library)`
883+
specification can be a function call, e.g. `(:dgemm_, find_blas())`. The call expression will
884+
be executed when the `ccall` itself is executed. However, it is assumed that the library
885+
location does not change once it is determined, so the result of the call can be cached and
886+
reused. Therefore, the number of times the expression executes is undefined, and returning
887+
different values for multiple calls results in undefined behavior.
888+
889+
If even more flexibility is needed, it is possible
882890
to use computed values as function names by staging through [`eval`](@ref) as follows:
883891

884892
```

src/ast.scm

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -355,6 +355,9 @@
355355
(define (ssavalue? e)
356356
(and (pair? e) (eq? (car e) 'ssavalue)))
357357

358+
(define (slot? e)
359+
(and (pair? e) (eq? (car e) 'slot)))
360+
358361
(define (globalref? e)
359362
(and (pair? e) (eq? (car e) 'globalref)))
360363

@@ -439,6 +442,11 @@
439442
(let ((x (cadr e)))
440443
(not (simple-atom? x)))))
441444

445+
(define (tuple-call? e)
446+
(and (length> e 1)
447+
(eq? (car e) 'call)
448+
(equal? (cadr e) '(core tuple))))
449+
442450
(define (eq-sym? a b)
443451
(or (eq? a b) (and (ssavalue? a) (ssavalue? b) (eqv? (cdr a) (cdr b)))))
444452

src/ccall.cpp

Lines changed: 83 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,8 @@ static bool runtime_sym_gvs(jl_codegen_params_t &emission_context, const char *f
7373
static Value *runtime_sym_lookup(
7474
jl_codegen_params_t &emission_context,
7575
IRBuilder<> &irbuilder,
76-
PointerType *funcptype, const char *f_lib,
76+
jl_codectx_t *ctx,
77+
PointerType *funcptype, const char *f_lib, jl_value_t *lib_expr,
7778
const char *f_name, Function *f,
7879
GlobalVariable *libptrgv,
7980
GlobalVariable *llvmgv, bool runtime_lib)
@@ -106,16 +107,25 @@ static Value *runtime_sym_lookup(
106107
assert(f->getParent() != NULL);
107108
f->getBasicBlockList().push_back(dlsym_lookup);
108109
irbuilder.SetInsertPoint(dlsym_lookup);
109-
Value *libname;
110-
if (runtime_lib) {
111-
libname = stringConstPtr(emission_context, irbuilder, f_lib);
110+
Instruction *llvmf;
111+
Value *nameval = stringConstPtr(emission_context, irbuilder, f_name);
112+
if (lib_expr) {
113+
jl_cgval_t libval = emit_expr(*ctx, lib_expr);
114+
llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jllazydlsym_func),
115+
{ boxed(*ctx, libval), nameval });
112116
}
113117
else {
114-
// f_lib is actually one of the special sentinel values
115-
libname = ConstantExpr::getIntToPtr(ConstantInt::get(T_size, (uintptr_t)f_lib), T_pint8);
118+
Value *libname;
119+
if (runtime_lib) {
120+
libname = stringConstPtr(emission_context, irbuilder, f_lib);
121+
}
122+
else {
123+
// f_lib is actually one of the special sentinel values
124+
libname = ConstantExpr::getIntToPtr(ConstantInt::get(T_size, (uintptr_t)f_lib), T_pint8);
125+
}
126+
llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jldlsym_func),
127+
{ libname, nameval, libptrgv });
116128
}
117-
Value *llvmf = irbuilder.CreateCall(prepare_call_in(jl_builderModule(irbuilder), jldlsym_func),
118-
{ libname, stringConstPtr(emission_context, irbuilder, f_name), libptrgv });
119129
StoreInst *store = irbuilder.CreateAlignedStore(llvmf, llvmgv, Align(sizeof(void*)));
120130
store->setAtomic(AtomicOrdering::Release);
121131
irbuilder.CreateBr(ccall_bb);
@@ -124,21 +134,49 @@ static Value *runtime_sym_lookup(
124134
irbuilder.SetInsertPoint(ccall_bb);
125135
PHINode *p = irbuilder.CreatePHI(T_pvoidfunc, 2);
126136
p->addIncoming(llvmf_orig, enter_bb);
127-
p->addIncoming(llvmf, dlsym_lookup);
137+
p->addIncoming(llvmf, llvmf->getParent());
128138
return irbuilder.CreateBitCast(p, funcptype);
129139
}
130140

131141
static Value *runtime_sym_lookup(
132142
jl_codectx_t &ctx,
133-
PointerType *funcptype, const char *f_lib,
143+
PointerType *funcptype, const char *f_lib, jl_value_t *lib_expr,
144+
const char *f_name, Function *f,
145+
GlobalVariable *libptrgv,
146+
GlobalVariable *llvmgv, bool runtime_lib)
147+
{
148+
return runtime_sym_lookup(ctx.emission_context, ctx.builder, &ctx, funcptype, f_lib, lib_expr,
149+
f_name, f, libptrgv, llvmgv, runtime_lib);
150+
}
151+
152+
static Value *runtime_sym_lookup(
153+
jl_codectx_t &ctx,
154+
PointerType *funcptype, const char *f_lib, jl_value_t *lib_expr,
134155
const char *f_name, Function *f)
135156
{
136157
GlobalVariable *libptrgv;
137158
GlobalVariable *llvmgv;
138-
bool runtime_lib = runtime_sym_gvs(ctx.emission_context, f_lib, f_name, libptrgv, llvmgv);
139-
libptrgv = prepare_global_in(jl_Module, libptrgv);
159+
bool runtime_lib;
160+
if (lib_expr) {
161+
// for computed library names, generate a global variable to cache the function
162+
// pointer just for this call site.
163+
runtime_lib = true;
164+
libptrgv = NULL;
165+
std::string gvname = "libname_";
166+
gvname += f_name;
167+
gvname += "_";
168+
gvname += std::to_string(globalUnique++);
169+
Module *M = ctx.emission_context.shared_module(jl_LLVMContext);
170+
llvmgv = new GlobalVariable(*M, T_pvoidfunc, false,
171+
GlobalVariable::ExternalLinkage,
172+
Constant::getNullValue(T_pvoidfunc), gvname);
173+
}
174+
else {
175+
runtime_lib = runtime_sym_gvs(ctx.emission_context, f_lib, f_name, libptrgv, llvmgv);
176+
libptrgv = prepare_global_in(jl_Module, libptrgv);
177+
}
140178
llvmgv = prepare_global_in(jl_Module, llvmgv);
141-
return runtime_sym_lookup(ctx.emission_context, ctx.builder, funcptype, f_lib, f_name, f, libptrgv, llvmgv, runtime_lib);
179+
return runtime_sym_lookup(ctx, funcptype, f_lib, lib_expr, f_name, f, libptrgv, llvmgv, runtime_lib);
142180
}
143181

144182
// Emit a "PLT" entry that will be lazily initialized
@@ -169,7 +207,7 @@ static GlobalVariable *emit_plt_thunk(
169207
fname);
170208
BasicBlock *b0 = BasicBlock::Create(jl_LLVMContext, "top", plt);
171209
IRBuilder<> irbuilder(b0);
172-
Value *ptr = runtime_sym_lookup(emission_context, irbuilder, funcptype, f_lib, f_name, plt, libptrgv,
210+
Value *ptr = runtime_sym_lookup(emission_context, irbuilder, NULL, funcptype, f_lib, NULL, f_name, plt, libptrgv,
173211
llvmgv, runtime_lib);
174212
StoreInst *store = irbuilder.CreateAlignedStore(irbuilder.CreateBitCast(ptr, T_pvoidfunc), got, Align(sizeof(void*)));
175213
store->setAtomic(AtomicOrdering::Release);
@@ -475,6 +513,7 @@ typedef struct {
475513
void (*fptr)(void); // if the argument is a constant pointer
476514
const char *f_name; // if the symbol name is known
477515
const char *f_lib; // if a library name is specified
516+
jl_value_t *lib_expr; // expression to compute library path lazily
478517
jl_value_t *gcroot;
479518
} native_sym_arg_t;
480519

@@ -488,6 +527,24 @@ static void interpret_symbol_arg(jl_codectx_t &ctx, native_sym_arg_t &out, jl_va
488527

489528
jl_value_t *ptr = static_eval(ctx, arg);
490529
if (ptr == NULL) {
530+
if (jl_is_expr(arg) && ((jl_expr_t*)arg)->head == call_sym && jl_expr_nargs(arg) == 3 &&
531+
jl_is_globalref(jl_exprarg(arg,0)) && jl_globalref_mod(jl_exprarg(arg,0)) == jl_core_module &&
532+
jl_globalref_name(jl_exprarg(arg,0)) == jl_symbol("tuple")) {
533+
// attempt to interpret a non-constant 2-tuple expression as (func_name, lib_name()), where
534+
// `lib_name()` will be executed when first used.
535+
jl_value_t *name_val = static_eval(ctx, jl_exprarg(arg,1));
536+
if (name_val && jl_is_symbol(name_val)) {
537+
f_name = jl_symbol_name((jl_sym_t*)name_val);
538+
out.lib_expr = jl_exprarg(arg, 2);
539+
return;
540+
}
541+
else if (name_val && jl_is_string(name_val)) {
542+
f_name = jl_string_data(name_val);
543+
out.gcroot = name_val;
544+
out.lib_expr = jl_exprarg(arg, 2);
545+
return;
546+
}
547+
}
491548
jl_cgval_t arg1 = emit_expr(ctx, arg);
492549
jl_value_t *ptr_ty = arg1.typ;
493550
if (!jl_is_cpointer_type(ptr_ty)) {
@@ -586,8 +643,11 @@ static jl_cgval_t emit_cglobal(jl_codectx_t &ctx, jl_value_t **args, size_t narg
586643
jl_printf(JL_STDERR,"WARNING: literal address used in cglobal for %s; code cannot be statically compiled\n", sym.f_name);
587644
}
588645
else {
589-
if (imaging_mode) {
590-
res = runtime_sym_lookup(ctx, cast<PointerType>(T_pint8), sym.f_lib, sym.f_name, ctx.f);
646+
if (sym.lib_expr) {
647+
res = runtime_sym_lookup(ctx, cast<PointerType>(T_pint8), NULL, sym.lib_expr, sym.f_name, ctx.f);
648+
}
649+
else if (imaging_mode) {
650+
res = runtime_sym_lookup(ctx, cast<PointerType>(T_pint8), sym.f_lib, NULL, sym.f_name, ctx.f);
591651
res = ctx.builder.CreatePtrToInt(res, lrt);
592652
}
593653
else {
@@ -597,7 +657,7 @@ static jl_cgval_t emit_cglobal(jl_codectx_t &ctx, jl_value_t **args, size_t narg
597657
if (!libsym || !jl_dlsym(libsym, sym.f_name, &symaddr, 0)) {
598658
// Error mode, either the library or the symbol couldn't be find during compiletime.
599659
// Fallback to a runtime symbol lookup.
600-
res = runtime_sym_lookup(ctx, cast<PointerType>(T_pint8), sym.f_lib, sym.f_name, ctx.f);
660+
res = runtime_sym_lookup(ctx, cast<PointerType>(T_pint8), sym.f_lib, NULL, sym.f_name, ctx.f);
601661
res = ctx.builder.CreatePtrToInt(res, lrt);
602662
} else {
603663
// since we aren't saving this code, there's no sense in
@@ -1737,11 +1797,14 @@ jl_cgval_t function_sig_t::emit_a_ccall(
17371797
else {
17381798
assert(symarg.f_name != NULL);
17391799
PointerType *funcptype = PointerType::get(functype, 0);
1740-
if (imaging_mode) {
1800+
if (symarg.lib_expr) {
1801+
llvmf = runtime_sym_lookup(ctx, funcptype, NULL, symarg.lib_expr, symarg.f_name, ctx.f);
1802+
}
1803+
else if (imaging_mode) {
17411804
// vararg requires musttail,
17421805
// but musttail is incompatible with noreturn.
17431806
if (functype->isVarArg())
1744-
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, symarg.f_name, ctx.f);
1807+
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, NULL, symarg.f_name, ctx.f);
17451808
else
17461809
llvmf = emit_plt(ctx, functype, attributes, cc, symarg.f_lib, symarg.f_name);
17471810
}
@@ -1751,7 +1814,7 @@ jl_cgval_t function_sig_t::emit_a_ccall(
17511814
if (!libsym || !jl_dlsym(libsym, symarg.f_name, &symaddr, 0)) {
17521815
// either the library or the symbol could not be found, place a runtime
17531816
// lookup here instead.
1754-
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, symarg.f_name, ctx.f);
1817+
llvmf = runtime_sym_lookup(ctx, funcptype, symarg.f_lib, NULL, symarg.f_name, ctx.f);
17551818
} else {
17561819
// since we aren't saving this code, there's no sense in
17571820
// putting anything complicated here: just JIT the function address

src/codegen.cpp

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,12 @@ static const auto jldlsym_func = new JuliaFunction{
704704
{T_pint8, T_pint8, PointerType::get(T_pint8, 0)}, false); },
705705
nullptr,
706706
};
707+
static const auto jllazydlsym_func = new JuliaFunction{
708+
"jl_lazy_load_and_lookup",
709+
[](LLVMContext &C) { return FunctionType::get(T_pvoidfunc,
710+
{T_prjlvalue, T_pint8}, false); },
711+
nullptr,
712+
};
707713
static const auto jltypeassert_func = new JuliaFunction{
708714
"jl_typeassert",
709715
[](LLVMContext &C) { return FunctionType::get(T_void,

src/julia-syntax.scm

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3906,11 +3906,9 @@ f(x) = yt(x)
39063906
(cond ((eq? (car e) 'foreigncall)
39073907
;; NOTE: 2nd to 5th arguments of ccall must be left in place
39083908
;; the 1st should be compiled if an atom.
3909-
(append (if (or (atom? (cadr e))
3910-
(let ((fptr (cadr e)))
3911-
(not (and (length> fptr 1)
3912-
(eq? (car fptr) 'call)
3913-
(equal? (cadr fptr) '(core tuple))))))
3909+
(append (if (let ((fptr (cadr e)))
3910+
(or (atom? fptr)
3911+
(not (tuple-call? fptr))))
39143912
(compile-args (list (cadr e)) break-labels)
39153913
(list (cadr e)))
39163914
(list-head (cddr e) 4)
@@ -4466,8 +4464,14 @@ f(x) = yt(x)
44664464
`(gotoifnot ,(renumber-stuff (cadr e)) ,(get label-table (caddr e))))
44674465
((eq? (car e) 'lambda)
44684466
(renumber-lambda e 'none 0))
4469-
(else (cons (car e)
4470-
(map renumber-stuff (cdr e))))))
4467+
(else
4468+
(let ((e (cons (car e)
4469+
(map renumber-stuff (cdr e)))))
4470+
(if (and (eq? (car e) 'foreigncall)
4471+
(tuple-call? (cadr e))
4472+
(expr-contains-p (lambda (x) (or (ssavalue? x) (slot? x))) (cadr e)))
4473+
(error "ccall function name and library expression cannot reference local variables"))
4474+
e))))
44714475
(let ((body (renumber-stuff (lam:body lam)))
44724476
(vi (lam:vinfo lam)))
44734477
(listify-lambda

src/julia_internal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,7 @@ void *jl_get_library_(const char *f_lib, int throw_err) JL_NOTSAFEPOINT;
988988
#define jl_get_library(f_lib) jl_get_library_(f_lib, 1)
989989
JL_DLLEXPORT void *jl_load_and_lookup(const char *f_lib, const char *f_name,
990990
void **hnd) JL_NOTSAFEPOINT;
991+
JL_DLLEXPORT void *jl_lazy_load_and_lookup(jl_value_t *lib_val, const char *f_name);
991992
JL_DLLEXPORT jl_value_t *jl_get_cfunction_trampoline(
992993
jl_value_t *fobj, jl_datatype_t *result, htable_t *cache, jl_svec_t *fill,
993994
void *(*init_trampoline)(void *tramp, void **nval),

src/runtime_ccall.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,23 @@ void *jl_load_and_lookup(const char *f_lib, const char *f_name, void **hnd) JL_N
6464
return ptr;
6565
}
6666

67+
// jl_load_and_lookup, but with library computed at run time on first call
68+
extern "C" JL_DLLEXPORT
69+
void *jl_lazy_load_and_lookup(jl_value_t *lib_val, const char *f_name)
70+
{
71+
char *f_lib;
72+
73+
if (jl_is_symbol(lib_val))
74+
f_lib = jl_symbol_name((jl_sym_t*)lib_val);
75+
else if (jl_is_string(lib_val))
76+
f_lib = jl_string_data(lib_val);
77+
else
78+
jl_type_error("ccall", (jl_value_t*)jl_symbol_type, lib_val);
79+
void *ptr;
80+
jl_dlsym(jl_get_library(f_lib), f_name, &ptr, 1);
81+
return ptr;
82+
}
83+
6784
// miscellany
6885
std::string jl_get_cpu_name_llvm(void)
6986
{

test/ccall.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1701,3 +1701,11 @@ end
17011701
str = GC.@preserve buffer unsafe_string(Cwstring(pointer(buffer)))
17021702
@test str == "α+β=15"
17031703
end
1704+
1705+
# issue #36458
1706+
compute_lib_name() = "libcc" * "alltest"
1707+
ccall_lazy_lib_name(x) = ccall((:testUcharX, compute_lib_name()), Int32, (UInt8,), x % UInt8)
1708+
@test ccall_lazy_lib_name(0) == 0
1709+
@test ccall_lazy_lib_name(3) == 1
1710+
ccall_with_undefined_lib() = ccall((:time, xx_nOt_DeFiNeD_xx), Cint, (Ptr{Cvoid},), C_NULL)
1711+
@test_throws UndefVarError(:xx_nOt_DeFiNeD_xx) ccall_with_undefined_lib()

test/syntax.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1656,6 +1656,8 @@ end
16561656
# #6080
16571657
@test Meta.lower(@__MODULE__, :(ccall(:a, Cvoid, (Cint,), &x))) == Expr(:error, "invalid syntax &x")
16581658

1659+
@test Meta.lower(@__MODULE__, :(f(x) = (y = x + 1; ccall((:a, y), Cvoid, ())))) == Expr(:error, "ccall function name and library expression cannot reference local variables")
1660+
16591661
@test_throws ParseError Meta.parse("x.'")
16601662
@test_throws ParseError Meta.parse("0.+1")
16611663

0 commit comments

Comments
 (0)