Skip to content

Commit 539f3ef

Browse files
authored
Merge pull request #21933 from kcbanner/comptime_nan_comparison
Fix float vector comparisons with signed zero and NaN, add test coverage
2 parents 9840157 + 144d69b commit 539f3ef

File tree

3 files changed

+138
-2
lines changed

3 files changed

+138
-2
lines changed

src/Sema.zig

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38029,6 +38029,11 @@ fn compareScalar(
3802938029
const pt = sema.pt;
3803038030
const coerced_lhs = try pt.getCoerced(lhs, ty);
3803138031
const coerced_rhs = try pt.getCoerced(rhs, ty);
38032+
38033+
// Equality comparisons of signed zero and NaN need to use floating point semantics
38034+
if (coerced_lhs.isFloat(pt.zcu) or coerced_rhs.isFloat(pt.zcu))
38035+
return Value.compareHeteroSema(coerced_lhs, op, coerced_rhs, pt);
38036+
3803238037
switch (op) {
3803338038
.eq => return sema.valuesEqual(coerced_lhs, coerced_rhs, ty),
3803438039
.neq => return !(try sema.valuesEqual(coerced_lhs, coerced_rhs, ty)),

src/Value.zig

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1132,6 +1132,8 @@ pub fn compareHeteroAdvanced(
11321132
else => {},
11331133
}
11341134
}
1135+
1136+
if (lhs.isNan(zcu) or rhs.isNan(zcu)) return op == .neq;
11351137
return (try orderAdvanced(lhs, rhs, strat, zcu, tid)).compare(op);
11361138
}
11371139

test/behavior/floatop.zig

Lines changed: 131 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -132,13 +132,20 @@ test "cmp f16" {
132132
try comptime testCmp(f16);
133133
}
134134

135-
test "cmp f32/f64" {
135+
test "cmp f32" {
136136
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
137-
if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
138137
if (builtin.cpu.arch.isArm() and builtin.target.abi.float() == .soft) return error.SkipZigTest; // https://github.com/ziglang/zig/issues/21234
138+
if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
139139

140140
try testCmp(f32);
141141
try comptime testCmp(f32);
142+
}
143+
144+
test "cmp f64" {
145+
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
146+
if (builtin.zig_backend == .stage2_x86_64 and builtin.target.ofmt != .elf and builtin.target.ofmt != .macho) return error.SkipZigTest;
147+
if (builtin.cpu.arch.isArm() and builtin.target.abi.float() == .soft) return error.SkipZigTest; // https://github.com/ziglang/zig/issues/21234
148+
142149
try testCmp(f64);
143150
try comptime testCmp(f64);
144151
}
@@ -224,6 +231,98 @@ fn testCmp(comptime T: type) !void {
224231
}
225232
}
226233

