Skip to content

Commit cef8c1f

Browse files
spirv: saturating arithmetic implementation (only add/sub)
1 parent b8ac740 commit cef8c1f

File tree

1 file changed

+84
-44
lines changed

1 file changed

+84
-44
lines changed

src/codegen/spirv.zig

Lines changed: 84 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -3249,10 +3249,12 @@ const NavGen = struct {
32493249
.rem, .rem_optimized => try self.airArithOp(inst, .f_rem, .s_rem, .u_mod),
32503250
.mod, .mod_optimized => try self.airArithOp(inst, .f_mod, .s_mod, .u_mod),
32513251

3252-
.add_with_overflow => try self.airAddSubOverflow(inst, .i_add, .u_lt, .s_lt),
3253-
.sub_with_overflow => try self.airAddSubOverflow(inst, .i_sub, .u_gt, .s_gt),
3252+
.add_with_overflow => try self.airAddSubWithOverflow(inst, .i_add),
3253+
.sub_with_overflow => try self.airAddSubWithOverflow(inst, .i_sub),
32543254
.mul_with_overflow => try self.airMulOverflow(inst),
32553255
.shl_with_overflow => try self.airShlOverflow(inst),
3256+
.add_sat => try self.airAddSubSaturating(inst, .i_add),
3257+
.sub_sat => try self.airAddSubSaturating(inst, .i_sub),
32563258

32573259
.mul_add => try self.airMulAdd(inst),
32583260

@@ -3654,68 +3656,106 @@ const NavGen = struct {
36543656
}
36553657
}
36563658

