Skip to content

Enhance switch on non-exhaustive enums #24381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
207 changes: 122 additions & 85 deletions lib/std/zig/AstGen.zig
Original file line number Diff line number Diff line change
Expand Up @@ -7652,10 +7652,12 @@ fn switchExpr(
var scalar_cases_len: u32 = 0;
var multi_cases_len: u32 = 0;
var inline_cases_len: u32 = 0;
var special_prong: Zir.SpecialProng = .none;
var special_node: Ast.Node.OptionalIndex = .none;
var else_case_node: Ast.Node.OptionalIndex = .none;
var else_src: ?Ast.TokenIndex = null;
var underscore_case_node: Ast.Node.OptionalIndex = .none;
var underscore_node: Ast.Node.OptionalIndex = .none;
var underscore_src: ?Ast.TokenIndex = null;
var underscore_additional_items: Zir.SpecialProngs.AdditionalItems = .none;
for (case_nodes) |case_node| {
const case = tree.fullSwitchCase(case_node).?;
if (case.payload_token) |payload_token| {
Expand All @@ -7676,7 +7678,8 @@ fn switchExpr(
any_non_inline_capture = true;
}
}
// Check for else/`_` prong.

// Check for else prong.
if (case.ast.values.len == 0) {
const case_src = case.ast.arrow_token - 1;
if (else_src) |src| {
Expand All @@ -7692,79 +7695,51 @@ fn switchExpr(
),
},
);
} else if (underscore_src) |some_underscore| {
return astgen.failNodeNotes(
node,
"else and '_' prong in switch expression",
.{},
&[_]u32{
try astgen.errNoteTok(
case_src,
"else prong here",
.{},
),
try astgen.errNoteTok(
some_underscore,
"'_' prong here",
.{},
),
},
);
}
special_node = case_node.toOptional();
special_prong = .@"else";
else_case_node = case_node.toOptional();
else_src = case_src;
continue;
} else if (case.ast.values.len == 1 and
tree.nodeTag(case.ast.values[0]) == .identifier and
mem.eql(u8, tree.tokenSlice(tree.nodeMainToken(case.ast.values[0])), "_"))
{
const case_src = case.ast.arrow_token - 1;
if (underscore_src) |src| {
return astgen.failTokNotes(
case_src,
"multiple '_' prongs in switch expression",
.{},
&[_]u32{
try astgen.errNoteTok(
src,
"previous '_' prong here",
.{},
),
},
);
} else if (else_src) |some_else| {
return astgen.failNodeNotes(
node,
"else and '_' prong in switch expression",
.{},
&[_]u32{
try astgen.errNoteTok(
some_else,
"else prong here",
.{},
),
try astgen.errNoteTok(
case_src,
"'_' prong here",
.{},
),
},
);
}
if (case.inline_token != null) {
return astgen.failTok(case_src, "cannot inline '_' prong", .{});
}
special_node = case_node.toOptional();
special_prong = .under;
underscore_src = case_src;
continue;
}

// Check for '_' prong.
var case_has_underscore = false;
for (case.ast.values) |val| {
if (tree.nodeTag(val) == .string_literal)
return astgen.failNode(val, "cannot switch on strings", .{});
switch (tree.nodeTag(val)) {
.identifier => if (mem.eql(u8, tree.tokenSlice(tree.nodeMainToken(val)), "_")) {
const val_src = tree.nodeMainToken(val);
if (underscore_src) |src| {
return astgen.failTokNotes(
val_src,
"multiple '_' prongs in switch expression",
.{},
&[_]u32{
try astgen.errNoteTok(
src,
"previous '_' prong here",
.{},
),
},
);
}
if (case.inline_token != null) {
return astgen.failTok(val_src, "cannot inline '_' prong", .{});
}
underscore_case_node = case_node.toOptional();
underscore_src = val_src;
underscore_node = val.toOptional();
underscore_additional_items = switch (case.ast.values.len) {
0 => unreachable,
1 => .none,
2 => .one,
else => .many,
};
case_has_underscore = true;
},
.string_literal => return astgen.failNode(val, "cannot switch on strings", .{}),
else => {},
}
}
if (case_has_underscore) continue;

if (case.ast.values.len == 1 and tree.nodeTag(case.ast.values[0]) != .switch_range) {
scalar_cases_len += 1;
Expand All @@ -7776,6 +7751,14 @@ fn switchExpr(
}
}

const special_prongs: Zir.SpecialProngs = .init(
else_src != null,
underscore_src != null,
underscore_additional_items,
);
const has_else = special_prongs.hasElse();
const has_under = special_prongs.hasUnder();

const operand_ri: ResultInfo = .{ .rl = if (any_payload_is_ref) .ref else .none };