234+
test "vector cmp f16" {
235+
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
236+
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
237+
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest;
238+
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
239+
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
240+
if (builtin.cpu.arch.isArm()) return error.SkipZigTest;
241+
if (builtin.cpu.arch.isPowerPC64()) return error.SkipZigTest;
242+
243+
try testCmpVector(f16);
244+
try comptime testCmpVector(f16);
245+
}
246+
247+
test "vector cmp f32" {
248+
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
249+
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
250+
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
251+
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
252+
if (builtin.cpu.arch.isArm()) return error.SkipZigTest;
253+
if (builtin.cpu.arch.isPowerPC64()) return error.SkipZigTest;
254+
255+
try testCmpVector(f32);
256+
try comptime testCmpVector(f32);
257+
}
258+
259+
test "vector cmp f64" {
260+
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
261+
if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO
262+
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
263+
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
264+
if (builtin.cpu.arch.isArm()) return error.SkipZigTest;
265+
if (builtin.cpu.arch.isPowerPC64()) return error.SkipZigTest;
266+
267+
try testCmpVector(f64);
268+
try comptime testCmpVector(f64);
269+
}
270+
271+
test "vector cmp f128" {
272+
if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO
273+
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
274+
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
275+
if (builtin.zig_backend == .stage2_c and builtin.cpu.arch.isArm()) return error.SkipZigTest;
276+
if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest;
277+
if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest;
278+
if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest;
279+
if (builtin.cpu.arch.isArm()) return error.SkipZigTest;
280+
if (builtin.cpu.arch.isPowerPC64()) return error.SkipZigTest;
281+
282+
try testCmpVector(f128);
283+
try comptime testCmpVector(f128);
284+
}
285+
286+
test "vector cmp f80/c_longdouble" {
287+
if (true) return error.SkipZigTest;
288+
289+
try testCmpVector(f80);
290+
try comptime testCmpVector(f80);
291+
try testCmpVector(c_longdouble);
292+
try comptime testCmpVector(c_longdouble);
293+
}
294+
fn testCmpVector(comptime T: type) !void {
295+
var edges = [_]T{
296+
-math.inf(T),
297+
-math.floatMax(T),
298+
-math.floatMin(T),
299+
-math.floatTrueMin(T),
300+
-0.0,
301+
math.nan(T),
302+
0.0,
303+
math.floatTrueMin(T),
304+
math.floatMin(T),
305+
math.floatMax(T),
306+
math.inf(T),
307+
};
308+
_ = &edges;
309+
for (edges, 0..) |rhs, rhs_i| {
310+
const rhs_v: @Vector(4, T) = .{ rhs, rhs, rhs, rhs };
311+
for (edges, 0..) |lhs, lhs_i| {
312+
const no_nan = lhs_i != 5 and rhs_i != 5;
313+
const lhs_order = if (lhs_i < 5) lhs_i else lhs_i - 2;
314+
const rhs_order = if (rhs_i < 5) rhs_i else rhs_i - 2;
315+
const lhs_v: @Vector(4, T) = .{ lhs, lhs, lhs, lhs };
316+
try expect(@reduce(.And, (lhs_v == rhs_v)) == (no_nan and lhs_order == rhs_order));
317+
try expect(@reduce(.And, (lhs_v != rhs_v)) == !(no_nan and lhs_order == rhs_order));
318+
try expect(@reduce(.And, (lhs_v < rhs_v)) == (no_nan and lhs_order < rhs_order));
319+
try expect(@reduce(.And, (lhs_v > rhs_v)) == (no_nan and lhs_order > rhs_order));
320+
try expect(@reduce(.And, (lhs_v <= rhs_v)) == (no_nan and lhs_order <= rhs_order));
321+
try expect(@reduce(.And, (lhs_v >= rhs_v)) == (no_nan and lhs_order >= rhs_order));
322+
}
323+
}
324+
}
325+
227326
test "different sized float comparisons" {
228327
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
229328
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
@@ -1703,3 +1802,33 @@ test "optimized float mode" {
17031802
try expect(S.optimized(small) == small);
17041803
try expect(S.strict(small) == tiny);
17051804
}
1805+
1806+
fn MakeType(comptime x: anytype) type {
1807+
return struct {
1808+
fn get() @TypeOf(x) {
1809+
return x;
1810+
}
1811+
};
1812+
}
1813+
1814+
const nan_a: f32 = @bitCast(@as(u32, 0xffc00000));
1815+
const nan_b: f32 = @bitCast(@as(u32, 0xffe00000));
1816+
1817+
fn testMemoization() !void {
1818+
try expect(MakeType(nan_a) == MakeType(nan_a));
1819+
try expect(MakeType(nan_b) == MakeType(nan_b));
1820+
try expect(MakeType(nan_a) != MakeType(nan_b));
1821+
}
1822+
1823+
fn testVectorMemoization(comptime T: type) !void {
1824+
const nan_a_v: T = @splat(nan_a);
1825+
const nan_b_v: T = @splat(nan_b);
1826+
try expect(MakeType(nan_a_v) == MakeType(nan_a_v));
1827+
try expect(MakeType(nan_b_v) == MakeType(nan_b_v));
1828+
try expect(MakeType(nan_a_v) != MakeType(nan_b_v));
1829+
}
1830+
1831+
test "comptime calls are only memoized when float arguments are bit-for-bit equal" {
1832+
try comptime testMemoization();
1833+
try comptime testVectorMemoization(@Vector(4, f32));
1834+
}

0 commit comments

Comments
 (0)