diff --git a/build.zig b/build.zig index ef09fa8..a7a378d 100644 --- a/build.zig +++ b/build.zig @@ -146,16 +146,22 @@ pub fn build(b: *std.Build) !void { apStep.dependOn(&b.addInstallArtifact(apExe, .{ .dest_dir = .{ .override = .{ .custom = "../application_processor/c/src" } } }).step); compStep.dependOn(&b.addInstallArtifact(compExe, .{ .dest_dir = .{ .override = .{ .custom = "../component/c/src" } } }).step); - const test_step = b.step("test", "Run unit tests"); const unit_tests = b.addTest(.{ .root_source_file = b.path("shared/main.zig"), .target = b.resolveTargetQuery(.{}), }); const run_unit_tests = b.addRunArtifact(unit_tests); + const test_step = b.step("test", "Run unit tests"); test_step.dependOn(&run_unit_tests.step); + const docs = b.addObject(.{ + .name = "main", + .root_source_file = b.path("shared/main.zig"), + .target = target, + .optimize = .Debug, + }); const install_docs = b.addInstallDirectory(.{ - .source_dir = unit_tests.getEmittedDocs(), + .source_dir = docs.getEmittedDocs(), .install_dir = .{ .custom = ".." }, .install_subdir = "docs", }); diff --git a/shared/layer3.zig b/shared/layer3.zig index 2fd3eed..2102523 100644 --- a/shared/layer3.zig +++ b/shared/layer3.zig @@ -32,7 +32,9 @@ pub fn ChannelInner(comptime T: type) type { try self.inner.send(data, to); } - /// Receive data from an address. Returns `null` if no data is available. + /// Receive data from an address. Returns `null` if no data is available. The returned buffer + /// is not guaranteed to be full remaining contents of the channel, and further data may be + /// available with subsequent calls to `recv`. pub inline fn recv(self: Self, from: Address) ChannelError!?Owned([]const u8) { return try self.inner.recv(from); } @@ -43,48 +45,136 @@ pub inline fn Channel(inner: anytype) ChannelInner(@TypeOf(inner)) { return .{ .inner = inner }; } -/// A mock implementation of the `Channel` interface pub const MockChannel = struct { const Self = @This(); allocator: std.mem.Allocator, - buffers: *std.AutoHashMap(Address, std.ArrayList(u8)), - recv_buffer_size: usize, + sendBuffer: *std.AutoHashMap(Address, std.ArrayList(u8)), + recvBuffer: *std.AutoHashMap(Address, std.ArrayList(u8)), + mutex: *std.Thread.Mutex, + + from: MockChannelSimplex, + to: MockChannelSimplex, + + pub fn init(allocator: std.mem.Allocator, recv_buffer_size: usize, unreliable: bool) !Self { + const sendBuffer = try allocator.create(std.AutoHashMap(Address, std.ArrayList(u8))); + sendBuffer.* = std.AutoHashMap(Address, std.ArrayList(u8)).init(allocator); - pub fn init(allocator: std.mem.Allocator, recv_buffer_size: usize) !Self { - const buffers = try allocator.create(std.AutoHashMap(Address, std.ArrayList(u8))); - buffers.* = std.AutoHashMap(Address, std.ArrayList(u8)).init(allocator); + const recvBuffer = try allocator.create(std.AutoHashMap(Address, std.ArrayList(u8))); + recvBuffer.* = std.AutoHashMap(Address, std.ArrayList(u8)).init(allocator); + + const mutex = try allocator.create(std.Thread.Mutex); + mutex.* = .{}; + + const to = try MockChannelSimplex.init(allocator, recv_buffer_size, unreliable, sendBuffer, recvBuffer, mutex); + const from = try MockChannelSimplex.init(allocator, recv_buffer_size, unreliable, recvBuffer, sendBuffer, mutex); return Self{ - .recv_buffer_size = recv_buffer_size, + .to = to, + .from = from, .allocator = allocator, - .buffers = buffers, + .sendBuffer = sendBuffer, + .recvBuffer = recvBuffer, + .mutex = mutex, }; } pub fn deinit(self: Self) void { - var iter = self.buffers.valueIterator(); + var iter = self.sendBuffer.valueIterator(); while (iter.next()) |buffer| { buffer.deinit(); } - self.buffers.deinit(); - self.allocator.destroy(self.buffers); + self.sendBuffer.deinit(); + self.allocator.destroy(self.sendBuffer); + + iter = self.recvBuffer.valueIterator(); + while (iter.next()) |buffer| { + buffer.deinit(); + } + + self.recvBuffer.deinit(); + self.allocator.destroy(self.recvBuffer); + + self.to.deinit(); + self.from.deinit(); + self.allocator.destroy(self.mutex); + } +}; + +/// A mock implementation of the `Channel` interface +const MockChannelSimplex = struct { + const Self = @This(); + + allocator: std.mem.Allocator, + sendBuffer: *std.AutoHashMap(Address, std.ArrayList(u8)), + recvBuffer: *std.AutoHashMap(Address, std.ArrayList(u8)), + mutex: *std.Thread.Mutex, + recv_buffer_size: usize, + + unreliable: bool, + n: *usize, + + pub fn init( + allocator: std.mem.Allocator, + recv_buffer_size: usize, + unreliable: bool, + sendBuffer: *std.AutoHashMap(Address, std.ArrayList(u8)), + recvBuffer: *std.AutoHashMap(Address, std.ArrayList(u8)), + mutex: *std.Thread.Mutex, + ) !Self { + const n = try allocator.create(usize); + n.* = 0; + + return Self{ + .recv_buffer_size = recv_buffer_size, + .allocator = allocator, + .sendBuffer = sendBuffer, + .recvBuffer = recvBuffer, + .unreliable = unreliable, + .mutex = mutex, + .n = n, + }; + } + + pub fn deinit(self: Self) void { + self.allocator.destroy(self.n); } pub fn send(self: Self, data: []const u8, to: Address) ChannelError!void { - if (self.buffers.getPtr(to)) |buffer| { + self.mutex.lock(); + defer self.mutex.unlock(); + + if (self.sendBuffer.getPtr(to)) |buffer| { try buffer.appendSlice(data); } else { - var list = std.ArrayList(u8).init(self.allocator); - errdefer list.deinit(); - try list.appendSlice(data); - try self.buffers.put(to, list); + var buffer = std.ArrayList(u8).init(self.allocator); + errdefer buffer.deinit(); + try buffer.appendSlice(data); + try self.sendBuffer.put(to, buffer); + } + + self.messWithData(); + } + + fn messWithData(self: Self) void { + if (self.unreliable) { + var iter = self.sendBuffer.valueIterator(); + while (iter.next()) |buffer| { + if (self.n.* % 3 == 0) { + buffer.items[buffer.items.len / 2] /= 2; + } + } + + self.n.* += 1; } } pub fn recv(self: Self, from: Address) ChannelError!?Owned([]const u8) { - const buffer = self.buffers.get(from) orelse return null; + self.mutex.lock(); + defer self.mutex.unlock(); + + const buffer = self.recvBuffer.get(from) orelse return null; const len = @min(self.recv_buffer_size, buffer.items.len); const recv_buffer = try self.allocator.dupe(u8, buffer.items[0..len]); @@ -92,13 +182,13 @@ pub const MockChannel = struct { if (buffer.items.len < self.recv_buffer_size) { buffer.deinit(); - const removed = self.buffers.remove(from); + const removed = self.recvBuffer.remove(from); std.debug.assert(removed); } else { var next_list = std.ArrayList(u8).init(self.allocator); try next_list.appendSlice(buffer.items[self.recv_buffer_size..]); buffer.deinit(); - try self.buffers.put(from, next_list); + try self.recvBuffer.put(from, next_list); } return toOwned(@as([]const u8, recv_buffer), self.allocator); @@ -106,83 +196,119 @@ pub const MockChannel = struct { }; test "basic channel" { - const mock = MockChannel.init(std.testing.allocator, 80) catch unreachable; - defer mock.deinit(); - const channel = Channel(mock); + const mocks = try MockChannel.init(std.testing.allocator, 80, false); + defer mocks.deinit(); + + const toChannel = Channel(mocks.to); + const fromChannel = Channel(mocks.from); const addrA = Address.from(11); const data = "hello, world!"; - try channel.send(data, addrA); + try toChannel.send(data, addrA); - const recv = (try channel.recv(addrA)).?; + const recv = (try fromChannel.recv(addrA)).?; defer recv.deinit(); + try std.testing.expectEqualSlices(u8, data, recv.inner); } test "channel large buffer" { const recv_buffer_size = 80; - const mock = try MockChannel.init(std.testing.allocator, recv_buffer_size); - defer mock.deinit(); - const channel = Channel(mock); + const mocks = try MockChannel.init(std.testing.allocator, recv_buffer_size, false); + defer mocks.deinit(); + + const toChannel = Channel(mocks.to); + const fromChannel = Channel(mocks.from); const addrA = Address.from(11); const data = try std.testing.allocator.alloc(u8, recv_buffer_size + 1); defer std.testing.allocator.free(data); - try channel.send(data, addrA); + try toChannel.send(data, addrA); - const recv = (try channel.recv(addrA)).?; + const recv = (try fromChannel.recv(addrA)).?; defer recv.deinit(); try std.testing.expectEqualSlices(u8, data[0..recv_buffer_size], recv.inner); - const recv2 = (try channel.recv(addrA)).?; + const recv2 = (try fromChannel.recv(addrA)).?; defer recv2.deinit(); try std.testing.expectEqualSlices(u8, data[recv_buffer_size .. recv_buffer_size + 1], recv2.inner); } test "channel multiple addresses" { - const mock = MockChannel.init(std.testing.allocator, 80) catch unreachable; - defer mock.deinit(); - const channel = Channel(mock); + const mocks = try MockChannel.init(std.testing.allocator, 80, false); + defer mocks.deinit(); + + const toChannel = Channel(mocks.to); + const fromChannel = Channel(mocks.from); const addrA = Address.from(11); const addrB = Address.from(12); const dataA = "data A"; - try channel.send(dataA, addrA); + try toChannel.send(dataA, addrA); const dataB = "data B"; - try channel.send(dataB, addrB); + try toChannel.send(dataB, addrB); - const recvB = (try channel.recv(addrB)).?; + const recvB = (try fromChannel.recv(addrB)).?; defer recvB.deinit(); try std.testing.expectEqualSlices(u8, dataB, recvB.inner); - const recvA = (try channel.recv(addrA)).?; + const recvA = (try fromChannel.recv(addrA)).?; defer recvA.deinit(); try std.testing.expectEqualSlices(u8, dataA, recvA.inner); } test "channel recv nothing" { - const mock = MockChannel.init(std.testing.allocator, 80) catch unreachable; - defer mock.deinit(); - const channel = Channel(mock); + const mocks = try MockChannel.init(std.testing.allocator, 80, false); + defer mocks.deinit(); + + const fromChannel = Channel(mocks.from); const addr = Address.from(11); - const recv = try channel.recv(addr); + const recv = try fromChannel.recv(addr); try std.testing.expectEqual(recv, null); } test "channel send nothing" { - const mock = MockChannel.init(std.testing.allocator, 80) catch unreachable; - defer mock.deinit(); - const channel = Channel(mock); + const mocks = try MockChannel.init(std.testing.allocator, 80, false); + defer mocks.deinit(); + + const toChannel = Channel(mocks.to); const addr = Address.from(11); const data = ([_]u8{})[0..0]; try std.testing.expectEqual(0, data.len); - try channel.send(data, addr); + try toChannel.send(data, addr); +} + +test "unreliable channel" { + const recv_buffer_size = 80; + const mocks = try MockChannel.init(std.testing.allocator, recv_buffer_size, true); + defer mocks.deinit(); + + const toChannel = Channel(mocks.to); + const fromChannel = Channel(mocks.from); + + const addrA = Address.from(11); + + const data = try std.testing.allocator.alloc(u8, recv_buffer_size * 10); + defer std.testing.allocator.free(data); + try toChannel.send(data, addrA); + + var unequal = false; + for (0..10) |_| { + const recv = (try fromChannel.recv(addrA)).?; + defer recv.deinit(); + unequal = !std.mem.eql(u8, data[0..recv_buffer_size], recv.inner); + if (unequal) { + break; + } + } + + try std.testing.expect(unequal); } diff --git a/shared/layer4.zig b/shared/layer4.zig new file mode 100644 index 0000000..14db251 --- /dev/null +++ b/shared/layer4.zig @@ -0,0 +1,452 @@ +const std = @import("std"); + +const layer3 = @import("layer3.zig"); + +const shared = @import("main.zig"); +const Owned = shared.Owned; +const toOwned = shared.toOwned; + +pub const ConnectionError = error{ + Timeout, + MaxRetriesExceeded, + UnexpectedPacket, + + /// An extremely significant error has occurred that can only happen if + /// someone has interfered with the data. + StopTheCount, +} || layer3.ChannelError; + +const Flags = packed struct { + /// Whether the packet is an acknowledgment + is_ack: bool, + _: u7 = 0, +}; + +/// Simple packet structure with checksum +pub const Packet = extern struct { + const data_size = 110; + + /// Sequence number of the packet + seq_num: u32 align(1) = 0, + + /// Sequence number of the packet being acknowledged + ack_num: u32 align(1) = 0, + + /// Index of the data chunk in the message + index: u32 align(1) = 0, + + /// CRC 32 checksum of the whole packet + checksum: u32 align(1) = 0, + + flags: Flags align(1) = .{ .is_ack = false }, + data_len: u8 align(1) = 0, + data: [data_size]u8 align(1) = [_]u8{0} ** data_size, + + /// Calculate and modify the checksum of the packet + pub fn calculateChecksum(self: *Packet) void { + self.checksum = 0; + const bytes = std.mem.asBytes(self); + var hasher = std.hash.Crc32.init(); + hasher.update(bytes); + self.checksum = hasher.final(); + } + + /// Check that the checksum of the packet is valid + pub fn verifyChecksum(self: *const Packet) bool { + var temp_packet = self.*; + temp_packet.checksum = 0; + const bytes = std.mem.asBytes(&temp_packet); + var hasher = std.hash.Crc32.init(); + hasher.update(bytes); + return hasher.final() == self.checksum; + } +}; + +/// Reliable connection over an unreliable channel. Can send arbitrarily large data over the channel. +/// `ChannelImpl` must implement the `Channel` interface. +pub fn Connection(comptime ChannelImpl: type) type { + return struct { + const Self = @This(); + const num_unacked_packets = 5; + const num_retries = 3; + const global_timeout = 10 * 1000; + + channel: layer3.ChannelInner(ChannelImpl), + address: layer3.Address, + + allocator: std.mem.Allocator, + seq_num: u32 = 0, + + unacked_packets: [num_unacked_packets]UnackedPacket = undefined, + unacked_packets_tail_index: usize = 0, + unacked_packets_head_index: usize = 0, + + buffer: [@sizeOf(Packet)]u8 = undefined, + buffer_size: usize = 0, + + const UnackedPacket = struct { + packet: Packet, + last_sent: i64, + retries: u32 = 0, + }; + + pub fn init(allocator: std.mem.Allocator, channel: ChannelImpl, address: layer3.Address) Self { + return .{ + .channel = layer3.Channel(channel), + .address = address, + .allocator = allocator, + }; + } + + /// Blocks until all data is sent and successfully acknowledged by the receiver. + /// Will return early if there is an error or the connection times out. + pub fn send(self: *Self, data: anytype) ConnectionError!void { + const T = @TypeOf(data); + const bytes = std.mem.asBytes(&data); + + var remaining: usize = @sizeOf(T) / Packet.data_size; + if (@mod(@sizeOf(T), Packet.data_size) != 0) remaining += 1; + + var offset: usize = 0; + var index: u32 = 0; + + var timeout: i64 = 0; + while (true) { + if (remaining == 0 and timeout == 0) { + std.debug.print("SENDER: sent all packets, starting timeout\n", .{}); + // if we ware done sending, start a timeout + timeout = std.time.milliTimestamp(); + } + if (remaining > 0 and self.unacked_packets_head_index - self.unacked_packets_tail_index < num_unacked_packets) { + remaining -= 1; + // only send a chunk if we have room in our unacked packets + const end = @min(offset + Packet.data_size, bytes.len); + const chunk = bytes[offset..end]; + offset += chunk.len; + + var packet = Packet{ + .seq_num = self.seq_num, + .index = index, + .data_len = @intCast(chunk.len), + }; + @memcpy(packet.data[0..chunk.len], chunk); + + self.unacked_packets[@mod(self.unacked_packets_head_index, num_unacked_packets)] = .{ + .packet = packet, + .last_sent = std.time.milliTimestamp(), + }; + self.unacked_packets_head_index += 1; + + self.seq_num += 1; + index += 1; + std.debug.print("SEND len={}, seq_num={}, index={}\n", .{ packet.data_len, packet.seq_num, packet.index }); + try self.sendPacket(&packet); + } else if (timeout == 0) { + std.debug.print("SENDER: not receiving acks, starting timeout\n", .{}); + // receiver might have disconnected + timeout = std.time.milliTimestamp(); + } + + if (try self.recvPacket()) |packet| { + if (!packet.flags.is_ack) return error.UnexpectedPacket; + + // find the packet to ack + var pos: usize = self.unacked_packets_tail_index; + while (pos < self.unacked_packets_head_index) : (pos += 1) { + const i = @mod(pos, num_unacked_packets); + if (self.unacked_packets[i].packet.seq_num == packet.ack_num) { + // swap the acked packet with the last unacked packet, which will be removed + const last_tail = @mod(self.unacked_packets_tail_index, num_unacked_packets); + if (last_tail != i) self.unacked_packets[i] = self.unacked_packets[last_tail]; + + // move the tail up, removing the acked packet from the unacked list. If the tail gets + // too high, move the tail and head down (to avoid overflow when it gets to @sizeOf(usize)) + self.unacked_packets_tail_index += 1; + if (self.unacked_packets_tail_index >= num_unacked_packets) { + self.unacked_packets_tail_index -= num_unacked_packets; + self.unacked_packets_head_index -= num_unacked_packets; + } + + timeout = std.time.milliTimestamp(); + std.debug.print("ACKN ack_num={}\n", .{packet.ack_num}); + break; + } + } + } + + const current_time = std.time.milliTimestamp(); + + if (timeout != 0 and current_time - timeout > global_timeout) { + return; + } + + var pos: usize = self.unacked_packets_tail_index; + while (pos < self.unacked_packets_head_index) : (pos += 1) { + const i = @mod(pos, num_unacked_packets); + const unacked = &self.unacked_packets[i]; + if (current_time - unacked.last_sent > 1000) { + if (unacked.retries >= num_retries) { + return error.MaxRetriesExceeded; + } + std.debug.print("RETX seq_num={}\n", .{unacked.packet.seq_num}); + try self.sendPacket(&unacked.packet); + unacked.last_sent = current_time; + unacked.retries += 1; + } + } + + if (remaining == 0 and self.unacked_packets_head_index == self.unacked_packets_tail_index) { + return; + } + } + } + + /// Blocks until all data is received. Will return early if there is an error or the connection times out. + pub fn recv(self: *Self, comptime T: type) ConnectionError!Owned(*T) { + const result = try self.allocator.create(T); + errdefer self.allocator.destroy(result); + const result_bytes = std.mem.asBytes(result); + var remaining: usize = @sizeOf(T) / Packet.data_size; + if (@mod(@sizeOf(T), Packet.data_size) != 0) remaining += 1; + var num_received_bytes: usize = 0; + + std.debug.print("RECEIVER: remaining={}\n", .{remaining}); + + var timeout = std.time.milliTimestamp(); + + while (remaining > 0) { + const currentTime = std.time.milliTimestamp(); + if (try self.recvPacket()) |packet| { + if (packet.flags.is_ack) return error.UnexpectedPacket; + + std.debug.print("RECV len={}, seq_num={}, index={}\n", .{ packet.data_len, packet.seq_num, packet.index }); + + var ack_packet = Packet{ .ack_num = packet.seq_num, .flags = .{ .is_ack = true } }; + try self.sendPacket(&ack_packet); + + const offset = packet.index * Packet.data_size; + const len = @min(Packet.data_size, packet.data_len); + if (offset + len > @sizeOf(T)) { + return error.StopTheCount; + } + + @memcpy(result_bytes[offset .. offset + len], packet.data[0..len]); + remaining -= 1; + num_received_bytes += len; + + timeout = currentTime; + } + + if (currentTime - timeout > global_timeout) { + return error.Timeout; + } + } + + if (num_received_bytes != @sizeOf(T)) { + return error.StopTheCount; + } + + return toOwned(result, self.allocator); + } + + /// Calculate the checksum of the packet and send it over the underlying channel + fn sendPacket(self: *Self, packet: *Packet) !void { + packet.calculateChecksum(); + const bytes = std.mem.asBytes(packet); + try self.channel.send(bytes, self.address); + } + + /// Poll receiving a `Packet` from the underlying channel + fn recvPacket(self: *Self) !?Packet { + if (try self.channel.recv(self.address)) |data| { + std.debug.assert(data.inner.len <= @sizeOf(Packet)); + defer data.deinit(); + + const remaining_bytes = @sizeOf(Packet) - self.buffer_size; + if (data.inner.len < remaining_bytes) { + @memcpy(self.buffer[self.buffer_size .. self.buffer_size + data.inner.len], data.inner); + self.buffer_size += data.inner.len; + return null; + } + + @memcpy(self.buffer[self.buffer_size..@sizeOf(Packet)], data.inner[0..remaining_bytes]); + const packet_ptr: *Packet = @ptrCast(&self.buffer[0]); + const packet = packet_ptr.*; + + self.buffer_size = data.inner.len - remaining_bytes; + @memcpy(self.buffer[0..self.buffer_size], data.inner[remaining_bytes..]); + + if (!packet.verifyChecksum()) { + std.debug.print("DROP packet failed checksum seq_num={}\n", .{packet.seq_num}); + return null; + } + + return packet; + } + + return null; + } + }; +} + +fn initBytes(num_bytes: comptime_int) [num_bytes]u8 { + var result: [num_bytes]u8 = undefined; + for (&result, 0..) |*b, i| { + b.* = @truncate(i); + } + return result; +} + +test "connection" { + std.debug.print("\nTEST: connection\n", .{}); + const mocks = try layer3.MockChannel.init(std.testing.allocator, 80, false); + defer mocks.deinit(); + + const addr = layer3.Address.from(11); + + const channel1 = layer3.Channel(mocks.to); + const obj1 = initBytes(1); + var conn1 = Connection(@TypeOf(channel1)).init(std.testing.allocator, channel1, addr); + + const thread = try std.Thread.spawn(.{}, struct { + fn run(conn: *@TypeOf(conn1), obj: *const @TypeOf(obj1)) void { + conn.send(obj.*) catch @panic("failed to send"); + } + }.run, .{ &conn1, &obj1 }); + defer thread.join(); + + const channel2 = layer3.Channel(mocks.from); + var conn2 = Connection(@TypeOf(channel2)).init(std.testing.allocator, channel2, addr); + const obj2 = try conn2.recv(@TypeOf(obj1)); + defer obj2.deinit(); + try std.testing.expectEqualDeep(obj1, obj2.inner.*); +} + +test "connection over big T" { + std.debug.print("\nTEST: connection over big T\n", .{}); + const mocks = try layer3.MockChannel.init(std.testing.allocator, 80, false); + defer mocks.deinit(); + + const addr = layer3.Address.from(11); + const obj1 = initBytes(Packet.data_size * 111 + 1); + + const channel1 = layer3.Channel(mocks.to); + var conn1 = Connection(@TypeOf(channel1)).init(std.testing.allocator, channel1, addr); + + const thread = try std.Thread.spawn(.{}, struct { + fn run(conn: *@TypeOf(conn1), obj: *const @TypeOf(obj1)) void { + conn.send(obj.*) catch @panic("failed to send"); + } + }.run, .{ &conn1, &obj1 }); + defer thread.join(); + + const channel2 = layer3.Channel(mocks.from); + var conn2 = Connection(@TypeOf(channel2)).init(std.testing.allocator, channel2, addr); + const obj2 = try conn2.recv(@TypeOf(obj1)); + defer obj2.deinit(); + try std.testing.expectEqualDeep(obj1, obj2.inner.*); +} + +test "connection over unreliable channel" { + std.debug.print("\nTEST: connection over unreliable channel\n", .{}); + const mocks = try layer3.MockChannel.init(std.testing.allocator, 80, true); + defer mocks.deinit(); + + const addr = layer3.Address.from(11); + const obj1 = initBytes(2048); + + const channel1 = layer3.Channel(mocks.to); + var conn1 = Connection(@TypeOf(channel1)).init(std.testing.allocator, channel1, addr); + + const thread = try std.Thread.spawn(.{}, struct { + fn run(conn: *@TypeOf(conn1), obj: *const @TypeOf(obj1)) void { + conn.send(obj.*) catch @panic("failed to send"); + } + }.run, .{ &conn1, &obj1 }); + defer thread.join(); + + const channel2 = layer3.Channel(mocks.from); + var conn2 = Connection(@TypeOf(channel2)).init(std.testing.allocator, channel2, addr); + const obj2 = try conn2.recv(@TypeOf(obj1)); + defer obj2.deinit(); + try std.testing.expectEqualDeep(obj1, obj2.inner.*); +} + +test "connection transmitting multiple Ts" { + std.debug.print("\nTEST: connection transmitting multiple Ts\n", .{}); + const mocks = try layer3.MockChannel.init(std.testing.allocator, 80, false); + defer mocks.deinit(); + + const addr = layer3.Address.from(11); + + const channel1 = layer3.Channel(mocks.to); + var conn1 = Connection(@TypeOf(channel1)).init(std.testing.allocator, channel1, addr); + + const obj1 = initBytes(12); + const obj2 = initBytes(123); + + const thread = try std.Thread.spawn(.{}, struct { + fn run(conn: *@TypeOf(conn1), a: *const @TypeOf(obj1), b: *const @TypeOf(obj2)) void { + conn.send(a.*) catch @panic("failed to send (a)"); + conn.send(b.*) catch @panic("failed to send (b)"); + } + }.run, .{ &conn1, &obj1, &obj2 }); + defer thread.join(); + + const channel2 = layer3.Channel(mocks.from); + var conn2 = Connection(@TypeOf(channel2)).init(std.testing.allocator, channel2, addr); + + const obj1r = try conn2.recv(@TypeOf(obj1)); + defer obj1r.deinit(); + try std.testing.expectEqualDeep(obj1, obj1r.inner.*); + + const obj12r = try conn2.recv(@TypeOf(obj2)); + defer obj12r.deinit(); + try std.testing.expectEqualDeep(obj2, obj12r.inner.*); +} + +test "connection send with no receiver" { + std.debug.print("\nTEST: connection send with no receiver\n", .{}); + const mocks = try layer3.MockChannel.init(std.testing.allocator, 80, false); + defer mocks.deinit(); + + const addr = layer3.Address.from(11); + + const channel = layer3.Channel(mocks.to); + var conn = Connection(@TypeOf(channel)).init(std.testing.allocator, channel, addr); + + const obj = initBytes(12); + const result = conn.send(obj); + try std.testing.expectError(error.MaxRetriesExceeded, result); +} + +test "connection receive with no sender" { + std.debug.print("\nTEST: connection receive with no sender\n", .{}); + const mocks = try layer3.MockChannel.init(std.testing.allocator, 80, false); + defer mocks.deinit(); + + const addr = layer3.Address.from(11); + + const channel = layer3.Channel(mocks.from); + var conn = Connection(@TypeOf(channel)).init(std.testing.allocator, channel, addr); + + const result = conn.recv(u32); + try std.testing.expectError(error.Timeout, result); +} + +test "packet checksum" { + var packet = Packet{}; + packet.calculateChecksum(); + try std.testing.expect(packet.verifyChecksum()); +} + +test "corrupted packet checksum" { + var packet = Packet{}; + packet.calculateChecksum(); + packet.data[13] += 2; + try std.testing.expect(!packet.verifyChecksum()); +} + +test "sizes" { + try std.testing.expectEqual(128, @sizeOf(Packet)); +} diff --git a/shared/main.zig b/shared/main.zig index 4061e18..4fb84f5 100644 --- a/shared/main.zig +++ b/shared/main.zig @@ -4,9 +4,11 @@ const std = @import("std"); pub const msdk = @import("msdk"); pub const layer3 = @import("layer3.zig"); +pub const layer4 = @import("layer4.zig"); comptime { _ = layer3; + _ = layer4; } /// Used to override Zig's default log function to work on the embedded, `freestanding` platform. Normally, Zig has a hard dependency on posix. @@ -49,12 +51,21 @@ pub fn Owned(comptime T: type) type { allocator: std.mem.Allocator, pub fn deinit(self: Self) void { - self.allocator.free(self.inner); + switch (@typeInfo(T)) { + .Array => self.allocator.free(self.inner), + .Vector => self.allocator.free(self.inner), + .Pointer => |info| switch (info.size) { + .One => self.allocator.destroy(self.inner), + .Many, .C, .Slice => self.allocator.free(self.inner), + }, + .Struct => self.inner.deinit(), + else => unreachable, + } } }; } -/// `inner` must have been allocated with `allocator.create` +/// `inner` must have been allocated with `allocator` pub fn toOwned(inner: anytype, allocator: std.mem.Allocator) Owned(@TypeOf(inner)) { return .{ .inner = inner,