Skip to content

spirv: saturating arithmetic implementation #24317

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: master
Choose a base branch
from
Open
264 changes: 219 additions & 45 deletions src/codegen/spirv.zig
Original file line number Diff line number Diff line change
Expand Up @@ -3249,10 +3249,14 @@ const NavGen = struct {
.rem, .rem_optimized => try self.airArithOp(inst, .f_rem, .s_rem, .u_mod),
.mod, .mod_optimized => try self.airArithOp(inst, .f_mod, .s_mod, .u_mod),

.add_with_overflow => try self.airAddSubOverflow(inst, .i_add, .u_lt, .s_lt),
.sub_with_overflow => try self.airAddSubOverflow(inst, .i_sub, .u_gt, .s_gt),
.add_with_overflow => try self.airAddSubWithOverflow(inst, .i_add),
.sub_with_overflow => try self.airAddSubWithOverflow(inst, .i_sub),
.mul_with_overflow => try self.airMulOverflow(inst),
.shl_with_overflow => try self.airShlOverflow(inst),
.add_sat => try self.airAddSubSaturating(inst, .i_add),
.sub_sat => try self.airAddSubSaturating(inst, .i_sub),
.mul_sat => try self.airMulSaturating(inst),
.shl_sat => try self.airShlSaturating(inst),

.mul_add => try self.airMulAdd(inst),

Expand Down Expand Up @@ -3608,8 +3612,8 @@ const NavGen = struct {
const info = self.arithmeticTypeInfo(lhs.ty);

const result = switch (info.class) {
.composite_integer => unreachable, // TODO
.integer, .strange_integer => switch (info.signedness) {
.composite_integer, .strange_integer => unreachable, // TODO
.integer => switch (info.signedness) {
.signed => try self.buildBinary(sop, lhs, rhs),
.unsigned => try self.buildBinary(uop, lhs, rhs),
},
Expand Down Expand Up @@ -3654,42 +3658,38 @@ const NavGen = struct {
}
}

fn airAddSubOverflow(
fn buildAddSub(
self: *NavGen,
inst: Air.Inst.Index,
comptime add: BinaryOp,
comptime ucmp: CmpPredicate,
comptime scmp: CmpPredicate,
lhs: Temporary,
rhs: Temporary,
result_ty: Type,
comptime op: BinaryOp,
comptime mode: enum { WithOverflow, Saturating },
) !?IdRef {
_ = scmp;
// Note: OpIAddCarry and OpISubBorrow are not really useful here: For unsigned numbers,
// there is in both cases only one extra operation required. For signed operations,
// the overflow bit is set when going from 0x80.. to 0x00.., but this doesn't actually
// normally set a carry bit. So the SPIR-V overflow operations are not particularly
// useful here.

const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;

const lhs = try self.temporary(extra.lhs);
const rhs = try self.temporary(extra.rhs);

const result_ty = self.typeOfIndex(inst);

const info = self.arithmeticTypeInfo(lhs.ty);
switch (info.class) {
.composite_integer => unreachable, // TODO
.strange_integer, .integer => {},
.float, .bool => unreachable,
}

const sum = try self.buildBinary(add, lhs, rhs);
const sum = try self.buildBinary(op, lhs, rhs);
const result = try self.normalize(sum, info);

const overflowed = switch (info.signedness) {
// Overflow happened if the result is smaller than either of the operands. It doesn't matter which.
// For subtraction the conditions need to be swapped.
.unsigned => try self.buildCmp(ucmp, result, lhs),
.unsigned => switch (op) {
.i_add => try self.buildCmp(.u_lt, result, lhs),
.i_sub => try self.buildCmp(.u_gt, result, lhs),
else => @compileLog(op),
},
// For signed operations, we check the signs of the operands and the result.
.signed => blk: {
// Signed overflow detection using the sign bits of the operands and the result.
Expand All @@ -3708,31 +3708,88 @@ const NavGen = struct {
const signs_match = try self.buildCmp(.l_eq, lhs_is_neg, rhs_is_neg);
const result_sign_differs = try self.buildCmp(.l_ne, lhs_is_neg, result_is_neg);

const overflow_condition = if (add == .i_add)
signs_match
else // .i_sub
try self.buildUnary(.l_not, signs_match);
const overflow_condition = switch (op) {
.i_add => signs_match,
.i_sub => try self.buildUnary(.l_not, signs_match),
else => @compileLog(op),
};

break :blk try self.buildBinary(.l_and, overflow_condition, result_sign_differs);
},
};

const ov = try self.intFromBool(overflowed);
switch (mode) {
.WithOverflow => {
const ov = try self.intFromBool(overflowed);
const struct_ty_id = try self.resolveType(result_ty, .direct);
return try self.constructComposite(struct_ty_id, &.{ try result.materialize(self), try ov.materialize(self) });
},
.Saturating => {
const scalar_ty = result_ty.scalarType(self.pt.zcu);

const result_ty_id = try self.resolveType(result_ty, .direct);
return try self.constructComposite(result_ty_id, &.{ try result.materialize(self), try ov.materialize(self) });
}
if (info.signedness == .signed) {
const min_val: i64 = if (info.bits == 0) 0 else if (info.bits == 64) std.math.minInt(i64) else -(@as(i64, 1) << @as(u6, @intCast(info.bits - 1)));
const max_val: i64 = if (info.bits == 0) 0 else if (info.bits == 64) std.math.maxInt(i64) else (@as(i64, 1) << @as(u6, @intCast(info.bits - 1))) - 1;

fn airMulOverflow(self: *NavGen, inst: Air.Inst.Index) !?IdRef {
const pt = self.pt;
const min_id = try self.constInt(scalar_ty, min_val);
const max_id = try self.constInt(scalar_ty, max_val);

const min_tmp = Temporary.init(scalar_ty, min_id);
const max_tmp = Temporary.init(scalar_ty, max_id);

// The sign of the left-hand-side operand predicts the direction of saturation.
// If lhs is negative, any overflow/underflow will be towards min.
// If lhs is non-negative (zero or positive), any overflow will be towards max.
const zero = Temporary.init(lhs.ty, try self.constInt(lhs.ty, 0));
const lhs_is_neg = try self.buildCmp(.s_lt, lhs, zero);

const saturation_value = try self.buildSelect(lhs_is_neg, min_tmp, max_tmp);
const final_result = try self.buildSelect(overflowed, saturation_value, result);
return try final_result.materialize(self);
} else {
const saturation_val: u64 = switch (op) {
.i_add => if (info.bits == 0) 0 else if (info.bits == 64) std.math.maxInt(u64) else (@as(u64, 1) << @as(u6, @intCast(info.bits))) - 1, // saturate to max on overflow
.i_sub => 0, // saturate to min (0) on underflow
else => @compileLog(op),
};
const saturation_id = try self.constInt(scalar_ty, saturation_val);
const saturation_tmp = Temporary.init(scalar_ty, saturation_id);
const final_result = try self.buildSelect(overflowed, saturation_tmp, result);
return try final_result.materialize(self);
}
},
}
}

/// Lowers an overflow-reporting add/sub instruction.
///
/// Decodes the instruction's `ty_pl` payload as an `Air.Bin`, turns both
/// operands into temporaries, and delegates to `buildAddSub` in
/// `.WithOverflow` mode using the instruction's own result type.
/// `op` is forwarded unchanged to `buildAddSub`.
fn airAddSubWithOverflow(self: *NavGen, inst: Air.Inst.Index, comptime op: BinaryOp) !?IdRef {
    const payload_index = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl.payload;
    const operands = self.air.extraData(Air.Bin, payload_index).data;

    // Arguments are evaluated left to right: lhs, rhs, then the result type.
    return self.buildAddSub(
        try self.temporary(operands.lhs),
        try self.temporary(operands.rhs),
        self.typeOfIndex(inst),
        op,
        .WithOverflow,
    );
}

/// Lowers a saturating add/sub instruction.
///
/// Reads the two operands from the instruction's `bin_op` data, turns them
/// into temporaries, and delegates to `buildAddSub` in `.Saturating` mode
/// using the instruction's own result type. `op` is forwarded unchanged.
fn airAddSubSaturating(self: *NavGen, inst: Air.Inst.Index, comptime op: BinaryOp) !?IdRef {
    const operands = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;

    // Arguments are evaluated left to right: lhs, rhs, then the result type.
    return self.buildAddSub(
        try self.temporary(operands.lhs),
        try self.temporary(operands.rhs),
        self.typeOfIndex(inst),
        op,
        .Saturating,
    );
}

fn buildMul(
self: *NavGen,
lhs: Temporary,
rhs: Temporary,
result_ty: Type,
comptime mode: enum { WithOverflow, Saturating },
) !?IdRef {
const pt = self.pt;

const info = self.arithmeticTypeInfo(lhs.ty);
switch (info.class) {
Expand Down Expand Up @@ -3889,26 +3946,79 @@ const NavGen = struct {
},
};

const ov = try self.intFromBool(overflowed);
switch (mode) {
.WithOverflow => {
const ov = try self.intFromBool(overflowed);

const result_ty_id = try self.resolveType(result_ty, .direct);
return try self.constructComposite(result_ty_id, &.{ try result.materialize(self), try ov.materialize(self) });
const result_ty_id = try self.resolveType(result_ty, .direct);
return try self.constructComposite(result_ty_id, &.{ try result.materialize(self), try ov.materialize(self) });
},
.Saturating => {
const scalar_ty = result_ty.scalarType(self.pt.zcu);
const sat_val_tmp: Temporary = blk: switch (info.signedness) {
.unsigned => {
const max_val: u64 = if (info.bits == 64) std.math.maxInt(u64) else (@as(u64, 1) << @as(u6, @intCast(info.bits))) - 1;
const max_id = try self.constInt(scalar_ty, max_val);
break :blk .{ .ty = scalar_ty, .value = .{ .singleton = max_id } };
},
.signed => {
const zero = Temporary.init(rhs.ty, try self.constInt(rhs.ty, 0));
const lhs_is_neg = try self.buildCmp(.s_lt, lhs, zero);
const rhs_is_neg = try self.buildCmp(.s_lt, rhs, zero);
const signs_differ = try self.buildCmp(.l_ne, lhs_is_neg, rhs_is_neg);

const min_val: i64 = if (info.bits == 0) 0 else -(@as(i64, 1) << @as(u6, @intCast(info.bits - 1)));
const max_val: u64 = if (info.bits == 0) 0 else (@as(u64, 1) << @as(u6, @intCast(info.bits - 1))) - 1;

const min_id = try self.constInt(scalar_ty, min_val);
const max_id = try self.constInt(scalar_ty, max_val);

break :blk try self.buildSelect(
signs_differ,
Temporary.init(scalar_ty, min_id),
Temporary.init(scalar_ty, max_id),
);
},
};
const final_result = try self.buildSelect(overflowed, sat_val_tmp, result);
return try final_result.materialize(self);
},
}
}

fn airShlOverflow(self: *NavGen, inst: Air.Inst.Index) !?IdRef {
const zcu = self.pt.zcu;

fn airMulOverflow(self: *NavGen, inst: Air.Inst.Index) !?IdRef {
const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl;
const extra = self.air.extraData(Air.Bin, ty_pl.payload).data;

if (self.typeOf(extra.lhs).isVector(zcu) and !self.typeOf(extra.rhs).isVector(zcu)) {
return self.fail("vector shift with scalar rhs", .{});
}
const lhs = try self.temporary(extra.lhs);
const rhs = try self.temporary(extra.rhs);

const base = try self.temporary(extra.lhs);
const shift = try self.temporary(extra.rhs);
const result_ty = self.typeOfIndex(inst);
return self.buildMul(lhs, rhs, result_ty, .WithOverflow);
}

/// Lowers a saturating multiply instruction.
///
/// Reads the two operands from the instruction's `bin_op` data, turns them
/// into temporaries, and delegates to `buildMul` in `.Saturating` mode using
/// the instruction's own result type.
fn airMulSaturating(self: *NavGen, inst: Air.Inst.Index) !?IdRef {
    const operands = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;

    // Arguments are evaluated left to right: lhs, rhs, then the result type.
    return self.buildMul(
        try self.temporary(operands.lhs),
        try self.temporary(operands.rhs),
        self.typeOfIndex(inst),
        .Saturating,
    );
}

fn buildShl(
self: *NavGen,
base: Temporary,
shift: Temporary,
result_ty: Type,
comptime mode: enum { WithOverflow, Saturating },
) !?IdRef {
const zcu = self.pt.zcu;

if (base.ty.isVector(zcu) and !shift.ty.isVector(zcu)) {
return self.fail("vector shift with scalar rhs", .{});
}

const info = self.arithmeticTypeInfo(base.ty);
switch (info.class) {
Expand All @@ -3924,16 +4034,80 @@ const NavGen = struct {
const left = try self.buildBinary(.sll, base, casted_shift);
const result = try self.normalize(left, info);

// Check if shift amount >= bit width, which always causes overflow (except when base is 0)
const bit_width_id = try self.constInt(base.ty.scalarType(zcu), info.bits);
const bit_width_tmp = Temporary.init(base.ty.scalarType(zcu), bit_width_id);
const shift_too_large = try self.buildCmp(.u_ge, casted_shift, bit_width_tmp);

const zero_tmp = Temporary.init(base.ty, try self.constInt(base.ty, 0));
const base_is_zero = try self.buildCmp(.i_eq, base, zero_tmp);
const large_shift_overflow = try self.buildBinary(.l_and, shift_too_large, try self.buildUnary(.l_not, base_is_zero));

const right = switch (info.signedness) {
.unsigned => try self.buildBinary(.srl, result, casted_shift),
.signed => try self.buildBinary(.sra, result, casted_shift),
};

const overflowed = try self.buildCmp(.i_ne, base, right);
const ov = try self.intFromBool(overflowed);
const round_trip_overflow = try self.buildCmp(.i_ne, base, right);
const overflowed = try self.buildBinary(.l_or, large_shift_overflow, round_trip_overflow);

const result_ty_id = try self.resolveType(result_ty, .direct);
return try self.constructComposite(result_ty_id, &.{ try result.materialize(self), try ov.materialize(self) });
switch (mode) {
.WithOverflow => {
const ov = try self.intFromBool(overflowed);
const result_ty_id = try self.resolveType(result_ty, .direct);
return try self.constructComposite(result_ty_id, &.{ try result.materialize(self), try ov.materialize(self) });
},
.Saturating => {
const sat_val_tmp: Temporary = blk: switch (info.signedness) {
.unsigned => {
const scalar_ty = result_ty.scalarType(self.pt.zcu);
const max_val: u64 = if (info.bits == 64) std.math.maxInt(u64) else (@as(u64, 1) << @as(u6, @intCast(info.bits))) - 1;
const max_id = try self.constInt(scalar_ty, max_val);
break :blk .{ .ty = scalar_ty, .value = .{ .singleton = max_id } };
},
.signed => {
const zero = Temporary.init(base.ty, try self.constInt(base.ty, 0));
const base_is_neg = try self.buildCmp(.s_lt, base, zero);

const scalar_ty = result_ty.scalarType(self.pt.zcu);
const min_val: i64 = if (info.bits == 0) 0 else -(@as(i64, 1) << @as(u6, @intCast(info.bits - 1)));
const max_val: u64 = if (info.bits == 0) 0 else (@as(u64, 1) << @as(u6, @intCast(info.bits - 1))) - 1;

const min_id = try self.constInt(scalar_ty, min_val);
const max_id = try self.constInt(scalar_ty, max_val);

break :blk try self.buildSelect(
base_is_neg,
Temporary.init(scalar_ty, min_id),
Temporary.init(scalar_ty, max_id),
);
},
};
const final_result = try self.buildSelect(overflowed, sat_val_tmp, result);
return try final_result.materialize(self);
},
}
}

/// Lowers an overflow-reporting left-shift instruction.
///
/// Decodes the instruction's `ty_pl` payload as an `Air.Bin` (lhs is the
/// value being shifted, rhs is the shift amount), turns both into
/// temporaries, and delegates to `buildShl` in `.WithOverflow` mode using
/// the instruction's own result type.
fn airShlOverflow(self: *NavGen, inst: Air.Inst.Index) !?IdRef {
    const payload_index = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl.payload;
    const operands = self.air.extraData(Air.Bin, payload_index).data;

    // Arguments are evaluated left to right: shifted value, shift amount, result type.
    return self.buildShl(
        try self.temporary(operands.lhs),
        try self.temporary(operands.rhs),
        self.typeOfIndex(inst),
        .WithOverflow,
    );
}

/// Lowers a saturating left-shift instruction.
///
/// Reads the shifted value (lhs) and the shift amount (rhs) from the
/// instruction's `bin_op` data, turns them into temporaries, and delegates
/// to `buildShl` in `.Saturating` mode using the instruction's own result type.
fn airShlSaturating(self: *NavGen, inst: Air.Inst.Index) !?IdRef {
    const operands = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;

    // Arguments are evaluated left to right: shifted value, shift amount, result type.
    return self.buildShl(
        try self.temporary(operands.lhs),
        try self.temporary(operands.rhs),
        self.typeOfIndex(inst),
        .Saturating,
    );
}

fn airMulAdd(self: *NavGen, inst: Air.Inst.Index) !?IdRef {
Expand Down
9 changes: 9 additions & 0 deletions test/behavior/math.zig
Original file line number Diff line number Diff line change
Expand Up @@ -872,6 +872,10 @@ test "@addWithOverflow" {
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; // TODO

try testAddWithOverflow(u8, 42, 0, 42, 0);
try testAddWithOverflow(i8, 42, 0, 42, 0);
try testAddWithOverflow(i8, -42, 0, -42, 0);

try testAddWithOverflow(u8, 250, 100, 94, 1);
try testAddWithOverflow(u8, 100, 150, 250, 0);

Expand Down Expand Up @@ -1117,6 +1121,11 @@ test "@subWithOverflow" {
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; // TODO

try testSubWithOverflow(u8, 42, 0, 42, 0);
try testSubWithOverflow(i8, 42, 0, 42, 0);
try testSubWithOverflow(i8, -42, 0, -42, 0);

try testSubWithOverflow(u8, 1, 2, 255, 1);
try testSubWithOverflow(u8, 1, 1, 0, 0);
Expand Down
Loading