Skip to content

Commit 185f431

Browse files
cagatay-ymkroening
authored andcommitted
fix(virtio-net): prepare checksum correctly
1 parent 2dbedb9 commit 185f431

File tree

1 file changed

+123
-30
lines changed

1 file changed

+123
-30
lines changed

src/drivers/net/virtio/mod.rs

Lines changed: 123 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -271,38 +271,14 @@ impl NetworkDriver for VirtioNetDriver<Init> {
271271
};
272272

273273
let mut header = Box::new_in(<Hdr as Default>::default(), DeviceAlloc);
274-
// If a checksum isn't necessary, we have inform the host within the header
274+
275+
// If a checksum calculation by the host is necessary, we have to inform the host within the header
275276
// see Virtio specification 5.1.6.2
276-
if !self.checksums.tcp.tx() || !self.checksums.udp.tx() {
277+
if let Some((ip_header_len, csum_offset)) = self.should_request_checksum(&mut packet) {
277278
header.flags = HdrF::NEEDS_CSUM;
278-
let ethernet_frame: smoltcp::wire::EthernetFrame<&[u8]> =
279-
EthernetFrame::new_unchecked(&packet);
280-
let packet_header_len: u16;
281-
let protocol;
282-
match ethernet_frame.ethertype() {
283-
smoltcp::wire::EthernetProtocol::Ipv4 => {
284-
let packet = Ipv4Packet::new_unchecked(ethernet_frame.payload());
285-
packet_header_len = packet.header_len().into();
286-
protocol = Some(packet.next_header());
287-
}
288-
smoltcp::wire::EthernetProtocol::Ipv6 => {
289-
let packet = Ipv6Packet::new_unchecked(ethernet_frame.payload());
290-
packet_header_len = packet.header_len().try_into().unwrap();
291-
protocol = Some(packet.next_header());
292-
}
293-
_ => {
294-
packet_header_len = 0;
295-
protocol = None;
296-
}
297-
}
298279
header.csum_start =
299-
(u16::try_from(ETHERNET_HEADER_LEN).unwrap() + packet_header_len).into();
300-
header.csum_offset = match protocol {
301-
Some(smoltcp::wire::IpProtocol::Tcp) => 16,
302-
Some(smoltcp::wire::IpProtocol::Udp) => 6,
303-
_ => 0,
304-
}
305-
.into();
280+
(u16::try_from(ETHERNET_HEADER_LEN).unwrap() + ip_header_len).into();
281+
header.csum_offset = csum_offset.into();
306282
}
307283

308284
let buff_tkn = AvailBufferToken::new(
@@ -488,6 +464,65 @@ impl VirtioNetDriver<Init> {
488464
// Only for receive? Because send is off anyway?
489465
self.inner.recv_vqs.enable_notifs();
490466
}
467+
468+
/// If necessary, sets the TCP or UDP checksum field to the checksum of the
469+
/// pseudo-header and returns the IP header length and the checksum offset.
470+
/// Otherwise, returns None.
471+
fn should_request_checksum<T: AsRef<[u8]> + AsMut<[u8]>>(
472+
&self,
473+
frame: T,
474+
) -> Option<(u16, u16)> {
475+
if self.checksums.tcp.tx() && self.checksums.udp.tx() {
476+
return None;
477+
}
478+
479+
let ip_header_len: u16;
480+
let ip_packet_len: usize;
481+
let protocol;
482+
let pseudo_header_checksum;
483+
let mut ethernet_frame = EthernetFrame::new_unchecked(frame);
484+
match ethernet_frame.ethertype() {
485+
smoltcp::wire::EthernetProtocol::Ipv4 => {
486+
let ip_packet = Ipv4Packet::new_unchecked(&*ethernet_frame.payload_mut());
487+
ip_header_len = ip_packet.header_len().into();
488+
ip_packet_len = ip_packet.total_len().into();
489+
protocol = ip_packet.next_header();
490+
pseudo_header_checksum =
491+
partial_checksum::ipv4_pseudo_header_partial_checksum(&ip_packet);
492+
}
493+
smoltcp::wire::EthernetProtocol::Ipv6 => {
494+
let ip_packet = Ipv6Packet::new_unchecked(&*ethernet_frame.payload_mut());
495+
ip_header_len = ip_packet.header_len().try_into().expect(
496+
"VIRTIO does not support IP headers that are longer than u16::MAX bytes.",
497+
);
498+
ip_packet_len = ip_packet.total_len();
499+
protocol = ip_packet.next_header();
500+
pseudo_header_checksum =
501+
partial_checksum::ipv6_pseudo_header_partial_checksum(&ip_packet);
502+
}
503+
// If the Ethernet protocol is not one of these two above, for which we know there may be a checksum field,
504+
// we default to not asking for checksum, as otherwise the frame will be corrupted by the device trying
505+
// to write the checksum.
506+
_ => return None,
507+
};
508+
509+
let csum_offset;
510+
let ip_payload = &mut ethernet_frame.payload_mut()[ip_header_len.into()..ip_packet_len];
511+
// Like the Ethernet protocol check, we check for IP protocols for which we know the location of the checksum field.
512+
if protocol == smoltcp::wire::IpProtocol::Tcp && !self.checksums.tcp.tx() {
513+
let mut tcp_packet = smoltcp::wire::TcpPacket::new_unchecked(ip_payload);
514+
tcp_packet.set_checksum(pseudo_header_checksum);
515+
csum_offset = 16;
516+
} else if protocol == smoltcp::wire::IpProtocol::Udp && !self.checksums.udp.tx() {
517+
let mut udp_packet = smoltcp::wire::UdpPacket::new_unchecked(ip_payload);
518+
udp_packet.set_checksum(pseudo_header_checksum);
519+
csum_offset = 6;
520+
} else {
521+
return None;
522+
};
523+
524+
Some((ip_header_len, csum_offset))
525+
}
491526
}
492527

493528
impl VirtioNetDriver<Uninit> {
@@ -524,7 +559,9 @@ impl VirtioNetDriver<Uninit> {
524559
// control queue support
525560
| virtio::net::F::CTRL_VQ
526561
// Multiqueue support
527-
| virtio::net::F::MQ;
562+
| virtio::net::F::MQ
563+
// Checksum calculation can partially be offloaded to the device
564+
| virtio::net::F::CSUM;
528565

529566
// Currently the driver does NOT support the features below.
530567
// In order to provide functionality for these, the driver
@@ -853,3 +890,59 @@ pub mod error {
853890
IncompatibleFeatureSets(virtio::net::F, virtio::net::F),
854891
}
855892
}
893+
894+
/// The checksum functions in this module only calculate the one's complement sum for the pseudo-header
895+
/// and their results are meant to be combined with the TCP payload to calculate the real checksum.
896+
/// They are only useful for the VIRTIO driver with the checksum offloading feature.
897+
///
898+
/// The calculations here can theoretically be made faster by exploiting the properties described in
899+
/// [RFC 1071 section 2](https://www.rfc-editor.org/rfc/rfc1071).
900+
mod partial_checksum {
901+
use smoltcp::wire::{Ipv4Packet, Ipv6Packet};
902+
903+
fn addr_sum<const N: usize>(addr: &[u8; N]) -> u16 {
904+
let mut sum = 0;
905+
const CHUNK_SIZE: usize = size_of::<u16>();
906+
for i in 0..(N / CHUNK_SIZE) {
907+
sum = ones_complement_add(
908+
sum,
909+
(u16::from(addr[CHUNK_SIZE * i]) << 8) | u16::from(addr[CHUNK_SIZE * i + 1]),
910+
);
911+
}
912+
sum
913+
}
914+
915+
/// Calculates the checksum for the IPv4 pseudo-header as described in
916+
/// [RFC 9293 subsection 3.1](https://www.rfc-editor.org/rfc/rfc9293.html#section-3.1-6.18.1) WITHOUT the final inversion.
917+
pub(super) fn ipv4_pseudo_header_partial_checksum<T: AsRef<[u8]>>(
918+
packet: &Ipv4Packet<T>,
919+
) -> u16 {
920+
let padded_protocol = u16::from(u8::from(packet.next_header()));
921+
let payload_len = packet.total_len() - u16::from(packet.header_len());
922+
923+
let mut sum = addr_sum(&packet.src_addr().octets());
924+
sum = ones_complement_add(sum, addr_sum(&packet.dst_addr().octets()));
925+
sum = ones_complement_add(sum, padded_protocol);
926+
ones_complement_add(sum, payload_len)
927+
}
928+
929+
/// Calculates the checksum for the IPv6 pseudo-header as described in
930+
/// [RFC 8200 subsection 8.1](https://www.rfc-editor.org/rfc/rfc8200.html#section-8.1) WITHOUT the final inversion.
931+
pub(super) fn ipv6_pseudo_header_partial_checksum<T: AsRef<[u8]>>(
932+
packet: &Ipv6Packet<T>,
933+
) -> u16 {
934+
warn!("The IPv6 partial checksum implementation is untested!");
935+
let padded_protocol = u16::from(u8::from(packet.next_header()));
936+
937+
let mut sum = addr_sum(&packet.src_addr().octets());
938+
sum = ones_complement_add(sum, addr_sum(&packet.dst_addr().octets()));
939+
sum = ones_complement_add(sum, packet.payload_len());
940+
ones_complement_add(sum, padded_protocol)
941+
}
942+
943+
/// Implements one's complement checksum as described in [RFC 1071 section 1](https://www.rfc-editor.org/rfc/rfc1071#section-1).
944+
fn ones_complement_add(lhs: u16, rhs: u16) -> u16 {
945+
let (sum, overflow) = u16::overflowing_add(lhs, rhs);
946+
sum + u16::from(overflow)
947+
}
948+
}

0 commit comments

Comments
 (0)