Skip to content

Commit

Permalink
Don't leak fds on error in UDS on Unix
Browse files Browse the repository at this point in the history
We do this by simply creating a fd-managing type, e.g. UnixDatagram,
from the fd once it's created.

Also first tries to parse the path as that can fail without doing a
system call.
  • Loading branch information
Thomasdezeeuw committed Nov 30, 2022
1 parent 0accf7d commit ef0fe1d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 23 deletions.
7 changes: 4 additions & 3 deletions src/sys/unix/uds/datagram.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,13 @@ use std::os::unix::net;
use std::path::Path;

pub(crate) fn bind(path: &Path) -> io::Result<net::UnixDatagram> {
let fd = new_socket(libc::AF_UNIX, libc::SOCK_DGRAM)?;
// Ensure the fd is closed.
let socket = unsafe { net::UnixDatagram::from_raw_fd(fd) };
let (sockaddr, socklen) = socket_addr(path)?;
let sockaddr = &sockaddr as *const libc::sockaddr_un as *const _;

let fd = new_socket(libc::AF_UNIX, libc::SOCK_DGRAM)?;
let socket = unsafe { net::UnixDatagram::from_raw_fd(fd) };
syscall!(bind(fd, sockaddr, socklen))?;

Ok(socket)
}

Expand Down
16 changes: 6 additions & 10 deletions src/sys/unix/uds/listener.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,19 +7,15 @@ use std::path::Path;
use std::{io, mem};

pub(crate) fn bind(path: &Path) -> io::Result<net::UnixListener> {
let socket = new_socket(libc::AF_UNIX, libc::SOCK_STREAM)?;
let (sockaddr, socklen) = socket_addr(path)?;
let sockaddr = &sockaddr as *const libc::sockaddr_un as *const libc::sockaddr;

syscall!(bind(socket, sockaddr, socklen))
.and_then(|_| syscall!(listen(socket, 1024)))
.map_err(|err| {
// Close the socket if we hit an error, ignoring the error from
// closing since we can't pass back two errors.
let _ = unsafe { libc::close(socket) };
err
})
.map(|_| unsafe { net::UnixListener::from_raw_fd(socket) })
let fd = new_socket(libc::AF_UNIX, libc::SOCK_STREAM)?;
let socket = unsafe { net::UnixListener::from_raw_fd(fd) };
syscall!(bind(fd, sockaddr, socklen))?;
syscall!(listen(fd, 1024))?;

Ok(socket)
}

pub(crate) fn accept(listener: &net::UnixListener) -> io::Result<(UnixStream, SocketAddr)> {
Expand Down
15 changes: 5 additions & 10 deletions src/sys/unix/uds/stream.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,18 @@ use std::os::unix::net;
use std::path::Path;

pub(crate) fn connect(path: &Path) -> io::Result<net::UnixStream> {
let socket = new_socket(libc::AF_UNIX, libc::SOCK_STREAM)?;
let (sockaddr, socklen) = socket_addr(path)?;
let sockaddr = &sockaddr as *const libc::sockaddr_un as *const libc::sockaddr;

match syscall!(connect(socket, sockaddr, socklen)) {
let fd = new_socket(libc::AF_UNIX, libc::SOCK_STREAM)?;
let socket = unsafe { net::UnixStream::from_raw_fd(fd) };
match syscall!(connect(fd, sockaddr, socklen)) {
Ok(_) => {}
Err(ref err) if err.raw_os_error() == Some(libc::EINPROGRESS) => {}
Err(e) => {
// Close the socket if we hit an error, ignoring the error
// from closing since we can't pass back two errors.
let _ = unsafe { libc::close(socket) };

return Err(e);
}
Err(e) => return Err(e),
}

Ok(unsafe { net::UnixStream::from_raw_fd(socket) })
Ok(socket)
}

pub(crate) fn pair() -> io::Result<(net::UnixStream, net::UnixStream)> {
Expand Down

0 comments on commit ef0fe1d

Please sign in to comment.