Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Remote Dynamic Forwarding #615

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
279 changes: 277 additions & 2 deletions src/Renci.SshNet/Channels/ChannelForwardedTcpip.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
using System;
using System.IO;
using System.Linq;
using System.Net;
using System.Net.Sockets;
using System.Threading;
using Renci.SshNet.Abstractions;
using Renci.SshNet.Common;
using Renci.SshNet.Messages.Connection;
Expand All @@ -16,6 +19,10 @@ internal class ChannelForwardedTcpip : ServerChannel, IChannelForwardedTcpip
private Socket _socket;
private IForwardedPort _forwardedPort;

private bool doSocks;
private bool doSocks5;
private ManualResetEvent completionWaitHandle;

/// <summary>
/// Initializes a new <see cref="ChannelForwardedTcpip"/> instance.
/// </summary>
Expand Down Expand Up @@ -69,7 +76,17 @@ public void Bind(IPEndPoint remoteEndpoint, IForwardedPort forwardedPort)
_forwardedPort = forwardedPort;
_forwardedPort.Closing += ForwardedPort_Closing;

// Try to connect to the socket
if (remoteEndpoint == null)
{
doSocks = true;
SendMessage(new ChannelOpenConfirmationMessage(RemoteChannelNumber, LocalWindowSize, LocalPacketSize, LocalChannelNumber));

completionWaitHandle = new ManualResetEvent(false);
completionWaitHandle.WaitOne();
completionWaitHandle.Dispose();
}

// Try to connect to the socket
try
{
_socket = SocketAbstraction.Connect(remoteEndpoint, ConnectionInfo.Timeout);
Expand Down Expand Up @@ -111,6 +128,11 @@ private void ForwardedPort_Closing(object sender, EventArgs eventArgs)
//
// if the FIN/ACK is not sent in time, the socket will be closed in Close(bool)
ShutdownSocket(SocketShutdown.Send);

if (completionWaitHandle != null)
{
completionWaitHandle.Set();
}
}

/// <summary>
Expand Down Expand Up @@ -190,6 +212,13 @@ protected override void Close()
/// <param name="data">The data.</param>
protected override void OnData(byte[] data)
{
if (doSocks)
{
var stream = new MemoryStream(data);
HandleSocks(stream);
return;
}

base.OnData(data);

var socket = _socket;
Expand All @@ -198,5 +227,251 @@ protected override void OnData(byte[] data)
SocketAbstraction.Send(socket, data, 0, data.Length);
}
}

private void HandleSocks(MemoryStream stream)
{
var version = ReadByte(stream);
switch (version)
{
case 4:
HandleSocks4(stream);
doSocks = false;
return;
case 5:
if (!doSocks5)
{
var authenticationMethodsCount = ReadByte(stream);
var authenticationMethods = new byte[authenticationMethodsCount];
if (stream.Read(authenticationMethods, 0, authenticationMethods.Length) == 0)
{
return;
}

if (authenticationMethods.Min() == 0)
{
// no user authentication is one of the authentication methods supported
// by the SOCKS client
SendData(new byte[] { 0x05, 0x00 });
}
else
{
// the SOCKS client requires authentication, which we currently do not support
SendData(new byte[] { 0x05, 0xFF });
}
doSocks5 = true;
return;
}
HandleSocks5(stream);
doSocks = false;
return;
}
throw new NotSupportedException(string.Format("SOCKS version {0} is not supported.", version));
}

private void HandleSocks4(MemoryStream stream)
{
var commandCode = ReadByte(stream);
if (commandCode == -1)
{
return;
}

var portBuffer = new byte[2];
if (stream.Read(portBuffer, 0, portBuffer.Length) == 0)
{
return;
}

var port = (portBuffer[0] * 256 + portBuffer[1]);

var ipBuffer = new byte[4];
if (stream.Read(ipBuffer, 0, ipBuffer.Length) == 0)
{
return;
}

var ipAddress = new IPAddress(ipBuffer);

ThreadAbstraction.ExecuteThread(() =>
{
var endpoint = new IPEndPoint(ipAddress, port);

try
{
_socket = SocketAbstraction.Connect(endpoint, ConnectionInfo.Timeout);
}
catch (Exception exp)
{
// send channel open failure message
SendMessage(new ChannelOpenFailureMessage(RemoteChannelNumber, exp.ToString(), ChannelOpenFailureMessage.ConnectFailed, "en"));
completionWaitHandle.Set();
throw;
}

SendData(new byte[] { 0x00, 0x5a });
SendData(portBuffer);
SendData(ipBuffer);

var buffer = new byte[RemotePacketSize];
SocketAbstraction.ReadContinuous(_socket, buffer, 0, buffer.Length, SendData);
});
}

private void HandleSocks5(MemoryStream stream)
{
var commandCode = ReadByte(stream);
if (commandCode == -1)
{
return;
}

var reserved = ReadByte(stream);
if (reserved == -1)
{
return;
}

if (reserved != 0)
{
throw new ProxyException("SOCKS5: 0 is expected for reserved byte.");
}

var addressType = ReadByte(stream);
if (addressType == -1)
{
// SOCKS client closed connection
return;
}

var ipAddress = GetSocks5Host(addressType, stream);
if (ipAddress == null)
{
// SOCKS client closed connection
return;
}

var portBuffer = new byte[2];
if (stream.Read(portBuffer, 0, portBuffer.Length) == 0)
{
return;
}

var port = (portBuffer[0] * 256 + portBuffer[1]);

ThreadAbstraction.ExecuteThread(() =>
{
var endpoint = new IPEndPoint(ipAddress, port);

try
{
_socket = SocketAbstraction.Connect(endpoint, ConnectionInfo.Timeout);
}
catch
{
// send channel open failure message
SendData(CreateSocks5Reply(false));
completionWaitHandle.Set();
throw;
}

SendData(CreateSocks5Reply(true));

var buffer = new byte[RemotePacketSize];
SocketAbstraction.ReadContinuous(_socket, buffer, 0, buffer.Length, SendData);
});
}