astgen.advanceSourceCursorToNode(operand_node);
Expand All @@ -7796,7 +7779,9 @@ fn switchExpr(
const payloads = &astgen.scratch;
const scratch_top = astgen.scratch.items.len;
const case_table_start = scratch_top;
const scalar_case_table = case_table_start + @intFromBool(special_prong != .none);
const else_case_index = if (has_else) case_table_start else undefined;
const under_case_index = if (has_under) case_table_start + @intFromBool(has_else) else undefined;
const scalar_case_table = case_table_start + @intFromBool(has_else) + @intFromBool(has_under);
const multi_case_table = scalar_case_table + scalar_cases_len;
const case_table_end = multi_case_table + multi_cases_len;
try astgen.scratch.resize(gpa, case_table_end);
Expand Down Expand Up @@ -7928,14 +7913,33 @@ fn switchExpr(

const header_index: u32 = @intCast(payloads.items.len);
const body_len_index = if (is_multi_case) blk: {
payloads.items[multi_case_table + multi_case_index] = header_index;
multi_case_index += 1;
if (case_node.toOptional() == underscore_case_node) {
payloads.items[under_case_index] = header_index;
if (special_prongs.hasOneAdditionalItem()) {
try payloads.resize(gpa, header_index + 2); // item, body_len
const maybe_item_node = case.ast.values[0];
const item_node = if (maybe_item_node.toOptional() == underscore_node)
case.ast.values[1]
else
maybe_item_node;
const item_inst = try comptimeExpr(parent_gz, scope, item_ri, item_node, .switch_item);
payloads.items[header_index] = @intFromEnum(item_inst);
break :blk header_index + 1;
}
} else {
payloads.items[multi_case_table + multi_case_index] = header_index;
multi_case_index += 1;
}
try payloads.resize(gpa, header_index + 3); // items_len, ranges_len, body_len

// items
var items_len: u32 = 0;
for (case.ast.values) |item_node| {
if (tree.nodeTag(item_node) == .switch_range) continue;
if (item_node.toOptional() == underscore_node or
tree.nodeTag(item_node) == .switch_range)
{
continue;
}
items_len += 1;

const item_inst = try comptimeExpr(parent_gz, scope, item_ri, item_node, .switch_item);
Expand All @@ -7945,7 +7949,9 @@ fn switchExpr(
// ranges
var ranges_len: u32 = 0;
for (case.ast.values) |range| {
if (tree.nodeTag(range) != .switch_range) continue;
if (tree.nodeTag(range) != .switch_range) {
continue;
}
ranges_len += 1;

const first_node, const last_node = tree.nodeData(range).node_and_node;
Expand All @@ -7959,8 +7965,13 @@ fn switchExpr(
payloads.items[header_index] = items_len;
payloads.items[header_index + 1] = ranges_len;
break :blk header_index + 2;
} else if (case_node.toOptional() == special_node) blk: {
payloads.items[case_table_start] = header_index;
} else if (case_node.toOptional() == else_case_node) blk: {
payloads.items[else_case_index] = header_index;
try payloads.resize(gpa, header_index + 1); // body_len
break :blk header_index;
} else if (case_node.toOptional() == underscore_case_node) blk: {
assert(!special_prongs.hasAdditionalItems());
payloads.items[under_case_index] = header_index;
try payloads.resize(gpa, header_index + 1); // body_len
break :blk header_index;
} else blk: {
Expand Down Expand Up @@ -8015,15 +8026,13 @@ fn switchExpr(
try astgen.extra.ensureUnusedCapacity(gpa, @typeInfo(Zir.Inst.SwitchBlock).@"struct".fields.len +
@intFromBool(multi_cases_len != 0) +
@intFromBool(any_has_tag_capture) +
payloads.items.len - case_table_end +
(case_table_end - case_table_start) * @typeInfo(Zir.Inst.As).@"struct".fields.len);
payloads.items.len - scratch_top);

const payload_index = astgen.addExtraAssumeCapacity(Zir.Inst.SwitchBlock{
.operand = raw_operand,
.bits = Zir.Inst.SwitchBlock.Bits{
.has_multi_cases = multi_cases_len != 0,
.has_else = special_prong == .@"else",
.has_under = special_prong == .under,
.special_prongs = special_prongs,
.any_has_tag_capture = any_has_tag_capture,
.any_non_inline_capture = any_non_inline_capture,
.has_continue = switch_full.label_token != null and block_scope.label.?.used_for_continue,
Expand All @@ -8042,13 +8051,41 @@ fn switchExpr(
const zir_datas = astgen.instructions.items(.data);
zir_datas[@intFromEnum(switch_block)].pl_node.payload_index = payload_index;

for (payloads.items[case_table_start..case_table_end], 0..) |start_index, i| {
if (has_else) {
const start_index = payloads.items[else_case_index];
var end_index = start_index + 1;
const prong_info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(payloads.items[start_index]);
end_index += prong_info.body_len;
astgen.extra.appendSliceAssumeCapacity(payloads.items[start_index..end_index]);
}
if (has_under) {
const start_index = payloads.items[under_case_index];
var body_len_index = start_index;
var end_index = start_index;
const table_index = case_table_start + i;
if (table_index < scalar_case_table) {
end_index += 1;
} else if (table_index < multi_case_table) {
switch (underscore_additional_items) {
.none => {
end_index += 1;
},
.one => {
body_len_index += 1;
end_index += 2;
},
.many => {
body_len_index += 2;
const items_len = payloads.items[start_index];
const ranges_len = payloads.items[start_index + 1];
end_index += 3 + items_len + 2 * ranges_len;
},
}
const prong_info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(payloads.items[body_len_index]);
end_index += prong_info.body_len;
astgen.extra.appendSliceAssumeCapacity(payloads.items[start_index..end_index]);
}
for (payloads.items[scalar_case_table..case_table_end], 0..) |start_index, i| {
var body_len_index = start_index;
var end_index = start_index;
const table_index = scalar_case_table + i;
if (table_index < multi_case_table) {
body_len_index += 1;
end_index += 2;
} else {
Expand Down
Loading