Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
153 changes: 123 additions & 30 deletions src/drivers/net/virtio/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -271,38 +271,14 @@ impl NetworkDriver for VirtioNetDriver<Init> {
};

let mut header = Box::new_in(<Hdr as Default>::default(), DeviceAlloc);
// If a checksum isn't necessary, we have inform the host within the header

// If a checksum calculation by the host is necessary, we have to inform the host within the header
// see Virtio specification 5.1.6.2
if !self.checksums.tcp.tx() || !self.checksums.udp.tx() {
if let Some((ip_header_len, csum_offset)) = self.should_request_checksum(&mut packet) {
header.flags = HdrF::NEEDS_CSUM;
let ethernet_frame: smoltcp::wire::EthernetFrame<&[u8]> =
EthernetFrame::new_unchecked(&packet);
let packet_header_len: u16;
let protocol;
match ethernet_frame.ethertype() {
smoltcp::wire::EthernetProtocol::Ipv4 => {
let packet = Ipv4Packet::new_unchecked(ethernet_frame.payload());
packet_header_len = packet.header_len().into();
protocol = Some(packet.next_header());
}
smoltcp::wire::EthernetProtocol::Ipv6 => {
let packet = Ipv6Packet::new_unchecked(ethernet_frame.payload());
packet_header_len = packet.header_len().try_into().unwrap();
protocol = Some(packet.next_header());
}
_ => {
packet_header_len = 0;
protocol = None;
}
}
header.csum_start =
(u16::try_from(ETHERNET_HEADER_LEN).unwrap() + packet_header_len).into();
header.csum_offset = match protocol {
Some(smoltcp::wire::IpProtocol::Tcp) => 16,
Some(smoltcp::wire::IpProtocol::Udp) => 6,
_ => 0,
}
.into();
(u16::try_from(ETHERNET_HEADER_LEN).unwrap() + ip_header_len).into();
header.csum_offset = csum_offset.into();
}

let buff_tkn = AvailBufferToken::new(
Expand Down Expand Up @@ -488,6 +464,65 @@ impl VirtioNetDriver<Init> {
// Only for receive? Because send is off anyway?
self.inner.recv_vqs.enable_notifs();
}

/// If necessary, sets the TCP or UDP checksum field to the checksum of the
/// pseudo-header and returns the IP header length and the checksum offset.
/// Otherwise, returns None.
fn should_request_checksum<T: AsRef<[u8]> + AsMut<[u8]>>(
&self,
frame: T,
) -> Option<(u16, u16)> {
if self.checksums.tcp.tx() && self.checksums.udp.tx() {
return None;
}

let ip_header_len: u16;
let ip_packet_len: usize;
let protocol;
let pseudo_header_checksum;
let mut ethernet_frame = EthernetFrame::new_unchecked(frame);
match ethernet_frame.ethertype() {
smoltcp::wire::EthernetProtocol::Ipv4 => {
let ip_packet = Ipv4Packet::new_unchecked(&*ethernet_frame.payload_mut());
ip_header_len = ip_packet.header_len().into();
ip_packet_len = ip_packet.total_len().into();
protocol = ip_packet.next_header();
pseudo_header_checksum =
partial_checksum::ipv4_pseudo_header_partial_checksum(&ip_packet);
}
smoltcp::wire::EthernetProtocol::Ipv6 => {
let ip_packet = Ipv6Packet::new_unchecked(&*ethernet_frame.payload_mut());
ip_header_len = ip_packet.header_len().try_into().expect(
"VIRTIO does not support IP headers that are longer than u16::MAX bytes.",
);
ip_packet_len = ip_packet.total_len();
protocol = ip_packet.next_header();
pseudo_header_checksum =
partial_checksum::ipv6_pseudo_header_partial_checksum(&ip_packet);
}
// If the Ethernet protocol is not one of these two above, for which we know there may be a checksum field,
// we default to not asking for checksum, as otherwise the frame will be corrupted by the device trying
// to write the checksum.
_ => return None,
};

let csum_offset;
let ip_payload = &mut ethernet_frame.payload_mut()[ip_header_len.into()..ip_packet_len];
// Like the Ethernet protocol check, we check for IP protocols for which we know the location of the checksum field.
if protocol == smoltcp::wire::IpProtocol::Tcp && !self.checksums.tcp.tx() {
let mut tcp_packet = smoltcp::wire::TcpPacket::new_unchecked(ip_payload);
tcp_packet.set_checksum(pseudo_header_checksum);
csum_offset = 16;
} else if protocol == smoltcp::wire::IpProtocol::Udp && !self.checksums.udp.tx() {
let mut udp_packet = smoltcp::wire::UdpPacket::new_unchecked(ip_payload);
udp_packet.set_checksum(pseudo_header_checksum);
csum_offset = 6;
} else {
return None;
};

Some((ip_header_len, csum_offset))
}
}

impl VirtioNetDriver<Uninit> {
Expand Down Expand Up @@ -524,7 +559,9 @@ impl VirtioNetDriver<Uninit> {
// control queue support
| virtio::net::F::CTRL_VQ
// Multiqueue support
| virtio::net::F::MQ;
| virtio::net::F::MQ
// Checksum calculation can partially be offloaded to the device
| virtio::net::F::CSUM;

// Currently the driver does NOT support the features below.
// In order to provide functionality for these, the driver
Expand Down Expand Up @@ -853,3 +890,59 @@ pub mod error {
IncompatibleFeatureSets(virtio::net::F, virtio::net::F),
}
}

/// The checksum functions in this module only calculate the one's complement sum for the pseudo-header
/// and their results are meant to be combined with the TCP payload to calculate the real checksum.
/// They are only useful for the VIRTIO driver with the checksum offloading feature.
///
/// The calculations here can theoretically be made faster by exploiting the properties described in
/// [RFC 1071 section 2](https://www.rfc-editor.org/rfc/rfc1071).
mod partial_checksum {
use smoltcp::wire::{Ipv4Packet, Ipv6Packet};

fn addr_sum<const N: usize>(addr: &[u8; N]) -> u16 {
let mut sum = 0;
const CHUNK_SIZE: usize = size_of::<u16>();
for i in 0..(N / CHUNK_SIZE) {
sum = ones_complement_add(
sum,
(u16::from(addr[CHUNK_SIZE * i]) << 8) | u16::from(addr[CHUNK_SIZE * i + 1]),
);
}
sum
}

/// Calculates the checksum for the IPv4 pseudo-header as described in
/// [RFC 9293 subsection 3.1](https://www.rfc-editor.org/rfc/rfc9293.html#section-3.1-6.18.1) WITHOUT the final inversion.
pub(super) fn ipv4_pseudo_header_partial_checksum<T: AsRef<[u8]>>(
packet: &Ipv4Packet<T>,
) -> u16 {
let padded_protocol = u16::from(u8::from(packet.next_header()));
let payload_len = packet.total_len() - u16::from(packet.header_len());

let mut sum = addr_sum(&packet.src_addr().octets());
sum = ones_complement_add(sum, addr_sum(&packet.dst_addr().octets()));
sum = ones_complement_add(sum, padded_protocol);
ones_complement_add(sum, payload_len)
}

/// Calculates the checksum for the IPv6 pseudo-header as described in
/// [RFC 8200 subsection 8.1](https://www.rfc-editor.org/rfc/rfc8200.html#section-8.1) WITHOUT the final inversion.
pub(super) fn ipv6_pseudo_header_partial_checksum<T: AsRef<[u8]>>(
packet: &Ipv6Packet<T>,
) -> u16 {
warn!("The IPv6 partial checksum implementation is untested!");
let padded_protocol = u16::from(u8::from(packet.next_header()));

let mut sum = addr_sum(&packet.src_addr().octets());
sum = ones_complement_add(sum, addr_sum(&packet.dst_addr().octets()));
sum = ones_complement_add(sum, packet.payload_len());
ones_complement_add(sum, padded_protocol)
}

/// Implements one's complement checksum as described in [RFC 1071 section 1](https://www.rfc-editor.org/rfc/rfc1071#section-1).
fn ones_complement_add(lhs: u16, rhs: u16) -> u16 {
let (sum, overflow) = u16::overflowing_add(lhs, rhs);
sum + u16::from(overflow)
}
}