@@ -263,38 +263,12 @@ impl NetworkDriver for VirtioNetDriver {
263
263
} ;
264
264
265
265
let mut header = Box :: new_in ( <Hdr as Default >:: default ( ) , DeviceAlloc ) ;
266
- // If a checksum isn't necessary, we have inform the host within the header
267
- // see Virtio specification 5.1.6.2
268
- if !self . checksums . tcp . tx ( ) || !self . checksums . udp . tx ( ) {
266
+
267
+ if let Some ( ( ip_header_len, csum_offset) ) = self . should_request_checksum ( & mut packet) {
269
268
header. flags = HdrF :: NEEDS_CSUM ;
270
- let ethernet_frame: smoltcp:: wire:: EthernetFrame < & [ u8 ] > =
271
- EthernetFrame :: new_unchecked ( & packet) ;
272
- let packet_header_len: u16 ;
273
- let protocol;
274
- match ethernet_frame. ethertype ( ) {
275
- smoltcp:: wire:: EthernetProtocol :: Ipv4 => {
276
- let packet = Ipv4Packet :: new_unchecked ( ethernet_frame. payload ( ) ) ;
277
- packet_header_len = packet. header_len ( ) . into ( ) ;
278
- protocol = Some ( packet. next_header ( ) ) ;
279
- }
280
- smoltcp:: wire:: EthernetProtocol :: Ipv6 => {
281
- let packet = Ipv6Packet :: new_unchecked ( ethernet_frame. payload ( ) ) ;
282
- packet_header_len = packet. header_len ( ) . try_into ( ) . unwrap ( ) ;
283
- protocol = Some ( packet. next_header ( ) ) ;
284
- }
285
- _ => {
286
- packet_header_len = 0 ;
287
- protocol = None ;
288
- }
289
- }
290
269
header. csum_start =
291
- ( u16:: try_from ( ETHERNET_HEADER_LEN ) . unwrap ( ) + packet_header_len) . into ( ) ;
292
- header. csum_offset = match protocol {
293
- Some ( smoltcp:: wire:: IpProtocol :: Tcp ) => 16 ,
294
- Some ( smoltcp:: wire:: IpProtocol :: Udp ) => 6 ,
295
- _ => 0 ,
296
- }
297
- . into ( ) ;
270
+ ( u16:: try_from ( ETHERNET_HEADER_LEN ) . unwrap ( ) + ip_header_len) . into ( ) ;
271
+ header. csum_offset = csum_offset. into ( ) ;
298
272
}
299
273
300
274
let buff_tkn = AvailBufferToken :: new (
@@ -778,6 +752,87 @@ impl VirtioNetDriver {
778
752
779
753
Ok ( ( ) )
780
754
}
755
+
756
+ /// Sets the TCP or UDP checksum field to the checksum of the pseudo-header if necessary or returns None otherwise.
757
+ fn should_request_checksum < T : AsRef < [ u8 ] > + AsMut < [ u8 ] > > (
758
+ & self ,
759
+ frame : T ,
760
+ ) -> Option < ( u16 , u16 ) > {
761
+ if !self . checksums . tcp . tx ( ) || !self . checksums . udp . tx ( ) {
762
+ // If a checksum calculation by the host is necessary, we have to inform the host within the header
763
+ // see Virtio specification 5.1.6.2
764
+ let mut ethernet_frame = EthernetFrame :: new_unchecked ( frame) ;
765
+ // If the Ethernet protocol is not one of these two, we default to not asking for checksum,
766
+ // as otherwise the frame will be corrupted by the device trying to write the checksum.
767
+ if let ip @ ( smoltcp:: wire:: EthernetProtocol :: Ipv4
768
+ | smoltcp:: wire:: EthernetProtocol :: Ipv6 ) = ethernet_frame. ethertype ( )
769
+ {
770
+ let ip_header_len: u16 ;
771
+ let ip_packet_len: usize ;
772
+ let protocol;
773
+ let pseudo_header_checksum;
774
+ match ip {
775
+ smoltcp:: wire:: EthernetProtocol :: Ipv4 => {
776
+ let ip_packet = Ipv4Packet :: new_unchecked ( & * ethernet_frame. payload_mut ( ) ) ;
777
+ ip_header_len = ip_packet. header_len ( ) . into ( ) ;
778
+ ip_packet_len = ip_packet. total_len ( ) . into ( ) ;
779
+ protocol = ip_packet. next_header ( ) ;
780
+ pseudo_header_checksum =
781
+ partial_checksum:: ipv4_pseudo_header_partial_checksum ( & ip_packet) ;
782
+ }
783
+ smoltcp:: wire:: EthernetProtocol :: Ipv6 => {
784
+ let ip_packet = Ipv6Packet :: new_unchecked ( & * ethernet_frame. payload_mut ( ) ) ;
785
+ ip_header_len = ip_packet. header_len ( ) . try_into ( ) . expect (
786
+ "VIRTIO does not support IP headers that are longer than u16::MAX bytes." ,
787
+ ) ;
788
+ ip_packet_len = ip_packet. total_len ( ) ;
789
+ protocol = ip_packet. next_header ( ) ;
790
+ pseudo_header_checksum =
791
+ partial_checksum:: ipv6_pseudo_header_partial_checksum ( & ip_packet) ;
792
+ }
793
+ _ => unreachable ! ( ) ,
794
+ }
795
+ // Like the Ethernet protocol check, we check for IP protocols for which we know the location of the checksum field.
796
+ if let smoltcp:: wire:: IpProtocol :: Tcp | smoltcp:: wire:: IpProtocol :: Udp = protocol {
797
+ let ip_payload =
798
+ & mut ethernet_frame. payload_mut ( ) [ ip_header_len. into ( ) ..ip_packet_len] ;
799
+
800
+ // We do not care about the offset of the checksum for the protocol if we don't require checksum
801
+ // from the host, so we use None to signal that checksum from the host is not needed.
802
+ let csum_offset = match protocol {
803
+ smoltcp:: wire:: IpProtocol :: Tcp => {
804
+ if !self . checksums . tcp . tx ( ) {
805
+ let mut tcp_packet =
806
+ smoltcp:: wire:: TcpPacket :: new_unchecked ( ip_payload) ;
807
+ tcp_packet. set_checksum ( pseudo_header_checksum) ;
808
+ Some ( 16 )
809
+ } else {
810
+ None
811
+ }
812
+ }
813
+ smoltcp:: wire:: IpProtocol :: Udp => {
814
+ if !self . checksums . tcp . tx ( ) {
815
+ let mut udp_packet =
816
+ smoltcp:: wire:: UdpPacket :: new_unchecked ( ip_payload) ;
817
+ udp_packet. set_checksum ( pseudo_header_checksum) ;
818
+ Some ( 6 )
819
+ } else {
820
+ None
821
+ }
822
+ }
823
+ _ => None ,
824
+ } ;
825
+ csum_offset. map ( |csum_offset| ( ip_header_len, csum_offset) )
826
+ } else {
827
+ None
828
+ }
829
+ } else {
830
+ None
831
+ }
832
+ } else {
833
+ None
834
+ }
835
+ }
781
836
}
782
837
783
838
pub mod constants {
@@ -802,3 +857,66 @@ pub mod error {
802
857
IncompatibleFeatureSets ( virtio:: net:: F , virtio:: net:: F ) ,
803
858
}
804
859
}
860
+
861
+ /// The checksum functions in this module only calculate the one's complement sum for the pseudo-header
862
+ /// and their results are meant to be combined with the TCP payload to calculate the real checksum.
863
+ /// They are only useful for the VIRTIO driver with the checksum offloading feature.
864
+ ///
865
+ /// The calculations here can theoretically be made faster by exploiting the properties described in
866
+ /// [RFC 1071 section 2](https://www.rfc-editor.org/rfc/rfc1071).
867
+ mod partial_checksum {
868
+ use core:: iter;
869
+
870
+ use smoltcp:: wire:: { Ipv4Packet , Ipv6Packet } ;
871
+
872
+ /// Calculates the checksum for the IPv4 pseudo-header as described in
873
+ /// [RFC 9293 subsection 3.1](https://www.rfc-editor.org/rfc/rfc9293.html#section-3.1-6.18.1) WITHOUT the final inversion.
874
+ pub ( super ) fn ipv4_pseudo_header_partial_checksum < T : AsRef < [ u8 ] > > (
875
+ packet : & Ipv4Packet < T > ,
876
+ ) -> u16 {
877
+ let src_addr = packet. src_addr ( ) ;
878
+ let dst_addr = packet. dst_addr ( ) ;
879
+ let address_words = src_addr
880
+ . as_bytes ( )
881
+ . iter ( )
882
+ . chain ( dst_addr. as_bytes ( ) )
883
+ . copied ( )
884
+ . array_chunks :: < { size_of :: < u16 > ( ) } > ( )
885
+ . map ( u16:: from_be_bytes) ;
886
+ let padded_protocol = u16:: from ( u8:: from ( packet. next_header ( ) ) ) ;
887
+ let payload_len = packet. total_len ( ) - u16:: from ( packet. header_len ( ) ) ;
888
+ address_words
889
+ . chain ( iter:: once ( padded_protocol) )
890
+ . chain ( iter:: once ( payload_len) )
891
+ . fold ( 0u16 , ones_complement_add)
892
+ }
893
+
894
+ /// Calculates the checksum for the IPv6 pseudo-header as described in
895
+ /// [RFC 8200 subsection 8.1](https://www.rfc-editor.org/rfc/rfc8200.html#section-8.1) WITHOUT the final inversion.
896
+ pub ( super ) fn ipv6_pseudo_header_partial_checksum < T : AsRef < [ u8 ] > > (
897
+ packet : & Ipv6Packet < T > ,
898
+ ) -> u16 {
899
+ warn ! ( "The IPv6 partial checksum implementation is untested!" ) ;
900
+ let src_addr = packet. src_addr ( ) ;
901
+ let dst_addr = packet. dst_addr ( ) ;
902
+ let payload_len = packet. payload_len ( ) ;
903
+ let padded_protocol = u16:: from ( u8:: from ( packet. next_header ( ) ) ) ;
904
+
905
+ src_addr
906
+ . as_bytes ( )
907
+ . iter ( )
908
+ . chain ( dst_addr. as_bytes ( ) )
909
+ . copied ( )
910
+ . array_chunks :: < { size_of :: < u16 > ( ) } > ( )
911
+ . map ( u16:: from_be_bytes)
912
+ . chain ( iter:: once ( payload_len) )
913
+ . chain ( iter:: once ( padded_protocol) )
914
+ . fold ( 0u16 , ones_complement_add)
915
+ }
916
+
917
+ /// Implements one's complement checksum as described in [RFC 1071 section 1](https://www.rfc-editor.org/rfc/rfc1071#section-1).
918
+ fn ones_complement_add ( lhs : u16 , rhs : u16 ) -> u16 {
919
+ let ( sum, overflow) = u16:: overflowing_add ( lhs, rhs) ;
920
+ sum + u16:: from ( overflow)
921
+ }
922
+ }
0 commit comments