private IPAddress GetSocks5Host(int addressType, MemoryStream stream)
{
switch (addressType)
{
case 0x01: // IPv4
{
var addressBuffer = new byte[4];
if (stream.Read(addressBuffer, 0, 4) == 0)
{
// SOCKS client closed connection
return null;
}

return new IPAddress(addressBuffer);
}
case 0x03: // Domain name
{
var length = ReadByte(stream);
if (length == -1)
{
// SOCKS client closed connection
return null;
}
var addressBuffer = new byte[length];
if (stream.Read(addressBuffer, 0, addressBuffer.Length) == 0)
{
// SOCKS client closed connection
return null;
}

var hostName = SshData.Ascii.GetString(addressBuffer, 0, addressBuffer.Length);
return DnsAbstraction.GetHostAddresses(hostName)[0];
}
case 0x04: // IPv6
{
var addressBuffer = new byte[16];
if (stream.Read(addressBuffer, 0, 16) == 0)
{
return null;
}

return new IPAddress(addressBuffer);
}
default:
throw new ProxyException(string.Format("SOCKS5: Address type '{0}' is not supported.", addressType));
}
}

private static byte[] CreateSocks5Reply(bool success)
{
var socksReply = new byte[
// SOCKS version
1 +
// Reply field
1 +
// Reserved; fixed: 0x00
1 +
// Address type; fixed: 0x01
1 +
// IPv4 server bound address; fixed: {0x00, 0x00, 0x00, 0x00}
4 +
// server bound port; fixed: {0x00, 0x00}
2];

socksReply[0] = 0x05;

if (success)
{
socksReply[1] = 0x00; // succeeded
}
else
{
socksReply[1] = 0x01; // general SOCKS server failure
}

// reserved
socksReply[2] = 0x00;

// IPv4 address type
socksReply[3] = 0x01;

return socksReply;
}

private int ReadByte(MemoryStream stream)
{
var buffer = new byte[1];
if (stream.Read(buffer, 0, 1) == 0)
return -1;

return buffer[0];
}
}
}
}
59 changes: 57 additions & 2 deletions src/Renci.SshNet/ForwardedPortRemote.cs
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,53 @@ public ForwardedPortRemote(string boundHost, uint boundPort, string host, uint p
{
}


/// <summary>
/// Initializes a new instance of the <see cref="ForwardedPortRemote"/> class.
/// </summary>
/// <param name="boundPort">The bound port.</param>
/// <example>
/// <code source="..\..\src\Renci.SshNet.Tests\Classes\ForwardedPortRemoteTest.cs" region="Example SshClient AddForwardedPort Start Stop ForwardedPortRemote" language="C#" title="Remote port forwarding" />
/// </example>
public ForwardedPortRemote(uint boundPort)
: this (string.Empty, boundPort)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="ForwardedPortRemote"/> class.
/// </summary>
/// <param name="boundHost">The bound host.</param>
/// <param name="boundPort">The bound port.</param>
/// <example>
/// <code source="..\..\src\Renci.SshNet.Tests\Classes\ForwardedPortRemoteTest.cs" region="Example SshClient AddForwardedPort Start Stop ForwardedPortRemote" language="C#" title="Remote port forwarding" />
/// </example>
public ForwardedPortRemote(string boundHost, uint boundPort)
: this(DnsAbstraction.GetHostAddresses(boundHost)[0],
boundPort)
{
}

/// <summary>
/// Initializes a new instance of the <see cref="ForwardedPortRemote" /> class.
/// </summary>
/// <param name="boundHostAddress">The bound host address.</param>
/// <param name="boundPort">The bound port.</param>
/// <exception cref="ArgumentNullException"><paramref name="boundHostAddress"/> is <c>null</c>.</exception>
/// <exception cref="ArgumentOutOfRangeException"><paramref name="boundPort" /> is greater than <see cref="F:System.Net.IPEndPoint.MaxPort" />.</exception>
public ForwardedPortRemote(IPAddress boundHostAddress, uint boundPort)
{
if (boundHostAddress == null)
throw new ArgumentNullException("boundHostAddress");

boundPort.ValidatePort("boundPort");

BoundHostAddress = boundHostAddress;
BoundPort = boundPort;

_status = ForwardedPortStatus.Stopped;
}

/// <summary>
/// Starts remote port forwarding.
/// </summary>
Expand All @@ -151,7 +198,7 @@ protected override void StartPort()

// send global request to start forwarding
Session.SendMessage(new TcpIpForwardGlobalRequestMessage(BoundHost, BoundPort));
// wat for response on global request to start direct tcpip
// wait for response on global request to start direct tcpip
Session.WaitOnHandle(_globalRequestResponse);

if (!_requestStatus)
Expand Down Expand Up @@ -250,7 +297,15 @@ private void Session_ChannelOpening(object sender, MessageEventArgs<ChannelOpenM
channelOpenMessage.InitialWindowSize, channelOpenMessage.MaximumPacketSize))
{
channel.Exception += Channel_Exception;
channel.Bind(new IPEndPoint(HostAddress, (int) Port), this);

if (HostAddress == null)
{
channel.Bind(null, this); // Get HostAddress vom SOCKS in Data
}
else
{
channel.Bind(new IPEndPoint(HostAddress, (int)Port), this);
}
}
}
catch (Exception exp)
Expand Down