@@ -271,38 +271,14 @@ impl NetworkDriver for VirtioNetDriver<Init> {
271
271
} ;
272
272
273
273
let mut header = Box :: new_in ( <Hdr as Default >:: default ( ) , DeviceAlloc ) ;
274
- // If a checksum isn't necessary, we have inform the host within the header
274
+
275
+ // If a checksum calculation by the host is necessary, we have to inform the host within the header
275
276
// see Virtio specification 5.1.6.2
276
- if ! self . checksums . tcp . tx ( ) || ! self . checksums . udp . tx ( ) {
277
+ if let Some ( ( ip_header_len , csum_offset ) ) = self . should_request_checksum ( & mut packet ) {
277
278
header. flags = HdrF :: NEEDS_CSUM ;
278
- let ethernet_frame: smoltcp:: wire:: EthernetFrame < & [ u8 ] > =
279
- EthernetFrame :: new_unchecked ( & packet) ;
280
- let packet_header_len: u16 ;
281
- let protocol;
282
- match ethernet_frame. ethertype ( ) {
283
- smoltcp:: wire:: EthernetProtocol :: Ipv4 => {
284
- let packet = Ipv4Packet :: new_unchecked ( ethernet_frame. payload ( ) ) ;
285
- packet_header_len = packet. header_len ( ) . into ( ) ;
286
- protocol = Some ( packet. next_header ( ) ) ;
287
- }
288
- smoltcp:: wire:: EthernetProtocol :: Ipv6 => {
289
- let packet = Ipv6Packet :: new_unchecked ( ethernet_frame. payload ( ) ) ;
290
- packet_header_len = packet. header_len ( ) . try_into ( ) . unwrap ( ) ;
291
- protocol = Some ( packet. next_header ( ) ) ;
292
- }
293
- _ => {
294
- packet_header_len = 0 ;
295
- protocol = None ;
296
- }
297
- }
298
279
header. csum_start =
299
- ( u16:: try_from ( ETHERNET_HEADER_LEN ) . unwrap ( ) + packet_header_len) . into ( ) ;
300
- header. csum_offset = match protocol {
301
- Some ( smoltcp:: wire:: IpProtocol :: Tcp ) => 16 ,
302
- Some ( smoltcp:: wire:: IpProtocol :: Udp ) => 6 ,
303
- _ => 0 ,
304
- }
305
- . into ( ) ;
280
+ ( u16:: try_from ( ETHERNET_HEADER_LEN ) . unwrap ( ) + ip_header_len) . into ( ) ;
281
+ header. csum_offset = csum_offset. into ( ) ;
306
282
}
307
283
308
284
let buff_tkn = AvailBufferToken :: new (
@@ -488,6 +464,65 @@ impl VirtioNetDriver<Init> {
488
464
// Only for receive? Because send is off anyway?
489
465
self . inner . recv_vqs . enable_notifs ( ) ;
490
466
}
467
+
468
+ /// If necessary, sets the TCP or UDP checksum field to the checksum of the
469
+ /// pseudo-header and returns the IP header length and the checksum offset.
470
+ /// Otherwise, returns None.
471
+ fn should_request_checksum < T : AsRef < [ u8 ] > + AsMut < [ u8 ] > > (
472
+ & self ,
473
+ frame : T ,
474
+ ) -> Option < ( u16 , u16 ) > {
475
+ if self . checksums . tcp . tx ( ) && self . checksums . udp . tx ( ) {
476
+ return None ;
477
+ }
478
+
479
+ let ip_header_len: u16 ;
480
+ let ip_packet_len: usize ;
481
+ let protocol;
482
+ let pseudo_header_checksum;
483
+ let mut ethernet_frame = EthernetFrame :: new_unchecked ( frame) ;
484
+ match ethernet_frame. ethertype ( ) {
485
+ smoltcp:: wire:: EthernetProtocol :: Ipv4 => {
486
+ let ip_packet = Ipv4Packet :: new_unchecked ( & * ethernet_frame. payload_mut ( ) ) ;
487
+ ip_header_len = ip_packet. header_len ( ) . into ( ) ;
488
+ ip_packet_len = ip_packet. total_len ( ) . into ( ) ;
489
+ protocol = ip_packet. next_header ( ) ;
490
+ pseudo_header_checksum =
491
+ partial_checksum:: ipv4_pseudo_header_partial_checksum ( & ip_packet) ;
492
+ }
493
+ smoltcp:: wire:: EthernetProtocol :: Ipv6 => {
494
+ let ip_packet = Ipv6Packet :: new_unchecked ( & * ethernet_frame. payload_mut ( ) ) ;
495
+ ip_header_len = ip_packet. header_len ( ) . try_into ( ) . expect (
496
+ "VIRTIO does not support IP headers that are longer than u16::MAX bytes." ,
497
+ ) ;
498
+ ip_packet_len = ip_packet. total_len ( ) ;
499
+ protocol = ip_packet. next_header ( ) ;
500
+ pseudo_header_checksum =
501
+ partial_checksum:: ipv6_pseudo_header_partial_checksum ( & ip_packet) ;
502
+ }
503
+ // If the Ethernet protocol is not one of these two above, for which we know there may be a checksum field,
504
+ // we default to not asking for checksum, as otherwise the frame will be corrupted by the device trying
505
+ // to write the checksum.
506
+ _ => return None ,
507
+ } ;
508
+
509
+ let csum_offset;
510
+ let ip_payload = & mut ethernet_frame. payload_mut ( ) [ ip_header_len. into ( ) ..ip_packet_len] ;
511
+ // Like the Ethernet protocol check, we check for IP protocols for which we know the location of the checksum field.
512
+ if protocol == smoltcp:: wire:: IpProtocol :: Tcp && !self . checksums . tcp . tx ( ) {
513
+ let mut tcp_packet = smoltcp:: wire:: TcpPacket :: new_unchecked ( ip_payload) ;
514
+ tcp_packet. set_checksum ( pseudo_header_checksum) ;
515
+ csum_offset = 16 ;
516
+ } else if protocol == smoltcp:: wire:: IpProtocol :: Udp && !self . checksums . udp . tx ( ) {
517
+ let mut udp_packet = smoltcp:: wire:: UdpPacket :: new_unchecked ( ip_payload) ;
518
+ udp_packet. set_checksum ( pseudo_header_checksum) ;
519
+ csum_offset = 6 ;
520
+ } else {
521
+ return None ;
522
+ } ;
523
+
524
+ Some ( ( ip_header_len, csum_offset) )
525
+ }
491
526
}
492
527
493
528
impl VirtioNetDriver < Uninit > {
@@ -524,7 +559,9 @@ impl VirtioNetDriver<Uninit> {
524
559
// control queue support
525
560
| virtio:: net:: F :: CTRL_VQ
526
561
// Multiqueue support
527
- | virtio:: net:: F :: MQ ;
562
+ | virtio:: net:: F :: MQ
563
+ // Checksum calculation can partially be offloaded to the device
564
+ | virtio:: net:: F :: CSUM ;
528
565
529
566
// Currently the driver does NOT support the features below.
530
567
// In order to provide functionality for these, the driver
@@ -853,3 +890,59 @@ pub mod error {
853
890
IncompatibleFeatureSets ( virtio:: net:: F , virtio:: net:: F ) ,
854
891
}
855
892
}
893
+
894
+ /// The checksum functions in this module only calculate the one's complement sum for the pseudo-header
895
+ /// and their results are meant to be combined with the TCP payload to calculate the real checksum.
896
+ /// They are only useful for the VIRTIO driver with the checksum offloading feature.
897
+ ///
898
+ /// The calculations here can theoretically be made faster by exploiting the properties described in
899
+ /// [RFC 1071 section 2](https://www.rfc-editor.org/rfc/rfc1071).
900
+ mod partial_checksum {
901
+ use smoltcp:: wire:: { Ipv4Packet , Ipv6Packet } ;
902
+
903
+ fn addr_sum < const N : usize > ( addr : & [ u8 ; N ] ) -> u16 {
904
+ let mut sum = 0 ;
905
+ const CHUNK_SIZE : usize = size_of :: < u16 > ( ) ;
906
+ for i in 0 ..( N / CHUNK_SIZE ) {
907
+ sum = ones_complement_add (
908
+ sum,
909
+ ( u16:: from ( addr[ CHUNK_SIZE * i] ) << 8 ) | u16:: from ( addr[ CHUNK_SIZE * i + 1 ] ) ,
910
+ ) ;
911
+ }
912
+ sum
913
+ }
914
+
915
+ /// Calculates the checksum for the IPv4 pseudo-header as described in
916
+ /// [RFC 9293 subsection 3.1](https://www.rfc-editor.org/rfc/rfc9293.html#section-3.1-6.18.1) WITHOUT the final inversion.
917
+ pub ( super ) fn ipv4_pseudo_header_partial_checksum < T : AsRef < [ u8 ] > > (
918
+ packet : & Ipv4Packet < T > ,
919
+ ) -> u16 {
920
+ let padded_protocol = u16:: from ( u8:: from ( packet. next_header ( ) ) ) ;
921
+ let payload_len = packet. total_len ( ) - u16:: from ( packet. header_len ( ) ) ;
922
+
923
+ let mut sum = addr_sum ( & packet. src_addr ( ) . octets ( ) ) ;
924
+ sum = ones_complement_add ( sum, addr_sum ( & packet. dst_addr ( ) . octets ( ) ) ) ;
925
+ sum = ones_complement_add ( sum, padded_protocol) ;
926
+ ones_complement_add ( sum, payload_len)
927
+ }
928
+
929
+ /// Calculates the checksum for the IPv6 pseudo-header as described in
930
+ /// [RFC 8200 subsection 8.1](https://www.rfc-editor.org/rfc/rfc8200.html#section-8.1) WITHOUT the final inversion.
931
+ pub ( super ) fn ipv6_pseudo_header_partial_checksum < T : AsRef < [ u8 ] > > (
932
+ packet : & Ipv6Packet < T > ,
933
+ ) -> u16 {
934
+ warn ! ( "The IPv6 partial checksum implementation is untested!" ) ;
935
+ let padded_protocol = u16:: from ( u8:: from ( packet. next_header ( ) ) ) ;
936
+
937
+ let mut sum = addr_sum ( & packet. src_addr ( ) . octets ( ) ) ;
938
+ sum = ones_complement_add ( sum, addr_sum ( & packet. dst_addr ( ) . octets ( ) ) ) ;
939
+ sum = ones_complement_add ( sum, packet. payload_len ( ) ) ;
940
+ ones_complement_add ( sum, padded_protocol)
941
+ }
942
+
943
+ /// Implements one's complement checksum as described in [RFC 1071 section 1](https://www.rfc-editor.org/rfc/rfc1071#section-1).
944
+ fn ones_complement_add ( lhs : u16 , rhs : u16 ) -> u16 {
945
+ let ( sum, overflow) = u16:: overflowing_add ( lhs, rhs) ;
946
+ sum + u16:: from ( overflow)
947
+ }
948
+ }
0 commit comments