Skip to content

fix(virtio-net): Only initialize virtqueues after feature negotiation #1729

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
May 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 17 additions & 21 deletions src/drivers/net/virtio/mmio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,21 +2,19 @@
//!
//! The module contains ...

use alloc::vec::Vec;
use core::str::FromStr;

use smoltcp::phy::ChecksumCapabilities;
use virtio::mmio::{DeviceRegisters, DeviceRegistersVolatileFieldAccess};
use volatile::VolatileRef;

use crate::drivers::InterruptLine;
use crate::drivers::net::virtio::{CtrlQueue, NetDevCfg, RxQueues, TxQueues, VirtioNetDriver};
use crate::drivers::net::virtio::{Init, NetDevCfg, Uninit, VirtioNetDriver};
use crate::drivers::virtio::error::{VirtioError, VirtioNetError};
use crate::drivers::virtio::transport::mmio::{ComCfg, IsrStatus, NotifCfg};
use crate::drivers::virtio::virtqueue::VirtQueue;

// Backend-dependent interface for Virtio network driver
impl VirtioNetDriver {
impl VirtioNetDriver<Uninit> {
pub fn new(
dev_id: u16,
mut registers: VolatileRef<'static, DeviceRegisters>,
Expand Down Expand Up @@ -47,30 +45,19 @@ impl VirtioNetDriver {
1514
};

let send_vqs = TxQueues::new(Vec::<VirtQueue>::new(), &dev_cfg);
let recv_vqs = RxQueues::new(Vec::<VirtQueue>::new(), &dev_cfg);
Ok(VirtioNetDriver {
dev_cfg,
com_cfg: ComCfg::new(registers, 1),
isr_stat,
notif_cfg,
ctrl_vq: CtrlQueue::new(None),
recv_vqs,
send_vqs,
inner: Uninit,
num_vqs: 0,
mtu,
irq,
checksums: ChecksumCapabilities::default(),
})
}

pub fn print_information(&mut self) {
self.com_cfg.print_information();
if self.dev_status() == virtio::net::S::LINK_UP {
info!("The link of the network device is up!");
}
}

/// Initializes virtio network device by mapping configuration layout to
/// respective structs (configuration structs are:
///
Expand All @@ -80,13 +67,13 @@ impl VirtioNetDriver {
dev_id: u16,
registers: VolatileRef<'static, DeviceRegisters>,
irq: InterruptLine,
) -> Result<VirtioNetDriver, VirtioError> {
if let Ok(mut drv) = VirtioNetDriver::new(dev_id, registers, irq) {
) -> Result<VirtioNetDriver<Init>, VirtioError> {
if let Ok(drv) = VirtioNetDriver::new(dev_id, registers, irq) {
match drv.init_dev() {
Err(error_code) => Err(VirtioError::NetDriver(error_code)),
_ => {
drv.print_information();
Ok(drv)
Ok(mut initialized_drv) => {
initialized_drv.print_information();
Ok(initialized_drv)
}
}
} else {
Expand All @@ -95,3 +82,12 @@ impl VirtioNetDriver {
}
}
}

impl VirtioNetDriver<Init> {
pub fn print_information(&mut self) {
self.com_cfg.print_information();
if self.dev_status() == virtio::net::S::LINK_UP {
info!("The link of the network device is up!");
}
}
}
103 changes: 57 additions & 46 deletions src/drivers/net/virtio/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,6 @@ pub(crate) struct NetDevCfg {
pub features: virtio::net::F,
}

pub struct CtrlQueue(Option<VirtQueue>);

impl CtrlQueue {
pub fn new(vq: Option<VirtQueue>) -> Self {
CtrlQueue(vq)
}
}

pub struct RxQueues {
vqs: Vec<VirtQueue>,
packet_size: u32,
Expand Down Expand Up @@ -191,27 +183,32 @@ impl TxQueues {
}
}

pub(crate) struct Uninit;
pub(crate) struct Init {
pub(super) ctrl_vq: Option<VirtQueue>,
pub(super) recv_vqs: RxQueues,
pub(super) send_vqs: TxQueues,
}

/// Virtio network driver struct.
///
/// Struct allows to control devices virtqueues as also
/// the device itself.
pub(crate) struct VirtioNetDriver {
pub(crate) struct VirtioNetDriver<T = Init> {
pub(super) dev_cfg: NetDevCfg,
pub(super) com_cfg: ComCfg,
pub(super) isr_stat: IsrStatus,
pub(super) notif_cfg: NotifCfg,

pub(super) ctrl_vq: CtrlQueue,
pub(super) recv_vqs: RxQueues,
pub(super) send_vqs: TxQueues,
pub(super) inner: T,

pub(super) num_vqs: u16,
pub(super) mtu: u16,
pub(super) irq: InterruptLine,
pub(super) checksums: ChecksumCapabilities,
}

impl NetworkDriver for VirtioNetDriver {
impl NetworkDriver for VirtioNetDriver<Init> {
/// Returns the mac address of the device.
/// If VIRTIO_NET_F_MAC is not set, the function panics currently!
fn get_mac_address(&self) -> [u8; 6] {
Expand All @@ -235,7 +232,7 @@ impl NetworkDriver for VirtioNetDriver {

#[allow(dead_code)]
fn has_packet(&self) -> bool {
self.recv_vqs.has_packet()
self.inner.recv_vqs.has_packet()
}

/// Provides smoltcp a slice to copy the IP packet and transfer the packet
Expand All @@ -246,9 +243,9 @@ impl NetworkDriver for VirtioNetDriver {
{
// We need to poll to get the queue to remove elements from the table and make space for
// what we are about to add
self.send_vqs.poll();
self.inner.send_vqs.poll();

assert!(len < usize::try_from(self.send_vqs.packet_length).unwrap());
assert!(len < usize::try_from(self.inner.send_vqs.packet_length).unwrap());
let mut packet = Vec::with_capacity_in(len, DeviceAlloc);
let result = unsafe {
let result = f(packet.spare_capacity_mut().assume_init_mut());
Expand Down Expand Up @@ -297,7 +294,7 @@ impl NetworkDriver for VirtioNetDriver {
)
.unwrap();

self.send_vqs.vqs[0]
self.inner.send_vqs.vqs[0]
.dispatch(buff_tkn, false, BufferType::Direct)
.unwrap();

Expand All @@ -306,7 +303,7 @@ impl NetworkDriver for VirtioNetDriver {

fn receive_packet(&mut self) -> Option<(RxToken, TxToken)> {
let mut receive_single_packet = || {
let mut buffer_tkn = self.recv_vqs.get_next()?;
let mut buffer_tkn = self.inner.recv_vqs.get_next()?;
RxQueues::post_processing(&mut buffer_tkn)
.inspect_err(|vnet_err| warn!("Post processing failed. Err: {vnet_err:?}"))
.ok()?;
Expand All @@ -333,9 +330,9 @@ impl NetworkDriver for VirtioNetDriver {
}

fill_queue(
&mut self.recv_vqs.vqs[0],
&mut self.inner.recv_vqs.vqs[0],
num_buffers,
self.recv_vqs.packet_size,
self.inner.recv_vqs.packet_size,
);

Some((RxToken::new(combined_packets), TxToken::new()))
Expand Down Expand Up @@ -368,7 +365,7 @@ impl NetworkDriver for VirtioNetDriver {
}
}

impl Driver for VirtioNetDriver {
impl Driver for VirtioNetDriver<Init> {
fn get_interrupt_number(&self) -> InterruptLine {
self.irq
}
Expand All @@ -379,17 +376,12 @@ impl Driver for VirtioNetDriver {
}

// Backend-independent interface for Virtio network driver
impl VirtioNetDriver {
impl VirtioNetDriver<Init> {
#[cfg(feature = "pci")]
pub fn get_dev_id(&self) -> u16 {
self.dev_cfg.dev_id
}

#[cfg(feature = "pci")]
pub fn set_failed(&mut self) {
self.com_cfg.set_failed();
}

/// Returns the current status of the device, if VIRTIO_NET_F_STATUS
/// has been negotiated. Otherwise assumes an active device.
#[cfg(not(feature = "pci"))]
Expand Down Expand Up @@ -453,21 +445,23 @@ impl VirtioNetDriver {
pub fn disable_interrupts(&mut self) {
// For send and receive queues?
// Only for receive? Because send is off anyway?
self.recv_vqs.disable_notifs();
self.inner.recv_vqs.disable_notifs();
}

pub fn enable_interrupts(&mut self) {
// For send and receive queues?
// Only for receive? Because send is off anyway?
self.recv_vqs.enable_notifs();
self.inner.recv_vqs.enable_notifs();
}
}

impl VirtioNetDriver<Uninit> {
/// Initializes the device in adherence to specification. Returns Some(VirtioNetError)
/// upon failure and None in case everything worked as expected.
///
/// See Virtio specification v1.1. - 3.1.1.
/// and v1.1. - 5.1.5
pub fn init_dev(&mut self) -> Result<(), VirtioNetError> {
pub fn init_dev(mut self) -> Result<VirtioNetDriver<Init>, VirtioNetError> {
// Reset
self.com_cfg.reset_dev();

Expand Down Expand Up @@ -593,7 +587,13 @@ impl VirtioNetDriver {
return Err(VirtioNetError::FailFeatureNeg(self.dev_cfg.dev_id));
}

self.dev_spec_init()?;
let mut inner = Init {
ctrl_vq: None,
recv_vqs: RxQueues::new(Vec::new(), &self.dev_cfg),
send_vqs: TxQueues::new(Vec::new(), &self.dev_cfg),
};

self.dev_spec_init(&mut inner)?;
info!(
"Device specific initialization for Virtio network device {:x} finished",
self.dev_cfg.dev_id
Expand All @@ -620,7 +620,17 @@ impl VirtioNetDriver {
self.mtu = self.dev_cfg.raw.as_ptr().mtu().read().to_ne();
}

Ok(())
Ok(VirtioNetDriver {
dev_cfg: self.dev_cfg,
com_cfg: self.com_cfg,
isr_stat: self.isr_stat,
notif_cfg: self.notif_cfg,
inner,
num_vqs: self.num_vqs,
mtu: self.mtu,
irq: self.irq,
checksums: self.checksums,
})
}

/// Negotiates a subset of features, understood and wanted by both the OS
Expand Down Expand Up @@ -650,14 +660,14 @@ impl VirtioNetDriver {
}

/// Device Specific initialization according to Virtio specifictation v1.1. - 5.1.5
fn dev_spec_init(&mut self) -> Result<(), VirtioNetError> {
self.virtqueue_init()?;
fn dev_spec_init(&mut self, inner: &mut Init) -> Result<(), VirtioNetError> {
self.virtqueue_init(inner)?;
info!("Network driver successfully initialized virtqueues.");

// Add a control if feature is negotiated
if self.dev_cfg.features.contains(virtio::net::F::CTRL_VQ) {
if self.dev_cfg.features.contains(virtio::net::F::RING_PACKED) {
self.ctrl_vq = CtrlQueue(Some(VirtQueue::Packed(
let mut ctrl_vq = if self.dev_cfg.features.contains(virtio::net::F::RING_PACKED) {
VirtQueue::Packed(
PackedVq::new(
&mut self.com_cfg,
&self.notif_cfg,
Expand All @@ -666,9 +676,9 @@ impl VirtioNetDriver {
self.dev_cfg.features.into(),
)
.unwrap(),
)));
)
} else {
self.ctrl_vq = CtrlQueue(Some(VirtQueue::Split(
VirtQueue::Split(
SplitVq::new(
&mut self.com_cfg,
&self.notif_cfg,
Expand All @@ -677,17 +687,18 @@ impl VirtioNetDriver {
self.dev_cfg.features.into(),
)
.unwrap(),
)));
}
)
};

self.ctrl_vq.0.as_mut().unwrap().enable_notifs();
ctrl_vq.enable_notifs();
inner.ctrl_vq = Some(ctrl_vq);
}

Ok(())
}

/// Initialize virtqueues via the queue interface and populates receiving queues
fn virtqueue_init(&mut self) -> Result<(), VirtioNetError> {
fn virtqueue_init(&mut self, inner: &mut Init) -> Result<(), VirtioNetError> {
// We are assuming here, that the device single source of truth is the
// device specific configuration. Hence we do NOT check if
//
Expand Down Expand Up @@ -744,7 +755,7 @@ impl VirtioNetDriver {
// Interrupt for receiving packets is wanted
vq.enable_notifs();

self.recv_vqs.add(VirtQueue::Packed(vq));
inner.recv_vqs.add(VirtQueue::Packed(vq));

let mut vq = PackedVq::new(
&mut self.com_cfg,
Expand All @@ -757,7 +768,7 @@ impl VirtioNetDriver {
// Interrupt for communicating that a sended packet left, is not needed
vq.disable_notifs();

self.send_vqs.add(VirtQueue::Packed(vq));
inner.send_vqs.add(VirtQueue::Packed(vq));
} else {
let mut vq = SplitVq::new(
&mut self.com_cfg,
Expand All @@ -770,7 +781,7 @@ impl VirtioNetDriver {
// Interrupt for receiving packets is wanted
vq.enable_notifs();

self.recv_vqs.add(VirtQueue::Split(vq));
inner.recv_vqs.add(VirtQueue::Split(vq));

let mut vq = SplitVq::new(
&mut self.com_cfg,
Expand All @@ -783,7 +794,7 @@ impl VirtioNetDriver {
// Interrupt for communicating that a sended packet left, is not needed
vq.disable_notifs();

self.send_vqs.add(VirtQueue::Split(vq));
inner.send_vqs.add(VirtQueue::Split(vq));
}
}

Expand Down
Loading