Skip to content

Commit

Permalink
add support for normal Unix domain socket for Linux
Browse files Browse the repository at this point in the history
To let Linux os use normal Unix domain socket, this commit
extends the Domain enum and control the socket type by
the address pattern passed in:

vsock://xxx:xx          -> Vsock
unix:///run/aaa.sock    -> normal Unix domain sock
unix://@/run/abc.sock   -> abstract Unix sock
other:///run/d.sock     -> error

Fixes: containerd#115

Signed-off-by: bin liu <[email protected]>
  • Loading branch information
liubin committed Dec 21, 2021
1 parent 465572d commit c0d98aa
Showing 1 changed file with 118 additions and 22 deletions.
140 changes: 118 additions & 22 deletions src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,12 @@ use nix::fcntl::{fcntl, FcntlArg, FdFlag, OFlag};
use nix::sys::socket::*;
use std::os::unix::io::RawFd;

#[derive(Debug)]
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum Domain {
Unix,
#[cfg(target_os = "linux")]
AbstractUnix,
#[cfg(target_os = "linux")]
Vsock,
}

Expand All @@ -39,20 +41,38 @@ pub fn do_listen(listener: RawFd) -> Result<()> {
listen(listener, 10).map_err(|e| Error::Socket(e.to_string()))
}

