Skip to content

Commit c764640

Browse files
committed
Sema: fix generics with struct literal coerced to tagged union
The `Value.eql` function has to test for value equality *as-if* the lhs value parameter is coerced into the type of the rhs. For tagged unions, there was a problematic case when the lhs was an anonymous struct, because in such case the value is empty_struct_value and the type contains all the value information. But the only type available in the function was the rhs type. So the fix involved making `Value.eqlAdvanced` also accept the lhs type, and then enhancing the logic to handle the case of the `.anon_struct` tag. closes #12418 Tests run locally: * test-behavior * test-cases
1 parent a12abc6 commit c764640

File tree

3 files changed

+110
-42
lines changed

3 files changed

+110
-42
lines changed

src/Sema.zig

Lines changed: 37 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1495,7 +1495,8 @@ pub fn resolveInst(sema: *Sema, zir_ref: Zir.Inst.Ref) !Air.Inst.Ref {
14951495

14961496
// Finally, the last section of indexes refers to the map of ZIR=>AIR.
14971497
const inst = sema.inst_map.get(@intCast(u32, i)).?;
1498-
if (sema.typeOf(inst).tag() == .generic_poison) return error.GenericPoison;
1498+
const ty = sema.typeOf(inst);
1499+
if (ty.tag() == .generic_poison) return error.GenericPoison;
14991500
return inst;
15001501
}
15011502

