Skip to content

Commit 0e4c1ad

Browse files
committed
fix(virtio-net): prepare checksum correctly
1 parent 1ca3169 commit 0e4c1ad

File tree

1 file changed

+140
-30
lines changed

1 file changed

+140
-30
lines changed

src/drivers/net/virtio/mod.rs

Lines changed: 140 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -257,38 +257,12 @@ impl NetworkDriver for VirtioNetDriver {
257257
};
258258

259259
let mut header = Box::new_in(<Hdr as Default>::default(), DeviceAlloc);
260-
// If a checksum isn't necessary, we have inform the host within the header
261-
// see Virtio specification 5.1.6.2
262-
if !self.checksums.tcp.tx() || !self.checksums.udp.tx() {
260+
261+
if let Some((ip_header_len, csum_offset)) = self.should_request_checksum(&mut packet) {
263262
header.flags = HdrF::NEEDS_CSUM;
264-
let ethernet_frame: smoltcp::wire::EthernetFrame<&[u8]> =
265-
EthernetFrame::new_unchecked(&packet);
266-
let packet_header_len: u16;
267-
let protocol;
268-
match ethernet_frame.ethertype() {
269-
smoltcp::wire::EthernetProtocol::Ipv4 => {
270-
let packet = Ipv4Packet::new_unchecked(ethernet_frame.payload());
271-
packet_header_len = packet.header_len().into();
272-
protocol = Some(packet.next_header());
273-
}
274-
smoltcp::wire::EthernetProtocol::Ipv6 => {
275-
let packet = Ipv6Packet::new_unchecked(ethernet_frame.payload());
276-
packet_header_len = packet.header_len().try_into().unwrap();
277-
protocol = Some(packet.next_header());
278-
}
279-
_ => {
280-
packet_header_len = 0;
281-
protocol = None;
282-
}
283-
}
284263
header.csum_start =
285-
(u16::try_from(ETHERNET_HEADER_LEN).unwrap() + packet_header_len).into();
286-
header.csum_offset = match protocol {
287-
Some(smoltcp::wire::IpProtocol::Tcp) => 16,
288-
Some(smoltcp::wire::IpProtocol::Udp) => 6,
289-
_ => 0,
290-
}
291-
.into();
264+
(u16::try_from(ETHERNET_HEADER_LEN).unwrap() + ip_header_len).into();
265+
header.csum_offset = csum_offset.into();
292266
}
293267

294268
let buff_tkn = AvailBufferToken::new(
@@ -791,6 +765,90 @@ impl VirtioNetDriver {
791765

792766
Ok(())
793767
}
768+
769+
/// If necessary, sets the TCP or UDP checksum field to the checksum of the
770+
/// pseudo-header and returns the IP header length and the checksum offset.
771+
/// Otherwise, returns None.
772+
fn should_request_checksum<T: AsRef<[u8]> + AsMut<[u8]>>(
773+
&self,
774+
frame: T,
775+
) -> Option<(u16, u16)> {
776+
if !self.checksums.tcp.tx() || !self.checksums.udp.tx() {
777+
// If a checksum calculation by the host is necessary, we have to inform the host within the header
778+
// see Virtio specification 5.1.6.2
779+
let mut ethernet_frame = EthernetFrame::new_unchecked(frame);
780+
// If the Ethernet protocol is not one of these two, for which we know there may be a checksum field,
781+
// we default to not asking for checksum, as otherwise the frame will be corrupted by the device trying
782+
// to write the checksum.
783+
if let ip @ (smoltcp::wire::EthernetProtocol::Ipv4
784+
| smoltcp::wire::EthernetProtocol::Ipv6) = ethernet_frame.ethertype()
785+
{
786+
let ip_header_len: u16;
787+
let ip_packet_len: usize;
788+
let protocol;
789+
let pseudo_header_checksum;
790+
match ip {
791+
smoltcp::wire::EthernetProtocol::Ipv4 => {
792+
let ip_packet = Ipv4Packet::new_unchecked(&*ethernet_frame.payload_mut());
793+
ip_header_len = ip_packet.header_len().into();
794+
ip_packet_len = ip_packet.total_len().into();
795+
protocol = ip_packet.next_header();
796+
pseudo_header_checksum =
797+
partial_checksum::ipv4_pseudo_header_partial_checksum(&ip_packet);
798+
}
799+
smoltcp::wire::EthernetProtocol::Ipv6 => {
800+
let ip_packet = Ipv6Packet::new_unchecked(&*ethernet_frame.payload_mut());
801+
ip_header_len = ip_packet.header_len().try_into().expect(
802+
"VIRTIO does not support IP headers that are longer than u16::MAX bytes.",
803+
);
804+
ip_packet_len = ip_packet.total_len();
805+
protocol = ip_packet.next_header();
806+
pseudo_header_checksum =
807+
partial_checksum::ipv6_pseudo_header_partial_checksum(&ip_packet);
808+
}
809+
_ => unreachable!(),
810+
}
811+
// Like the Ethernet protocol check, we check for IP protocols for which we know the location of the checksum field.
812+
if let smoltcp::wire::IpProtocol::Tcp | smoltcp::wire::IpProtocol::Udp = protocol {
813+
let ip_payload =
814+
&mut ethernet_frame.payload_mut()[ip_header_len.into()..ip_packet_len];
815+
816+
// We do not care about the offset of the checksum for the protocol if we don't require checksum
817+
// from the host, so we use None to signal that checksum from the host is not needed.
818+
let csum_offset = match protocol {
819+
smoltcp::wire::IpProtocol::Tcp => {
820+
if self.checksums.tcp.tx() {
821+
None
822+
} else {
823+
let mut tcp_packet =
824+
smoltcp::wire::TcpPacket::new_unchecked(ip_payload);
825+
tcp_packet.set_checksum(pseudo_header_checksum);
826+
Some(16)
827+
}
828+
}
829+
smoltcp::wire::IpProtocol::Udp => {
830+
if self.checksums.tcp.tx() {
831+
None
832+
} else {
833+
let mut udp_packet =
834+
smoltcp::wire::UdpPacket::new_unchecked(ip_payload);
835+
udp_packet.set_checksum(pseudo_header_checksum);
836+
Some(6)
837+
}
838+
}
839+
_ => None,
840+
};
841+
csum_offset.map(|csum_offset| (ip_header_len, csum_offset))
842+
} else {
843+
None
844+
}
845+
} else {
846+
None
847+
}
848+
} else {
849+
None
850+
}
851+
}
794852
}
795853

796854
pub mod constants {
@@ -815,3 +873,55 @@ pub mod error {
815873
IncompatibleFeatureSets(virtio::net::F, virtio::net::F),
816874
}
817875
}
876+
877+
/// The checksum functions in this module only calculate the one's complement sum for the pseudo-header
878+
/// and their results are meant to be combined with the TCP payload to calculate the real checksum.
879+
/// They are only useful for the VIRTIO driver with the checksum offloading feature.
880+
///
881+
/// The calculations here can theoretically be made faster by exploiting the properties described in
882+
/// [RFC 1071 section 2](https://www.rfc-editor.org/rfc/rfc1071).
883+
mod partial_checksum {
884+
use smoltcp::wire::{Ipv4Packet, Ipv6Packet};
885+
886+
fn addr_sum<const N: usize>(addr: &[u8; N]) -> u16 {
887+
let mut sum = 0;
888+
for i in 0..(N / size_of::<u16>()) {
889+
sum = ones_complement_add(sum, (u16::from(addr[i]) << 8) | u16::from(addr[i + 1]));
890+
}
891+
sum
892+
}
893+
894+
/// Calculates the checksum for the IPv4 pseudo-header as described in
895+
/// [RFC 9293 subsection 3.1](https://www.rfc-editor.org/rfc/rfc9293.html#section-3.1-6.18.1) WITHOUT the final inversion.
896+
pub(super) fn ipv4_pseudo_header_partial_checksum<T: AsRef<[u8]>>(
897+
packet: &Ipv4Packet<T>,
898+
) -> u16 {
899+
let padded_protocol = u16::from(u8::from(packet.next_header()));
900+
let payload_len = packet.total_len() - u16::from(packet.header_len());
901+
902+
let mut sum = addr_sum(&packet.src_addr().octets());
903+
sum = ones_complement_add(sum, addr_sum(&packet.dst_addr().octets()));
904+
sum = ones_complement_add(sum, padded_protocol);
905+
ones_complement_add(sum, payload_len)
906+
}
907+
908+
/// Calculates the checksum for the IPv6 pseudo-header as described in
909+
/// [RFC 8200 subsection 8.1](https://www.rfc-editor.org/rfc/rfc8200.html#section-8.1) WITHOUT the final inversion.
910+
pub(super) fn ipv6_pseudo_header_partial_checksum<T: AsRef<[u8]>>(
911+
packet: &Ipv6Packet<T>,
912+
) -> u16 {
913+
warn!("The IPv6 partial checksum implementation is untested!");
914+
let padded_protocol = u16::from(u8::from(packet.next_header()));
915+
916+
let mut sum = addr_sum(&packet.src_addr().octets());
917+
sum = ones_complement_add(sum, addr_sum(&packet.dst_addr().octets()));
918+
sum = ones_complement_add(sum, packet.payload_len());
919+
ones_complement_add(sum, padded_protocol)
920+
}
921+
922+
/// Implements one's complement checksum as described in [RFC 1071 section 1](https://www.rfc-editor.org/rfc/rfc1071#section-1).
923+
fn ones_complement_add(lhs: u16, rhs: u16) -> u16 {
924+
let (sum, overflow) = u16::overflowing_add(lhs, rhs);
925+
sum + u16::from(overflow)
926+
}
927+
}

0 commit comments

Comments
 (0)