Skip to content

Commit 441fcb1

Browse files
Merge opaque closure modules with the rest of the workqueue (#50724)
This sticks the compiled opaque closure module into the `compiled_functions` list of modules that we have compiled for the particular `jl_codegen_params_t`. We probably should manage that vector in codegen_params, since it lets us see if a particular codeinst has already been compiled but not yet emitted.
1 parent f4cb8bc commit 441fcb1

File tree

4 files changed

+78
-89
lines changed

4 files changed

+78
-89
lines changed

src/aotcompile.cpp

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,6 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
274274
jl_native_code_desc_t *data = new jl_native_code_desc_t;
275275
CompilationPolicy policy = (CompilationPolicy) _policy;
276276
bool imaging = imaging_default() || _imaging_mode == 1;
277-
jl_workqueue_t emitted;
278277
jl_method_instance_t *mi = NULL;
279278
jl_code_info_t *src = NULL;
280279
JL_GC_PUSH1(&src);
@@ -335,7 +334,7 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
335334
// find and prepare the source code to compile
336335
jl_code_instance_t *codeinst = NULL;
337336
jl_ci_cache_lookup(*cgparams, mi, params.world, &codeinst, &src);
338-
if (src && !emitted.count(codeinst)) {
337+
if (src && !params.compiled_functions.count(codeinst)) {
339338
// now add it to our compilation results
340339
JL_GC_PROMISE_ROOTED(codeinst->rettype);
341340
orc::ThreadSafeModule result_m = jl_create_ts_module(name_from_method_instance(codeinst->def),
@@ -344,13 +343,13 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
344343
Triple(clone.getModuleUnlocked()->getTargetTriple()));
345344
jl_llvm_functions_t decls = jl_emit_code(result_m, mi, src, codeinst->rettype, params);
346345
if (result_m)
347-
emitted[codeinst] = {std::move(result_m), std::move(decls)};
346+
params.compiled_functions[codeinst] = {std::move(result_m), std::move(decls)};
348347
}
349348
}
350349
}
351350

352351
// finally, make sure all referenced methods also get compiled or fixed up
353-
jl_compile_workqueue(emitted, *clone.getModuleUnlocked(), params, policy);
352+
jl_compile_workqueue(params, *clone.getModuleUnlocked(), policy);
354353
}
355354
JL_UNLOCK(&jl_codegen_lock); // Might GC
356355
JL_GC_POP();
@@ -369,7 +368,7 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
369368
data->jl_value_to_llvm[idx] = global.first;
370369
idx++;
371370
}
372-
CreateNativeMethods += emitted.size();
371+
CreateNativeMethods += params.compiled_functions.size();
373372