@@ -5570,11 +5571,15 @@ const GenericCallAdapter = struct {
55705571
generic_fn: *Module.Fn,
55715572
precomputed_hash: u64,
55725573
func_ty_info: Type.Payload.Function.Data,
5573-
/// Unlike comptime_args, the Type here is not always present.
5574-
/// .generic_poison is used to communicate non-anytype parameters.
5575-
comptime_tvs: []const TypedValue,
5574+
args: []const Arg,
55765575
module: *Module,
55775576

5577+
const Arg = struct {
5578+
ty: Type,
5579+
val: Value,
5580+
is_anytype: bool,
5581+
};
5582+
55785583
pub fn eql(ctx: @This(), adapted_key: void, other_key: *Module.Fn) bool {
55795584
_ = adapted_key;
55805585
// The generic function Decl is guaranteed to be the first dependency
@@ -5585,10 +5590,10 @@ const GenericCallAdapter = struct {
55855590

55865591
const other_comptime_args = other_key.comptime_args.?;
55875592
for (other_comptime_args[0..ctx.func_ty_info.param_types.len]) |other_arg, i| {
5588-
const this_arg = ctx.comptime_tvs[i];
5593+
const this_arg = ctx.args[i];
55895594
const this_is_comptime = this_arg.val.tag() != .generic_poison;
55905595
const other_is_comptime = other_arg.val.tag() != .generic_poison;
5591-
const this_is_anytype = this_arg.ty.tag() != .generic_poison;
5596+
const this_is_anytype = this_arg.is_anytype;
55925597
const other_is_anytype = other_key.isAnytypeParam(ctx.module, @intCast(u32, i));
55935598

55945599
if (other_is_anytype != this_is_anytype) return false;
@@ -5607,7 +5612,17 @@ const GenericCallAdapter = struct {
56075612
}
56085613
} else if (this_is_comptime) {
56095614
// Both are comptime parameters but not anytype parameters.
5610-
if (!this_arg.val.eql(other_arg.val, other_arg.ty, ctx.module)) {
5615+
// We assert no error is possible here because any lazy values must be resolved
5616+
// before inserting into the generic function hash map.
5617+
const is_eql = Value.eqlAdvanced(
5618+
this_arg.val,
5619+
this_arg.ty,
5620+
other_arg.val,
5621+
other_arg.ty,
5622+
ctx.module,
5623+
null,
5624+
) catch unreachable;
5625+
if (!is_eql) {
56115626
return false;
56125627
}
56135628
}
@@ -6258,8 +6273,7 @@ fn instantiateGenericCall(
62586273
var hasher = std.hash.Wyhash.init(0);
62596274
std.hash.autoHash(&hasher, @ptrToInt(module_fn));
62606275

6261-
const comptime_tvs = try sema.arena.alloc(TypedValue, func_ty_info.param_types.len);
6262-
6276+
const generic_args = try sema.arena.alloc(GenericCallAdapter.Arg, func_ty_info.param_types.len);
62636277
{
62646278
var i: usize = 0;
62656279
for (fn_info.param_body) |inst| {
@@ -6283,8 +6297,9 @@ fn instantiateGenericCall(
62836297
else => continue,
62846298
}
62856299

6300+
const arg_ty = sema.typeOf(uncasted_args[i]);
6301+
62866302
if (is_comptime) {
6287-
const arg_ty = sema.typeOf(uncasted_args[i]);
62886303
const arg_val = sema.analyzeGenericCallArgVal(block, .unneeded, uncasted_args[i]) catch |err| switch (err) {
62896304
error.NeededSourceLocation => {
62906305
const decl = sema.mod.declPtr(block.src_decl);
@@ -6297,27 +6312,30 @@ fn instantiateGenericCall(
62976312
arg_val.hash(arg_ty, &hasher, mod);
62986313
if (is_anytype) {
62996314
arg_ty.hashWithHasher(&hasher, mod);
6300-
comptime_tvs[i] = .{
6315+
generic_args[i] = .{
63016316
.ty = arg_ty,
63026317
.val = arg_val,
6318+
.is_anytype = true,
63036319
};
63046320
} else {
6305-
comptime_tvs[i] = .{
6306-
.ty = Type.initTag(.generic_poison),
6321+
generic_args[i] = .{
6322+
.ty = arg_ty,
63076323
.val = arg_val,
6324+
.is_anytype = false,
63086325
};
63096326
}
63106327
} else if (is_anytype) {
6311-
const arg_ty = sema.typeOf(uncasted_args[i]);
63126328
arg_ty.hashWithHasher(&hasher, mod);
6313-
comptime_tvs[i] = .{
6329+
generic_args[i] = .{
63146330
.ty = arg_ty,
63156331
.val = Value.initTag(.generic_poison),
6332+
.is_anytype = true,
63166333
};
63176334
} else {
6318-
comptime_tvs[i] = .{
6319-
.ty = Type.initTag(.generic_poison),
6335+
generic_args[i] = .{
6336+
.ty = arg_ty,
63206337
.val = Value.initTag(.generic_poison),
6338+
.is_anytype = false,
63216339
};
63226340
}
63236341

@@ -6331,7 +6349,7 @@ fn instantiateGenericCall(
63316349
.generic_fn = module_fn,
63326350
.precomputed_hash = precomputed_hash,
63336351
.func_ty_info = func_ty_info,
6334-
.comptime_tvs = comptime_tvs,
6352+
.args = generic_args,
63356353
.module = mod,
63366354
};
63376355
const gop = try mod.monomorphed_funcs.getOrPutAdapted(gpa, {}, adapter);
@@ -30124,7 +30142,7 @@ fn valuesEqual(
3012430142
rhs: Value,
3012530143
ty: Type,
3012630144
) CompileError!bool {
30127-
return Value.eqlAdvanced(lhs, rhs, ty, sema.mod, sema.kit(block, src));
30145+
return Value.eqlAdvanced(lhs, ty, rhs, ty, sema.mod, sema.kit(block, src));
3012830146
}
3012930147

3013030148
/// Asserts the values are comparable vectors of type `ty`.

src/value.zig

Lines changed: 54 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2004,6 +2004,10 @@ pub const Value = extern union {
20042004
return (try orderAgainstZeroAdvanced(lhs, sema_kit)).compare(op);
20052005
}
20062006

2007+
pub fn eql(a: Value, b: Value, ty: Type, mod: *Module) bool {
2008+
return eqlAdvanced(a, ty, b, ty, mod, null) catch unreachable;
2009+
}
2010+
20072011
/// This function is used by hash maps and so treats floating-point NaNs as equal
20082012
/// to each other, and not equal to other floating-point values.
20092013
/// Similarly, it treats `undef` as a distinct value from all other values.
@@ -2012,13 +2016,10 @@ pub const Value = extern union {
20122016
/// for `a`. This function must act *as if* `a` has been coerced to `ty`. This complication
20132017
/// is required in order to make generic function instantiation efficient - specifically
20142018
/// the insertion into the monomorphized function table.
2015-
pub fn eql(a: Value, b: Value, ty: Type, mod: *Module) bool {
2016-
return eqlAdvanced(a, b, ty, mod, null) catch unreachable;
2017-
}
2018-
20192019
/// If `null` is provided for `sema_kit` then it is guaranteed no error will be returned.
20202020
pub fn eqlAdvanced(
20212021
a: Value,
2022+
a_ty: Type,
20222023
b: Value,
20232024
ty: Type,
20242025
mod: *Module,
@@ -2044,33 +2045,34 @@ pub const Value = extern union {
20442045
const a_payload = a.castTag(.opt_payload).?.data;
20452046
const b_payload = b.castTag(.opt_payload).?.data;
20462047
var buffer: Type.Payload.ElemType = undefined;
2047-
return eqlAdvanced(a_payload, b_payload, ty.optionalChild(&buffer), mod, sema_kit);
2048+
const payload_ty = ty.optionalChild(&buffer);
2049+
return eqlAdvanced(a_payload, payload_ty, b_payload, payload_ty, mod, sema_kit);
20482050
},
20492051
.slice => {
20502052
const a_payload = a.castTag(.slice).?.data;
20512053
const b_payload = b.castTag(.slice).?.data;
2052-
if (!(try eqlAdvanced(a_payload.len, b_payload.len, Type.usize, mod, sema_kit))) {
2054+
if (!(try eqlAdvanced(a_payload.len, Type.usize, b_payload.len, Type.usize, mod, sema_kit))) {
20532055
return false;
20542056
}
20552057

20562058
var ptr_buf: Type.SlicePtrFieldTypeBuffer = undefined;
20572059
const ptr_ty = ty.slicePtrFieldType(&ptr_buf);
20582060

2059-
return eqlAdvanced(a_payload.ptr, b_payload.ptr, ptr_ty, mod, sema_kit);
2061+
return eqlAdvanced(a_payload.ptr, ptr_ty, b_payload.ptr, ptr_ty, mod, sema_kit);
20602062
},
20612063
.elem_ptr => {
20622064
const a_payload = a.castTag(.elem_ptr).?.data;
20632065
const b_payload = b.castTag(.elem_ptr).?.data;
20642066
if (a_payload.index != b_payload.index) return false;
20652067

2066-
return eqlAdvanced(a_payload.array_ptr, b_payload.array_ptr, ty, mod, sema_kit);
2068+
return eqlAdvanced(a_payload.array_ptr, ty, b_payload.array_ptr, ty, mod, sema_kit);
20672069
},
20682070
.field_ptr => {
20692071
const a_payload = a.castTag(.field_ptr).?.data;
20702072
const b_payload = b.castTag(.field_ptr).?.data;
20712073
if (a_payload.field_index != b_payload.field_index) return false;
20722074

2073-
return eqlAdvanced(a_payload.container_ptr, b_payload.container_ptr, ty, mod, sema_kit);
2075+
return eqlAdvanced(a_payload.container_ptr, ty, b_payload.container_ptr, ty, mod, sema_kit);
20742076
},
20752077
.@"error" => {
20762078
const a_name = a.castTag(.@"error").?.data.name;
@@ -2080,7 +2082,8 @@ pub const Value = extern union {
20802082
.eu_payload => {
20812083
const a_payload = a.castTag(.eu_payload).?.data;
20822084
const b_payload = b.castTag(.eu_payload).?.data;
2083-
return eqlAdvanced(a_payload, b_payload, ty.errorUnionPayload(), mod, sema_kit);
2085+
const payload_ty = ty.errorUnionPayload();
2086+
return eqlAdvanced(a_payload, payload_ty, b_payload, payload_ty, mod, sema_kit);
20842087
},
20852088
.eu_payload_ptr => @panic("TODO: Implement more pointer eql cases"),
20862089
.opt_payload_ptr => @panic("TODO: Implement more pointer eql cases"),
@@ -2098,7 +2101,7 @@ pub const Value = extern union {
20982101
const types = ty.tupleFields().types;
20992102
assert(types.len == a_field_vals.len);
21002103
for (types) |field_ty, i| {
2101-
if (!(try eqlAdvanced(a_field_vals[i], b_field_vals[i], field_ty, mod, sema_kit))) {
2104+
if (!(try eqlAdvanced(a_field_vals[i], field_ty, b_field_vals[i], field_ty, mod, sema_kit))) {
21022105
return false;
21032106
}
21042107
}
@@ -2109,7 +2112,7 @@ pub const Value = extern union {
21092112
const fields = ty.structFields().values();
21102113
assert(fields.len == a_field_vals.len);
21112114
for (fields) |field, i| {
2112-
if (!(try eqlAdvanced(a_field_vals[i], b_field_vals[i], field.ty, mod, sema_kit))) {
2115+
if (!(try eqlAdvanced(a_field_vals[i], field.ty, b_field_vals[i], field.ty, mod, sema_kit))) {
21132116
return false;
21142117
}
21152118
}
@@ -2120,7 +2123,7 @@ pub const Value = extern union {
21202123
for (a_field_vals) |a_elem, i| {
21212124
const b_elem = b_field_vals[i];
21222125

2123-
if (!(try eqlAdvanced(a_elem, b_elem, elem_ty, mod, sema_kit))) {
2126+
if (!(try eqlAdvanced(a_elem, elem_ty, b_elem, elem_ty, mod, sema_kit))) {
21242127
return false;
21252128
}
21262129
}
@@ -2132,21 +2135,21 @@ pub const Value = extern union {
21322135
switch (ty.containerLayout()) {
21332136
.Packed, .Extern => {
21342137
const tag_ty = ty.unionTagTypeHypothetical();
2135-
if (!(try a_union.tag.eqlAdvanced(b_union.tag, tag_ty, mod, sema_kit))) {
2138+
if (!(try eqlAdvanced(a_union.tag, tag_ty, b_union.tag, tag_ty, mod, sema_kit))) {
21362139
// In this case, we must disregard mismatching tags and compare
21372140
// based on the in-memory bytes of the payloads.
21382141
@panic("TODO comptime comparison of extern union values with mismatching tags");
21392142
}
21402143
},
21412144
.Auto => {
21422145
const tag_ty = ty.unionTagTypeHypothetical();
2143-
if (!(try a_union.tag.eqlAdvanced(b_union.tag, tag_ty, mod, sema_kit))) {
2146+
if (!(try eqlAdvanced(a_union.tag, tag_ty, b_union.tag, tag_ty, mod, sema_kit))) {
21442147
return false;
21452148
}
21462149
},
21472150
}
21482151
const active_field_ty = ty.unionFieldType(a_union.tag, mod);
2149-
return a_union.val.eqlAdvanced(b_union.val, active_field_ty, mod, sema_kit);
2152+
return eqlAdvanced(a_union.val, active_field_ty, b_union.val, active_field_ty, mod, sema_kit);
21502153
},
21512154
else => {},
21522155
} else if (a_tag == .null_value or b_tag == .null_value) {
@@ -2180,7 +2183,7 @@ pub const Value = extern union {
21802183
const b_val = b.enumToInt(ty, &buf_b);
21812184
var buf_ty: Type.Payload.Bits = undefined;
21822185
const int_ty = ty.intTagType(&buf_ty);
2183-
return eqlAdvanced(a_val, b_val, int_ty, mod, sema_kit);
2186+
return eqlAdvanced(a_val, int_ty, b_val, int_ty, mod, sema_kit);
21842187
},
21852188
.Array, .Vector => {
21862189
const len = ty.arrayLen();
@@ -2191,17 +2194,44 @@ pub const Value = extern union {
21912194
while (i < len) : (i += 1) {
21922195
const a_elem = elemValueBuffer(a, mod, i, &a_buf);
21932196
const b_elem = elemValueBuffer(b, mod, i, &b_buf);
2194-
if (!(try eqlAdvanced(a_elem, b_elem, elem_ty, mod, sema_kit))) {
2197+
if (!(try eqlAdvanced(a_elem, elem_ty, b_elem, elem_ty, mod, sema_kit))) {
21952198
return false;
21962199
}
21972200
}
21982201
return true;
21992202
},
22002203
.Struct => {
2201-
// A tuple can be represented with .empty_struct_value,
2202-
// the_one_possible_value, .aggregate in which case we could
2203-
// end up here and the values are equal if the type has zero fields.
2204-
return ty.isTupleOrAnonStruct() and ty.structFieldCount() != 0;
2204+
// A struct can be represented with one of:
2205+
// .empty_struct_value,
2206+
// .the_one_possible_value,
2207+
// .aggregate,
2208+
// Note that we already checked above for matching tags, e.g. both .aggregate.
2209+
return ty.onePossibleValue() != null;
2210+
},
2211+
.Union => {
2212+
// Here we have to check for value equality, as-if `a` has been coerced to `ty`.
2213+
if (ty.onePossibleValue() != null) {
2214+
return true;
2215+
}
2216+
if (a_ty.castTag(.anon_struct)) |payload| {
2217+
const tuple = payload.data;
2218+
if (tuple.values.len != 1) {
2219+
return false;
2220+
}
2221+
const field_name = tuple.names[0];
2222+
const union_obj = ty.cast(Type.Payload.Union).?.data;
2223+
const field_index = union_obj.fields.getIndex(field_name) orelse return false;
2224+
const tag_and_val = b.castTag(.@"union").?.data;
2225+
var field_tag_buf: Value.Payload.U32 = .{
2226+
.base = .{ .tag = .enum_field_index },
2227+
.data = @intCast(u32, field_index),
2228+
};
2229+
const field_tag = Value.initPayload(&field_tag_buf.base);
2230+
const tag_matches = tag_and_val.tag.eql(field_tag, union_obj.tag_ty, mod);
2231+
if (!tag_matches) return false;
2232+
return eqlAdvanced(tag_and_val.val, union_obj.tag_ty, tuple.values[0], tuple.types[0], mod, sema_kit);
2233+
}
2234+
return false;
22052235
},
22062236
.Float => {
22072237
switch (ty.floatBits(target)) {
@@ -2230,7 +2260,8 @@ pub const Value = extern union {
22302260
.base = .{ .tag = .opt_payload },
22312261
.data = a,
22322262
};
2233-
return eqlAdvanced(Value.initPayload(&buffer.base), b, ty, mod, sema_kit);
2263+
const opt_val = Value.initPayload(&buffer.base);
2264+
return eqlAdvanced(opt_val, ty, b, ty, mod, sema_kit);
22342265
}
22352266
},
22362267
else => {},

test/behavior/generics.zig

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,3 +323,22 @@ test "generic function instantiation non-duplicates" {
323323
S.copy(u8, &buffer, "hello");
324324
S.copy(u8, &buffer, "hello2");
325325
}
326+
327+
test "generic instantiation of tagged union with only one field" {
328+
if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO
329+
if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO
330+
if (builtin.os.tag == .wasi) return error.SkipZigTest;
331+
332+
const S = struct {
333+
const U = union(enum) {
334+
s: []const u8,
335+
};
336+
337+
fn foo(comptime u: U) usize {
338+
return u.s.len;
339+
}
340+
};
341+
342+
try expect(S.foo(.{ .s = "a" }) == 1);
343+
try expect(S.foo(.{ .s = "ab" }) == 2);
344+
}

0 commit comments

Comments
 (0)