Skip to content

Commit

Permalink
Adapt talpid-windows-net to windows-sys 0.48
Browse files Browse the repository at this point in the history
  • Loading branch information
faern committed Aug 8, 2023
1 parent 7f948ba commit e350149
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 109 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

70 changes: 70 additions & 0 deletions talpid-types/src/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
use std::{error::Error, fmt, fmt::Write};

/// Used to generate string representations of error chains.
pub trait ErrorExt {
/// Creates a string representation of the entire error chain.
fn display_chain(&self) -> String;

/// Like [Self::display_chain] but with an extra message at the start of the chain
fn display_chain_with_msg(&self, msg: &str) -> String;
}

impl<E: Error> ErrorExt for E {
fn display_chain(&self) -> String {
let mut s = format!("Error: {self}");
let mut source = self.source();
while let Some(error) = source {
write!(&mut s, "\nCaused by: {error}").expect("formatting failed");
source = error.source();
}
s
}

fn display_chain_with_msg(&self, msg: &str) -> String {
let mut s = format!("Error: {msg}\nCaused by: {self}");
let mut source = self.source();
while let Some(error) = source {
write!(&mut s, "\nCaused by: {error}").expect("formatting failed");
source = error.source();
}
s
}
}

#[derive(Debug)]
pub struct BoxedError(Box<dyn Error + 'static + Send>);

impl fmt::Display for BoxedError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}

impl Error for BoxedError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
self.0.source()
}
}

impl BoxedError {
pub fn new(error: impl Error + 'static + Send) -> Self {
BoxedError(Box::new(error))
}
}

/// Helper macro allowing simpler handling of Windows FFI returning `WIN32_ERROR`
/// status codes. Converts a `WIN32_ERROR` into an `io::Result<()>`.
///
/// The caller of this macro must have `windows_sys` as a dependency.
#[cfg(windows)]
#[macro_export]
macro_rules! win32_err {
($expr:expr) => {{
let status = $expr;
if status == ::windows_sys::Win32::Foundation::NO_ERROR {
Ok(())
} else {
Err(::std::io::Error::from_raw_os_error(status as i32))
}
}};
}
55 changes: 2 additions & 53 deletions talpid-types/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,5 @@
#![deny(rust_2018_idioms)]

use std::{error::Error, fmt, fmt::Write};

#[cfg(target_os = "android")]
pub mod android;
pub mod net;
Expand All @@ -13,54 +11,5 @@ pub mod cgroup;
#[cfg(target_os = "windows")]
pub mod split_tunnel;

/// Used to generate string representations of error chains.
pub trait ErrorExt {
/// Creates a string representation of the entire error chain.
fn display_chain(&self) -> String;

/// Like [Self::display_chain] but with an extra message at the start of the chain
fn display_chain_with_msg(&self, msg: &str) -> String;
}

impl<E: Error> ErrorExt for E {
fn display_chain(&self) -> String {
let mut s = format!("Error: {self}");
let mut source = self.source();
while let Some(error) = source {
write!(&mut s, "\nCaused by: {error}").expect("formatting failed");
source = error.source();
}
s
}

fn display_chain_with_msg(&self, msg: &str) -> String {
let mut s = format!("Error: {msg}\nCaused by: {self}");
let mut source = self.source();
while let Some(error) = source {
write!(&mut s, "\nCaused by: {error}").expect("formatting failed");
source = error.source();
}
s
}
}

#[derive(Debug)]
pub struct BoxedError(Box<dyn Error + 'static + Send>);

impl fmt::Display for BoxedError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
self.0.fmt(f)
}
}

impl Error for BoxedError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
self.0.source()
}
}

impl BoxedError {
pub fn new(error: impl Error + 'static + Send) -> Self {
BoxedError(Box::new(error))
}
}
mod error;
pub use error::*;
2 changes: 2 additions & 0 deletions talpid-windows-net/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ socket2 = { version = "0.4.2", features = ["all"] }
futures = "0.3.15"
winapi = { version = "0.3.6", features = ["ws2def"] }

talpid-types = { path = "../talpid-types" }

