@@ -257,38 +257,12 @@ impl NetworkDriver for VirtioNetDriver {
257
257
} ;
258
258
259
259
let mut header = Box :: new_in ( <Hdr as Default >:: default ( ) , DeviceAlloc ) ;
260
- // If a checksum isn't necessary, we have inform the host within the header
261
- // see Virtio specification 5.1.6.2
262
- if !self . checksums . tcp . tx ( ) || !self . checksums . udp . tx ( ) {
260
+
261
+ if let Some ( ( ip_header_len, csum_offset) ) = self . should_request_checksum ( & mut packet) {
263
262
header. flags = HdrF :: NEEDS_CSUM ;
264
- let ethernet_frame: smoltcp:: wire:: EthernetFrame < & [ u8 ] > =
265
- EthernetFrame :: new_unchecked ( & packet) ;
266
- let packet_header_len: u16 ;
267
- let protocol;
268
- match ethernet_frame. ethertype ( ) {
269
- smoltcp:: wire:: EthernetProtocol :: Ipv4 => {
270
- let packet = Ipv4Packet :: new_unchecked ( ethernet_frame. payload ( ) ) ;
271
- packet_header_len = packet. header_len ( ) . into ( ) ;
272
- protocol = Some ( packet. next_header ( ) ) ;
273
- }
274
- smoltcp:: wire:: EthernetProtocol :: Ipv6 => {
275
- let packet = Ipv6Packet :: new_unchecked ( ethernet_frame. payload ( ) ) ;
276
- packet_header_len = packet. header_len ( ) . try_into ( ) . unwrap ( ) ;
277
- protocol = Some ( packet. next_header ( ) ) ;
278
- }
279
- _ => {
280
- packet_header_len = 0 ;
281
- protocol = None ;
282
- }
283
- }
284
263
header. csum_start =
285
- ( u16:: try_from ( ETHERNET_HEADER_LEN ) . unwrap ( ) + packet_header_len) . into ( ) ;
286
- header. csum_offset = match protocol {
287
- Some ( smoltcp:: wire:: IpProtocol :: Tcp ) => 16 ,
288
- Some ( smoltcp:: wire:: IpProtocol :: Udp ) => 6 ,
289
- _ => 0 ,
290
- }
291
- . into ( ) ;
264
+ ( u16:: try_from ( ETHERNET_HEADER_LEN ) . unwrap ( ) + ip_header_len) . into ( ) ;
265
+ header. csum_offset = csum_offset. into ( ) ;
292
266
}
293
267
294
268
let buff_tkn = AvailBufferToken :: new (
@@ -791,6 +765,90 @@ impl VirtioNetDriver {
791
765
792
766
Ok ( ( ) )
793
767
}
768
+
769
+ /// If necessary, sets the TCP or UDP checksum field to the checksum of the
770
+ /// pseudo-header and returns the IP header length and the checksum offset.
771
+ /// Otherwise, returns None.
772
+ fn should_request_checksum < T : AsRef < [ u8 ] > + AsMut < [ u8 ] > > (
773
+ & self ,
774
+ frame : T ,
775
+ ) -> Option < ( u16 , u16 ) > {
776
+ if !self . checksums . tcp . tx ( ) || !self . checksums . udp . tx ( ) {
777
+ // If a checksum calculation by the host is necessary, we have to inform the host within the header
778
+ // see Virtio specification 5.1.6.2
779
+ let mut ethernet_frame = EthernetFrame :: new_unchecked ( frame) ;
780
+ // If the Ethernet protocol is not one of these two, for which we know there may be a checksum field,
781
+ // we default to not asking for checksum, as otherwise the frame will be corrupted by the device trying
782
+ // to write the checksum.
783
+ if let ip @ ( smoltcp:: wire:: EthernetProtocol :: Ipv4
784
+ | smoltcp:: wire:: EthernetProtocol :: Ipv6 ) = ethernet_frame. ethertype ( )
785
+ {
786
+ let ip_header_len: u16 ;
787
+ let ip_packet_len: usize ;
788
+ let protocol;
789
+ let pseudo_header_checksum;
790
+ match ip {
791
+ smoltcp:: wire:: EthernetProtocol :: Ipv4 => {
792
+ let ip_packet = Ipv4Packet :: new_unchecked ( & * ethernet_frame. payload_mut ( ) ) ;
793
+ ip_header_len = ip_packet. header_len ( ) . into ( ) ;
794
+ ip_packet_len = ip_packet. total_len ( ) . into ( ) ;
795
+ protocol = ip_packet. next_header ( ) ;
796
+ pseudo_header_checksum =
797
+ partial_checksum:: ipv4_pseudo_header_partial_checksum ( & ip_packet) ;
798
+ }
799
+ smoltcp:: wire:: EthernetProtocol :: Ipv6 => {
800
+ let ip_packet = Ipv6Packet :: new_unchecked ( & * ethernet_frame. payload_mut ( ) ) ;
801
+ ip_header_len = ip_packet. header_len ( ) . try_into ( ) . expect (
802
+ "VIRTIO does not support IP headers that are longer than u16::MAX bytes." ,
803
+ ) ;
804
+ ip_packet_len = ip_packet. total_len ( ) ;
805
+ protocol = ip_packet. next_header ( ) ;
806
+ pseudo_header_checksum =
807
+ partial_checksum:: ipv6_pseudo_header_partial_checksum ( & ip_packet) ;
808
+ }
809
+ _ => unreachable ! ( ) ,
810
+ }
811
+ // Like the Ethernet protocol check, we check for IP protocols for which we know the location of the checksum field.
812
+ if let smoltcp:: wire:: IpProtocol :: Tcp | smoltcp:: wire:: IpProtocol :: Udp = protocol {
813
+ let ip_payload =
814
+ & mut ethernet_frame. payload_mut ( ) [ ip_header_len. into ( ) ..ip_packet_len] ;
815
+
816
+ // We do not care about the offset of the checksum for the protocol if we don't require checksum
817
+ // from the host, so we use None to signal that checksum from the host is not needed.
818
+ let csum_offset = match protocol {
819
+ smoltcp:: wire:: IpProtocol :: Tcp => {
820
+ if self . checksums . tcp . tx ( ) {
821
+ None
822
+ } else {
823
+ let mut tcp_packet =
824
+ smoltcp:: wire:: TcpPacket :: new_unchecked ( ip_payload) ;
825
+ tcp_packet. set_checksum ( pseudo_header_checksum) ;
826
+ Some ( 16 )
827
+ }
828
+ }
829
+ smoltcp:: wire:: IpProtocol :: Udp => {
830
+ if self . checksums . tcp . tx ( ) {
831
+ None
832
+ } else {
833
+ let mut udp_packet =
834
+ smoltcp:: wire:: UdpPacket :: new_unchecked ( ip_payload) ;
835
+ udp_packet. set_checksum ( pseudo_header_checksum) ;
836
+ Some ( 6 )
837
+ }
838
+ }
839
+ _ => None ,
840
+ } ;
841
+ csum_offset. map ( |csum_offset| ( ip_header_len, csum_offset) )
842
+ } else {
843
+ None
844
+ }
845
+ } else {
846
+ None
847
+ }
848
+ } else {
849
+ None
850
+ }
851
+ }
794
852
}
795
853
796
854
pub mod constants {
@@ -815,3 +873,55 @@ pub mod error {
815
873
IncompatibleFeatureSets ( virtio:: net:: F , virtio:: net:: F ) ,
816
874
}
817
875
}
876
+
877
+ /// The checksum functions in this module only calculate the one's complement sum for the pseudo-header
878
+ /// and their results are meant to be combined with the TCP payload to calculate the real checksum.
879
+ /// They are only useful for the VIRTIO driver with the checksum offloading feature.
880
+ ///
881
+ /// The calculations here can theoretically be made faster by exploiting the properties described in
882
+ /// [RFC 1071 section 2](https://www.rfc-editor.org/rfc/rfc1071).
883
+ mod partial_checksum {
884
+ use smoltcp:: wire:: { Ipv4Packet , Ipv6Packet } ;
885
+
886
+ fn addr_sum < const N : usize > ( addr : & [ u8 ; N ] ) -> u16 {
887
+ let mut sum = 0 ;
888
+ for i in 0 ..( N / size_of :: < u16 > ( ) ) {
889
+ sum = ones_complement_add ( sum, ( u16:: from ( addr[ i] ) << 8 ) | u16:: from ( addr[ i + 1 ] ) ) ;
890
+ }
891
+ sum
892
+ }
893
+
894
+ /// Calculates the checksum for the IPv4 pseudo-header as described in
895
+ /// [RFC 9293 subsection 3.1](https://www.rfc-editor.org/rfc/rfc9293.html#section-3.1-6.18.1) WITHOUT the final inversion.
896
+ pub ( super ) fn ipv4_pseudo_header_partial_checksum < T : AsRef < [ u8 ] > > (
897
+ packet : & Ipv4Packet < T > ,
898
+ ) -> u16 {
899
+ let padded_protocol = u16:: from ( u8:: from ( packet. next_header ( ) ) ) ;
900
+ let payload_len = packet. total_len ( ) - u16:: from ( packet. header_len ( ) ) ;
901
+
902
+ let mut sum = addr_sum ( & packet. src_addr ( ) . octets ( ) ) ;
903
+ sum = ones_complement_add ( sum, addr_sum ( & packet. dst_addr ( ) . octets ( ) ) ) ;
904
+ sum = ones_complement_add ( sum, padded_protocol) ;
905
+ ones_complement_add ( sum, payload_len)
906
+ }
907
+
908
+ /// Calculates the checksum for the IPv6 pseudo-header as described in
909
+ /// [RFC 8200 subsection 8.1](https://www.rfc-editor.org/rfc/rfc8200.html#section-8.1) WITHOUT the final inversion.
910
+ pub ( super ) fn ipv6_pseudo_header_partial_checksum < T : AsRef < [ u8 ] > > (
911
+ packet : & Ipv6Packet < T > ,
912
+ ) -> u16 {
913
+ warn ! ( "The IPv6 partial checksum implementation is untested!" ) ;
914
+ let padded_protocol = u16:: from ( u8:: from ( packet. next_header ( ) ) ) ;
915
+
916
+ let mut sum = addr_sum ( & packet. src_addr ( ) . octets ( ) ) ;
917
+ sum = ones_complement_add ( sum, addr_sum ( & packet. dst_addr ( ) . octets ( ) ) ) ;
918
+ sum = ones_complement_add ( sum, packet. payload_len ( ) ) ;
919
+ ones_complement_add ( sum, padded_protocol)
920
+ }
921
+
922
+ /// Implements one's complement checksum as described in [RFC 1071 section 1](https://www.rfc-editor.org/rfc/rfc1071#section-1).
923
+ fn ones_complement_add ( lhs : u16 , rhs : u16 ) -> u16 {
924
+ let ( sum, overflow) = u16:: overflowing_add ( lhs, rhs) ;
925
+ sum + u16:: from ( overflow)
926
+ }
927
+ }
0 commit comments