From 9c758676f5d67a222e1445af79b31789ce2e91b3 Mon Sep 17 00:00:00 2001 From: Jens Reidel Date: Mon, 7 Apr 2025 01:01:53 +0200 Subject: [PATCH] refactor: Remove dyn Any usage in BufferElem Signed-off-by: Jens Reidel --- Cargo.toml | 2 +- src/drivers/fs/virtio_fs.rs | 25 +++--- src/drivers/net/virtio/mod.rs | 19 ++--- src/drivers/virtio/virtqueue/mod.rs | 128 ++++++++++++++-------------- src/drivers/vsock/mod.rs | 9 +- src/fs/fuse.rs | 29 +++---- src/lib.rs | 1 + 7 files changed, 102 insertions(+), 111 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 4e9ade06df..e69961ffa7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -132,7 +132,7 @@ take-static = "0.1" talc = { version = "4" } time = { version = "0.3", default-features = false } volatile = "0.6" -zerocopy = { version = "0.8", default-features = false } +zerocopy = { version = "0.8", features = ["derive"], default-features = false } uhyve-interface = "0.1.3" [dependencies.smoltcp] diff --git a/src/drivers/fs/virtio_fs.rs b/src/drivers/fs/virtio_fs.rs index aea830adb0..cc4f88e659 100644 --- a/src/drivers/fs/virtio_fs.rs +++ b/src/drivers/fs/virtio_fs.rs @@ -1,4 +1,3 @@ -use alloc::boxed::Box; use alloc::string::{String, ToString}; use alloc::vec::Vec; use core::str; @@ -8,6 +7,7 @@ use virtio::FeatureBits; use virtio::fs::ConfigVolatileFieldAccess; use volatile::VolatileRef; use volatile::access::ReadOnly; +use zerocopy::{FromBytes, Immutable, IntoBytes}; use crate::config::VIRTIO_MAX_QUEUE_SIZE; use crate::drivers::Driver; @@ -158,30 +158,26 @@ impl FuseInterface for VirtioFsDriver { rsp_payload_len: u32, ) -> Result, VirtqError> where - ::InStruct: Send, - ::OutStruct: Send, + ::InStruct: Send + IntoBytes + Immutable, + ::OutStruct: Send + FromBytes, { let fuse::Cmd { headers: cmd_headers, payload: cmd_payload_opt, } = cmd; let send = if let Some(cmd_payload) = cmd_payload_opt { - vec![ - BufferElem::Sized(cmd_headers), - BufferElem::Vector(cmd_payload), - ] + vec![BufferElem::from(cmd_headers), BufferElem(cmd_payload)] } else { - vec![BufferElem::Sized(cmd_headers)] + vec![BufferElem::from(cmd_headers)] }; - let rsp_headers = Box::, _>::new_uninit_in(DeviceAlloc); let recv = if rsp_payload_len == 0 { - vec![BufferElem::Sized(rsp_headers)] + vec![BufferElem::new_uninit::>()] } else { let rsp_payload = Vec::with_capacity_in(rsp_payload_len as usize, DeviceAlloc); vec![ - BufferElem::Sized(rsp_headers), - BufferElem::Vector(rsp_payload), + BufferElem::new_uninit::>(), + BufferElem(rsp_payload), ] }; @@ -189,7 +185,10 @@ impl FuseInterface for VirtioFsDriver { let mut transfer_result = self.vqueues[1].dispatch_blocking(buffer_tkn, BufferType::Direct)?; - let headers = transfer_result.used_recv_buff.pop_front_downcast().unwrap(); + let headers = transfer_result + .used_recv_buff + .pop_front_deserialize() + .unwrap(); let payload = transfer_result.used_recv_buff.pop_front_vec(); Ok(Rsp { headers, payload }) } diff --git a/src/drivers/net/virtio/mod.rs b/src/drivers/net/virtio/mod.rs index 5002ba323b..1232fbaff5 100644 --- a/src/drivers/net/virtio/mod.rs +++ b/src/drivers/net/virtio/mod.rs @@ -10,7 +10,6 @@ cfg_if::cfg_if! { } } -use alloc::boxed::Box; use alloc::vec::Vec; use smoltcp::phy::{Checksum, ChecksumCapabilities}; @@ -116,8 +115,8 @@ fn fill_queue(vq: &mut dyn Virtq, num_packets: u16, packet_size: u32) { let buff_tkn = match AvailBufferToken::new( vec![], vec![ - BufferElem::Sized(Box::::new_uninit_in(DeviceAlloc)), - BufferElem::Vector(Vec::with_capacity_in( + BufferElem::new_uninit::(), + BufferElem(Vec::with_capacity_in( packet_size.try_into().unwrap(), DeviceAlloc, )), @@ -256,7 +255,7 @@ impl NetworkDriver for VirtioNetDriver { result }; - let mut header = Box::new_in(::default(), DeviceAlloc); + let mut header = Hdr::default(); // If a checksum isn't necessary, we have inform the host within the header // see Virtio specification 5.1.6.2 if !self.checksums.tcp.tx() || !self.checksums.udp.tx() { @@ -291,11 +290,9 @@ impl NetworkDriver for VirtioNetDriver { .into(); } - let buff_tkn = AvailBufferToken::new( - vec![BufferElem::Sized(header), BufferElem::Vector(packet)], - vec![], - ) - .unwrap(); + let buff_tkn = + AvailBufferToken::new(vec![BufferElem::from(header), BufferElem(packet)], vec![]) + .unwrap(); self.send_vqs.vqs[0] .dispatch(buff_tkn, false, BufferType::Direct) @@ -309,7 +306,7 @@ impl NetworkDriver for VirtioNetDriver { RxQueues::post_processing(&mut buffer_tkn) .inspect_err(|vnet_err| warn!("Post processing failed. Err: {vnet_err:?}")) .ok()?; - let first_header = buffer_tkn.used_recv_buff.pop_front_downcast::()?; + let first_header = buffer_tkn.used_recv_buff.pop_front_deserialize::()?; let first_packet = buffer_tkn.used_recv_buff.pop_front_vec()?; trace!("Header: {first_header:?}"); @@ -329,7 +326,7 @@ impl NetworkDriver for VirtioNetDriver { RxQueues::post_processing(&mut buffer_tkn) .inspect_err(|vnet_err| warn!("Post processing failed. Err: {vnet_err:?}")) .ok()?; - let _header = buffer_tkn.used_recv_buff.pop_front_downcast::()?; + let _header = buffer_tkn.used_recv_buff.pop_front_deserialize::()?; let packet = buffer_tkn.used_recv_buff.pop_front_vec()?; packets.push(packet); } diff --git a/src/drivers/virtio/virtqueue/mod.rs b/src/drivers/virtio/virtqueue/mod.rs index cfc86a4688..d13bfc4927 100644 --- a/src/drivers/virtio/virtqueue/mod.rs +++ b/src/drivers/virtio/virtqueue/mod.rs @@ -15,12 +15,11 @@ pub mod split; use alloc::boxed::Box; use alloc::collections::vec_deque::VecDeque; use alloc::vec::Vec; -use core::any::Any; -use core::mem::MaybeUninit; -use core::{mem, ptr}; +use core::mem; use memory_addresses::VirtAddr; use virtio::{le32, le64, pvirtq, virtq}; +use zerocopy::{Immutable, IntoBytes}; use self::error::VirtqError; #[cfg(not(feature = "pci"))] @@ -265,7 +264,7 @@ trait VirtqPrivate { .chain(recv_desc_iter) .map(|(mem_descr, len, incomplete_flags)| { Self::Descriptor::incomplete_desc( - paging::virt_to_phys(VirtAddr::from_ptr(mem_descr.addr())) + paging::virt_to_phys(VirtAddr::from_ptr(mem_descr.as_ptr())) .as_u64() .into(), len.into(), @@ -344,41 +343,49 @@ impl TransferToken { } #[derive(Debug)] -pub enum BufferElem { - Sized(Box), - Vector(Vec), -} +pub struct BufferElem(pub Vec); impl BufferElem { - // Returns the initialized length of the element. Assumes [Self::Sized] to - // be initialized, since the type of the object is erased and we cannot - // detect if the content is actually a [MaybeUninit]. However, this function - // should be only relevant for read buffer elements, which should not be uninit. - // If the element belongs to a write buffer, it is likely that [Self::capacity] - // is more appropriate. + /// Returns the initialized length of the element. pub fn len(&self) -> u32 { - match self { - BufferElem::Sized(sized) => mem::size_of_val(sized.as_ref()), - BufferElem::Vector(vec) => vec.len(), - } - .try_into() - .unwrap() + self.0.len().try_into().unwrap() } + /// Returns the allocated capacity of the element. pub fn capacity(&self) -> u32 { - match self { - BufferElem::Sized(sized) => mem::size_of_val(sized.as_ref()), - BufferElem::Vector(vec) => vec.capacity(), - } - .try_into() - .unwrap() + self.0.capacity().try_into().unwrap() } - pub fn addr(&self) -> *const u8 { - match self { - BufferElem::Sized(sized) => ptr::from_ref(sized.as_ref()).cast::(), - BufferElem::Vector(vec) => vec.as_ptr(), - } + /// Returns a pointer to the buffer. + pub fn as_ptr(&self) -> *const u8 { + self.0.as_ptr() + } + + /// Helper method to create a [`BufferElem`] that pre-allocates capacity + /// for a given element of type `T`. This ensures the buffer is aligned + /// to the same boundaries as `T`. + pub fn new_uninit() -> Self { + let uninit_mem = Box::::new_uninit_in(DeviceAlloc); + // SAFETY: Length is 0 because it's uninit, capacity matches the memory amount that the Box allocated. + // The pointer was allocated with the same allocator: DeviceAlloc. + let uninit_vec = unsafe { + Vec::from_raw_parts_in( + Box::into_raw(uninit_mem).cast(), + 0, + size_of::(), + DeviceAlloc, + ) + }; + Self(uninit_vec) + } +} + +impl From for BufferElem +where + T: IntoBytes + Immutable, +{ + fn from(value: T) -> Self { + Self(value.as_bytes().to_vec_in(DeviceAlloc)) } } @@ -419,46 +426,39 @@ pub(crate) struct UsedDeviceWritableBuffer { } impl UsedDeviceWritableBuffer { - pub fn pop_front_downcast(&mut self) -> Option> - where - T: Any, - { + pub fn pop_front_deserialize(&mut self) -> Option> { if self.remaining_written_len < u32::try_from(size_of::()).unwrap() { return None; } - let elem = self.elems.pop_front()?; - if let BufferElem::Sized(sized) = elem { - match sized.downcast::>() { - Ok(cast) => { - self.remaining_written_len -= u32::try_from(size_of::()).unwrap(); - Some(unsafe { cast.assume_init() }) - } - Err(sized) => { - self.elems.push_front(BufferElem::Sized(sized)); - None - } - } - } else { - self.elems.push_front(elem); - None - } + let BufferElem(buf) = self.elems.pop_front()?; + self.remaining_written_len -= u32::try_from(size_of::()).unwrap(); + + // Ensure the buffer is aligned to T. This is the case if it was created via + // [`BufferElem::new_uninit::`] and should always be the case, but since + // it is technically possible to construct an unaligned buffer and use that, + // we should check it. + assert!( + buf.as_ptr().addr() % align_of::() == 0, + "Attempted to deserialize buffer as type with different alignment" + ); + + // SAFETY: Management of the memory is transferred from the Vec to the Box + // Both heap allocations were made with the same alloc: DeviceAlloc + // The alignment was checked manually before. + Some(unsafe { Box::from_raw_in(buf.into_raw_parts().0.cast(), DeviceAlloc) }) } pub fn pop_front_vec(&mut self) -> Option> { - let elem = self.elems.pop_front()?; - if let BufferElem::Vector(mut vector) = elem { - let new_len = u32::min( - vector.capacity().try_into().unwrap(), - self.remaining_written_len, - ); - self.remaining_written_len -= new_len; - unsafe { vector.set_len(new_len.try_into().unwrap()) }; - Some(vector) - } else { - self.elems.push_front(elem); - None - } + let BufferElem(mut vector) = self.elems.pop_front()?; + let new_len = u32::min( + vector.capacity().try_into().unwrap(), + self.remaining_written_len, + ); + self.remaining_written_len -= new_len; + unsafe { vector.set_len(new_len.try_into().unwrap()) }; + + Some(vector) } } diff --git a/src/drivers/vsock/mod.rs b/src/drivers/vsock/mod.rs index 0879b4b574..2ebd4235fc 100644 --- a/src/drivers/vsock/mod.rs +++ b/src/drivers/vsock/mod.rs @@ -3,7 +3,6 @@ #[cfg(feature = "pci")] pub mod pci; -use alloc::boxed::Box; use alloc::vec::Vec; use core::mem; @@ -29,8 +28,8 @@ fn fill_queue(vq: &mut dyn Virtq, num_packets: u16, packet_size: u32) { let buff_tkn = match AvailBufferToken::new( vec![], vec![ - BufferElem::Sized(Box::::new_uninit_in(DeviceAlloc)), - BufferElem::Vector(Vec::with_capacity_in( + BufferElem::new_uninit::(), + BufferElem(Vec::with_capacity_in( packet_size.try_into().unwrap(), DeviceAlloc, )), @@ -99,7 +98,7 @@ impl RxQueue { while let Some(mut buffer_tkn) = self.get_next() { let header = buffer_tkn .used_recv_buff - .pop_front_downcast::() + .pop_front_deserialize::() .unwrap(); let packet = buffer_tkn.used_recv_buff.pop_front_vec().unwrap(); @@ -170,7 +169,7 @@ impl TxQueue { result }; - let buff_tkn = AvailBufferToken::new(vec![BufferElem::Vector(packet)], vec![]).unwrap(); + let buff_tkn = AvailBufferToken::new(vec![BufferElem(packet)], vec![]).unwrap(); vq.dispatch(buff_tkn, false, BufferType::Direct).unwrap(); diff --git a/src/fs/fuse.rs b/src/fs/fuse.rs index cc6bf349e7..19db685444 100644 --- a/src/fs/fuse.rs +++ b/src/fs/fuse.rs @@ -13,7 +13,7 @@ use async_lock::Mutex; use async_trait::async_trait; use fuse_abi::linux::*; use num_traits::FromPrimitive; -use zerocopy::FromBytes; +use zerocopy::{FromBytes, Immutable, IntoBytes}; use crate::alloc::string::ToString; #[cfg(not(feature = "pci"))] @@ -50,8 +50,8 @@ pub(crate) trait FuseInterface { rsp_payload_len: u32, ) -> Result, VirtqError> where - ::InStruct: Send, - ::OutStruct: Send; + ::InStruct: Send + IntoBytes + Immutable, + ::OutStruct: Send + FromBytes; fn get_mount_point(&self) -> String; } @@ -62,13 +62,14 @@ pub(crate) mod ops { use alloc::ffi::CString; use fuse_abi::linux::*; + use zerocopy::FromBytes; use super::Cmd; use crate::fd::PollEvent; use crate::fs::SeekWhence; #[repr(C)] - #[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq)] + #[derive(Debug, Default, Copy, Clone, Hash, PartialEq, Eq, FromBytes)] pub(crate) struct CreateOut { pub entry: fuse_entry_out, pub open: fuse_open_out, @@ -447,8 +448,8 @@ impl From for FileAttr { } } -#[repr(C)] -#[derive(Debug)] +#[repr(C, packed)] +#[derive(Debug, IntoBytes, Immutable)] pub(crate) struct CmdHeader { pub in_header: fuse_in_header, op_header: O::InStruct, @@ -484,7 +485,7 @@ impl CmdHeader { } pub(crate) struct Cmd { - pub headers: Box, DeviceAlloc>, + pub headers: CmdHeader, pub payload: Option>, } @@ -494,7 +495,7 @@ where { fn new(nodeid: u64, op_header: O::InStruct) -> Self { Self { - headers: Box::new_in(CmdHeader::new(nodeid, op_header), DeviceAlloc), + headers: CmdHeader::new(nodeid, op_header), payload: None, } } @@ -507,10 +508,7 @@ where fn with_cstring(nodeid: u64, op_header: O::InStruct, cstring: CString) -> Self { let cstring_bytes = cstring.into_bytes_with_nul().to_vec_in(DeviceAlloc); Self { - headers: Box::new_in( - CmdHeader::with_payload_size(nodeid, op_header, cstring_bytes.len()), - DeviceAlloc, - ), + headers: CmdHeader::with_payload_size(nodeid, op_header, cstring_bytes.len()), payload: Some(cstring_bytes), } } @@ -524,17 +522,14 @@ where let mut device_slice = Vec::with_capacity_in(slice.len(), DeviceAlloc); device_slice.extend_from_slice(&slice); Self { - headers: Box::new_in( - CmdHeader::with_payload_size(nodeid, op_header, slice.len()), - DeviceAlloc, - ), + headers: CmdHeader::with_payload_size(nodeid, op_header, slice.len()), payload: Some(device_slice), } } } #[repr(C)] -#[derive(Debug)] +#[derive(Debug, FromBytes)] pub(crate) struct RspHeader { out_header: fuse_out_header, op_header: O::OutStruct, diff --git a/src/lib.rs b/src/lib.rs index 205471e173..e78abdc69b 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -24,6 +24,7 @@ feature(specialization) )] #![feature(thread_local)] +#![feature(vec_into_raw_parts)] #![cfg_attr(target_os = "none", no_std)] #![cfg_attr(target_os = "none", feature(custom_test_frameworks))] #![cfg_attr(all(target_os = "none", test), test_runner(crate::test_runner))]