Skip to content

Commit

Permalink
Use function pointers for interop in System.Net.NetworkInformation (d…
Browse files Browse the repository at this point in the history
  • Loading branch information
jkotas authored May 6, 2021
1 parent 91dcd97 commit bf62ae4
Show file tree
Hide file tree
Showing 11 changed files with 249 additions and 168 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,19 +42,15 @@ public unsafe struct NetworkInterfaceInfo
private fixed byte __padding[3];
}

public unsafe delegate void IPv4AddressDiscoveredCallback(string ifaceName, IpAddressInfo* ipAddressInfo);
public unsafe delegate void IPv6AddressDiscoveredCallback(string ifaceName, IpAddressInfo* ipAddressInfo, uint* scopeId);
public unsafe delegate void LinkLayerAddressDiscoveredCallback(string ifaceName, LinkLayerAddressInfo* llAddress);
public unsafe delegate void DnsAddessDiscoveredCallback(IpAddressInfo* gatewayAddress);

[DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_EnumerateInterfaceAddresses")]
public static extern int EnumerateInterfaceAddresses(
IPv4AddressDiscoveredCallback ipv4Found,
IPv6AddressDiscoveredCallback? ipv6Found,
LinkLayerAddressDiscoveredCallback? linkLayerFound);
public static extern unsafe int EnumerateInterfaceAddresses(
void* context,
delegate* unmanaged<void*, byte*, IpAddressInfo*, void> ipv4Found,
delegate* unmanaged<void*, byte*, IpAddressInfo*, uint*, void> ipv6Found,
delegate* unmanaged<void*, byte*, LinkLayerAddressInfo*, void> linkLayerFound);

[DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_EnumerateGatewayAddressesForInterface")]
public static extern int EnumerateGatewayAddressesForInterface(uint interfaceIndex, DnsAddessDiscoveredCallback onGatewayFound);
public static extern unsafe int EnumerateGatewayAddressesForInterface(void* context, uint interfaceIndex, delegate* unmanaged<void*, IpAddressInfo*, void> onGatewayFound);

[DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_GetNetworkInterfaces")]
public static unsafe extern int GetNetworkInterfaces(ref int count, ref NetworkInterfaceInfo* addrs, ref int addressCount, ref IpAddressInfo *aa);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,15 +16,13 @@ public enum NetworkChangeKind
AvailabilityChanged = 2
}

public delegate void NetworkChangeEvent(int socket, NetworkChangeKind kind);

[DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_CreateNetworkChangeListenerSocket")]
public static extern Error CreateNetworkChangeListenerSocket(out int socket);

[DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_CloseNetworkChangeListenerSocket")]
public static extern Error CloseNetworkChangeListenerSocket(int socket);

[DllImport(Libraries.SystemNative, EntryPoint = "SystemNative_ReadEvents")]
public static extern void ReadEvents(int socket, NetworkChangeEvent onNetworkChange);
public static extern unsafe void ReadEvents(int socket, delegate* unmanaged<int, NetworkChangeKind, void> onNetworkChange);
}
}
19 changes: 11 additions & 8 deletions src/libraries/Native/Unix/System.Native/pal_interfaceaddresses.c
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,8 @@ static inline uint8_t mask2prefix(uint8_t* mask, int length)
}
#endif

int32_t SystemNative_EnumerateInterfaceAddresses(IPv4AddressFound onIpv4Found,
int32_t SystemNative_EnumerateInterfaceAddresses(void* context,
IPv4AddressFound onIpv4Found,
IPv6AddressFound onIpv6Found,
LinkLayerAddressFound onLinkLayerFound)
{
Expand Down Expand Up @@ -142,7 +143,7 @@ int32_t SystemNative_EnumerateInterfaceAddresses(IPv4AddressFound onIpv4Found,
// ifa_netmask can be NULL according to documentation, probably P2P interfaces.
iai.PrefixLength = mask_sain != NULL ? mask2prefix((uint8_t*)&mask_sain->sin_addr.s_addr, NUM_BYTES_IN_IPV4_ADDRESS) : NUM_BYTES_IN_IPV4_ADDRESS * 8;

onIpv4Found(actualName, &iai);
onIpv4Found(context, actualName, &iai);
}
}
else if (family == AF_INET6)
Expand All @@ -160,7 +161,7 @@ int32_t SystemNative_EnumerateInterfaceAddresses(IPv4AddressFound onIpv4Found,

struct sockaddr_in6* mask_sain6 = (struct sockaddr_in6*)current->ifa_netmask;
iai.PrefixLength = mask_sain6 != NULL ? mask2prefix((uint8_t*)&mask_sain6->sin6_addr.s6_addr, NUM_BYTES_IN_IPV6_ADDRESS) : NUM_BYTES_IN_IPV6_ADDRESS * 8;
onIpv6Found(actualName, &iai, &scopeId);
onIpv6Found(context, actualName, &iai, &scopeId);
}
}
#if defined(AF_PACKET)
Expand All @@ -186,7 +187,7 @@ int32_t SystemNative_EnumerateInterfaceAddresses(IPv4AddressFound onIpv4Found,
lla.HardwareType = MapHardwareType(sall->sll_hatype);

memcpy_s(&lla.AddressBytes, sizeof_member(LinkLayerAddressInfo, AddressBytes), &sall->sll_addr, sall->sll_halen);
onLinkLayerFound(current->ifa_name, &lla);
onLinkLayerFound(context, current->ifa_name, &lla);
}
}
#elif defined(AF_LINK)
Expand Down Expand Up @@ -223,7 +224,7 @@ int32_t SystemNative_EnumerateInterfaceAddresses(IPv4AddressFound onIpv4Found,
}
#endif
memcpy_s(&lla.AddressBytes, sizeof_member(LinkLayerAddressInfo, AddressBytes), (uint8_t*)LLADDR(sadl), sadl->sdl_alen);
onLinkLayerFound(current->ifa_name, &lla);
onLinkLayerFound(context, current->ifa_name, &lla);
}
}
#endif
Expand All @@ -233,6 +234,7 @@ int32_t SystemNative_EnumerateInterfaceAddresses(IPv4AddressFound onIpv4Found,
return 0;
#else
// Not supported on e.g. Android. Also, prevent a compiler error because parameters are unused
(void)context;
(void)onIpv4Found;
(void)onIpv6Found;
(void)onLinkLayerFound;
Expand Down Expand Up @@ -460,7 +462,7 @@ int32_t SystemNative_GetNetworkInterfaces(int32_t * interfaceCount, NetworkInter
}

#if HAVE_RT_MSGHDR && defined(CTL_NET)
int32_t SystemNative_EnumerateGatewayAddressesForInterface(uint32_t interfaceIndex, GatewayAddressFound onGatewayFound)
int32_t SystemNative_EnumerateGatewayAddressesForInterface(void* context, uint32_t interfaceIndex, GatewayAddressFound onGatewayFound)
{
static struct in6_addr anyaddr = IN6ADDR_ANY_INIT;
int routeDumpName[] = {CTL_NET, AF_ROUTE, 0, 0, NET_RT_DUMP, 0};
Expand Down Expand Up @@ -558,16 +560,17 @@ int32_t SystemNative_EnumerateGatewayAddressesForInterface(uint32_t interfaceInd
// Ignore other address families.
continue;
}
onGatewayFound(&iai);
onGatewayFound(context, &iai);
}
}

