Skip to content

Commit 19d5ffc

Browse files
committed
Sema: add safety check for non-power-of-two shift amounts
1 parent 9116e26 commit 19d5ffc

File tree

5 files changed

+102
-50
lines changed

5 files changed

+102
-50
lines changed

src/Sema.zig

Lines changed: 88 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -10227,34 +10227,57 @@ fn zirShl(
1022710227
} else rhs;
1022810228

1022910229
try sema.requireRuntimeBlock(block, src, runtime_src);
10230-
if (block.wantSafety() and air_tag == .shl_exact) {
10231-
const op_ov_tuple_ty = try sema.overflowArithmeticTupleType(lhs_ty);
10232-
const op_ov = try block.addInst(.{
10233-
.tag = .shl_with_overflow,
10234-
.data = .{ .ty_pl = .{
10235-
.ty = try sema.addType(op_ov_tuple_ty),
10236-
.payload = try sema.addExtra(Air.Bin{
10237-
.lhs = lhs,
10238-
.rhs = rhs,
10239-
}),
10240-
} },
10241-
});
10242-
const ov_bit = try sema.tupleFieldValByIndex(block, src, op_ov, 1, op_ov_tuple_ty);
10243-
const any_ov_bit = if (lhs_ty.zigTypeTag() == .Vector)
10244-
try block.addInst(.{
10245-
.tag = if (block.float_mode == .Optimized) .reduce_optimized else .reduce,
10246-
.data = .{ .reduce = .{
10247-
.operand = ov_bit,
10248-
.operation = .Or,
10230+
if (block.wantSafety()) {
10231+
const bit_count = scalar_ty.intInfo(target).bits;
10232+
if (!std.math.isPowerOfTwo(bit_count)) {
10233+
const bit_count_val = try Value.Tag.int_u64.create(sema.arena, bit_count);
10234+
10235+
const ok = if (rhs_ty.zigTypeTag() == .Vector) ok: {
10236+
const bit_count_inst = try sema.addConstant(rhs_ty, try Value.Tag.repeated.create(sema.arena, bit_count_val));
10237+
const lt = try block.addCmpVector(rhs, bit_count_inst, .lt, try sema.addType(rhs_ty));
10238+
break :ok try block.addInst(.{
10239+
.tag = .reduce,
10240+
.data = .{ .reduce = .{
10241+
.operand = lt,
10242+
.operation = .And,
10243+
} },
10244+
});
10245+
} else ok: {
10246+
const bit_count_inst = try sema.addConstant(rhs_ty, bit_count_val);
10247+
break :ok try block.addBinOp(.cmp_lt, rhs, bit_count_inst);
10248+
};
10249+
try sema.addSafetyCheck(block, ok, .shift_rhs_too_big);
10250+
}
10251+
10252+
if (air_tag == .shl_exact) {
10253+
const op_ov_tuple_ty = try sema.overflowArithmeticTupleType(lhs_ty);
10254+
const op_ov = try block.addInst(.{
10255+
.tag = .shl_with_overflow,
10256+
.data = .{ .ty_pl = .{
10257+
.ty = try sema.addType(op_ov_tuple_ty),
10258+
.payload = try sema.addExtra(Air.Bin{
10259+
.lhs = lhs,
10260+
.rhs = rhs,
10261+
}),
1024910262
} },
10250-
})
10251-
else
10252-
ov_bit;
10253-
const zero_ov = try sema.addConstant(Type.@"u1", Value.zero);
10254-
const no_ov = try block.addBinOp(.cmp_eq, any_ov_bit, zero_ov);
10263+
});
10264+
const ov_bit = try sema.tupleFieldValByIndex(block, src, op_ov, 1, op_ov_tuple_ty);
10265+
const any_ov_bit = if (lhs_ty.zigTypeTag() == .Vector)
10266+
try block.addInst(.{
10267+
.tag = if (block.float_mode == .Optimized) .reduce_optimized else .reduce,
10268+
.data = .{ .reduce = .{
10269+
.operand = ov_bit,
10270+
.operation = .Or,
10271+
} },
10272+
})
10273+
else
10274+
ov_bit;
10275+
const zero_ov = try sema.addConstant(Type.@"u1", Value.zero);
10276+
const no_ov = try block.addBinOp(.cmp_eq, any_ov_bit, zero_ov);
1025510277

10256-
try sema.addSafetyCheck(block, no_ov, .shl_overflow);
10257-
return sema.tupleFieldValByIndex(block, src, op_ov, 0, op_ov_tuple_ty);
10278+
try sema.addSafetyCheck(block, no_ov, .shl_overflow);
10279+
return sema.tupleFieldValByIndex(block, src, op_ov, 0, op_ov_tuple_ty);
10280+
}
1025810281
}
1025910282
return block.addBinOp(air_tag, lhs, new_rhs);
1026010283
}
@@ -10333,20 +10356,43 @@ fn zirShr(
1033310356

1033410357
try sema.requireRuntimeBlock(block, src, runtime_src);
1033510358
const result = try block.addBinOp(air_tag, lhs, rhs);
10336-
if (block.wantSafety() and air_tag == .shr_exact) {
10337-
const back = try block.addBinOp(.shl, result, rhs);
10338-
10339-
const ok = if (rhs_ty.zigTypeTag() == .Vector) ok: {
10340-
const eql = try block.addCmpVector(lhs, back, .eq, try sema.addType(rhs_ty));
10341-
break :ok try block.addInst(.{
10342-
.tag = if (block.float_mode == .Optimized) .reduce_optimized else .reduce,
10343-
.data = .{ .reduce = .{
10344-
.operand = eql,
10345-
.operation = .And,
10346-
} },
10347-
});
10348-
} else try block.addBinOp(.cmp_eq, lhs, back);
10349-
try sema.addSafetyCheck(block, ok, .shr_overflow);
10359+
if (block.wantSafety()) {
10360+
const bit_count = scalar_ty.intInfo(target).bits;
10361+
if (!std.math.isPowerOfTwo(bit_count)) {
10362+
const bit_count_val = try Value.Tag.int_u64.create(sema.arena, bit_count);
10363+
10364+
const ok = if (rhs_ty.zigTypeTag() == .Vector) ok: {
10365+
const bit_count_inst = try sema.addConstant(rhs_ty, try Value.Tag.repeated.create(sema.arena, bit_count_val));
10366+
const lt = try block.addCmpVector(rhs, bit_count_inst, .lt, try sema.addType(rhs_ty));
10367+
break :ok try block.addInst(.{
10368+
.tag = .reduce,
10369+
.data = .{ .reduce = .{
10370+
.operand = lt,
10371+
.operation = .And,
10372+
} },
10373+
});
10374+
} else ok: {
10375+
const bit_count_inst = try sema.addConstant(rhs_ty, bit_count_val);
10376+
break :ok try block.addBinOp(.cmp_lt, rhs, bit_count_inst);
10377+
};
10378+
try sema.addSafetyCheck(block, ok, .shift_rhs_too_big);
10379+
}
10380+
10381+
if (air_tag == .shr_exact) {
10382+
const back = try block.addBinOp(.shl, result, rhs);
10383+
10384+
const ok = if (rhs_ty.zigTypeTag() == .Vector) ok: {
10385+
const eql = try block.addCmpVector(lhs, back, .eq, try sema.addType(rhs_ty));
10386+
break :ok try block.addInst(.{
10387+
.tag = if (block.float_mode == .Optimized) .reduce_optimized else .reduce,
10388+
.data = .{ .reduce = .{
10389+
.operand = eql,
10390+
.operation = .And,
10391+
} },
10392+
});
10393+
} else try block.addBinOp(.cmp_eq, lhs, back);
10394+
try sema.addSafetyCheck(block, ok, .shr_overflow);
10395+
}
1035010396
}
1035110397
return result;
1035210398
}
@@ -19972,6 +20018,7 @@ pub const PanicId = enum {
1997220018
inactive_union_field,
1997320019
integer_part_out_of_bounds,
1997420020
corrupt_switch,
20021+
shift_rhs_too_big,
1997520022
};
1997620023

1997720024
fn addSafetyCheck(
@@ -20268,6 +20315,7 @@ fn safetyPanic(
2026820315
.inactive_union_field => "access of inactive union field",
2026920316
.integer_part_out_of_bounds => "integer part of floating point value out of bounds",
2027020317
.corrupt_switch => "switch on corrupt value",
20318+
.shift_rhs_too_big => "shift amount is greater than the type size",
2027120319
};
2027220320

2027320321
const msg_inst = msg_inst: {

test/cases/safety/shift left by huge amount.zig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@ pub fn main() !void {
1717
}
1818

1919
// run
20-
// backend=stage1
20+
// backend=llvm
2121
// target=native

test/cases/safety/shift right by huge amount.zig

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,5 @@ pub fn main() !void {
1717
}
1818

1919
// run
20-
// backend=stage1
20+
// backend=llvm
2121
// target=native

test/cases/safety/signed integer division overflow - vectors.zig

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
const std = @import("std");
22

33
pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
4-
_ = message;
54
_ = stack_trace;
6-
std.process.exit(0);
5+
if (std.mem.eql(u8, message, "integer overflow")) {
6+
std.process.exit(0);
7+
}
8+
std.process.exit(1);
79
}
810

911
pub fn main() !void {
@@ -17,5 +19,5 @@ fn div(a: @Vector(4, i16), b: @Vector(4, i16)) @Vector(4, i16) {
1719
return @divTrunc(a, b);
1820
}
1921
// run
20-
// backend=stage1
21-
// target=native
22+
// backend=llvm
23+
// target=native
Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
const std = @import("std");
22

33
pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace) noreturn {
4-
_ = message;
54
_ = stack_trace;
6-
std.process.exit(0);
5+
if (std.mem.eql(u8, message, "integer overflow")) {
6+
std.process.exit(0);
7+
}
8+
std.process.exit(1);
79
}
810

911
pub fn main() !void {
@@ -15,5 +17,5 @@ fn div(a: i16, b: i16) i16 {
1517
return @divTrunc(a, b);
1618
}
1719
// run
18-
// backend=stage1
19-
// target=native
20+
// backend=llvm
21+
// target=native

0 commit comments

Comments
 (0)