From d6a3d1c9042dbf3135fa793d67166bcd501007c2 Mon Sep 17 00:00:00 2001 From: Bruno Dal Bo Date: Mon, 13 Jun 2022 11:49:29 -0700 Subject: [PATCH] Initial basic cmsg support for unix Fixes #313 --- src/cmsg.rs | 359 ++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 5 + src/socket.rs | 72 ++++++++++ src/sys/unix.rs | 38 +++-- tests/socket.rs | 57 ++++++++ 5 files changed, 522 insertions(+), 9 deletions(-) create mode 100644 src/cmsg.rs diff --git a/src/cmsg.rs b/src/cmsg.rs new file mode 100644 index 00000000..b10122a0 --- /dev/null +++ b/src/cmsg.rs @@ -0,0 +1,359 @@ +use std::convert::TryInto as _; +use std::io::IoSlice; + +#[derive(Debug, Clone)] +struct MsgHdrWalker { + buffer: B, + position: Option, +} + +impl> MsgHdrWalker { + fn next_ptr(&mut self) -> Option<*const libc::cmsghdr> { + // Build a msghdr so we can use the functionality in libc. + let mut msghdr: libc::msghdr = unsafe { std::mem::zeroed() }; + let buffer = self.buffer.as_ref(); + // SAFETY: We're giving msghdr a mutable pointer to comply with the C + // API. We'll only allow mutation of `cmsghdr`, however if `B` is + // AsMut<[u8]>. + msghdr.msg_control = buffer.as_ptr() as *mut _; + msghdr.msg_controllen = buffer.len().try_into().expect("buffer is too long"); + + let nxt_hdr = if let Some(position) = self.position { + if position >= buffer.len() { + return None; + } + let cur_hdr = &buffer[position] as *const u8 as *const _; + // Safety: msghdr is a valid pointer and cur_hdr is not null. + unsafe { libc::CMSG_NXTHDR(&msghdr, cur_hdr) } + } else { + // Safety: msghdr is a valid pointer. + unsafe { libc::CMSG_FIRSTHDR(&msghdr) } + }; + + if nxt_hdr.is_null() { + self.position = Some(buffer.len()); + return None; + } + + // SAFETY: nxt_hdr always points to data within the buffer, they must be + // part of the same allocation. 
+ let distance = unsafe { (nxt_hdr as *const u8).offset_from(buffer.as_ptr()) }; + // nxt_hdr is always ahead of the buffer and not null if we're here, + // meaning the distance is always positive. + self.position = Some(distance.try_into().unwrap()); + Some(nxt_hdr) + } + + fn next(&mut self) -> Option<(&libc::cmsghdr, &[u8])> { + self.next_ptr().map(|cmsghdr| { + // SAFETY: cmsghdr is a valid pointer given to us by `next_ptr`. + let data = unsafe { libc::CMSG_DATA(cmsghdr) }; + let cmsghdr = unsafe { &*cmsghdr }; + // SAFETY: data points to buffer and is controlled by control + // message length. + let data = unsafe { + std::slice::from_raw_parts( + data, + (cmsghdr.cmsg_len as usize) + .saturating_sub(std::mem::size_of::()), + ) + }; + (cmsghdr, data) + }) + } +} + +impl + AsMut<[u8]>> MsgHdrWalker { + fn next_mut(&mut self) -> Option<(&mut libc::cmsghdr, &mut [u8])> { + match self.next_ptr() { + Some(cmsghdr) => { + // SAFETY: cmsghdr is a valid pointer given to us by `next_ptr`. + let data = unsafe { libc::CMSG_DATA(cmsghdr) }; + // SAFETY: The mutable pointer is safe because we're not going to + // vend any concurrent access to the same memory region and B is + // AsMut<[u8]> guaranteeing we have exclusive access to the buffer. + let cmsghdr = cmsghdr as *mut libc::cmsghdr; + let cmsghdr = unsafe { &mut *cmsghdr }; + + // We'll always yield the entirety of the rest of the buffer. + let distance = unsafe { data.offset_from(self.buffer.as_ref().as_ptr()) }; + // The data pointer is always part of the buffer, can't be before + // it. + let distance: usize = distance.try_into().unwrap(); + Some((cmsghdr, &mut self.buffer.as_mut()[distance..])) + } + None => None, + } + } +} + +/// A wrapper around a buffer that can be used to write ancillary control +/// messages. 
+#[derive(Debug)] +pub struct CmsgWriter<'a> { + walker: MsgHdrWalker<&'a mut [u8]>, + last_push: usize, +} + +impl<'a> CmsgWriter<'a> { + /// Creates a new [`CmsgWriter`] backed by the bytes in `buffer`. + pub fn new(buffer: &'a mut [u8]) -> Self { + Self { + walker: MsgHdrWalker { + buffer, + position: None, + }, + last_push: 0, + } + } + + /// Pushes a new control message `m` to the buffer. + /// + /// # Panics + /// + /// Panics if the contained buffer does not have enough space to fit `m`. + pub fn push(&mut self, m: &Cmsg) { + let (cmsg_level, cmsg_type, size) = m.level_type_size(); + let (nxt_hdr, data) = self + .walker + .next_mut() + .unwrap_or_else(|| panic!("can't fit message {:?}", m)); + // Safety: All values are passed by copy. + let cmsg_len = unsafe { libc::CMSG_LEN(size) }.try_into().unwrap(); + nxt_hdr.cmsg_len = cmsg_len; + nxt_hdr.cmsg_level = cmsg_level; + nxt_hdr.cmsg_type = cmsg_type; + m.write(&mut data[..size as usize]); + // Always store the space required for the last push because the walker + // maintains its position cursor at the currently written option, we + // must always add the space for the last control message when returning + // the consolidated buffer. + self.last_push = unsafe { libc::CMSG_SPACE(size) } as usize; + } + + pub(crate) fn io_slice(&self) -> IoSlice<'_> { + IoSlice::new(self.buffer()) + } + + pub(crate) fn buffer(&self) -> &[u8] { + if let Some(position) = self.walker.position { + &self.walker.buffer.as_ref()[..position + self.last_push] + } else { + &[] + } + } +} + +impl<'a, C: std::borrow::Borrow> Extend for CmsgWriter<'a> { + fn extend>(&mut self, iter: T) { + for cmsg in iter { + self.push(cmsg.borrow()) + } + } +} + +/// An iterator over received control messages. 
+#[derive(Debug, Clone)] +pub struct CmsgIter<'a> { + walker: MsgHdrWalker<&'a [u8]>, +} + +impl<'a> CmsgIter<'a> { + pub(crate) fn new(buffer: &'a [u8]) -> Self { + Self { + walker: MsgHdrWalker { + buffer, + position: None, + }, + } + } +} + +impl<'a> Iterator for CmsgIter<'a> { + type Item = Cmsg; + + fn next(&mut self) -> Option { + self.walker.next().map( + |( + libc::cmsghdr { + cmsg_len: _, + cmsg_level, + cmsg_type, + .. + }, + data, + )| Cmsg::from_raw(*cmsg_level, *cmsg_type, data), + ) + } +} + +/// An unknown control message. +#[derive(Debug, Eq, PartialEq)] +pub struct UnknownCmsg { + cmsg_level: libc::c_int, + cmsg_type: libc::c_int, +} + +/// Control messages. +#[derive(Debug, Eq, PartialEq)] +pub enum Cmsg { + /// The `IP_TOS` control message. + #[cfg(not(any(target_os = "solaris", target_os = "illumos")))] + IpTos(u8), + /// The `IPV6_PKTINFO` control message. + #[cfg(not(any(target_os = "fuchsia", target_os = "solaris", target_os = "illumos")))] + Ipv6PktInfo { + /// The address the packet is destined to/received from. Equivalent to + /// `in6_pktinfo.ipi6_addr`. + addr: std::net::Ipv6Addr, + /// The interface index the packet is destined to/received from. + /// Equivalent to `in6_pktinfo.ipi6_ifindex`. + ifindex: u32, + }, + /// An unrecognized control message. + Unknown(UnknownCmsg), +} + +impl Cmsg { + /// Returns the amount of buffer space required to hold this option. + pub fn space(&self) -> usize { + let (_, _, size) = self.level_type_size(); + // Safety: All values are passed by copy. + let size = unsafe { libc::CMSG_SPACE(size) }; + size as usize + } + + fn level_type_size(&self) -> (libc::c_int, libc::c_int, libc::c_uint) { + match self { + #[cfg(not(any(target_os = "solaris", target_os = "illumos")))] + Cmsg::IpTos(_) => ( + libc::IPPROTO_IP, + libc::IP_TOS, + std::mem::size_of::() as libc::c_uint, + ), + #[cfg(not(any(target_os = "fuchsia", target_os = "solaris", target_os = "illumos")))] + Cmsg::Ipv6PktInfo { .. 
} => ( + libc::IPPROTO_IPV6, + libc::IPV6_PKTINFO, + std::mem::size_of::() as libc::c_uint, + ), + Cmsg::Unknown(UnknownCmsg { + cmsg_level, + cmsg_type, + }) => (*cmsg_level, *cmsg_type, 0), + } + } + + fn write(&self, buffer: &mut [u8]) { + match self { + #[cfg(not(any(target_os = "solaris", target_os = "illumos")))] + Cmsg::IpTos(tos) => { + buffer[0] = *tos; + } + #[cfg(not(any(target_os = "fuchsia", target_os = "solaris", target_os = "illumos")))] + Cmsg::Ipv6PktInfo { addr, ifindex } => { + let pktinfo = libc::in6_pktinfo { + ipi6_addr: crate::sys::to_in6_addr(addr), + ipi6_ifindex: *ifindex as _, + }; + let size = std::mem::size_of::(); + assert_eq!(buffer.len(), size); + // Safety: `pktinfo` is valid for reads for its size in bytes. + // `buffer` is valid for write for the same length, as + // guaranteed by the assertion above. Copy unit is byte, so + // alignment is okay. The two regions do not overlap. + unsafe { + std::ptr::copy_nonoverlapping( + &pktinfo as *const libc::in6_pktinfo as *const _, + buffer.as_mut_ptr(), + size, + ) + } + } + Cmsg::Unknown(_) => { + // NOTE: We don't actually allow users of the public API + // to serialize unknown control messages, but we use this code path + // for testing. + debug_assert_eq!(buffer.len(), 0); + } + } + } + + fn from_raw(cmsg_level: libc::c_int, cmsg_type: libc::c_int, bytes: &[u8]) -> Self { + match (cmsg_level, cmsg_type) { + #[cfg(not(any(target_os = "solaris", target_os = "illumos")))] + (libc::IPPROTO_IP, libc::IP_TOS) => { + assert_eq!(bytes.len(), std::mem::size_of::(), "{:?}", bytes); + Cmsg::IpTos(bytes[0]) + } + #[cfg(not(any(target_os = "fuchsia", target_os = "solaris", target_os = "illumos")))] + (libc::IPPROTO_IPV6, libc::IPV6_PKTINFO) => { + let mut pktinfo = unsafe { std::mem::zeroed::() }; + let size = std::mem::size_of::(); + assert!(bytes.len() >= size, "{:?}", bytes); + // Safety: `pktinfo` is valid for writes for its size in bytes. 
+ // `bytes` is valid for read for the same length, as + // guaranteed by the assertion above. Copy unit is byte, so + // alignment is okay. The two regions do not overlap. + unsafe { + std::ptr::copy_nonoverlapping( + bytes.as_ptr(), + &mut pktinfo as *mut libc::in6_pktinfo as *mut _, + size, + ) + } + Cmsg::Ipv6PktInfo { + addr: crate::sys::from_in6_addr(pktinfo.ipi6_addr), + ifindex: pktinfo.ipi6_ifindex as _, + } + } + (cmsg_level, cmsg_type) => { + let _ = bytes; + Cmsg::Unknown(UnknownCmsg { + cmsg_level, + cmsg_type, + }) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn ser_deser() { + let cmsgs = [ + #[cfg(not(any(target_os = "solaris", target_os = "illumos")))] + Cmsg::IpTos(2), + #[cfg(not(any(target_os = "fuchsia", target_os = "solaris", target_os = "illumos")))] + Cmsg::Ipv6PktInfo { + addr: std::net::Ipv6Addr::new(1, 2, 3, 4, 5, 6, 7, 8), + ifindex: 13, + }, + Cmsg::Unknown(UnknownCmsg { + cmsg_level: 12345678, + cmsg_type: 87654321, + }), + ]; + let mut buffer = [0u8; 256]; + let mut writer = CmsgWriter::new(&mut buffer[..]); + writer.extend(cmsgs.iter()); + let deser = CmsgIter::new(writer.buffer()).collect::>(); + assert_eq!(&cmsgs[..], &deser[..]); + } + + #[test] + #[should_panic] + #[cfg(not(any(target_os = "solaris", target_os = "illumos")))] + fn ser_insufficient_space_panics() { + let mut buffer = CmsgWriter::new(&mut []); + buffer.push(&Cmsg::IpTos(2)); + } + + #[test] + fn empty_deser() { + assert_eq!(CmsgIter::new(&[]).next(), None); + } +} diff --git a/src/lib.rs b/src/lib.rs index d9786260..3adfac7e 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -115,6 +115,8 @@ macro_rules! 
from { }; } +#[cfg(all(unix, not(target_os = "redox")))] +mod cmsg; mod sockaddr; mod socket; mod sockref; @@ -141,6 +143,9 @@ pub use sockref::SockRef; )))] pub use socket::InterfaceIndexOrAddress; +#[cfg(all(unix, not(target_os = "redox")))] +pub use cmsg::{Cmsg, CmsgIter, CmsgWriter}; + /// Specification of the communication domain for a socket. /// /// This is a newtype wrapper around an integer which provides a nicer API in diff --git a/src/socket.rs b/src/socket.rs index c5cca84f..7118c98c 100644 --- a/src/socket.rs +++ b/src/socket.rs @@ -19,6 +19,8 @@ use std::os::windows::io::{FromRawSocket, IntoRawSocket}; use std::time::Duration; use crate::sys::{self, c_int, getsockopt, setsockopt, Bool}; +#[cfg(all(unix, not(target_os = "redox")))] +use crate::{CmsgIter, CmsgWriter}; use crate::{Domain, Protocol, SockAddr, TcpKeepalive, Type}; #[cfg(not(target_os = "redox"))] use crate::{MaybeUninitSlice, RecvFlags}; @@ -660,6 +662,76 @@ impl Socket { ) -> io::Result { sys::send_to_vectored(self.as_raw(), bufs, addr, flags) } + + /// Sends data on the socket to the connected peer accompanied by ancillary + /// control message data. + #[cfg(all(unix, not(target_os = "redox")))] + pub fn send_msg( + &self, + bufs: &[IoSlice<'_>], + cmsg: &CmsgWriter<'_>, + flags: c_int, + ) -> io::Result { + sys::sendmsg( + self.as_raw(), + std::ptr::null(), + 0, + bufs, + cmsg.io_slice(), + flags, + ) + } + + /// Sends data on the socket to the given address accompanied by ancillary + /// control message data. + #[cfg(all(unix, not(target_os = "redox")))] + pub fn send_msg_to( + &self, + addr: &SockAddr, + bufs: &[IoSlice<'_>], + cmsg: &CmsgWriter<'_>, + flags: c_int, + ) -> io::Result { + sys::sendmsg( + self.as_raw(), + addr.as_storage_ptr(), + addr.len(), + bufs, + cmsg.io_slice(), + flags, + ) + } + + /// Receives data on the socket accompanied by ancillary control message data. 
+ #[cfg(all(unix, not(target_os = "redox")))] + pub fn recv_msg( + &self, + bufs: &mut [MaybeUninitSlice<'_>], + control_data: &mut [MaybeUninit], + flags: c_int, + ) -> io::Result<(usize, SockAddr, CmsgIter<'_>, RecvFlags)> { + // Safety: `recvmsg` initialises the address storage and we set the length + // manually. + unsafe { + SockAddr::init(|storage, len| { + sys::recvmsg(self.as_raw(), storage, bufs, control_data, flags).map( + |(n, addrlen, cmsg_len, recv_flags)| { + // Set the correct address length. + *len = addrlen; + // Safety: Slice was initialized up to cmsg_len by + // recvmsg. + let ptr = std::slice::from_raw_parts( + control_data.as_ptr() as *const u8, + cmsg_len, + ); + let cmsg = CmsgIter::new(ptr); + (n, cmsg, recv_flags) + }, + ) + }) + } + .map(|((n, cmsg, recv_flags), addr)| (n, addr, cmsg, recv_flags)) + } } /// Set `SOCK_CLOEXEC` and `NO_HANDLE_INHERIT` on the `ty`pe on platforms that diff --git a/src/sys/unix.rs b/src/sys/unix.rs index 13e09662..577a7ffd 100644 --- a/src/sys/unix.rs +++ b/src/sys/unix.rs @@ -743,7 +743,7 @@ pub(crate) fn recv_vectored( bufs: &mut [crate::MaybeUninitSlice<'_>], flags: c_int, ) -> io::Result<(usize, RecvFlags)> { - recvmsg(fd, ptr::null_mut(), bufs, flags).map(|(n, _, recv_flags)| (n, recv_flags)) + recvmsg(fd, ptr::null_mut(), bufs, &mut [], flags).map(|(n, _, _, recv_flags)| (n, recv_flags)) } #[cfg(not(target_os = "redox"))] @@ -756,7 +756,7 @@ pub(crate) fn recv_from_vectored( // manually. unsafe { SockAddr::init(|storage, len| { - recvmsg(fd, storage, bufs, flags).map(|(n, addrlen, recv_flags)| { + recvmsg(fd, storage, bufs, &mut [], flags).map(|(n, addrlen, _, recv_flags)| { // Set the correct address length. *len = addrlen; (n, recv_flags) @@ -768,12 +768,13 @@ pub(crate) fn recv_from_vectored( /// Returns the (bytes received, sending address len, `RecvFlags`). 
#[cfg(not(target_os = "redox"))] -fn recvmsg( +pub(crate) fn recvmsg( fd: Socket, msg_name: *mut sockaddr_storage, bufs: &mut [crate::MaybeUninitSlice<'_>], + control_data: &mut [MaybeUninit], flags: c_int, -) -> io::Result<(usize, libc::socklen_t, RecvFlags)> { +) -> io::Result<(usize, libc::socklen_t, libc::size_t, RecvFlags)> { let msg_namelen = if msg_name.is_null() { 0 } else { @@ -785,8 +786,16 @@ fn recvmsg( msg.msg_namelen = msg_namelen; msg.msg_iov = bufs.as_mut_ptr().cast(); msg.msg_iovlen = min(bufs.len(), IovLen::MAX as usize) as IovLen; - syscall!(recvmsg(fd, &mut msg, flags)) - .map(|n| (n as usize, msg.msg_namelen, RecvFlags(msg.msg_flags))) + msg.msg_control = control_data.as_mut_ptr().cast(); + msg.msg_controllen = control_data.len() as _; + syscall!(recvmsg(fd, &mut msg, flags)).map(|n| { + ( + n as usize, + msg.msg_namelen, + msg.msg_controllen as libc::size_t, + RecvFlags(msg.msg_flags), + ) + }) } pub(crate) fn send(fd: Socket, buf: &[u8], flags: c_int) -> io::Result { @@ -801,7 +810,7 @@ pub(crate) fn send(fd: Socket, buf: &[u8], flags: c_int) -> io::Result { #[cfg(not(target_os = "redox"))] pub(crate) fn send_vectored(fd: Socket, bufs: &[IoSlice<'_>], flags: c_int) -> io::Result { - sendmsg(fd, ptr::null(), 0, bufs, flags) + sendmsg(fd, ptr::null(), 0, bufs, IoSlice::new(&[]), flags) } pub(crate) fn send_to(fd: Socket, buf: &[u8], addr: &SockAddr, flags: c_int) -> io::Result { @@ -823,16 +832,24 @@ pub(crate) fn send_to_vectored( addr: &SockAddr, flags: c_int, ) -> io::Result { - sendmsg(fd, addr.as_storage_ptr(), addr.len(), bufs, flags) + sendmsg( + fd, + addr.as_storage_ptr(), + addr.len(), + bufs, + IoSlice::new(&[]), + flags, + ) } /// Returns the (bytes received, sending address len, `RecvFlags`). 
#[cfg(not(target_os = "redox"))] -fn sendmsg( +pub(crate) fn sendmsg( fd: Socket, msg_name: *const sockaddr_storage, msg_namelen: socklen_t, bufs: &[IoSlice<'_>], + control_data: IoSlice<'_>, flags: c_int, ) -> io::Result { // libc::msghdr contains unexported padding fields on Fuchsia. @@ -845,6 +862,9 @@ fn sendmsg( // Safety: Same as above about `*const` -> `*mut`. msg.msg_iov = bufs.as_ptr() as *mut _; msg.msg_iovlen = min(bufs.len(), IovLen::MAX as usize) as IovLen; + // Safety: Same as above about `*const` -> `*mut`. + msg.msg_control = control_data.as_ptr() as *mut _; + msg.msg_controllen = control_data.len() as _; syscall!(sendmsg(fd, &msg, flags)).map(|n| n as usize) } diff --git a/tests/socket.rs b/tests/socket.rs index 2d6cafe6..805ff098 100644 --- a/tests/socket.rs +++ b/tests/socket.rs @@ -1265,3 +1265,60 @@ fn header_included() { let got = socket.header_included().expect("failed to get value"); assert_eq!(got, true, "set and get values differ"); } + +#[test] +#[cfg(all( + unix, + not(any( + target_os = "fuchsia", + target_os = "solaris", + target_os = "illumos", + target_os = "netbsd", + target_os = "redox", + )) +))] +fn sendmsg_recvmsg() { + use socket2::{Cmsg, CmsgWriter}; + + let receiver = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)).unwrap(); + receiver.bind(&any_ipv4()).unwrap(); + let receiver_addr = receiver.local_addr().unwrap(); + receiver.set_recv_tos(true).unwrap(); + + let sender = Socket::new(Domain::IPV4, Type::DGRAM, Some(Protocol::UDP)).unwrap(); + sender.bind(&any_ipv4()).unwrap(); + let sender_addr = sender.local_addr().unwrap(); + + let data = "hello".as_bytes(); + + let input_messages = [Cmsg::IpTos(16)]; + let mut cmsg_buffer = [0; 32]; + let mut cmsg_writer = CmsgWriter::new(&mut cmsg_buffer[..]); + cmsg_writer.extend(input_messages.iter()); + let sent = sender + .send_msg_to( + &receiver_addr, + &[IoSlice::new(&data[..])][..], + &cmsg_writer, + 0, + ) + .unwrap(); + assert_eq!(sent, data.len()); + + let mut hello 
= [MaybeUninit::new(10); 10]; + let mut control_data = [MaybeUninit::uninit(); 32]; + let (read, from, cmsg, flags) = receiver + .recv_msg( + &mut [MaybeUninitSlice::new(&mut hello)], + &mut control_data[..], + 0, + ) + .unwrap(); + assert_eq!(read, data.len()); + assert_eq!( + from.as_socket_ipv4().unwrap(), + sender_addr.as_socket_ipv4().unwrap() + ); + assert_eq!(flags.is_truncated(), false); + assert_eq!(cmsg.collect::>(), &input_messages[..]); +}