pub fn parse_host(host: &str) -> Result<(Domain, Vec<&str>)> {
pub fn parse_host(host: &str) -> Result<(Domain, &str)> {
let hostv: Vec<&str> = host.trim().split("://").collect();
if hostv.len() != 2 {
return Err(Error::Others(format!("Host {} is not right", host)));
}

let addr = hostv[1];
if addr.is_empty() {
return Err(Error::Others(format!("address {} is empty", addr)));
}

let domain = match &hostv[0].to_lowercase()[..] {
"unix" => Domain::Unix,
"unix" if !addr.starts_with('@') => Domain::Unix,
#[cfg(not(target_os = "linux"))]
"unix" if addr.starts_with('@') => {
return Err(Error::Others(
"Abstract socket is not supported".to_string(),
))
}
#[cfg(target_os = "linux")]
"unix" if addr.starts_with('@') => Domain::AbstractUnix,
#[cfg(target_os = "linux")]
"vsock" => Domain::Vsock,
x => return Err(Error::Others(format!("Scheme {:?} is not supported", x))),
};

Ok((domain, hostv))
#[cfg(target_os = "linux")]
if domain == Domain::AbstractUnix {
return Ok((domain, &addr[1..]));
}

Ok((domain, addr))
}

pub fn set_fd_close_exec(fd: RawFd) -> Result<RawFd> {
Expand All @@ -72,37 +92,47 @@ pub(crate) const SOCK_CLOEXEC: SockFlag = SockFlag::SOCK_CLOEXEC;
pub(crate) const SOCK_CLOEXEC: SockFlag = SockFlag::empty();

#[cfg(target_os = "linux")]
fn make_addr(host: &str) -> Result<UnixAddr> {
UnixAddr::new_abstract(host.as_bytes()).map_err(err_to_others_err!(e, ""))
fn make_addr(domain: Domain, host: &str) -> Result<UnixAddr> {
match domain {
Domain::Unix => UnixAddr::new(host).map_err(err_to_others_err!(e, "")),
Domain::AbstractUnix => {
UnixAddr::new_abstract(host.as_bytes()).map_err(err_to_others_err!(e, ""))
}
Domain::Vsock => Err(Error::Others(
"function make_addr does not support create vsock socket".to_string(),
)),
}
}

#[cfg(not(target_os = "linux"))]
fn make_addr(host: &str) -> Result<UnixAddr> {
fn make_addr(_domain: Domain, host: &str) -> Result<UnixAddr> {
UnixAddr::new(host).map_err(err_to_others_err!(e, ""))
}

fn make_socket(addr: (&str, u32)) -> Result<(RawFd, Domain, SockAddr)> {
let (host, _) = addr;
let (domain, hostv) = parse_host(host)?;

let sockaddr: SockAddr;
let fd: RawFd;
let get_sock_addr = |domain, host| -> Result<(RawFd, SockAddr)> {
let fd = socket(AddressFamily::Unix, SockType::Stream, SOCK_CLOEXEC, None)
.map_err(|e| Error::Socket(e.to_string()))?;

match domain {
Domain::Unix => {
fd = socket(AddressFamily::Unix, SockType::Stream, SOCK_CLOEXEC, None)
.map_err(|e| Error::Socket(e.to_string()))?;
// MacOS doesn't support atomic creation of a socket descriptor with SOCK_CLOEXEC flag,
// so there is a chance of leak if fork + exec happens in between of these calls.
#[cfg(target_os = "macos")]
set_fd_close_exec(fd)?;

// MacOS doesn't support atomic creation of a socket descriptor with SOCK_CLOEXEC flag,
// so there is a chance of leak if fork + exec happens in between of these calls.
#[cfg(target_os = "macos")]
set_fd_close_exec(fd)?;
let sockaddr = SockAddr::Unix(make_addr(domain, host)?);
Ok((fd, sockaddr))
};

sockaddr = SockAddr::Unix(make_addr(hostv[1])?);
}
let (fd, sockaddr) = match domain {
Domain::Unix => get_sock_addr(domain, hostv)?,
#[cfg(target_os = "linux")]
Domain::AbstractUnix => get_sock_addr(domain, hostv)?,
#[cfg(target_os = "linux")]
Domain::Vsock => {
let host_port_v: Vec<&str> = hostv[1].split(':').collect();
let host_port_v: Vec<&str> = hostv.split(':').collect();
if host_port_v.len() != 2 {
return Err(Error::Others(format!(
"Host {} is not right for vsock",
Expand All @@ -112,15 +142,16 @@ fn make_socket(addr: (&str, u32)) -> Result<(RawFd, Domain, SockAddr)> {
let port: u32 = host_port_v[1]
.parse()
.expect("the vsock port is not an number");
fd = socket(
let fd = socket(
AddressFamily::Vsock,
SockType::Stream,
SockFlag::SOCK_CLOEXEC,
None,
)
.map_err(|e| Error::Socket(e.to_string()))?;
let cid = addr.1;
sockaddr = SockAddr::new_vsock(cid, port);
let sockaddr = SockAddr::new_vsock(cid, port);
(fd, sockaddr)
}
};

Expand Down Expand Up @@ -180,3 +211,68 @@ pub const MESSAGE_LENGTH_MAX: usize = 4 << 20;

pub const MESSAGE_TYPE_REQUEST: u8 = 0x1;
pub const MESSAGE_TYPE_RESPONSE: u8 = 0x2;

#[cfg(test)]
mod tests {
use super::parse_host;
use super::Domain;

#[cfg(target_os = "linux")]
#[test]
fn test_parse_host() {
for i in &[
(
"unix:///run/a.sock",
Some(Domain::Unix),
"/run/a.sock",
true,
),
("vsock://8:1024", Some(Domain::Vsock), "8:1024", true),
("Vsock://8:1025", Some(Domain::Vsock), "8:1025", true),
(
"unix://@/run/b.sock",
Some(Domain::AbstractUnix),
"/run/b.sock",
true,
),
("abc:///run/c.sock", None, "", false),
] {
let (input, domain, addr, success) = (i.0, i.1, i.2, i.3);
let r = parse_host(input);
if success {
let (rd, ra) = r.unwrap();
assert_eq!(rd, domain.unwrap());
assert_eq!(ra, addr);
} else {
assert!(r.is_err());
}
}
}

#[cfg(not(target_os = "linux"))]
#[test]
fn test_parse_host() {
for i in &[
(
"unix:///run/a.sock",
Some(Domain::Unix),
"/run/a.sock",
true,
),
("vsock:///run/c.sock", None, "", false),
("Vsock:///run/c.sock", None, "", false),
("unix://@/run/b.sock", None, "", false),
("abc:///run/c.sock", None, "", false),
] {
let (input, domain, addr, success) = (i.0, i.1, i.2, i.3);
let r = parse_host(input);
if success {
let (rd, ra) = r.unwrap();
assert_eq!(rd, domain.unwrap());
assert_eq!(ra, addr);
} else {
assert!(r.is_err());
}
}
}
}

0 comments on commit c0d98aa

Please sign in to comment.