[target.'cfg(windows)'.dependencies.windows-sys]
workspace = true
features = [
Expand Down
76 changes: 20 additions & 56 deletions talpid-windows-net/src/net.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,12 @@ use std::{
sync::Mutex,
time::{Duration, Instant},
};
use talpid_types::win32_err;
use winapi::shared::ws2def::SOCKADDR_STORAGE as sockaddr_storage;
use windows_sys::{
core::GUID,
Win32::{
Foundation::{ERROR_NOT_FOUND, HANDLE, NO_ERROR},
Foundation::{ERROR_NOT_FOUND, HANDLE},
NetworkManagement::{
IpHelper::{
CancelMibChangeNotify2, ConvertInterfaceAliasToLuid, ConvertInterfaceLuidToAlias,
Expand Down Expand Up @@ -174,21 +175,16 @@ pub fn notify_ip_interface_change<'a, T: FnMut(&MIB_IPINTERFACE_ROW, i32) + Send
handle: 0,
});

let status = unsafe {
win32_err!(unsafe {
NotifyIpInterfaceChange(
af_family_from_family(family),
Some(inner_callback),
&mut *context as *mut _ as *mut _,
0,
(&mut context.handle) as *mut _,
)
};

if status == NO_ERROR as i32 {
Ok(context)
} else {
Err(io::Error::from_raw_os_error(status))
}
})?;
Ok(context)
}

/// Returns information about a network IP interface.
Expand All @@ -200,22 +196,13 @@ pub fn get_ip_interface_entry(
row.Family = family as u16;
row.InterfaceLuid = *luid;

let result = unsafe { GetIpInterfaceEntry(&mut row) };
if result == NO_ERROR as i32 {
Ok(row)
} else {
Err(io::Error::from_raw_os_error(result))
}
win32_err!(unsafe { GetIpInterfaceEntry(&mut row) })?;
Ok(row)
}

/// Set the properties of an IP interface.
pub fn set_ip_interface_entry(row: &mut MIB_IPINTERFACE_ROW) -> io::Result<()> {
let result = unsafe { SetIpInterfaceEntry(row as *mut _) };
if result == NO_ERROR as i32 {
Ok(())
} else {
Err(io::Error::from_raw_os_error(result))
}
win32_err!(unsafe { SetIpInterfaceEntry(row as *mut _) })
}

fn ip_interface_entry_exists(family: AddressFamily, luid: &NET_LUID_LH) -> io::Result<bool> {
Expand Down Expand Up @@ -293,12 +280,8 @@ pub async fn wait_for_addresses(luid: NET_LUID_LH) -> Result<()> {
let mut ready = true;

for row in &mut unicast_rows {
let status = unsafe { GetUnicastIpAddressEntry(row) };
if status != NO_ERROR as i32 {
return Err(Error::ObtainUnicastAddress(io::Error::from_raw_os_error(
status,
)));
}
win32_err!(unsafe { GetUnicastIpAddressEntry(row) })
.map_err(Error::ObtainUnicastAddress)?;
if row.DadState == IpDadStateTentative {
ready = false;
break;
Expand Down Expand Up @@ -347,13 +330,7 @@ pub fn add_ip_address_for_interface(luid: NET_LUID_LH, address: IpAddr) -> Resul
row.DadState = IpDadStatePreferred;
row.OnLinkPrefixLength = 255;

let status = unsafe { CreateUnicastIpAddressEntry(&row) };
if status != NO_ERROR as i32 {
return Err(Error::CreateUnicastEntry(io::Error::from_raw_os_error(
status,
)));
}
Ok(())
win32_err!(unsafe { CreateUnicastIpAddressEntry(&row) }).map_err(Error::CreateUnicastEntry)
}

/// Returns the unicast IP address table. If `family` is `None`, then addresses for all families are
Expand All @@ -364,11 +341,9 @@ pub fn get_unicast_table(
let mut unicast_rows = vec![];
let mut unicast_table: *mut MIB_UNICASTIPADDRESS_TABLE = std::ptr::null_mut();

let status =
unsafe { GetUnicastIpAddressTable(af_family_from_family(family), &mut unicast_table) };
if status != NO_ERROR as i32 {
return Err(io::Error::from_raw_os_error(status));
}
win32_err!(unsafe {
GetUnicastIpAddressTable(af_family_from_family(family), &mut unicast_table)
})?;
let first_row = unsafe { &(*unicast_table).Table[0] } as *const MIB_UNICASTIPADDRESS_ROW;
for i in 0..unsafe { *unicast_table }.NumEntries {
unicast_rows.push(unsafe { *(first_row.offset(i as isize)) });
Expand All @@ -381,20 +356,14 @@ pub fn get_unicast_table(
/// Returns the index of a network interface given its LUID.
pub fn index_from_luid(luid: &NET_LUID_LH) -> io::Result<u32> {
let mut index = 0u32;
let status = unsafe { ConvertInterfaceLuidToIndex(luid, &mut index) };
if status != NO_ERROR as i32 {
return Err(io::Error::from_raw_os_error(status));
}
win32_err!(unsafe { ConvertInterfaceLuidToIndex(luid, &mut index) })?;
Ok(index)
}

/// Returns the GUID of a network interface given its LUID.
pub fn guid_from_luid(luid: &NET_LUID_LH) -> io::Result<GUID> {
let mut guid = MaybeUninit::zeroed();
let status = unsafe { ConvertInterfaceLuidToGuid(luid, guid.as_mut_ptr()) };
if status != NO_ERROR as i32 {
return Err(io::Error::from_raw_os_error(status));
}
win32_err!(unsafe { ConvertInterfaceLuidToGuid(luid, guid.as_mut_ptr()) })?;
Ok(unsafe { guid.assume_init() })
}

Expand All @@ -406,21 +375,16 @@ pub fn luid_from_alias<T: AsRef<OsStr>>(alias: T) -> io::Result<NET_LUID_LH> {
.chain(std::iter::once(0u16))
.collect();
let mut luid: NET_LUID_LH = unsafe { std::mem::zeroed() };
let status = unsafe { ConvertInterfaceAliasToLuid(alias_wide.as_ptr(), &mut luid) };
if status != NO_ERROR as i32 {
return Err(io::Error::from_raw_os_error(status));
}
win32_err!(unsafe { ConvertInterfaceAliasToLuid(alias_wide.as_ptr(), &mut luid) })?;
Ok(luid)
}

/// Returns the alias of an interface given its LUID.
pub fn alias_from_luid(luid: &NET_LUID_LH) -> io::Result<OsString> {
let mut buffer = [0u16; IF_MAX_STRING_SIZE as usize + 1];
let status =
unsafe { ConvertInterfaceLuidToAlias(luid, &mut buffer[0] as *mut _, buffer.len()) };
if status != NO_ERROR as i32 {
return Err(io::Error::from_raw_os_error(status));
}
win32_err!(unsafe {
ConvertInterfaceLuidToAlias(luid, &mut buffer[0] as *mut _, buffer.len())
})?;
let nul = buffer.iter().position(|&c| c == 0u16).unwrap();
Ok(OsString::from_wide(&buffer[0..nul]))
}
Expand Down

0 comments on commit e350149

Please sign in to comment.