Skip to content

Commit ccaf2a9

Browse files
committed
fix(virtio-net): prepare checksum correctly
1 parent 7e69055 commit ccaf2a9

File tree

2 files changed

+149
-30
lines changed

2 files changed

+149
-30
lines changed

src/drivers/net/virtio/mod.rs

Lines changed: 148 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -263,38 +263,12 @@ impl NetworkDriver for VirtioNetDriver {
263263
};
264264

265265
let mut header = Box::new_in(<Hdr as Default>::default(), DeviceAlloc);
266-
// If a checksum isn't necessary, we have inform the host within the header
267-
// see Virtio specification 5.1.6.2
268-
if !self.checksums.tcp.tx() || !self.checksums.udp.tx() {
266+
267+
if let Some((ip_header_len, csum_offset)) = self.should_request_checksum(&mut packet) {
269268
header.flags = HdrF::NEEDS_CSUM;
270-
let ethernet_frame: smoltcp::wire::EthernetFrame<&[u8]> =
271-
EthernetFrame::new_unchecked(&packet);
272-
let packet_header_len: u16;
273-
let protocol;
274-
match ethernet_frame.ethertype() {
275-
smoltcp::wire::EthernetProtocol::Ipv4 => {
276-
let packet = Ipv4Packet::new_unchecked(ethernet_frame.payload());
277-
packet_header_len = packet.header_len().into();
278-
protocol = Some(packet.next_header());
279-
}
280-
smoltcp::wire::EthernetProtocol::Ipv6 => {
281-
let packet = Ipv6Packet::new_unchecked(ethernet_frame.payload());
282-
packet_header_len = packet.header_len().try_into().unwrap();
283-
protocol = Some(packet.next_header());
284-
}
285-
_ => {
286-
packet_header_len = 0;
287-
protocol = None;
288-
}
289-
}
290269
header.csum_start =
291-
(u16::try_from(ETHERNET_HEADER_LEN).unwrap() + packet_header_len).into();
292-
header.csum_offset = match protocol {
293-
Some(smoltcp::wire::IpProtocol::Tcp) => 16,
294-
Some(smoltcp::wire::IpProtocol::Udp) => 6,
295-
_ => 0,
296-
}
297-
.into();
270+
(u16::try_from(ETHERNET_HEADER_LEN).unwrap() + ip_header_len).into();
271+
header.csum_offset = csum_offset.into();
298272
}
299273

300274
let buff_tkn = AvailBufferToken::new(
@@ -778,6 +752,87 @@ impl VirtioNetDriver {
778752

779753
Ok(())
780754
}
755+
756+
/// Sets the TCP or UDP checksum field to the checksum of the pseudo-header if necessary or returns None otherwise.
757+
fn should_request_checksum<T: AsRef<[u8]> + AsMut<[u8]>>(
758+
&self,
759+
frame: T,
760+
) -> Option<(u16, u16)> {
761+
if !self.checksums.tcp.tx() || !self.checksums.udp.tx() {
762+
// If a checksum calculation by the host is necessary, we have to inform the host within the header
763+
// see Virtio specification 5.1.6.2
764+
let mut ethernet_frame = EthernetFrame::new_unchecked(frame);
765+
// If the Ethernet protocol is not one of these two, we default to not asking for checksum,
766+
// as otherwise the frame will be corrupted by the device trying to write the checksum.
767+
if let ip @ (smoltcp::wire::EthernetProtocol::Ipv4
768+
| smoltcp::wire::EthernetProtocol::Ipv6) = ethernet_frame.ethertype()
769+
{
770+
let ip_header_len: u16;
771+
let ip_packet_len: usize;
772+
let protocol;
773+
let pseudo_header_checksum;
774+
match ip {
775+
smoltcp::wire::EthernetProtocol::Ipv4 => {
776+
let ip_packet = Ipv4Packet::new_unchecked(&*ethernet_frame.payload_mut());
777+
ip_header_len = ip_packet.header_len().into();
778+
ip_packet_len = ip_packet.total_len().into();
779+
protocol = ip_packet.next_header();
780+
pseudo_header_checksum =
781+
partial_checksum::ipv4_pseudo_header_partial_checksum(&ip_packet);
782+
}
783+
smoltcp::wire::EthernetProtocol::Ipv6 => {
784+
let ip_packet = Ipv6Packet::new_unchecked(&*ethernet_frame.payload_mut());
785+
ip_header_len = ip_packet.header_len().try_into().expect(
786+
"VIRTIO does not support IP headers that are longer than u16::MAX bytes.",
787+
);
788+
ip_packet_len = ip_packet.total_len();
789+
protocol = ip_packet.next_header();
790+
pseudo_header_checksum =
791+
partial_checksum::ipv6_pseudo_header_partial_checksum(&ip_packet);
792+
}
793+
_ => unreachable!(),
794+
}
795+
// Like the Ethernet protocol check, we check for IP protocols for which we know the location of the checksum field.
796+
if let smoltcp::wire::IpProtocol::Tcp | smoltcp::wire::IpProtocol::Udp = protocol {
797+
let ip_payload =
798+
&mut ethernet_frame.payload_mut()[ip_header_len.into()..ip_packet_len];
799+
800+
// We do not care about the offset of the checksum for the protocol if we don't require checksum
801+
// from the host, so we use None to signal that checksum from the host is not needed.
802+
let csum_offset = match protocol {
803+
smoltcp::wire::IpProtocol::Tcp => {
804+
if !self.checksums.tcp.tx() {
805+
let mut tcp_packet =
806+
smoltcp::wire::TcpPacket::new_unchecked(ip_payload);
807+
tcp_packet.set_checksum(pseudo_header_checksum);
808+
Some(16)
809+
} else {
810+
None
811+
}
812+
}
813+
smoltcp::wire::IpProtocol::Udp => {
814+
if !self.checksums.tcp.tx() {
815+
let mut udp_packet =
816+
smoltcp::wire::UdpPacket::new_unchecked(ip_payload);
817+
udp_packet.set_checksum(pseudo_header_checksum);
818+
Some(6)
819+
} else {
820+
None
821+
}
822+
}
823+
_ => None,
824+
};
825+
csum_offset.map(|csum_offset| (ip_header_len, csum_offset))
826+
} else {
827+
None
828+
}
829+
} else {
830+
None
831+
}
832+
} else {
833+
None
834+
}
835+
}
781836
}
782837

783838
pub mod constants {
@@ -802,3 +857,66 @@ pub mod error {
802857
IncompatibleFeatureSets(virtio::net::F, virtio::net::F),
803858
}
804859
}
860+
861+
/// The checksum functions in this module only calculate the one's complement sum for the pseudo-header
862+
/// and their results are meant to be combined with the TCP payload to calculate the real checksum.
863+
/// They are only useful for the VIRTIO driver with the checksum offloading feature.
864+
///
865+
/// The calculations here can theoretically be made faster by exploiting the properties described in
866+
/// [RFC 1071 section 2](https://www.rfc-editor.org/rfc/rfc1071).
867+
mod partial_checksum {
868+
use core::iter;
869+
870+
use smoltcp::wire::{Ipv4Packet, Ipv6Packet};
871+
872+
/// Calculates the checksum for the IPv4 pseudo-header as described in
873+
/// [RFC 9293 subsection 3.1](https://www.rfc-editor.org/rfc/rfc9293.html#section-3.1-6.18.1) WITHOUT the final inversion.
874+
pub(super) fn ipv4_pseudo_header_partial_checksum<T: AsRef<[u8]>>(
875+
packet: &Ipv4Packet<T>,
876+
) -> u16 {
877+
let src_addr = packet.src_addr();
878+
let dst_addr = packet.dst_addr();
879+
let address_words = src_addr
880+
.as_bytes()
881+
.iter()
882+
.chain(dst_addr.as_bytes())
883+
.copied()
884+
.array_chunks::<{ size_of::<u16>() }>()
885+
.map(u16::from_be_bytes);
886+
let padded_protocol = u16::from(u8::from(packet.next_header()));
887+
let payload_len = packet.total_len() - u16::from(packet.header_len());
888+
address_words
889+
.chain(iter::once(padded_protocol))
890+
.chain(iter::once(payload_len))
891+
.fold(0u16, ones_complement_add)
892+
}
893+
894+
/// Calculates the checksum for the IPv6 pseudo-header as described in
895+
/// [RFC 8200 subsection 8.1](https://www.rfc-editor.org/rfc/rfc8200.html#section-8.1) WITHOUT the final inversion.
896+
pub(super) fn ipv6_pseudo_header_partial_checksum<T: AsRef<[u8]>>(
897+
packet: &Ipv6Packet<T>,
898+
) -> u16 {
899+
warn!("The IPv6 partial checksum implementation is untested!");
900+
let src_addr = packet.src_addr();
901+
let dst_addr = packet.dst_addr();
902+
let payload_len = packet.payload_len();
903+
let padded_protocol = u16::from(u8::from(packet.next_header()));
904+
905+
src_addr
906+
.as_bytes()
907+
.iter()
908+
.chain(dst_addr.as_bytes())
909+
.copied()
910+
.array_chunks::<{ size_of::<u16>() }>()
911+
.map(u16::from_be_bytes)
912+
.chain(iter::once(payload_len))
913+
.chain(iter::once(padded_protocol))
914+
.fold(0u16, ones_complement_add)
915+
}
916+
917+
/// Implements one's complement checksum as described in [RFC 1071 section 1](https://www.rfc-editor.org/rfc/rfc1071#section-1).
918+
fn ones_complement_add(lhs: u16, rhs: u16) -> u16 {
919+
let (sum, overflow) = u16::overflowing_add(lhs, rhs);
920+
sum + u16::from(overflow)
921+
}
922+
}

src/lib.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
)]
1414
#![cfg_attr(target_arch = "x86_64", feature(abi_x86_interrupt))]
1515
#![feature(allocator_api)]
16+
#![feature(iter_array_chunks)]
1617
#![feature(linked_list_cursors)]
1718
#![feature(map_try_insert)]
1819
#![feature(maybe_uninit_as_bytes)]

0 commit comments

Comments
 (0)