free(buffer);
return 0;
}
#else
int32_t SystemNative_EnumerateGatewayAddressesForInterface(uint32_t interfaceIndex, GatewayAddressFound onGatewayFound)
int32_t SystemNative_EnumerateGatewayAddressesForInterface(void* context, uint32_t interfaceIndex, GatewayAddressFound onGatewayFound)
{
(void)context;
(void)interfaceIndex;
(void)onGatewayFound;
errno = ENOTSUP;
Expand Down
12 changes: 6 additions & 6 deletions src/libraries/Native/Unix/System.Native/pal_interfaceaddresses.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,13 @@ typedef struct
uint8_t __padding[3];
} NetworkInterfaceInfo;

typedef void (*IPv4AddressFound)(const char* interfaceName, IpAddressInfo* addressInfo);
typedef void (*IPv6AddressFound)(const char* interfaceName, IpAddressInfo* info, uint32_t* scopeId);
typedef void (*LinkLayerAddressFound)(const char* interfaceName, LinkLayerAddressInfo* llAddress);
typedef void (*GatewayAddressFound)(IpAddressInfo* addressInfo);
typedef void (*IPv4AddressFound)(void* context, const char* interfaceName, IpAddressInfo* addressInfo);
typedef void (*IPv6AddressFound)(void* context, const char* interfaceName, IpAddressInfo* info, uint32_t* scopeId);
typedef void (*LinkLayerAddressFound)(void* context, const char* interfaceName, LinkLayerAddressInfo* llAddress);
typedef void (*GatewayAddressFound)(void* context, IpAddressInfo* addressInfo);

PALEXPORT int32_t SystemNative_EnumerateInterfaceAddresses(
IPv4AddressFound onIpv4Found, IPv6AddressFound onIpv6Found, LinkLayerAddressFound onLinkLayerFound);
void* context, IPv4AddressFound onIpv4Found, IPv6AddressFound onIpv6Found, LinkLayerAddressFound onLinkLayerFound);
PALEXPORT int32_t SystemNative_GetNetworkInterfaces(int32_t * interfaceCount, NetworkInterfaceInfo** interfaces, int32_t * addressCount, IpAddressInfo **addressList);