3657-
fn airAddSubOverflow(
3659+
/// Emits an integer addition or subtraction of `lhs` and `rhs` together with
/// overflow detection, and finishes it according to `mode`:
///
/// - `.WithOverflow`: returns a composite of type `result_ty` holding the
///   wrapped result and the overflow flag converted to an integer.
/// - `.Saturating`: returns the wrapped result, replaced by the saturation
///   bound whenever the overflow flag is set.
///
/// `op` must be `.i_add` or `.i_sub`. Only `.integer` and `.strange_integer`
/// operands are handled; composite integers are still a TODO.
fn buildAddSub(
    self: *NavGen,
    lhs: Temporary,
    rhs: Temporary,
    result_ty: Type,
    comptime op: BinaryOp,
    comptime mode: enum { WithOverflow, Saturating },
) !?IdRef {
    // `op` is comptime-known, so reject unsupported operations at compile time
    // instead of reaching an `unreachable` while generating code.
    if (op != .i_add and op != .i_sub) {
        @compileError("buildAddSub only supports .i_add and .i_sub");
    }

    const info = self.arithmeticTypeInfo(lhs.ty);
    switch (info.class) {
        .composite_integer => unreachable, // TODO
        .strange_integer, .integer => {},
        .float, .bool => unreachable,
    }

    const sum = try self.buildBinary(op, lhs, rhs);
    const result = try self.normalize(sum, info);

    const overflowed = switch (info.signedness) {
        // Unsigned addition overflowed if the wrapped result is smaller than an
        // operand (it does not matter which one). Unsigned subtraction
        // overflowed if the wrapped result is larger than the minuend.
        .unsigned => switch (op) {
            .i_add => try self.buildCmp(.u_lt, result, lhs),
            .i_sub => try self.buildCmp(.u_gt, result, lhs),
            else => unreachable,
        },
        .signed => blk: {
            // Signed overflow detection using the signs of the operands and the result.
            // For addition (a + b), overflow occurs if the operands have the same sign
            // and the result's sign differs from it:
            //   (sign(a) == sign(b)) && (sign(a) != sign(result))
            // For subtraction (a - b), overflow occurs if the operands have different
            // signs and the result's sign differs from the minuend's (a's) sign:
            //   (sign(a) != sign(b)) && (sign(a) != sign(result))
            const zero = Temporary.init(rhs.ty, try self.constInt(rhs.ty, 0));

            const lhs_is_neg = try self.buildCmp(.s_lt, lhs, zero);
            const rhs_is_neg = try self.buildCmp(.s_lt, rhs, zero);
            const result_is_neg = try self.buildCmp(.s_lt, result, zero);

            const signs_match = try self.buildCmp(.l_eq, lhs_is_neg, rhs_is_neg);
            const result_sign_differs = try self.buildCmp(.l_ne, lhs_is_neg, result_is_neg);

            const overflow_condition = if (op == .i_add)
                signs_match
            else // .i_sub
                try self.buildUnary(.l_not, signs_match);

            break :blk try self.buildBinary(.l_and, overflow_condition, result_sign_differs);
        },
    };

    switch (mode) {
        .WithOverflow => {
            const ov = try self.intFromBool(overflowed);
            const struct_ty_id = try self.resolveType(result_ty, .direct);
            return try self.constructComposite(struct_ty_id, &.{ try result.materialize(self), try ov.materialize(self) });
        },
        .Saturating => {
            const scalar_ty = result_ty.scalarType(self.pt.zcu);

            // Bounds of a `bits`-wide integer, obtained by shifting the 64-bit
            // extremes right by the unused width. Unlike `-(1 << (bits - 1))`,
            // this cannot overflow when bits == 64. A zero-bit integer only
            // has the value 0.
            const has_bits = info.bits != 0;
            const unused_bits: u6 = if (has_bits) @intCast(64 - info.bits) else 0;
            const unsigned_max: u64 = if (has_bits) @as(u64, std.math.maxInt(u64)) >> unused_bits else 0;
            const signed_max: u64 = if (has_bits) @as(u64, std.math.maxInt(i64)) >> unused_bits else 0;
            const signed_min: i64 = if (has_bits) @as(i64, std.math.minInt(i64)) >> unused_bits else 0;

            const saturated = switch (info.signedness) {
                // Unsigned addition can only exceed the maximum; unsigned
                // subtraction can only go below zero.
                .unsigned => blk: {
                    const bound: u64 = switch (op) {
                        .i_add => unsigned_max,
                        .i_sub => 0,
                        else => unreachable,
                    };
                    break :blk Temporary.init(scalar_ty, try self.constInt(scalar_ty, bound));
                },
                // Signed operations can overflow in either direction. Per the
                // overflow conditions above, an overflowing add has operands of
                // equal sign and an overflowing sub has a result whose sign
                // differs from lhs, so the true result lies beyond the bound on
                // the side of lhs: minimum if lhs is negative, maximum otherwise.
                .signed => blk: {
                    const min_tmp = Temporary.init(scalar_ty, try self.constInt(scalar_ty, signed_min));
                    const max_tmp = Temporary.init(scalar_ty, try self.constInt(scalar_ty, signed_max));
                    const zero = Temporary.init(lhs.ty, try self.constInt(lhs.ty, 0));
                    const lhs_is_neg = try self.buildCmp(.s_lt, lhs, zero);
                    break :blk try self.buildSelect(lhs_is_neg, min_tmp, max_tmp);
                },
            };

            const final_result = try self.buildSelect(overflowed, saturated, result);
            return try final_result.materialize(self);
        },
    }
}
37163739

3717-
const result_ty_id = try self.resolveType(result_ty, .direct);
3718-
return try self.constructComposite(result_ty_id, &.{ try result.materialize(self), try ov.materialize(self) });
3740+
/// Lowers an add/sub-with-overflow AIR instruction: reads the two operands
/// from the instruction's `Air.Bin` extra payload and forwards them to
/// `buildAddSub` in `.WithOverflow` mode.
fn airAddSubWithOverflow(self: *NavGen, inst: Air.Inst.Index, comptime op: BinaryOp) !?IdRef {
    const inst_data = self.air.instructions.items(.data)[@intFromEnum(inst)];
    const operands = self.air.extraData(Air.Bin, inst_data.ty_pl.payload).data;

    const lhs_tmp = try self.temporary(operands.lhs);
    const rhs_tmp = try self.temporary(operands.rhs);

    return self.buildAddSub(lhs_tmp, rhs_tmp, self.typeOfIndex(inst), op, .WithOverflow);
}
3750+
3751+
/// Lowers a saturating add/sub AIR instruction: reads the two operands from
/// the instruction's `bin_op` data and forwards them to `buildAddSub` in
/// `.Saturating` mode.
fn airAddSubSaturating(self: *NavGen, inst: Air.Inst.Index, comptime op: BinaryOp) !?IdRef {
    const operands = self.air.instructions.items(.data)[@intFromEnum(inst)].bin_op;

    const lhs_tmp = try self.temporary(operands.lhs);
    const rhs_tmp = try self.temporary(operands.rhs);

    return self.buildAddSub(lhs_tmp, rhs_tmp, self.typeOfIndex(inst), op, .Saturating);
}
37203760

37213761
fn airMulOverflow(self: *NavGen, inst: Air.Inst.Index) !?IdRef {

0 commit comments

Comments
 (0)