diff --git a/src/net.rs b/src/net.rs index 6ddf2e76..dc0f1b6a 100644 --- a/src/net.rs +++ b/src/net.rs @@ -72,6 +72,16 @@ impl AsyncFd { Recv::new(self, buf, flags) } + /// Shuts down the read, write, or both halves of this connection. + pub const fn shutdown<'fd>(&'fd self, how: std::net::Shutdown) -> Shutdown<'fd> { + let how = match how { + std::net::Shutdown::Read => libc::SHUT_RD, + std::net::Shutdown::Write => libc::SHUT_WR, + std::net::Shutdown::Both => libc::SHUT_RDWR, + }; + Shutdown::new(self, how) + } + /// Accept a new socket stream ([`AsyncFd`]). /// /// If an accepted stream is returned, the remote address of the peer is @@ -198,6 +208,19 @@ op_future! { }, } +// Shutdown. +op_future! { + fn AsyncFd::shutdown -> (), + struct Shutdown<'fd> { + // Doesn't need any fields. + }, + setup_state: flags: libc::c_int, + setup: |submission, fd, (), how| unsafe { + submission.shutdown(fd.fd, how); + }, + map_result: |n| Ok(debug_assert!(n == 0)), +} + // Accept. op_future! { fn AsyncFd::accept -> (AsyncFd, SocketAddr), diff --git a/src/op.rs b/src/op.rs index 44a12293..d705f236 100644 --- a/src/op.rs +++ b/src/op.rs @@ -294,6 +294,12 @@ impl Submission { self.inner.len = len; } + pub(crate) unsafe fn shutdown(&mut self, fd: RawFd, how: libc::c_int) { + self.inner.opcode = libc::IORING_OP_SHUTDOWN as u8; + self.inner.fd = fd; + self.inner.len = how as u32; + } + /// Create a accept submission starting. /// /// Avaialable since Linux kernel 5.5. diff --git a/tests/net.rs b/tests/net.rs index 245a9614..205d7c2c 100644 --- a/tests/net.rs +++ b/tests/net.rs @@ -4,7 +4,7 @@ use std::io::{Read, Write}; use std::mem; -use std::net::{SocketAddr, SocketAddrV4, TcpListener, TcpStream}; +use std::net::{Shutdown, SocketAddr, SocketAddrV4, TcpListener, TcpStream}; use std::pin::Pin; use a10::io::ReadBufPool; @@ -414,6 +414,38 @@ fn send_zc_extractor() { assert_eq!(&buf[0..n], DATA2); } +#[test] +fn shutdown() { + let sq = test_queue(); + let waker = Waker::new(); + + // Bind a socket. + let listener = TcpListener::bind("127.0.0.1:0").expect("failed to bind listener"); + let local_addr = match listener.local_addr().unwrap() { + SocketAddr::V4(addr) => addr, + _ => unreachable!(), + }; + + // Create a socket and connect the listener. + let stream = waker.block_on(tcp_ipv4_socket(sq)); + let addr = addr_storage(&local_addr); + let addr_len = mem::size_of::() as libc::socklen_t; + let mut connect_future = stream.connect(addr, addr_len); + // Poll the future to schedule the operation. + assert!(poll_nop(Pin::new(&mut connect_future)).is_pending()); + + let (mut client, _) = listener.accept().expect("failed to accept connection"); + + waker.block_on(connect_future).expect("failed to connect"); + + waker + .block_on(stream.shutdown(Shutdown::Write)) + .expect("failed to shutdown"); + let mut buf = vec![0; 10]; + let n = client.read(&mut buf).expect("failed to send data"); + assert_eq!(n, 0); +} + fn addr_storage(addres: &SocketAddrV4) -> libc::sockaddr_storage { // SAFETY: zeroed out `sockaddr_storage` is valid. let mut addr: libc::sockaddr_storage = unsafe { mem::zeroed() };