Skip to content

spirv: fix null file handle crash & improve flush err handling #24296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/link.zig
Original file line number Diff line number Diff line change
Expand Up @@ -568,9 +568,9 @@ pub const File = struct {
const gpa = comp.gpa;
switch (base.tag) {
.lld => assert(base.file == null),
.coff, .elf, .macho, .plan9, .wasm, .goff, .xcoff => {
.coff, .elf, .macho, .plan9, .wasm, .goff, .xcoff, .spirv => {
if (base.file != null) return;
dev.checkAny(&.{ .coff_linker, .elf_linker, .macho_linker, .plan9_linker, .wasm_linker, .goff_linker, .xcoff_linker });
dev.checkAny(&.{ .coff_linker, .elf_linker, .macho_linker, .plan9_linker, .wasm_linker, .goff_linker, .xcoff_linker, .spirv_linker });
const emit = base.emit;
if (base.child_pid) |pid| {
if (builtin.os.tag == .windows) {
Expand Down Expand Up @@ -608,7 +608,7 @@ pub const File = struct {
.mode = determineMode(output_mode, link_mode),
});
},
.c, .spirv => dev.checkAny(&.{ .c_linker, .spirv_linker }),
.c => dev.check(.c_linker),
}
}

Expand Down
51 changes: 42 additions & 9 deletions src/link/SpirV.zig
Original file line number Diff line number Diff line change
Expand Up @@ -180,6 +180,28 @@ pub fn flush(
tid: Zcu.PerThread.Id,
prog_node: std.Progress.Node,
) link.File.FlushError!void {
const tracy = trace(@src());
defer tracy.end();

const comp = self.base.comp;
const diags = &comp.link_diags;

const sub_prog_node = prog_node.start("SPIR-V Flush", 0);
defer sub_prog_node.end();

return flushInner(self, arena, tid, sub_prog_node) catch |err| switch (err) {
error.OutOfMemory => return error.OutOfMemory,
error.LinkFailure => return error.LinkFailure,
else => |e| return diags.fail("SPIR-V flush failed: {s}", .{@errorName(e)}),
};
Comment on lines +192 to +196
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

just return LinkFailure for every other error:

     return flushInner(self, arena, tid, sub_prog_node) catch |err| switch (err) {
         error.OutOfMemory => return error.OutOfMemory,
-        error.LinkFailure => return error.LinkFailure,
-        else => |e| return diags.fail("SPIR-V flush failed: {s}", .{@errorName(e)}),
+        else => return error.LinkFailure,
     };

}

fn flushInner(
self: *SpirV,
arena: Allocator,
tid: Zcu.PerThread.Id,
prog_node: std.Progress.Node,
) !void {
// The goal is to never use this because it's only needed if we need to
// write to InternPool, but flush is too late to be writing to the
// InternPool.
Expand All @@ -189,12 +211,6 @@ pub fn flush(
@panic("Attempted to compile for architecture that was disabled by build configuration");
}

const tracy = trace(@src());
defer tracy.end();

const sub_prog_node = prog_node.start("Flush Module", 0);
defer sub_prog_node.end();

const comp = self.base.comp;
const spv = &self.object.spv;
const diags = &comp.link_diags;
Expand Down Expand Up @@ -235,13 +251,14 @@ pub fn flush(
const module = try spv.finalize(arena);
errdefer arena.free(module);

const linked_module = self.linkModule(arena, module, sub_prog_node) catch |err| switch (err) {
const linked_module = self.linkModule(arena, module, prog_node) catch |err| switch (err) {
error.OutOfMemory => return error.OutOfMemory,
else => |other| return diags.fail("error while linking: {s}", .{@errorName(other)}),
};

self.base.file.?.writeAll(std.mem.sliceAsBytes(linked_module)) catch |err|
return diags.fail("failed to write: {s}", .{@errorName(err)});
try self.base.makeWritable();
try self.pwriteAll(std.mem.sliceAsBytes(linked_module), 0);
try self.setEndPos(linked_module.len * @sizeOf(Word));
Comment on lines +260 to +261
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Don't abstract them into mini functions. just use try:

- try self.pwriteAll(std.mem.sliceAsBytes(linked_module), 0);
- try self.setEndPos(linked_module.len * @sizeOf(Word));
+ try self.base.file.?.pwriteAll(std.mem.sliceAsBytes(linked_module), 0);
+ try self.base.file.?.setEndPos(linked_module.len * @sizeOf(Word));

}

fn linkModule(self: *SpirV, a: Allocator, module: []Word, progress: std.Progress.Node) ![]Word {
Expand All @@ -261,3 +278,19 @@ fn linkModule(self: *SpirV, a: Allocator, module: []Word, progress: std.Progress

return binary.finalize(a);
}

pub fn pwriteAll(spirv_file: *SpirV, bytes: []const u8, offset: u64) error{LinkFailure}!void {
const comp = spirv_file.base.comp;
const diags = &comp.link_diags;
spirv_file.base.file.?.pwriteAll(bytes, offset) catch |err| {
return diags.fail("failed to write: {s}", .{@errorName(err)});
};
}

pub fn setEndPos(spirv_file: *SpirV, length: u64) error{LinkFailure}!void {
const comp = spirv_file.base.comp;
const diags = &comp.link_diags;
spirv_file.base.file.?.setEndPos(length) catch |err| {
return diags.fail("failed to set file end pos: {s}", .{@errorName(err)});
};
}