Skip to content

Commit f4f55bc

Browse files
committed
fix(virtio-net): Only initialize virtqueues after feature negotiation
This refactors the initialization sequence using type state. Signed-off-by: Jens Reidel <[email protected]>
1 parent 6786d48 commit f4f55bc

File tree

3 files changed

+97
-94
lines changed

3 files changed

+97
-94
lines changed

src/drivers/net/virtio/mmio.rs

Lines changed: 17 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,21 +2,19 @@
22
//!
33
//! The module contains ...
44
5-
use alloc::vec::Vec;
65
use core::str::FromStr;
76

87
use smoltcp::phy::ChecksumCapabilities;
98
use virtio::mmio::{DeviceRegisters, DeviceRegistersVolatileFieldAccess};
109
use volatile::VolatileRef;
1110

1211
use crate::drivers::InterruptLine;
13-
use crate::drivers::net::virtio::{CtrlQueue, NetDevCfg, RxQueues, TxQueues, VirtioNetDriver};
12+
use crate::drivers::net::virtio::{Init, NetDevCfg, Uninit, VirtioNetDriver};
1413
use crate::drivers::virtio::error::{VirtioError, VirtioNetError};
1514
use crate::drivers::virtio::transport::mmio::{ComCfg, IsrStatus, NotifCfg};
16-
use crate::drivers::virtio::virtqueue::VirtQueue;
1715