PALEXPORT int32_t SystemNative_EnumerateGatewayAddressesForInterface(uint32_t interfaceIndex, GatewayAddressFound onGatewayFound);
PALEXPORT int32_t SystemNative_EnumerateGatewayAddressesForInterface(void* context, uint32_t interfaceIndex, GatewayAddressFound onGatewayFound);
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@
<Reference Include="System.Net.Primitives" />
<Reference Include="System.Net.Sockets" />
<Reference Include="System.Runtime" />
<Reference Include="System.Runtime.CompilerServices.Unsafe" />
<Reference Include="System.Runtime.InteropServices" />
<Reference Include="System.Threading" />
<Reference Include="System.Threading.Overlapped" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace System.Net.NetworkInformation
{
Expand All @@ -25,6 +27,21 @@ internal sealed class BsdIPv4GlobalStatistics : IPGlobalStatistics
private readonly int _numIPAddresses;
private readonly int _numRoutes;

private struct Context
{
internal int _numIPAddresses;
internal HashSet<string> _interfaceSet;
}

[UnmanagedCallersOnly]
private static unsafe void ProcessIpv4Address(void* pContext, byte* ifaceName, Interop.Sys.IpAddressInfo* ipAddr)
{
ref Context context = ref Unsafe.As<byte, Context>(ref *(byte*)pContext);

context._interfaceSet.Add(new string((sbyte*)ifaceName));
context._numIPAddresses++;
}

public unsafe BsdIPv4GlobalStatistics()
{
Interop.Sys.IPv4GlobalStatistics statistics;
Expand All @@ -48,19 +65,18 @@ public unsafe BsdIPv4GlobalStatistics()
_defaultTtl = statistics.DefaultTtl;
_forwarding = statistics.Forwarding == 1;

HashSet<string> interfaceSet = new HashSet<string>();
int numIPAddresses = 0;
Context context;
context._numIPAddresses = 0;
context._interfaceSet = new HashSet<string>();

Interop.Sys.EnumerateInterfaceAddresses(
(name, addressInfo) =>
{
interfaceSet.Add(name);
numIPAddresses++;
},
Unsafe.AsPointer(ref context),
&ProcessIpv4Address,
null,
null);

_numInterfaces = interfaceSet.Count;
_numIPAddresses = numIPAddresses;
_numInterfaces = context._interfaceSet.Count;
_numIPAddresses = context._numIPAddresses;

_numRoutes = Interop.Sys.GetNumRoutes();
if (_numRoutes == -1)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
// The .NET Foundation licenses this file to you under the MIT license.

using System.Collections.Generic;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;

namespace System.Net.NetworkInformation
{
Expand Down Expand Up @@ -40,36 +42,43 @@ public override IPv6InterfaceProperties GetIPv6Properties()
return _ipv6Properties;
}

private struct Context
{
internal int _interfaceIndex;
internal HashSet<IPAddress> _addressSet;
}

private static unsafe GatewayIPAddressInformationCollection GetGatewayAddresses(int interfaceIndex)
{
HashSet<IPAddress> addressSet = new HashSet<IPAddress>();
if (Interop.Sys.EnumerateGatewayAddressesForInterface((uint)interfaceIndex,
(gatewayAddressInfo) =>
{
byte[] ipBytes = new byte[gatewayAddressInfo->NumAddressBytes];
fixed (byte* ipArrayPtr = ipBytes)
{
Buffer.MemoryCopy(gatewayAddressInfo->AddressBytes, ipArrayPtr, ipBytes.Length, ipBytes.Length);
}
IPAddress ipAddress = new IPAddress(ipBytes);
if (ipAddress.IsIPv6LinkLocal)
{
// For Link-Local addresses add ScopeId as that is not part of the route entry.
ipAddress.ScopeId = interfaceIndex;
}
addressSet.Add(ipAddress);
}) == -1)
Context context;
context._interfaceIndex = interfaceIndex;
context._addressSet = new HashSet<IPAddress>();
if (Interop.Sys.EnumerateGatewayAddressesForInterface(Unsafe.AsPointer(ref context), (uint)interfaceIndex, &OnGatewayFound) == -1)
{
throw new NetworkInformationException(SR.net_PInvokeError);
}

GatewayIPAddressInformationCollection collection = new GatewayIPAddressInformationCollection();
foreach (IPAddress address in addressSet)
foreach (IPAddress address in context._addressSet)
{
collection.InternalAdd(new SimpleGatewayIPAddressInformation(address));
}

return collection;
}

[UnmanagedCallersOnly]
private static unsafe void OnGatewayFound(void* pContext, Interop.Sys.IpAddressInfo* gatewayAddressInfo)
{
ref Context context = ref Unsafe.As<byte, Context>(ref *(byte*)pContext);

IPAddress ipAddress = new IPAddress(new Span<byte>(gatewayAddressInfo->AddressBytes, gatewayAddressInfo->NumAddressBytes).ToArray());
if (ipAddress.IsIPv6LinkLocal)
{
// For Link-Local addresses add ScopeId as that is not part of the route entry.
ipAddress.ScopeId = context._interfaceIndex;
}
context._addressSet.Add(ipAddress);
}
}
}
Loading

0 comments on commit bf62ae4

Please sign in to comment.