374373
size_t offset = gvars.size();
375374
data->jl_external_to_llvm.resize(params.external_fns.size());
@@ -394,7 +393,7 @@ void *jl_create_native_impl(jl_array_t *methods, LLVMOrcThreadSafeModuleRef llvm
394393
{
395394
JL_TIMING(NATIVE_AOT, NATIVE_Merge);
396395
Linker L(*clone.getModuleUnlocked());
397-
for (auto &def : emitted) {
396+
for (auto &def : params.compiled_functions) {
398397
jl_merge_module(clone, std::move(std::get<0>(def.second)));
399398
jl_code_instance_t *this_code = def.first;
400399
jl_llvm_functions_t decls = std::get<1>(def.second);

src/codegen.cpp

Lines changed: 48 additions & 64 deletions
Original file line numberDiff line numberDiff line change
@@ -1616,7 +1616,6 @@ class jl_codectx_t {
16161616
std::vector<std::tuple<jl_cgval_t, BasicBlock *, AllocaInst *, PHINode *, jl_value_t *>> PhiNodes;
16171617
std::vector<bool> ssavalue_assigned;
16181618
std::vector<int> ssavalue_usecount;
1619-
std::vector<orc::ThreadSafeModule> oc_modules;
16201619
jl_module_t *module = NULL;
16211620
jl_typecache_t type_cache;
16221621
jl_tbaacache_t tbaa_cache;
@@ -4460,7 +4459,7 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, const
44604459
// Check if we already queued this up
44614460
auto it = ctx.call_targets.find(codeinst);
44624461
if (need_to_emit && it != ctx.call_targets.end()) {
4463-
protoname = std::get<2>(it->second)->getName();
4462+
protoname = it->second.decl->getName();
44644463
need_to_emit = cache_valid = false;
44654464
}
44664465

@@ -4504,7 +4503,7 @@ static jl_cgval_t emit_invoke(jl_codectx_t &ctx, const jl_cgval_t &lival, const
45044503
handled = true;
45054504
if (need_to_emit) {
45064505
Function *trampoline_decl = cast<Function>(jl_Module->getNamedValue(protoname));
4507-
ctx.call_targets[codeinst] = std::make_tuple(cc, return_roots, trampoline_decl, specsig);
4506+
ctx.call_targets[codeinst] = {cc, return_roots, trampoline_decl, specsig};
45084507
}
45094508
}
45104509
}
@@ -5369,8 +5368,7 @@ static std::pair<Function*, Function*> get_oc_function(jl_codectx_t &ctx, jl_met
53695368
{
53705369
jl_svec_t *sig_args = NULL;
53715370
jl_value_t *sigtype = NULL;
5372-
jl_code_info_t *ir = NULL;
5373-
JL_GC_PUSH3(&sig_args, &sigtype, &ir);
5371+
JL_GC_PUSH2(&sig_args, &sigtype);
53745372

53755373
size_t nsig = 1 + jl_svec_len(argt_typ->parameters);
53765374
sig_args = jl_alloc_svec_uninit(nsig);
@@ -5392,16 +5390,25 @@ static std::pair<Function*, Function*> get_oc_function(jl_codectx_t &ctx, jl_met
53925390
JL_GC_POP();
53935391
return std::make_pair((Function*)NULL, (Function*)NULL);
53945392
}
5395-
++EmittedOpaqueClosureFunctions;
53965393

5397-
ir = jl_uncompress_ir(closure_method, ci, (jl_value_t*)inferred);
5394+
auto it = ctx.emission_context.compiled_functions.find(ci);
53985395

5399-
// TODO: Emit this inline and outline it late using LLVM's coroutine support.
5400-
orc::ThreadSafeModule closure_m = jl_create_ts_module(
5401-
name_from_method_instance(mi), ctx.emission_context.tsctx,
5402-
ctx.emission_context.imaging,
5403-
jl_Module->getDataLayout(), Triple(jl_Module->getTargetTriple()));
5404-
jl_llvm_functions_t closure_decls = emit_function(closure_m, mi, ir, rettype, ctx.emission_context);
5396+
if (it == ctx.emission_context.compiled_functions.end()) {
5397+
++EmittedOpaqueClosureFunctions;
5398+
jl_code_info_t *ir = jl_uncompress_ir(closure_method, ci, (jl_value_t*)inferred);
5399+
JL_GC_PUSH1(&ir);
5400+
// TODO: Emit this inline and outline it late using LLVM's coroutine support.
5401+
orc::ThreadSafeModule closure_m = jl_create_ts_module(
5402+
name_from_method_instance(mi), ctx.emission_context.tsctx,
5403+
ctx.emission_context.imaging,
5404+
jl_Module->getDataLayout(), Triple(jl_Module->getTargetTriple()));
5405+
jl_llvm_functions_t closure_decls = emit_function(closure_m, mi, ir, rettype, ctx.emission_context);
5406+
JL_GC_POP();
5407+
it = ctx.emission_context.compiled_functions.insert(std::make_pair(ci, std::make_pair(std::move(closure_m), std::move(closure_decls)))).first;
5408+
}
5409+
5410+
auto &closure_m = it->second.first;
5411+
auto &closure_decls = it->second.second;
54055412

54065413
assert(closure_decls.functionObject != "jl_fptr_sparam");
54075414
bool isspecsig = closure_decls.functionObject != "jl_fptr_args";
@@ -5432,7 +5439,6 @@ static std::pair<Function*, Function*> get_oc_function(jl_codectx_t &ctx, jl_met
54325439
specF = cast<Function>(returninfo.decl.getCallee());
54335440
}
54345441
}
5435-
ctx.oc_modules.push_back(std::move(closure_m));
54365442
JL_GC_POP();
54375443
return std::make_pair(F, specF);
54385444
}
@@ -5715,7 +5721,7 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaidx_
57155721
if (jl_is_concrete_type(env_t)) {
57165722
jl_tupletype_t *argt_typ = (jl_tupletype_t*)argt.constant;
57175723
Function *F, *specF;
5718-
std::tie(F, specF) = get_oc_function(ctx, (jl_method_t*)source.constant, (jl_datatype_t*)env_t, argt_typ, ub.constant);
5724+
std::tie(F, specF) = get_oc_function(ctx, (jl_method_t*)source.constant, (jl_tupletype_t*)env_t, argt_typ, ub.constant);
57195725
if (F) {
57205726
jl_cgval_t jlcall_ptr = mark_julia_type(ctx, F, false, jl_voidpointer_type);
57215727
jl_aliasinfo_t ai = jl_aliasinfo_t::fromTBAA(ctx, ctx.tbaa().tbaa_gcframe);
@@ -5725,7 +5731,7 @@ static jl_cgval_t emit_expr(jl_codectx_t &ctx, jl_value_t *expr, ssize_t ssaidx_
57255731
if (specF)
57265732
fptr = mark_julia_type(ctx, specF, false, jl_voidpointer_type);
57275733
else
5728-
fptr = mark_julia_type(ctx, (llvm::Value*)Constant::getNullValue(ctx.types().T_size), false, jl_voidpointer_type);
5734+
fptr = mark_julia_type(ctx, Constant::getNullValue(ctx.types().T_size), false, jl_voidpointer_type);
57295735

57305736
// TODO: Inline the env at the end of the opaque closure and generate a descriptor for GC
57315737
jl_cgval_t env = emit_new_struct(ctx, env_t, nargs-4, &argv.data()[4]);
@@ -8757,19 +8763,6 @@ static jl_llvm_functions_t
87578763
jl_Module->getFunction(FN)->setLinkage(GlobalVariable::InternalLinkage);
87588764
}
87598765

8760-
// link in opaque closure modules
8761-
for (auto &TSMod : ctx.oc_modules) {
8762-
SmallVector<std::string, 1> Exports;
8763-
TSMod.withModuleDo([&](Module &Mod) {
8764-
for (const auto &F: Mod.functions())
8765-
if (!F.isDeclaration())
8766-
Exports.push_back(F.getName().str());
8767-
});
8768-
jl_merge_module(TSM, std::move(TSMod));
8769-
for (auto FN: Exports)
8770-
jl_Module->getFunction(FN)->setLinkage(GlobalVariable::InternalLinkage);
8771-
}
8772-
87738766
JL_GC_POP();
87748767
return declarations;
87758768
}
@@ -8931,22 +8924,18 @@ jl_llvm_functions_t jl_emit_codeinst(
89318924

89328925

89338926
void jl_compile_workqueue(
8934-
jl_workqueue_t &emitted,
8927+
jl_codegen_params_t &params,
89358928
Module &original,
8936-
jl_codegen_params_t &params, CompilationPolicy policy)
8929+
CompilationPolicy policy)
89378930
{
89388931
JL_TIMING(CODEGEN, CODEGEN_Workqueue);
89398932
jl_code_info_t *src = NULL;
89408933
JL_GC_PUSH1(&src);
89418934
while (!params.workqueue.empty()) {
89428935
jl_code_instance_t *codeinst;
8943-
Function *protodecl;
8944-
jl_returninfo_t::CallingConv proto_cc;
8945-
bool proto_specsig;
8946-
unsigned proto_return_roots;
89478936
auto it = params.workqueue.back();
89488937
codeinst = it.first;
8949-
std::tie(proto_cc, proto_return_roots, protodecl, proto_specsig) = it.second;
8938+
auto proto = it.second;
89508939
params.workqueue.pop_back();
89518940
// try to emit code for this item from the workqueue
89528941
assert(codeinst->min_world <= params.world && codeinst->max_world >= params.world &&
@@ -8974,12 +8963,8 @@ void jl_compile_workqueue(
89748963
}
89758964
}
89768965
else {
8977-
auto &result = emitted[codeinst];
8978-
jl_llvm_functions_t *decls = NULL;
8979-
if (std::get<0>(result)) {
8980-
decls = &std::get<1>(result);
8981-
}
8982-
else {
8966+
auto it = params.compiled_functions.find(codeinst);
8967+
if (it == params.compiled_functions.end()) {
89838968
// Reinfer the function. The JIT came along and removed the inferred
89848969
// method body. See #34993
89858970
if (policy != CompilationPolicy::Default &&
@@ -8990,47 +8975,46 @@ void jl_compile_workqueue(
89908975
jl_create_ts_module(name_from_method_instance(codeinst->def),
89918976
params.tsctx, params.imaging,
89928977
original.getDataLayout(), Triple(original.getTargetTriple()));
8993-
result.second = jl_emit_code(result_m, codeinst->def, src, src->rettype, params);
8994-
result.first = std::move(result_m);
8978+
auto decls = jl_emit_code(result_m, codeinst->def, src, src->rettype, params);
8979+
if (result_m)
8980+
it = params.compiled_functions.insert(std::make_pair(codeinst, std::make_pair(std::move(result_m), std::move(decls)))).first;
89958981
}
89968982
}
89978983
else {
89988984
orc::ThreadSafeModule result_m =
89998985
jl_create_ts_module(name_from_method_instance(codeinst->def),
90008986
params.tsctx, params.imaging,
90018987
original.getDataLayout(), Triple(original.getTargetTriple()));
9002-
result.second = jl_emit_codeinst(result_m, codeinst, NULL, params);
9003-
result.first = std::move(result_m);
8988+
auto decls = jl_emit_codeinst(result_m, codeinst, NULL, params);
8989+
if (result_m)
8990+
it = params.compiled_functions.insert(std::make_pair(codeinst, std::make_pair(std::move(result_m), std::move(decls)))).first;
90048991
}
9005-
if (std::get<0>(result))
9006-
decls = &std::get<1>(result);
9007-
else
9008-
emitted.erase(codeinst); // undo the insert above
90098992
}
9010-
if (decls) {
9011-
if (decls->functionObject == "jl_fptr_args") {
9012-
preal_decl = decls->specFunctionObject;
8993+
if (it != params.compiled_functions.end()) {
8994+
auto &decls = it->second.second;
8995+
if (decls.functionObject == "jl_fptr_args") {
8996+
preal_decl = decls.specFunctionObject;
90138997
}
9014-
else if (decls->functionObject != "jl_fptr_sparam") {
9015-
preal_decl = decls->specFunctionObject;
8998+
else if (decls.functionObject != "jl_fptr_sparam") {
8999+
preal_decl = decls.specFunctionObject;
90169000
preal_specsig = true;
90179001
}
90189002
}
90199003
}
90209004
// patch up the prototype we emitted earlier
9021-
Module *mod = protodecl->getParent();
9022-
assert(protodecl->isDeclaration());
9023-
if (proto_specsig) {
9005+
Module *mod = proto.decl->getParent();
9006+
assert(proto.decl->isDeclaration());
9007+
if (proto.specsig) {
90249008
// expected specsig
90259009
if (!preal_specsig) {
90269010
// emit specsig-to-(jl)invoke conversion
90279011
Function *preal = emit_tojlinvoke(codeinst, mod, params);
9028-
protodecl->setLinkage(GlobalVariable::InternalLinkage);
9012+
proto.decl->setLinkage(GlobalVariable::InternalLinkage);
90299013
//protodecl->setAlwaysInline();
9030-
jl_init_function(protodecl, params.TargetTriple);
9014+
jl_init_function(proto.decl, params.TargetTriple);
90319015
size_t nrealargs = jl_nparams(codeinst->def->specTypes); // number of actual arguments being passed
90329016
// TODO: maybe this can be cached in codeinst->specfptr?
9033-
emit_cfunc_invalidate(protodecl, proto_cc, proto_return_roots, codeinst->def->specTypes, codeinst->rettype, false, nrealargs, params, preal);
9017+
emit_cfunc_invalidate(proto.decl, proto.cc, proto.return_roots, codeinst->def->specTypes, codeinst->rettype, false, nrealargs, params, preal);
90349018
preal_decl = ""; // no need to fixup the name
90359019
}
90369020
else {
@@ -9047,11 +9031,11 @@ void jl_compile_workqueue(
90479031
if (!preal_decl.empty()) {
90489032
// merge and/or rename this prototype to the real function
90499033
if (Value *specfun = mod->getNamedValue(preal_decl)) {
9050-
if (protodecl != specfun)
9051-
protodecl->replaceAllUsesWith(specfun);
9034+
if (proto.decl != specfun)
9035+
proto.decl->replaceAllUsesWith(specfun);
90529036
}
90539037
else {
9054-
protodecl->setName(preal_decl);
9038+
proto.decl->setName(preal_decl);
90559039
}
90569040
}
90579041
}

src/jitlayers.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -215,16 +215,15 @@ static jl_callptr_t _jl_compile_codeinst(
215215
params.world = world;
216216
params.imaging = imaging_default();
217217
params.debug_level = jl_options.debug_level;
218-
jl_workqueue_t emitted;
219218
{
220219
orc::ThreadSafeModule result_m =
221220
jl_create_ts_module(name_from_method_instance(codeinst->def), params.tsctx, params.imaging, params.DL, params.TargetTriple);
222221
jl_llvm_functions_t decls = jl_emit_codeinst(result_m, codeinst, src, params);
223222
if (result_m)
224-
emitted[codeinst] = {std::move(result_m), std::move(decls)};
223+
params.compiled_functions[codeinst] = {std::move(result_m), std::move(decls)};
225224
{
226225
auto temp_module = jl_create_llvm_module(name_from_method_instance(codeinst->def), params.getContext(), params.imaging);
227-
jl_compile_workqueue(emitted, *temp_module, params, CompilationPolicy::Default);
226+
jl_compile_workqueue(params, *temp_module, CompilationPolicy::Default);
228227
}
229228

230229
if (params._shared_module)
@@ -241,7 +240,7 @@ static jl_callptr_t _jl_compile_codeinst(
241240
for (auto &global : params.global_targets) {
242241
NewGlobals[global.second->getName()] = global.first;
243242
}
244-
for (auto &def : emitted) {
243+
for (auto &def : params.compiled_functions) {
245244
auto M = std::get<0>(def.second).getModuleUnlocked();
246245
for (auto &GV : M->globals()) {
247246
auto InitValue = NewGlobals.find(GV.getName());
@@ -252,14 +251,14 @@ static jl_callptr_t _jl_compile_codeinst(
252251
}
253252
}
254253

255-
// Collect the exported functions from the emitted modules,
254+
// Collect the exported functions from the params.compiled_functions modules,
256255
// which form dependencies on which functions need to be
257256
// compiled first. Cycles of functions are compiled together.
258257
// (essentially we compile a DAG of SCCs in reverse topological order,
259258
// if we treat declarations of external functions as edges from declaration
260259
// to definition)
261260
StringMap<orc::ThreadSafeModule*> NewExports;
262-
for (auto &def : emitted) {
261+
for (auto &def : params.compiled_functions) {
263262
orc::ThreadSafeModule &TSM = std::get<0>(def.second);
264263
//The underlying context object is still locked because params is not destroyed yet
265264
auto M = TSM.getModuleUnlocked();
@@ -271,19 +270,19 @@ static jl_callptr_t _jl_compile_codeinst(
271270
}
272271
DenseMap<orc::ThreadSafeModule*, int> Queued;
273272
std::vector<orc::ThreadSafeModule*> Stack;
274-
for (auto &def : emitted) {
273+
for (auto &def : params.compiled_functions) {
275274
// Add the results to the execution engine now
276275
orc::ThreadSafeModule &M = std::get<0>(def.second);
277276
jl_add_to_ee(M, NewExports, Queued, Stack);
278277
assert(Queued.empty() && Stack.empty() && !M);
279278
}
280279
++CompiledCodeinsts;
281-
MaxWorkqueueSize.updateMax(emitted.size());
282-
IndirectCodeinsts += emitted.size() - 1;
280+
MaxWorkqueueSize.updateMax(params.compiled_functions.size());
281+
IndirectCodeinsts += params.compiled_functions.size() - 1;
283282
}
284283

285284
size_t i = 0;
286-
for (auto &def : emitted) {
285+
for (auto &def : params.compiled_functions) {
287286
jl_code_instance_t *this_code = def.first;
288287
if (i < jl_timing_print_limit)
289288
jl_timing_show_func_sig(this_code->def->specTypes, JL_TIMING_DEFAULT_BLOCK);

0 commit comments

Comments
 (0)