1816
// Backend-dependent interface for Virtio network driver
19-
impl VirtioNetDriver {
17+
impl VirtioNetDriver<Uninit> {
2018
pub fn new(
2119
dev_id: u16,
2220
mut registers: VolatileRef<'static, DeviceRegisters>,
@@ -47,30 +45,19 @@ impl VirtioNetDriver {
4745
1514
4846
};
4947

50-
let send_vqs = TxQueues::new(Vec::<VirtQueue>::new(), &dev_cfg);
51-
let recv_vqs = RxQueues::new(Vec::<VirtQueue>::new(), &dev_cfg);
5248
Ok(VirtioNetDriver {
5349
dev_cfg,
5450
com_cfg: ComCfg::new(registers, 1),
5551
isr_stat,
5652
notif_cfg,
57-
ctrl_vq: CtrlQueue::new(None),
58-
recv_vqs,
59-
send_vqs,
53+
inner: Uninit,
6054
num_vqs: 0,
6155
mtu,
6256
irq,
6357
checksums: ChecksumCapabilities::default(),
6458
})
6559
}
6660

67-
pub fn print_information(&mut self) {
68-
self.com_cfg.print_information();
69-
if self.dev_status() == virtio::net::S::LINK_UP {
70-
info!("The link of the network device is up!");
71-
}
72-
}
73-
7461
/// Initializes virtio network device by mapping configuration layout to
7562
/// respective structs (configuration structs are:
7663
///
@@ -80,13 +67,13 @@ impl VirtioNetDriver {
8067
dev_id: u16,
8168
registers: VolatileRef<'static, DeviceRegisters>,
8269
irq: InterruptLine,
83-
) -> Result<VirtioNetDriver, VirtioError> {
84-
if let Ok(mut drv) = VirtioNetDriver::new(dev_id, registers, irq) {
70+
) -> Result<VirtioNetDriver<Init>, VirtioError> {
71+
if let Ok(drv) = VirtioNetDriver::new(dev_id, registers, irq) {
8572
match drv.init_dev() {
8673
Err(error_code) => Err(VirtioError::NetDriver(error_code)),
87-
_ => {
88-
drv.print_information();
89-
Ok(drv)
74+
Ok(mut initialized_drv) => {
75+
initialized_drv.print_information();
76+
Ok(initialized_drv)
9077
}
9178
}
9279
} else {
@@ -95,3 +82,12 @@ impl VirtioNetDriver {
9582
}
9683
}
9784
}
85+
86+
impl VirtioNetDriver<Init> {
87+
pub fn print_information(&mut self) {
88+
self.com_cfg.print_information();
89+
if self.dev_status() == virtio::net::S::LINK_UP {
90+
info!("The link of the network device is up!");
91+
}
92+
}
93+
}

src/drivers/net/virtio/mod.rs

Lines changed: 63 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -46,32 +46,22 @@ pub(crate) struct NetDevCfg {
4646
pub features: virtio::net::F,
4747
}
4848

49-
pub struct CtrlQueue(Option<VirtQueue>);
50-
51-
impl CtrlQueue {
52-
pub fn new(vq: Option<VirtQueue>) -> Self {
53-
CtrlQueue(vq)
54-
}
55-
}
56-
5749
pub struct RxQueues {
5850
vqs: Vec<VirtQueue>,
5951
packet_size: u32,
6052
}
6153

6254
impl RxQueues {
6355
pub fn new(vqs: Vec<VirtQueue>, dev_cfg: &NetDevCfg) -> Self {
64-
// See Virtio specification v1.1 - 5.1.6.3.1
56+
// See Virtio specification v1.1 - 5.1.6.3.1 and 5.1.4.2
6557
//
66-
let packet_size = if dev_cfg.features.contains(virtio::net::F::MRG_RXBUF) {
58+
#[allow(clippy::decimal_literal_representation)]
59+
let packet_size = if dev_cfg.features.contains(virtio::net::F::MTU) {
60+
65550
61+
} else if dev_cfg.features.contains(virtio::net::F::MRG_RXBUF) {
6762
1514
68-
} else if dev_cfg.features.contains(virtio::net::F::GUEST_TSO4)
69-
|| dev_cfg.features.contains(virtio::net::F::GUEST_TSO6)
70-
|| dev_cfg.features.contains(virtio::net::F::GUEST_UFO)
71-
{
72-
dev_cfg.raw.as_ptr().mtu().read().to_ne().into()
7363
} else {
74-
1514
64+
dev_cfg.raw.as_ptr().mtu().read().to_ne().into()
7565
};
7666

7767
Self { vqs, packet_size }
@@ -196,27 +186,32 @@ impl TxQueues {
196186
}
197187
}
198188

189+
pub(crate) struct Uninit;
190+
pub(crate) struct Init {
191+
pub(super) ctrl_vq: Option<VirtQueue>,
192+
pub(super) recv_vqs: RxQueues,
193+
pub(super) send_vqs: TxQueues,
194+
}
195+
199196
/// Virtio network driver struct.
200197
///
201198
/// Struct allows to control devices virtqueues as also
202199
/// the device itself.
203-
pub(crate) struct VirtioNetDriver {
200+
pub(crate) struct VirtioNetDriver<T = Init> {
204201
pub(super) dev_cfg: NetDevCfg,
205202
pub(super) com_cfg: ComCfg,
206203
pub(super) isr_stat: IsrStatus,
207204
pub(super) notif_cfg: NotifCfg,
208205

209-
pub(super) ctrl_vq: CtrlQueue,
210-
pub(super) recv_vqs: RxQueues,
211-
pub(super) send_vqs: TxQueues,
206+
pub(super) inner: T,
212207

213208
pub(super) num_vqs: u16,
214209
pub(super) mtu: u16,
215210
pub(super) irq: InterruptLine,
216211
pub(super) checksums: ChecksumCapabilities,
217212
}
218213

219-
impl NetworkDriver for VirtioNetDriver {
214+
impl NetworkDriver for VirtioNetDriver<Init> {
220215
/// Returns the mac address of the device.
221216
/// If VIRTIO_NET_F_MAC is not set, the function panics currently!
222217
fn get_mac_address(&self) -> [u8; 6] {
@@ -240,7 +235,7 @@ impl NetworkDriver for VirtioNetDriver {
240235

241236
#[allow(dead_code)]
242237
fn has_packet(&self) -> bool {
243-
self.recv_vqs.has_packet()
238+
self.inner.recv_vqs.has_packet()
244239
}
245240

246241
/// Provides smoltcp a slice to copy the IP packet and transfer the packet
@@ -251,9 +246,9 @@ impl NetworkDriver for VirtioNetDriver {
251246
{
252247
// We need to poll to get the queue to remove elements from the table and make space for
253248
// what we are about to add
254-
self.send_vqs.poll();
249+
self.inner.send_vqs.poll();
255250

256-
assert!(len < usize::try_from(self.send_vqs.packet_length).unwrap());
251+
assert!(len < usize::try_from(self.inner.send_vqs.packet_length).unwrap());
257252
let mut packet = Vec::with_capacity_in(len, DeviceAlloc);
258253
let result = unsafe {
259254
let result = f(packet.spare_capacity_mut().assume_init_mut());
@@ -302,7 +297,7 @@ impl NetworkDriver for VirtioNetDriver {
302297
)
303298
.unwrap();
304299

305-
self.send_vqs.vqs[0]
300+
self.inner.send_vqs.vqs[0]
306301
.dispatch(buff_tkn, false, BufferType::Direct)
307302
.unwrap();
308303

@@ -311,7 +306,7 @@ impl NetworkDriver for VirtioNetDriver {
311306

312307
fn receive_packet(&mut self) -> Option<(RxToken, TxToken)> {
313308
let mut receive_single_packet = || {
314-
let mut buffer_tkn = self.recv_vqs.get_next()?;
309+
let mut buffer_tkn = self.inner.recv_vqs.get_next()?;
315310
RxQueues::post_processing(&mut buffer_tkn)
316311
.inspect_err(|vnet_err| warn!("Post processing failed. Err: {vnet_err:?}"))
317312
.ok()?;
@@ -338,9 +333,9 @@ impl NetworkDriver for VirtioNetDriver {
338333
}
339334

340335
fill_queue(
341-
&mut self.recv_vqs.vqs[0],
336+
&mut self.inner.recv_vqs.vqs[0],
342337
num_buffers,
343-
self.recv_vqs.packet_size,
338+
self.inner.recv_vqs.packet_size,
344339
);
345340

346341
Some((RxToken::new(combined_packets), TxToken::new()))
@@ -373,7 +368,7 @@ impl NetworkDriver for VirtioNetDriver {
373368
}
374369
}
375370

376-
impl Driver for VirtioNetDriver {
371+
impl Driver for VirtioNetDriver<Init> {
377372
fn get_interrupt_number(&self) -> InterruptLine {
378373
self.irq
379374
}
@@ -384,17 +379,12 @@ impl Driver for VirtioNetDriver {
384379
}
385380

386381
// Backend-independent interface for Virtio network driver
387-
impl VirtioNetDriver {
382+
impl VirtioNetDriver<Init> {
388383
#[cfg(feature = "pci")]
389384
pub fn get_dev_id(&self) -> u16 {
390385
self.dev_cfg.dev_id
391386
}
392387

393-
#[cfg(feature = "pci")]
394-
pub fn set_failed(&mut self) {
395-
self.com_cfg.set_failed();
396-
}
397-
398388
/// Returns the current status of the device, if VIRTIO_NET_F_STATUS
399389
/// has been negotiated. Otherwise assumes an active device.
400390
#[cfg(not(feature = "pci"))]
@@ -458,21 +448,23 @@ impl VirtioNetDriver {
458448
pub fn disable_interrupts(&mut self) {
459449
// For send and receive queues?
460450
// Only for receive? Because send is off anyway?
461-
self.recv_vqs.disable_notifs();
451+
self.inner.recv_vqs.disable_notifs();
462452
}
463453

464454
pub fn enable_interrupts(&mut self) {
465455
// For send and receive queues?
466456
// Only for receive? Because send is off anyway?
467-
self.recv_vqs.enable_notifs();
457+
self.inner.recv_vqs.enable_notifs();
468458
}
459+
}
469460

461+
impl VirtioNetDriver<Uninit> {
470462
/// Initializes the device in adherence to specification. Returns Some(VirtioNetError)
471463
/// upon failure and None in case everything worked as expected.
472464
///
473465
/// See Virtio specification v1.1. - 3.1.1.
474466
/// and v1.1. - 5.1.5
475-
pub fn init_dev(&mut self) -> Result<(), VirtioNetError> {
467+
pub fn init_dev(mut self) -> Result<VirtioNetDriver<Init>, VirtioNetError> {
476468
// Reset
477469
self.com_cfg.reset_dev();
478470

@@ -598,7 +590,13 @@ impl VirtioNetDriver {
598590
return Err(VirtioNetError::FailFeatureNeg(self.dev_cfg.dev_id));
599591
}
600592

601-
self.dev_spec_init()?;
593+
let mut inner = Init {
594+
ctrl_vq: None,
595+
recv_vqs: RxQueues::new(Vec::new(), &self.dev_cfg),
596+
send_vqs: TxQueues::new(Vec::new(), &self.dev_cfg),
597+
};
598+
599+
self.dev_spec_init(&mut inner)?;
602600
info!(
603601
"Device specific initialization for Virtio network device {:x} finished",
604602
self.dev_cfg.dev_id
@@ -625,7 +623,17 @@ impl VirtioNetDriver {
625623
self.mtu = self.dev_cfg.raw.as_ptr().mtu().read().to_ne();
626624
}
627625

628-
Ok(())
626+
Ok(VirtioNetDriver {
627+
dev_cfg: self.dev_cfg,
628+
com_cfg: self.com_cfg,
629+
isr_stat: self.isr_stat,
630+
notif_cfg: self.notif_cfg,
631+
inner,
632+
num_vqs: self.num_vqs,
633+
mtu: self.mtu,
634+
irq: self.irq,
635+
checksums: self.checksums,
636+
})
629637
}
630638

631639
/// Negotiates a subset of features, understood and wanted by both the OS
@@ -655,14 +663,14 @@ impl VirtioNetDriver {
655663
}
656664

657665
/// Device Specific initialization according to Virtio specifictation v1.1. - 5.1.5
658-
fn dev_spec_init(&mut self) -> Result<(), VirtioNetError> {
659-
self.virtqueue_init()?;
666+
fn dev_spec_init(&mut self, inner: &mut Init) -> Result<(), VirtioNetError> {
667+
self.virtqueue_init(inner)?;
660668
info!("Network driver successfully initialized virtqueues.");
661669

662670
// Add a control if feature is negotiated
663671
if self.dev_cfg.features.contains(virtio::net::F::CTRL_VQ) {
664-
if self.dev_cfg.features.contains(virtio::net::F::RING_PACKED) {
665-
self.ctrl_vq = CtrlQueue(Some(VirtQueue::Packed(
672+
let mut ctrl_vq = if self.dev_cfg.features.contains(virtio::net::F::RING_PACKED) {
673+
VirtQueue::Packed(
666674
PackedVq::new(
667675
&mut self.com_cfg,
668676
&self.notif_cfg,
@@ -671,9 +679,9 @@ impl VirtioNetDriver {
671679
self.dev_cfg.features.into(),
672680
)
673681
.unwrap(),
674-
)));
682+
)
675683
} else {
676-
self.ctrl_vq = CtrlQueue(Some(VirtQueue::Split(
684+
VirtQueue::Split(
677685
SplitVq::new(
678686
&mut self.com_cfg,
679687
&self.notif_cfg,
@@ -682,17 +690,18 @@ impl VirtioNetDriver {
682690
self.dev_cfg.features.into(),
683691
)
684692
.unwrap(),
685-
)));
686-
}
693+
)
694+
};
687695

688-
self.ctrl_vq.0.as_mut().unwrap().enable_notifs();
696+
ctrl_vq.enable_notifs();
697+
inner.ctrl_vq = Some(ctrl_vq);
689698
}
690699

691700
Ok(())
692701
}
693702

694703
/// Initialize virtqueues via the queue interface and populates receiving queues
695-
fn virtqueue_init(&mut self) -> Result<(), VirtioNetError> {
704+
fn virtqueue_init(&mut self, inner: &mut Init) -> Result<(), VirtioNetError> {
696705
// We are assuming here, that the device single source of truth is the
697706
// device specific configuration. Hence we do NOT check if
698707
//
@@ -749,7 +758,7 @@ impl VirtioNetDriver {
749758
// Interrupt for receiving packets is wanted
750759
vq.enable_notifs();
751760

752-
self.recv_vqs.add(VirtQueue::Packed(vq));
761+
inner.recv_vqs.add(VirtQueue::Packed(vq));
753762

754763
let mut vq = PackedVq::new(
755764
&mut self.com_cfg,
@@ -762,7 +771,7 @@ impl VirtioNetDriver {
762771
// Interrupt for communicating that a sended packet left, is not needed
763772
vq.disable_notifs();
764773

765-
self.send_vqs.add(VirtQueue::Packed(vq));
774+
inner.send_vqs.add(VirtQueue::Packed(vq));
766775
} else {
767776
let mut vq = SplitVq::new(
768777
&mut self.com_cfg,
@@ -775,7 +784,7 @@ impl VirtioNetDriver {
775784
// Interrupt for receiving packets is wanted
776785
vq.enable_notifs();
777786

778-
self.recv_vqs.add(VirtQueue::Split(vq));
787+
inner.recv_vqs.add(VirtQueue::Split(vq));
779788

780789
let mut vq = SplitVq::new(
781790
&mut self.com_cfg,
@@ -788,7 +797,7 @@ impl VirtioNetDriver {
788797
// Interrupt for communicating that a sended packet left, is not needed
789798
vq.disable_notifs();
790799

791-
self.send_vqs.add(VirtQueue::Split(vq));
800+
inner.send_vqs.add(VirtQueue::Split(vq));
792801
}
793802
}
794803

0 commit comments

Comments
 (0)