Skip to content

Commit 6c2c940

Browse files
vtjnashJeffBezanson
authored andcommitted
malloc wrappers: ensure thread-safe (#33284)
Better align the API of the jl_ wrappers for malloc/realloc/free with the libc namesakes, including being safe to use on threads. fix #33223
1 parent b6ddd87 commit 6c2c940

File tree

5 files changed

+53
-37
lines changed

5 files changed

+53
-37
lines changed

src/codegen.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4512,7 +4512,9 @@ static Function* gen_cfun_wrapper(
45124512
// for now, just use a dummy field to avoid a branch in this function
45134513
ctx.world_age_field = ctx.builder.CreateSelect(have_tls, ctx.world_age_field, dummy_world);
45144514
Value *last_age = tbaa_decorate(tbaa_gcframe, ctx.builder.CreateLoad(ctx.world_age_field));
4515-
have_tls = ctx.builder.CreateAnd(have_tls, ctx.builder.CreateIsNotNull(last_age));
4515+
Value *valid_tls = ctx.builder.CreateIsNotNull(last_age);
4516+
have_tls = ctx.builder.CreateAnd(have_tls, valid_tls);
4517+
ctx.world_age_field = ctx.builder.CreateSelect(valid_tls, ctx.world_age_field, dummy_world);
45164518
Value *world_v = ctx.builder.CreateLoad(prepare_global(jlgetworld_global));
45174519

45184520
Value *age_ok = NULL;

src/dump.c

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2002,6 +2002,8 @@ static jl_value_t *jl_deserialize_value_any(jl_serializer_state *s, uint8_t tag,
20022002
int32_t nw = (sz == 0 ? 1 : (sz < 0 ? -sz : sz));
20032003
size_t nb = nw * gmp_limb_size;
20042004
void *buf = jl_gc_counted_malloc(nb);
2005+
if (buf == NULL)
2006+
jl_throw(jl_memory_exception);
20052007
ios_read(s->s, (char*)buf, nb);
20062008
jl_set_nth_field(v, 0, jl_box_int32(nw));
20072009
jl_set_nth_field(v, 1, sizefield);

src/gc.c

Lines changed: 34 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -3015,54 +3015,47 @@ JL_DLLEXPORT void jl_throw_out_of_memory_error(void)
30153015
JL_DLLEXPORT void *jl_gc_counted_malloc(size_t sz)
30163016
{
30173017
jl_ptls_t ptls = jl_get_ptls_states();
3018-
maybe_collect(ptls);
3019-
ptls->gc_num.allocd += sz;
3020-
ptls->gc_num.malloc++;
3021-
void *b = malloc(sz);
3022-
if (b == NULL)
3023-
jl_throw(jl_memory_exception);
3024-
return b;
3018+
if (ptls && ptls->world_age) {
3019+
maybe_collect(ptls);
3020+
ptls->gc_num.allocd += sz;
3021+
ptls->gc_num.malloc++;
3022+
}
3023+
return malloc(sz);
30253024
}
30263025

30273026
JL_DLLEXPORT void *jl_gc_counted_calloc(size_t nm, size_t sz)
30283027
{
30293028
jl_ptls_t ptls = jl_get_ptls_states();
3030-
maybe_collect(ptls);
3031-
ptls->gc_num.allocd += nm*sz;
3032-
ptls->gc_num.malloc++;
3033-
void *b = calloc(nm, sz);
3034-
if (b == NULL)
3035-
jl_throw(jl_memory_exception);
3036-
return b;
3029+
if (ptls && ptls->world_age) {
3030+
maybe_collect(ptls);
3031+
ptls->gc_num.allocd += nm*sz;
3032+
ptls->gc_num.malloc++;
3033+
}
3034+
return calloc(nm, sz);
30373035
}
30383036

30393037
JL_DLLEXPORT void jl_gc_counted_free_with_size(void *p, size_t sz)
30403038
{
30413039
jl_ptls_t ptls = jl_get_ptls_states();
30423040
free(p);
3043-
ptls->gc_num.freed += sz;
3044-
ptls->gc_num.freecall++;
3045-
}
3046-
3047-
// older name for jl_gc_counted_free_with_size
3048-
JL_DLLEXPORT void jl_gc_counted_free(void *p, size_t sz)
3049-
{
3050-
jl_gc_counted_free_with_size(p, sz);
3041+
if (ptls && ptls->world_age) {
3042+
ptls->gc_num.freed += sz;
3043+
ptls->gc_num.freecall++;
3044+
}
30513045
}
30523046

30533047
JL_DLLEXPORT void *jl_gc_counted_realloc_with_old_size(void *p, size_t old, size_t sz)
30543048
{
30553049
jl_ptls_t ptls = jl_get_ptls_states();
3056-
maybe_collect(ptls);
3057-
if (sz < old)
3058-
ptls->gc_num.freed += (old - sz);
3059-
else
3060-
ptls->gc_num.allocd += (sz - old);
3061-
ptls->gc_num.realloc++;
3062-
void *b = realloc(p, sz);
3063-
if (b == NULL)
3064-
jl_throw(jl_memory_exception);
3065-
return b;
3050+
if (ptls && ptls->world_age) {
3051+
maybe_collect(ptls);
3052+
if (sz < old)
3053+
ptls->gc_num.freed += (old - sz);
3054+
else
3055+
ptls->gc_num.allocd += (sz - old);
3056+
ptls->gc_num.realloc++;
3057+
}
3058+
return realloc(p, sz);
30663059
}
30673060

30683061
// allocation wrappers that save the size of allocations, to allow using
@@ -3071,16 +3064,20 @@ JL_DLLEXPORT void *jl_gc_counted_realloc_with_old_size(void *p, size_t old, size
30713064
JL_DLLEXPORT void *jl_malloc(size_t sz)
30723065
{
30733066
int64_t *p = (int64_t *)jl_gc_counted_malloc(sz + JL_SMALL_BYTE_ALIGNMENT);
3067+
if (p == NULL)
3068+
return NULL;
30743069
p[0] = sz;
3075-
return (void *)(p + 2);
3070+
return (void *)(p + 2); // assumes JL_SMALL_BYTE_ALIGNMENT == 16
30763071
}
30773072

30783073
JL_DLLEXPORT void *jl_calloc(size_t nm, size_t sz)
30793074
{
30803075
size_t nmsz = nm*sz;
30813076
int64_t *p = (int64_t *)jl_gc_counted_calloc(nmsz + JL_SMALL_BYTE_ALIGNMENT, 1);
3077+
if (p == NULL)
3078+
return NULL;
30823079
p[0] = nmsz;
3083-
return (void *)(p + 2);
3080+
return (void *)(p + 2); // assumes JL_SMALL_BYTE_ALIGNMENT == 16
30843081
}
30853082

30863083
JL_DLLEXPORT void jl_free(void *p)
@@ -3105,8 +3102,10 @@ JL_DLLEXPORT void *jl_realloc(void *p, size_t sz)
31053102
szold = pp[0] + JL_SMALL_BYTE_ALIGNMENT;
31063103
}
31073104
int64_t *pnew = (int64_t *)jl_gc_counted_realloc_with_old_size(pp, szold, sz + JL_SMALL_BYTE_ALIGNMENT);
3105+
if (pnew == NULL)
3106+
return NULL;
31083107
pnew[0] = sz;
3109-
return (void *)(pnew + 2);
3108+
return (void *)(pnew + 2); // assumes JL_SMALL_BYTE_ALIGNMENT == 16
31103109
}
31113110

31123111
// allocating blocks for Arrays and Strings

src/jl_uv.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -973,7 +973,6 @@ struct work_baton {
973973
void *work_args;
974974
void *work_retval;
975975
notify_cb_t notify_func;
976-
int tid;
977976
int notify_idx;
978977
};
979978

test/ccall.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1017,6 +1017,20 @@ end
10171017

10181018
@test ccall(:jl_getpagesize, Clong, ()) == @threadcall(:jl_getpagesize, Clong, ())
10191019

1020+
# make sure our malloc/realloc/free adapters are thread-safe and repeatable
1021+
for i = 1:8
1022+
ptr = @threadcall(:jl_malloc, Ptr{Cint}, (Csize_t,), sizeof(Cint))
1023+
@test ptr != C_NULL
1024+
unsafe_store!(ptr, 3)
1025+
@test unsafe_load(ptr) == 3
1026+
ptr = @threadcall(:jl_realloc, Ptr{Cint}, (Ptr{Cint}, Csize_t,), ptr, 2 * sizeof(Cint))
1027+
@test ptr != C_NULL
1028+
unsafe_store!(ptr, 4, 2)
1029+
@test unsafe_load(ptr, 1) == 3
1030+
@test unsafe_load(ptr, 2) == 4
1031+
@threadcall(:jl_free, Cvoid, (Ptr{Cint},), ptr)
1032+
end
1033+
10201034
# Pointer finalizer (issue #15408)
10211035
let A = [1]
10221036
ccall((:set_c_int, libccalltest), Cvoid, (Cint,), 1)

0 commit comments

Comments
 (0)