Skip to content

Commit acb7bd9

Browse files
vtjnash authored and JeffBezanson committed
setindex: disallow breaking the object model (#34176)
This code was written fairly carefully to be safe, assuming it was not improperly optimized, but others are not as careful when copying it. It is also simply better not to break the object model by attempting to mutate constant values.
1 parent 2835347 commit acb7bd9

File tree

10 files changed

+124
-62
lines changed

10 files changed

+124
-62
lines changed

base/deepcopy.jl

Lines changed: 26 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -56,20 +56,35 @@ end
5656

5757
function deepcopy_internal(@nospecialize(x), stackdict::IdDict)
5858
T = typeof(x)::DataType
59-
isbitstype(T) && return x
60-
if haskey(stackdict, x)
61-
return stackdict[x]
62-
end
63-
y = ccall(:jl_new_struct_uninit, Any, (Any,), T)
59+
nf = nfields(x)
6460
if T.mutable
61+
if haskey(stackdict, x)
62+
return stackdict[x]
63+
end
64+
y = ccall(:jl_new_struct_uninit, Any, (Any,), T)
6565
stackdict[x] = y
66-
end
67-
for i in 1:nfields(x)
68-
if isdefined(x,i)
69-
xi = getfield(x, i)
70-
xi = deepcopy_internal(xi, stackdict)::typeof(xi)
71-
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi)
66+
for i in 1:nf
67+
if isdefined(x, i)
68+
xi = getfield(x, i)
69+
xi = deepcopy_internal(xi, stackdict)::typeof(xi)
70+
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), y, i-1, xi)
71+
end
72+
end
73+
elseif nf == 0 || isbitstype(T)
74+
y = x
75+
else
76+
flds = Vector{Any}(undef, nf)
77+
for i in 1:nf
78+
if isdefined(x, i)
79+
xi = getfield(x, i)
80+
xi = deepcopy_internal(xi, stackdict)::typeof(xi)
81+
flds[i] = xi
82+
else
83+
nf = i - 1 # rest of tail must be undefined values
84+
break
85+
end
7286
end
87+
y = ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), T, flds, nf)
7388
end
7489
return y::T
7590
end

src/builtins.c

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -717,15 +717,22 @@ JL_CALLABLE(jl_f_tuple)
717717
tt = jl_inst_concrete_tupletype(types);
718718
JL_GC_POP();
719719
}
720-
return jl_new_structv(tt, args, nargs);
720+
if (tt->instance != NULL)
721+
return tt->instance;
722+
jl_ptls_t ptls = jl_get_ptls_states();
723+
jl_value_t *jv = jl_gc_alloc(ptls, jl_datatype_size(tt), tt);
724+
for (i = 0; i < nargs; i++)
725+
set_nth_field(tt, (void*)jv, i, args[i]);
726+
return jv;
721727
}
722728

723729
JL_CALLABLE(jl_f_svec)
724730
{
725731
size_t i;
726-
if (nargs == 0) return (jl_value_t*)jl_emptysvec;
732+
if (nargs == 0)
733+
return (jl_value_t*)jl_emptysvec;
727734
jl_svec_t *t = jl_alloc_svec_uninit(nargs);
728-
for(i=0; i < nargs; i++) {
735+
for (i = 0; i < nargs; i++) {
729736
jl_svecset(t, i, args[i]);
730737
}
731738
return (jl_value_t*)t;
@@ -785,11 +792,11 @@ JL_CALLABLE(jl_f_setfield)
785792
JL_TYPECHK(setfield!, symbol, args[1]);
786793
idx = jl_field_index(st, (jl_sym_t*)args[1], 1);
787794
}
788-
jl_value_t *ft = jl_field_type(st,idx);
795+
jl_value_t *ft = jl_field_type(st, idx);
789796
if (!jl_isa(args[2], ft)) {
790797
jl_type_error("setfield!", ft, args[2]);
791798
}
792-
jl_set_nth_field(v, idx, args[2]);
799+
set_nth_field(st, (void*)v, idx, args[2]);
793800
return args[2];
794801
}
795802

