Skip to content

Commit

Permalink
Wrap operations requiring SocketPermission with doPrivileged blocks
Browse files Browse the repository at this point in the history
Motivation:

Currently Netty does not wrap socket connect, bind, or accept
operations in doPrivileged blocks. Nor does it wrap cases where a dns
lookup might happen.

This prevents an application utilizing the SecurityManager from
isolating SocketPermissions to Netty.

Modifications:

I have introduced a class (SocketUtils) that wraps operations
requiring SocketPermissions in doPrivileged blocks.

Result:

A user of Netty can grant SocketPermissions explicitly to the Netty
jar, without granting it to the rest of their application.
  • Loading branch information
Tim-Brooks authored and normanmaurer committed Jan 19, 2017
1 parent 2d11331 commit 3344cd2
Show file tree
Hide file tree
Showing 42 changed files with 527 additions and 81 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.socket.InternetProtocolFamily;
import io.netty.util.internal.SocketUtils;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.ThreadLocalRandom;
import org.junit.Test;
Expand Down Expand Up @@ -69,12 +70,12 @@ private static void testEncodeName(byte[] expected, String name) throws Exceptio

@Test
public void testOptEcsRecordIpv4() throws Exception {
testOptEcsRecordIp(InetAddress.getByName("1.2.3.4"));
testOptEcsRecordIp(SocketUtils.addressByName("1.2.3.4"));
}

@Test
public void testOptEcsRecordIpv6() throws Exception {
testOptEcsRecordIp(InetAddress.getByName("::0"));
testOptEcsRecordIp(SocketUtils.addressByName("::0"));
}

private static void testOptEcsRecordIp(InetAddress address) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import io.netty.channel.embedded.EmbeddedChannel;

import io.netty.channel.socket.DatagramPacket;
import io.netty.util.internal.SocketUtils;
import org.junit.Assert;
import org.junit.Test;

