Skip to content

Commit 3ea80fb

Browse files
authored
Add a macro to opt into aggressive constprop (#38080)
Right now aggressive constprop is essentially tied to the inlining threshold (or to their name being `getproperty` or `setproperty!` respectively, which can be both somewhat brittle if the inlining cost changes and insufficient when you do really know that const prop would be beneficial even if the function is not inlineable. This adds a simple macro that can be used to manually annotate methods to force aggressive constprop on them.
1 parent debf26e commit 3ea80fb

File tree

11 files changed

+67
-11
lines changed

11 files changed

+67
-11
lines changed

base/compiler/abstractinterpretation.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -288,7 +288,7 @@ function abstract_call_method_with_const_args(interp::AbstractInterpreter, @nosp
288288
istopfunction(f, :<<) || istopfunction(f, :>>))
289289
return Any
290290
end
291-
force_inference = allconst || InferenceParams(interp).aggressive_constant_propagation
291+
force_inference = allconst || method.aggressive_constprop || InferenceParams(interp).aggressive_constant_propagation
292292
if istopfunction(f, :getproperty) || istopfunction(f, :setproperty!)
293293
force_inference = true
294294
end

base/expr.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,19 @@ macro pure(ex)
242242
esc(isa(ex, Expr) ? pushmeta!(ex, :pure) : ex)
243243
end
244244

245+
"""
246+
@aggressive_constprop ex
247+
@aggressive_constprop(ex)
248+
249+
`@aggressive_constprop` requests more aggressive interprocedural constant
250+
propagation for the annotated function. For a method where the return type
251+
depends on the value of the arguments, this can yield improved inference results
252+
at the cost of additional compile time.
253+
"""
254+
macro aggressive_constprop(ex)
255+
esc(isa(ex, Expr) ? pushmeta!(ex, :aggressive_constprop) : ex)
256+
end
257+
245258
"""
246259
@propagate_inbounds
247260

src/ast.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ jl_sym_t *static_parameter_sym; jl_sym_t *inline_sym;
5858
jl_sym_t *noinline_sym; jl_sym_t *generated_sym;
5959
jl_sym_t *generated_only_sym; jl_sym_t *isdefined_sym;
6060
jl_sym_t *propagate_inbounds_sym; jl_sym_t *specialize_sym;
61+
jl_sym_t *aggressive_constprop_sym;
6162
jl_sym_t *nospecialize_sym; jl_sym_t *macrocall_sym;
6263
jl_sym_t *colon_sym; jl_sym_t *hygienicscope_sym;
6364
jl_sym_t *throw_undef_if_not_sym; jl_sym_t *getfield_undefref_sym;
@@ -385,6 +386,7 @@ void jl_init_common_symbols(void)
385386
noinline_sym = jl_symbol("noinline");
386387
polly_sym = jl_symbol("polly");
387388
propagate_inbounds_sym = jl_symbol("propagate_inbounds");
389+
aggressive_constprop_sym = jl_symbol("aggressive_constprop");
388390
isdefined_sym = jl_symbol("isdefined");
389391
nospecialize_sym = jl_symbol("nospecialize");
390392
specialize_sym = jl_symbol("specialize");

src/dump.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -594,6 +594,7 @@ static void jl_serialize_value_(jl_serializer_state *s, jl_value_t *v, int as_li
594594
write_int8(s->s, m->isva);
595595
write_int8(s->s, m->pure);
596596
write_int8(s->s, m->is_for_opaque_closure);
597+
write_int8(s->s, m->aggressive_constprop);
597598
jl_serialize_value(s, (jl_value_t*)m->slot_syms);
598599
jl_serialize_value(s, (jl_value_t*)m->roots);
599600
jl_serialize_value(s, (jl_value_t*)m->ccallable);
@@ -1442,6 +1443,7 @@ static jl_value_t *jl_deserialize_value_method(jl_serializer_state *s, jl_value_
14421443
m->isva = read_int8(s->s);
14431444
m->pure = read_int8(s->s);
14441445
m->is_for_opaque_closure = read_int8(s->s);
1446+
m->aggressive_constprop = read_int8(s->s);
14451447
m->slot_syms = jl_deserialize_value(s, (jl_value_t**)&m->slot_syms);
14461448
jl_gc_wb(m, m->slot_syms);
14471449
m->roots = (jl_array_t*)jl_deserialize_value(s, (jl_value_t**)&m->roots);

src/ircode.c

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -702,7 +702,8 @@ JL_DLLEXPORT jl_array_t *jl_compress_ir(jl_method_t *m, jl_code_info_t *code)
702702
jl_get_ptls_states()
703703
};
704704

705-
uint8_t flags = (code->inferred << 3)
705+
uint8_t flags = (code->aggressive_constprop << 4)
706+
| (code->inferred << 3)
706707
| (code->inlineable << 2)
707708
| (code->propagate_inbounds << 1)
708709
| (code->pure << 0);
@@ -787,6 +788,7 @@ JL_DLLEXPORT jl_code_info_t *jl_uncompress_ir(jl_method_t *m, jl_code_instance_t
787788

788789
jl_code_info_t *code = jl_new_code_info_uninit();
789790
uint8_t flags = read_uint8(s.s);
791+
code->aggressive_constprop = !!(flags & (1 << 4));
790792
code->inferred = !!(flags & (1 << 3));
791793
code->inlineable = !!(flags & (1 << 2));
792794
code->propagate_inbounds = !!(flags & (1 << 1));

src/jltypes.c

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2150,7 +2150,7 @@ void jl_init_types(void) JL_GC_DISABLED
21502150
jl_code_info_type =
21512151
jl_new_datatype(jl_symbol("CodeInfo"), core,
21522152
jl_any_type, jl_emptysvec,
2153-
jl_perm_symsvec(18,
2153+
jl_perm_symsvec(19,
21542154
"code",
21552155
"codelocs",
21562156
"ssavaluetypes",
@@ -2168,8 +2168,9 @@ void jl_init_types(void) JL_GC_DISABLED
21682168
"inferred",
21692169
"inlineable",
21702170
"propagate_inbounds",
2171-
"pure"),
2172-
jl_svec(18,
2171+
"pure",
2172+
"aggressive_constprop"),
2173+
jl_svec(19,
21732174
jl_array_any_type,
21742175
jl_array_int32_type,
21752176
jl_any_type,
@@ -2187,13 +2188,14 @@ void jl_init_types(void) JL_GC_DISABLED
21872188
jl_bool_type,
21882189
jl_bool_type,
21892190
jl_bool_type,
2191+
jl_bool_type,
21902192
jl_bool_type),
2191-
0, 1, 18);
2193+
0, 1, 19);
21922194

21932195
jl_method_type =
21942196
jl_new_datatype(jl_symbol("Method"), core,
21952197
jl_any_type, jl_emptysvec,
2196-
jl_perm_symsvec(23,
2198+
jl_perm_symsvec(24,
21972199
"name",
21982200
"module",
21992201
"file",
@@ -2216,8 +2218,9 @@ void jl_init_types(void) JL_GC_DISABLED
22162218
"nkw",
22172219
"isva",
22182220
"pure",
2219-
"is_for_opaque_closure"),
2220-
jl_svec(23,
2221+
"is_for_opaque_closure",
2222+
"aggressive_constprop"),
2223+
jl_svec(24,
22212224
jl_symbol_type,
22222225
jl_module_type,
22232226
jl_symbol_type,
@@ -2240,6 +2243,7 @@ void jl_init_types(void) JL_GC_DISABLED
22402243
jl_int32_type,
22412244
jl_bool_type,
22422245
jl_bool_type,
2246+
jl_bool_type,
22432247
jl_bool_type),
22442248
0, 1, 10);
22452249

src/julia.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -288,6 +288,7 @@ typedef struct _jl_code_info_t {
288288
uint8_t inlineable;
289289
uint8_t propagate_inbounds;
290290
uint8_t pure;
291+
uint8_t aggressive_constprop;
291292
} jl_code_info_t;
292293

293294
// This type describes a single method definition, and stores data
@@ -328,6 +329,7 @@ typedef struct _jl_method_t {
328329
uint8_t isva;
329330
uint8_t pure;
330331
uint8_t is_for_opaque_closure;
332+
uint8_t aggressive_constprop;
331333

332334
// hidden fields:
333335
// lock for modifications to the method

src/julia_internal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1289,6 +1289,7 @@ extern jl_sym_t *static_parameter_sym; extern jl_sym_t *inline_sym;
12891289
extern jl_sym_t *noinline_sym; extern jl_sym_t *generated_sym;
12901290
extern jl_sym_t *generated_only_sym; extern jl_sym_t *isdefined_sym;
12911291
extern jl_sym_t *propagate_inbounds_sym; extern jl_sym_t *specialize_sym;
1292+
extern jl_sym_t *aggressive_constprop_sym;
12921293
extern jl_sym_t *nospecialize_sym; extern jl_sym_t *macrocall_sym;
12931294
extern jl_sym_t *colon_sym; extern jl_sym_t *hygienicscope_sym;
12941295
extern jl_sym_t *throw_undef_if_not_sym; extern jl_sym_t *getfield_undefref_sym;

src/method.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,8 @@ static void jl_code_info_set_ir(jl_code_info_t *li, jl_expr_t *ir)
269269
li->inlineable = 1;
270270
else if (ma == (jl_value_t*)propagate_inbounds_sym)
271271
li->propagate_inbounds = 1;
272+
else if (ma == (jl_value_t*)aggressive_constprop_sym)
273+
li->aggressive_constprop = 1;
272274
else
273275
jl_array_ptr_set(meta, ins++, ma);
274276
}
@@ -528,6 +530,7 @@ static void jl_method_set_source(jl_method_t *m, jl_code_info_t *src)
528530
}
529531
m->called = called;
530532
m->pure = src->pure;
533+
m->aggressive_constprop = src->aggressive_constprop;
531534
jl_add_function_name_to_lineinfo(src, (jl_value_t*)m->name);
532535

533536
jl_array_t *copy = NULL;

stdlib/Serialization/src/Serialization.jl

Lines changed: 17 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ mutable struct Serializer{I<:IO} <: AbstractSerializer
2222
table::IdDict{Any,Any}
2323
pending_refs::Vector{Int}
2424
known_object_data::Dict{UInt64,Any}
25-
Serializer{I}(io::I) where I<:IO = new(io, 0, IdDict(), Int[], Dict{UInt64,Any}())
25+
version::Int
26+
Serializer{I}(io::I) where I<:IO = new(io, 0, IdDict(), Int[], Dict{UInt64,Any}(), ser_version)
2627
end
2728

2829
Serializer(io::IO) = Serializer{typeof(io)}(io)
@@ -78,7 +79,10 @@ const TAGS = Any[
7879

7980
@assert length(TAGS) == 255
8081

81-
const ser_version = 13 # do not make changes without bumping the version #!
82+
const ser_version = 14 # do not make changes without bumping the version #!
83+
84+
format_version(::AbstractSerializer) = ser_version
85+
format_version(s::Serializer) = s.version
8286

8387
const NTAGS = length(TAGS)
8488

@@ -414,6 +418,7 @@ function serialize(s::AbstractSerializer, meth::Method)
414418
serialize(s, meth.nargs)
415419
serialize(s, meth.isva)
416420
serialize(s, meth.is_for_opaque_closure)
421+
serialize(s, meth.aggressive_constprop)
417422
if isdefined(meth, :source)
418423
serialize(s, Base._uncompressed_ast(meth, meth.source))
419424
else
@@ -717,6 +722,8 @@ function readheader(s::AbstractSerializer)
717722
error("""Cannot read stream serialized with a newer version of Julia.
718723
Got data version $version > current version $ser_version""")
719724
end
725+
s.version = version
726+
return
720727
end
721728

722729
"""
@@ -988,9 +995,13 @@ function deserialize(s::AbstractSerializer, ::Type{Method})
988995
nargs = deserialize(s)::Int32
989996
isva = deserialize(s)::Bool
990997
is_for_opaque_closure = false
998+
aggressive_constprop = false
991999
template_or_is_opaque = deserialize(s)
9921000
if isa(template_or_is_opaque, Bool)
9931001
is_for_opaque_closure = template_or_is_opaque
1002+
if format_version(s) >= 14
1003+
aggressive_constprop = deserialize(s)::Bool
1004+
end
9941005
template = deserialize(s)
9951006
else
9961007
template = template_or_is_opaque
@@ -1005,6 +1016,7 @@ function deserialize(s::AbstractSerializer, ::Type{Method})
10051016
meth.nargs = nargs
10061017
meth.isva = isva
10071018
meth.is_for_opaque_closure = is_for_opaque_closure
1019+
meth.aggressive_constprop = aggressive_constprop
10081020
if template !== nothing
10091021
# TODO: compress template
10101022
meth.source = template::CodeInfo
@@ -1125,6 +1137,9 @@ function deserialize(s::AbstractSerializer, ::Type{CodeInfo})
11251137
ci.inlineable = deserialize(s)
11261138
ci.propagate_inbounds = deserialize(s)
11271139
ci.pure = deserialize(s)
1140+
if format_version(s) >= 14
1141+
ci.aggressive_constprop = deserialize(s)::Bool
1142+
end
11281143
return ci
11291144
end
11301145

0 commit comments

Comments
 (0)