src/datatype.c

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -897,7 +897,7 @@ JL_DLLEXPORT jl_value_t *jl_new_struct(jl_datatype_t *type, ...)
897897
va_start(args, type);
898898
jl_value_t *jv = jl_gc_alloc(ptls, jl_datatype_size(type), type);
899899
for (size_t i = 0; i < nf; i++) {
900-
jl_set_nth_field(jv, i, va_arg(args, jl_value_t*));
900+
set_nth_field(type, (void*)jv, i, va_arg(args, jl_value_t*));
901901
}
902902
va_end(args);
903903
return jv;
@@ -906,14 +906,15 @@ JL_DLLEXPORT jl_value_t *jl_new_struct(jl_datatype_t *type, ...)
906906
static void init_struct_tail(jl_datatype_t *type, jl_value_t *jv, size_t na)
907907
{
908908
size_t nf = jl_datatype_nfields(type);
909-
for(size_t i=na; i < nf; i++) {
909+
char *data = (char*)jl_data_ptr(jv);
910+
for (size_t i = na; i < nf; i++) {
910911
if (jl_field_isptr(type, i)) {
911-
*(jl_value_t**)((char*)jl_data_ptr(jv)+jl_field_offset(type,i)) = NULL;
912+
*(jl_value_t**)(data + jl_field_offset(type, i)) = NULL;
912913
}
913914
else {
914915
jl_value_t *ft = jl_field_type(type, i);
915916
if (jl_is_uniontype(ft)) {
916-
uint8_t *psel = &((uint8_t *)jv)[jl_field_offset(type, i) + jl_field_size(type, i) - 1];
917+
uint8_t *psel = &((uint8_t *)data)[jl_field_offset(type, i) + jl_field_size(type, i) - 1];
917918
*psel = 0;
918919
}
919920
}
@@ -923,6 +924,10 @@ static void init_struct_tail(jl_datatype_t *type, jl_value_t *jv, size_t na)
923924
JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args, uint32_t na)
924925
{
925926
jl_ptls_t ptls = jl_get_ptls_states();
927+
if (!jl_is_datatype(type) || type->layout == NULL)
928+
jl_type_error("new", (jl_value_t*)jl_datatype_type, (jl_value_t*)type);
929+
if (type->ninitialized > na || na > jl_datatype_nfields(type))
930+
jl_error("invalid struct allocation");
926931
if (type->instance != NULL) {
927932
for (size_t i = 0; i < na; i++) {
928933
jl_value_t *ft = jl_field_type(type, i);
@@ -931,15 +936,13 @@ JL_DLLEXPORT jl_value_t *jl_new_structv(jl_datatype_t *type, jl_value_t **args,
931936
}
932937
return type->instance;
933938
}
934-
if (type->layout == NULL)
935-
jl_type_error("new", (jl_value_t*)jl_datatype_type, (jl_value_t*)type);
936939
jl_value_t *jv = jl_gc_alloc(ptls, jl_datatype_size(type), type);
937940
JL_GC_PUSH1(&jv);
938941
for (size_t i = 0; i < na; i++) {
939942
jl_value_t *ft = jl_field_type(type, i);
940943
if (!jl_isa(args[i], ft))
941944
jl_type_error("new", ft, args[i]);
942-
jl_set_nth_field(jv, i, args[i]);
945+
set_nth_field(type, (void*)jv, i, args[i]);
943946
}
944947
init_struct_tail(type, jv, na);
945948
JL_GC_POP();
@@ -951,7 +954,7 @@ JL_DLLEXPORT jl_value_t *jl_new_structt(jl_datatype_t *type, jl_value_t *tup)
951954
jl_ptls_t ptls = jl_get_ptls_states();
952955
if (!jl_is_tuple(tup))
953956
jl_type_error("new", (jl_value_t*)jl_tuple_type, tup);
954-
if (type->layout == NULL)
957+
if (!jl_is_datatype(type) || type->layout == NULL)
955958
jl_type_error("new", (jl_value_t *)jl_datatype_type, (jl_value_t *)type);
956959
size_t nargs = jl_nfields(tup);
957960
size_t nf = jl_datatype_nfields(type);
@@ -975,7 +978,7 @@ JL_DLLEXPORT jl_value_t *jl_new_structt(jl_datatype_t *type, jl_value_t *tup)
975978
fi = jl_get_nth_field(tup, i);
976979
if (!jl_isa(fi, ft))
977980
jl_type_error("new", ft, fi);
978-
jl_set_nth_field(jv, i, fi);
981+
set_nth_field(type, (void*)jv, i, fi);
979982
}
980983
JL_GC_POP();
981984
return jv;
@@ -1074,9 +1077,8 @@ JL_DLLEXPORT jl_value_t *jl_get_nth_field_checked(jl_value_t *v, size_t i)
10741077
return undefref_check((jl_datatype_t*)ty, jl_new_bits(ty, (char*)v + offs));
10751078
}
10761079

1077-
JL_DLLEXPORT void jl_set_nth_field(jl_value_t *v, size_t i, jl_value_t *rhs) JL_NOTSAFEPOINT
1080+
void set_nth_field(jl_datatype_t *st, void *v, size_t i, jl_value_t *rhs) JL_NOTSAFEPOINT
10781081
{
1079-
jl_datatype_t *st = (jl_datatype_t*)jl_typeof(v);
10801082
size_t offs = jl_field_offset(st, i);
10811083
if (jl_field_isptr(st, i)) {
10821084
*(jl_value_t**)((char*)v + offs) = rhs;

src/dump.c

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2119,7 +2119,7 @@ static jl_value_t *jl_deserialize_value(jl_serializer_state *s, jl_value_t **loc
21192119
v = jl_new_struct_uninit(tag == TAG_GOTONODE ? jl_gotonode_type : jl_quotenode_type);
21202120
if (usetable)
21212121
arraylist_push(&backref_list, v);
2122-
jl_set_nth_field(v, 0, jl_deserialize_value(s, NULL));
2122+
set_nth_field(tag == TAG_GOTONODE ? jl_gotonode_type : jl_quotenode_type, (void*)v, 0, jl_deserialize_value(s, NULL));
21232123
return v;
21242124
case TAG_UNIONALL:
21252125
pos = backref_list.len;
@@ -2235,7 +2235,7 @@ static jl_value_t *jl_deserialize_value(jl_serializer_state *s, jl_value_t **loc
22352235
arraylist_push(&backref_list, v);
22362236
for (i = 0; i < jl_datatype_nfields(jl_lineinfonode_type); i++) {
22372237
size_t offs = jl_field_offset(jl_lineinfonode_type, i);
2238-
jl_set_nth_field(v, i, jl_deserialize_value(s, (jl_value_t**)((char*)v + offs)));
2238+
set_nth_field(jl_lineinfonode_type, (void*)v, i, jl_deserialize_value(s, (jl_value_t**)((char*)v + offs)));
22392239
}
22402240
return v;
22412241
case TAG_DATATYPE:

src/interpreter.c

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -509,7 +509,6 @@ static jl_value_t *eval_value(jl_value_t *e, interpreter_state *s)
509509
JL_GC_PUSHARGS(argv, nargs);
510510
for (size_t i = 0; i < nargs; i++)
511511
argv[i] = eval_value(args[i], s);
512-
assert(jl_is_structtype(argv[0]));
513512
jl_value_t *v = jl_new_structv((jl_datatype_t*)argv[0], &argv[1], nargs - 1);
514513
JL_GC_POP();
515514
return v;
@@ -519,7 +518,6 @@ static jl_value_t *eval_value(jl_value_t *e, interpreter_state *s)
519518
JL_GC_PUSHARGS(argv, 2);
520519
argv[0] = eval_value(args[0], s);
521520
argv[1] = eval_value(args[1], s);
522-
assert(jl_is_structtype(argv[0]));
523521
jl_value_t *v = jl_new_structt((jl_datatype_t*)argv[0], argv[1]);
524522
JL_GC_POP();
525523
return v;

src/julia_internal.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -440,6 +440,7 @@ void jl_precompute_memoized_dt(jl_datatype_t *dt);
440440
jl_datatype_t *jl_wrap_Type(jl_value_t *t); // x -> Type{x}
441441
jl_value_t *jl_wrap_vararg(jl_value_t *t, jl_value_t *n);
442442
void jl_assign_bits(void *dest, jl_value_t *bits) JL_NOTSAFEPOINT;
443+
void set_nth_field(jl_datatype_t *st, void *v, size_t i, jl_value_t *rhs) JL_NOTSAFEPOINT;
443444
jl_expr_t *jl_exprn(jl_sym_t *head, size_t n);
444445
jl_function_t *jl_new_generic_function(jl_sym_t *name, jl_module_t *module);
445446
jl_function_t *jl_new_generic_function_with_supertype(jl_sym_t *name, jl_module_t *module, jl_datatype_t *st);

src/rtutils.c

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,22 @@ JL_DLLEXPORT jl_value_t *jl_value_ptr(jl_value_t *a)
343343
return a;
344344
}
345345

346+
// optimization of setfield which bypasses boxing of the idx (and checking field type validity)
347+
JL_DLLEXPORT void jl_set_nth_field(jl_value_t *v, size_t idx0, jl_value_t *rhs)
348+
{
349+
jl_datatype_t *st = (jl_datatype_t*)jl_typeof(v);
350+
if (!st->mutabl)
351+
jl_errorf("setfield! immutable struct of type %s cannot be changed", jl_symbol_name(st->name->name));
352+
if (idx0 >= jl_datatype_nfields(st))
353+
jl_bounds_error_int(v, idx0 + 1);
354+
//jl_value_t *ft = jl_field_type(st, idx0);
355+
//if (!jl_isa(rhs, ft)) {
356+
// jl_type_error("setfield!", ft, rhs);
357+
//}
358+
set_nth_field(st, (void*)v, idx0, rhs);
359+
}
360+
361+
346362
// parsing --------------------------------------------------------------------
347363

348364
int substr_isspace(char *p, char *pend)

stdlib/Serialization/src/Serialization.jl

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -631,7 +631,7 @@ function serialize_any(s::AbstractSerializer, @nospecialize(x))
631631
serialize_type(s, t)
632632
write(s.io, x)
633633
else
634-
if t.mutable && nf > 0
634+
if t.mutable
635635
serialize_cycle(s, x) && return
636636
serialize_type(s, t, true)
637637
else
@@ -1288,36 +1288,31 @@ function deserialize(s::AbstractSerializer, t::DataType)
12881288
if nf == 0 && t.size > 0
12891289
# bits type
12901290
return read(s.io, t)
1291-
end
1292-
if nf == 0
1293-
return ccall(:jl_new_struct, Any, (Any,Any...), t)
1294-
elseif isbitstype(t)
1295-
if nf == 1
1296-
f1 = deserialize(s)
1297-
return ccall(:jl_new_struct, Any, (Any,Any...), t, f1)
1298-
elseif nf == 2
1299-
f1 = deserialize(s)
1300-
f2 = deserialize(s)
1301-
return ccall(:jl_new_struct, Any, (Any,Any...), t, f1, f2)
1302-
elseif nf == 3
1303-
f1 = deserialize(s)
1304-
f2 = deserialize(s)
1305-
f3 = deserialize(s)
1306-
return ccall(:jl_new_struct, Any, (Any,Any...), t, f1, f2, f3)
1307-
else
1308-
flds = Any[ deserialize(s) for i = 1:nf ]
1309-
return ccall(:jl_new_structv, Any, (Any,Ptr{Cvoid},UInt32), t, flds, nf)
1310-
end
1311-
else
1291+
elseif t.mutable
13121292
x = ccall(:jl_new_struct_uninit, Any, (Any,), t)
1313-
t.mutable && deserialize_cycle(s, x)
1293+
deserialize_cycle(s, x)
13141294
for i in 1:nf
13151295
tag = Int32(read(s.io, UInt8)::UInt8)
13161296
if tag != UNDEFREF_TAG
13171297
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), x, i-1, handle_deserialize(s, tag))
13181298
end
13191299
end
13201300
return x
1301+
elseif nf == 0
1302+
return ccall(:jl_new_struct_uninit, Any, (Any,), t)
1303+
else
1304+
na = nf
1305+
vflds = Vector{Any}(undef, nf)
1306+
for i in 1:nf
1307+
tag = Int32(read(s.io, UInt8)::UInt8)
1308+
if tag != UNDEFREF_TAG
1309+
f = handle_deserialize(s, tag)
1310+
na >= i && (vflds[i] = f)
1311+
else
1312+
na >= i && (na = i - 1) # rest of tail must be undefined values
1313+
end
1314+
end
1315+
return ccall(:jl_new_structv, Any, (Any, Ptr{Any}, UInt32), t, vflds, na)
13211316
end
13221317
end
13231318

stdlib/Serialization/test/runtests.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,17 @@ create_serialization_stream() do s
427427
@test C[1] === C[2]
428428
end
429429

430+
mutable struct MSingle end
431+
create_serialization_stream() do s
432+
x = MSingle()
433+
A = [x, x, MSingle()]
434+
serialize(s, A)
435+
seekstart(s)
436+
C = deserialize(s)
437+
@test A[1] === x === A[2] !== A[3]
438+
@test x !== C[1] === C[2] !== C[3]
439+
end
440+
430441
# Regex
431442
create_serialization_stream() do s
432443
r1 = r"a?b.*"

test/core.jl

Lines changed: 25 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,8 +1085,8 @@ let
10851085
@test_throws BoundsError(z, -1) getfield(z, -1)
10861086
@test_throws BoundsError(z, 0) getfield(z, 0)
10871087
@test_throws BoundsError(z, 3) getfield(z, 3)
1088-
1089-
strct = LoadError("yofile", 0, "bad")
1088+
end
1089+
let strct = LoadError("yofile", 0, "bad")
10901090
@test nfields(strct) == 3 # sanity test
10911091
@test_throws BoundsError(strct, 10) getfield(strct, 10)
10921092
@test_throws ErrorException("setfield! immutable struct of type LoadError cannot be changed") setfield!(strct, 0, "")
@@ -1098,8 +1098,8 @@ let
10981098
@test getfield(strct, 1) == "yofile"
10991099
@test getfield(strct, 2) === 0
11001100
@test getfield(strct, 3) == "bad"
1101-
1102-
mstrct = TestMutable("melm", 1, nothing)
1101+
end
1102+
let mstrct = TestMutable("melm", 1, nothing)
11031103
@test Base.setproperty!(mstrct, :line, 8.0) === 8
11041104
@test mstrct.line === 8
11051105
@test_throws TypeError(:setfield!, "", Int, 8.0) setfield!(mstrct, :line, 8.0)
@@ -1112,6 +1112,14 @@ let
11121112
@test_throws BoundsError(mstrct, 0) setfield!(mstrct, 0, "")
11131113
@test_throws BoundsError(mstrct, 4) setfield!(mstrct, 4, "")
11141114
end
1115+
let strct = LoadError("yofile", 0, "bad")
1116+
@test_throws(ErrorException("setfield! immutable struct of type LoadError cannot be changed"),
1117+
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), strct, 0, ""))
1118+
end
1119+
let mstrct = TestMutable("melm", 1, nothing)
1120+
@test_throws(BoundsError(mstrct, 4),
1121+
ccall(:jl_set_nth_field, Cvoid, (Any, Csize_t, Any), mstrct, 3, ""))
1122+
end
11151123

11161124
# test getfield-overloading
11171125
function Base.getproperty(mstrct::TestMutable, p::Symbol)
@@ -3614,10 +3622,19 @@ end
36143622
return nothing
36153623
end
36163624
end
3617-
@test_throws TypeError f1()
3618-
@test_throws TypeError f2()
3619-
@test_throws TypeError f3()
3620-
@test_throws TypeError eval(Expr(:new, B, 1))
3625+
@test_throws TypeError("new", A, 1) f1()
3626+
@test_throws TypeError("new", A, 1) f2()
3627+
@test_throws TypeError("new", A, 1) f3()
3628+
@test_throws TypeError("new", A, 1) eval(Expr(:new, B, 1))
3629+
3630+
# some tests for handling of malformed syntax--these cases should not be possible in normal code
3631+
@test eval(Expr(:new, B, A())) == B(A())
3632+
@test_throws ErrorException("invalid struct allocation") eval(Expr(:new, B))
3633+
@test_throws ErrorException("invalid struct allocation") eval(Expr(:new, B, A(), A()))
3634+
@test_throws TypeError("new", DataType, Complex) eval(Expr(:new, Complex))
3635+
@test_throws TypeError("new", DataType, Complex.body) eval(Expr(:new, Complex.body))
3636+
@test_throws TypeError("new", DataType, Complex) eval(Expr(:splatnew, Complex, ()))
3637+
@test_throws TypeError("new", DataType, Complex.body) eval(Expr(:splatnew, Complex.body, ()))
36213638

36223639
end
36233640

0 commit comments

Comments
 (0)