Expand All @@ -32,7 +33,7 @@ public class DnsQueryTest {

@Test
public void writeQueryTest() throws Exception {
InetSocketAddress addr = new InetSocketAddress("8.8.8.8", 53);
InetSocketAddress addr = SocketUtils.socketAddress("8.8.8.8", 53);
EmbeddedChannel embedder = new EmbeddedChannel(new DatagramDnsQueryEncoder());
List<DnsQuery> queries = new ArrayList<DnsQuery>(5);
queries.add(new DatagramDnsQuery(null, addr, 1).setRecord(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
package io.netty.handler.codec.socks;

import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.internal.SocketUtils;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import org.junit.Test;

import java.net.InetAddress;
import java.net.UnknownHostException;

import static org.junit.Assert.*;
Expand Down Expand Up @@ -65,7 +65,7 @@ public void testCmdRequestDecoderIPv4() {

@Test
public void testCmdRequestDecoderIPv6() throws UnknownHostException {
String[] hosts = {SocksCommonUtils.ipv6toStr(InetAddress.getByName("::1").getAddress())};
String[] hosts = {SocksCommonUtils.ipv6toStr(SocketUtils.addressByName("::1").getAddress())};
int[] ports = {1, 32769, 65535};
for (SocksCmdType cmdType : SocksCmdType.values()) {
for (String host : hosts) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,12 +17,12 @@

import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.NetUtil;
import io.netty.util.internal.SocketUtils;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import org.junit.Test;

import java.net.IDN;
import java.net.InetAddress;
import java.net.UnknownHostException;
import java.util.Arrays;

Expand Down Expand Up @@ -68,7 +68,7 @@ public void testCmdRequestDecoderIPv4() {
@Test
public void testCmdRequestDecoderIPv6() throws UnknownHostException {
String[] hosts = {
NetUtil.bytesToIpAddress(InetAddress.getByName("::1").getAddress()) };
NetUtil.bytesToIpAddress(SocketUtils.addressByName("::1").getAddress()) };
int[] ports = {1, 32769, 65535};
for (Socks5CommandType cmdType: Arrays.asList(Socks5CommandType.BIND,
Socks5CommandType.CONNECT,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import io.netty.channel.socket.DatagramPacket;
import io.netty.handler.codec.string.StringDecoder;
import io.netty.util.CharsetUtil;
import io.netty.util.internal.SocketUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
Expand Down Expand Up @@ -49,8 +50,8 @@ public void tearDown() {

@Test
public void testDecode() {
InetSocketAddress recipient = new InetSocketAddress("127.0.0.1", 10000);
InetSocketAddress sender = new InetSocketAddress("127.0.0.1", 20000);
InetSocketAddress recipient = SocketUtils.socketAddress("127.0.0.1", 10000);
InetSocketAddress sender = SocketUtils.socketAddress("127.0.0.1", 20000);
ByteBuf content = Unpooled.wrappedBuffer("netty".getBytes(CharsetUtil.UTF_8));
assertTrue(channel.writeInbound(new DatagramPacket(content, recipient, sender)));
assertEquals("netty", channel.readInbound());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import io.netty.channel.socket.DatagramPacket;
import io.netty.handler.codec.string.StringEncoder;
import io.netty.util.CharsetUtil;
import io.netty.util.internal.SocketUtils;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -46,8 +47,8 @@ public void tearDown() {

@Test
public void testEncode() {
InetSocketAddress recipient = new InetSocketAddress("127.0.0.1", 10000);
InetSocketAddress sender = new InetSocketAddress("127.0.0.1", 20000);
InetSocketAddress recipient = SocketUtils.socketAddress("127.0.0.1", 10000);
InetSocketAddress sender = SocketUtils.socketAddress("127.0.0.1", 20000);
assertTrue(channel.writeOutbound(
new DefaultAddressedEnvelope<String, InetSocketAddress>("netty", recipient, sender)));
DatagramPacket packet = channel.readOutbound();
Expand All @@ -62,8 +63,8 @@ public void testEncode() {

@Test
public void testUnmatchedMessageType() {
InetSocketAddress recipient = new InetSocketAddress("127.0.0.1", 10000);
InetSocketAddress sender = new InetSocketAddress("127.0.0.1", 20000);
InetSocketAddress recipient = SocketUtils.socketAddress("127.0.0.1", 10000);
InetSocketAddress sender = SocketUtils.socketAddress("127.0.0.1", 20000);
DefaultAddressedEnvelope<Long, InetSocketAddress> envelope =
new DefaultAddressedEnvelope<Long, InetSocketAddress>(1L, recipient, sender);
assertTrue(channel.writeOutbound(envelope));
Expand Down
7 changes: 4 additions & 3 deletions common/src/main/java/io/netty/util/NetUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package io.netty.util;

import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.SocketUtils;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

Expand Down Expand Up @@ -165,7 +166,7 @@ public final class NetUtil {
for (Enumeration<NetworkInterface> i = NetworkInterface.getNetworkInterfaces(); i.hasMoreElements();) {
NetworkInterface iface = i.nextElement();
// Use the interface with proper INET addresses only.
if (iface.getInetAddresses().hasMoreElements()) {
if (SocketUtils.addressesFromNetworkInterface(iface).hasMoreElements()) {
ifaces.add(iface);
}
}
Expand All @@ -179,7 +180,7 @@ public final class NetUtil {
NetworkInterface loopbackIface = null;
InetAddress loopbackAddr = null;
loop: for (NetworkInterface iface: ifaces) {
for (Enumeration<InetAddress> i = iface.getInetAddresses(); i.hasMoreElements();) {
for (Enumeration<InetAddress> i = SocketUtils.addressesFromNetworkInterface(iface); i.hasMoreElements();) {
InetAddress addr = i.nextElement();
if (addr.isLoopbackAddress()) {
// Found
Expand All @@ -195,7 +196,7 @@ public final class NetUtil {
try {
for (NetworkInterface iface: ifaces) {
if (iface.isLoopback()) {
Enumeration<InetAddress> i = iface.getInetAddresses();
Enumeration<InetAddress> i = SocketUtils.addressesFromNetworkInterface(iface);
if (i.hasMoreElements()) {
// Found the one with INET address.
loopbackIface = iface;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ public static byte[] bestAvailableMac() {
for (Enumeration<NetworkInterface> i = NetworkInterface.getNetworkInterfaces(); i.hasMoreElements();) {
NetworkInterface iface = i.nextElement();
// Use the interface with proper INET addresses only.
Enumeration<InetAddress> addrs = iface.getInetAddresses();
Enumeration<InetAddress> addrs = SocketUtils.addressesFromNetworkInterface(iface);
if (addrs.hasMoreElements()) {
InetAddress a = addrs.nextElement();
if (!a.isLoopbackAddress()) {
Expand All @@ -76,7 +76,7 @@ public static byte[] bestAvailableMac() {

byte[] macAddr;
try {
macAddr = iface.getHardwareAddress();
macAddr = SocketUtils.hardwareAddressFromNetworkInterface(iface);
} catch (SocketException e) {
logger.debug("Failed to get the hardware address of a network interface: {}", iface, e);
continue;
Expand Down
197 changes: 197 additions & 0 deletions common/src/main/java/io/netty/util/internal/SocketUtils.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
/*
* Copyright 2016 The Netty Project
*
* The Netty Project licenses this file to you under the Apache License,
* version 2.0 (the "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at:
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations
* under the License.
*/
package io.netty.util.internal;

import java.io.IOException;
import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;
import java.net.ServerSocket;
import java.net.Socket;
import java.net.SocketAddress;
import java.net.SocketException;
import java.net.SocketPermission;
import java.net.UnknownHostException;
import java.nio.channels.DatagramChannel;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.security.AccessController;
import java.security.PrivilegedAction;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.Enumeration;

/**
* Provides socket operations with privileges enabled. This is necessary for applications that use the
* {@link SecurityManager} to restrict {@link SocketPermission} to their application. By asserting that these
* operations are privileged, the operations can proceed even if some code in the calling chain lacks the appropriate
* {@link SocketPermission}.
*/
public final class SocketUtils {

private SocketUtils() {
}

public static void connect(final Socket socket, final SocketAddress remoteAddress, final int timeout)
throws IOException {
try {
AccessController.doPrivileged(new PrivilegedExceptionAction<Void>() {
@Override
public Void run() throws IOException {
socket.connect(remoteAddress, timeout);
return null;
}
});
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
}

public static void bind(final Socket socket, final SocketAddress bindpoint) throws IOException {
try {
AccessController.doPrivileged(new PrivilegedExceptionAction<Void>() {
@Override
public Void run() throws IOException {
socket.bind(bindpoint);
return null;
}
});
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
}

public static boolean connect(final SocketChannel socketChannel, final SocketAddress remoteAddress)
throws IOException {
try {
return AccessController.doPrivileged(new PrivilegedExceptionAction<Boolean>() {
@Override
public Boolean run() throws IOException {
return socketChannel.connect(remoteAddress);
}
});
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
}

public static void bind(final SocketChannel socketChannel, final SocketAddress address) throws IOException {
try {
AccessController.doPrivileged(new PrivilegedExceptionAction<Void>() {
@Override
public Void run() throws IOException {
socketChannel.bind(address);
return null;
}
});
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
}

public static SocketChannel accept(final ServerSocketChannel serverSocketChannel) throws IOException {
try {
return AccessController.doPrivileged(new PrivilegedExceptionAction<SocketChannel>() {
@Override
public SocketChannel run() throws IOException {
return serverSocketChannel.accept();
}
});
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
}

public static void bind(final DatagramChannel networkChannel, final SocketAddress address) throws IOException {
try {
AccessController.doPrivileged(new PrivilegedExceptionAction<Void>() {
@Override
public Void run() throws IOException {
networkChannel.bind(address);
return null;
}
});
} catch (PrivilegedActionException e) {
throw (IOException) e.getCause();
}
}

public static SocketAddress localSocketAddress(final ServerSocket socket) {
return AccessController.doPrivileged(new PrivilegedAction<SocketAddress>() {
@Override
public SocketAddress run() {
return socket.getLocalSocketAddress();
}
});
}

public static InetAddress addressByName(final String hostname) throws UnknownHostException {
try {
return AccessController.doPrivileged(new PrivilegedExceptionAction<InetAddress>() {
@Override
public InetAddress run() throws UnknownHostException {
return InetAddress.getByName(hostname);
}
});
} catch (PrivilegedActionException e) {
throw (UnknownHostException) e.getCause();
}
}

public static InetAddress[] allAddressesByName(final String hostname) throws UnknownHostException {
try {
return AccessController.doPrivileged(new PrivilegedExceptionAction<InetAddress[]>() {
@Override
public InetAddress[] run() throws UnknownHostException {
return InetAddress.getAllByName(hostname);
}
});
} catch (PrivilegedActionException e) {
throw (UnknownHostException) e.getCause();
}
}

public static InetSocketAddress socketAddress(final String hostname, final int port) {
return AccessController.doPrivileged(new PrivilegedAction<InetSocketAddress>() {
@Override
public InetSocketAddress run() {
return new InetSocketAddress(hostname, port);
}
});
}

public static Enumeration<InetAddress> addressesFromNetworkInterface(final NetworkInterface intf) {
return AccessController.doPrivileged(new PrivilegedAction<Enumeration<InetAddress>>() {
@Override
public Enumeration<InetAddress> run() {
return intf.getInetAddresses();
}
});
}

public static byte[] hardwareAddressFromNetworkInterface(final NetworkInterface intf) throws SocketException {
try {
return AccessController.doPrivileged(new PrivilegedExceptionAction<byte[]>() {
@Override
public byte[] run() throws SocketException {
return intf.getHardwareAddress();
}
});
} catch (PrivilegedActionException e) {
throw (SocketException) e.getCause();
}
}
}
Loading

0 comments on commit 3344cd2

Please sign in to comment.