Skip to content

Commit

Permalink
Reverse fwd supports port 0 and improved connections
Browse files Browse the repository at this point in the history
- Port 0 is now allowed to be randomly assigned by server and mapped back from
our side once the first connection is done.
- Map a new connection per new channel.
  • Loading branch information
Carlos Cabanero authored and Carlos Cabanero committed Mar 24, 2021
1 parent b4dba42 commit 49ea563
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 58 deletions.
18 changes: 13 additions & 5 deletions SSH/SSHClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -788,9 +788,8 @@ public class SSHClient {
}

public func requestReverseForward(bindTo address: String?, port: Int32) -> AnyPublisher<Stream, Error> {
if port == 0 {
// TODO Passing 0 will allocate an available port, and be returned at the last param on listen
return .fail(error: SSHError(title: "Ask server to allocate a bound port not allowed"))
if let _ = self.reversePorts[port] {
return .fail(error: SSHError(title: "Reverse forward already exits for that port."))
}

self.log.message("REVERSE Forward requested to address \(address) on port \(port)", SSH_LOG_INFO)
Expand All @@ -800,8 +799,16 @@ public class SSHClient {

ctxt.log.message("REVERSE Forward callback", SSH_LOG_INFO)

guard let pub = ctxt.reversePorts[Int32(port)] else {
// Should never happen
let port = Int32(port)

// If there is no associated port, check if it may be on 0
if ctxt.reversePorts[port] == nil {
if let _ = ctxt.reversePorts[0] {
ctxt.reversePorts[port] = ctxt.reversePorts.removeValue(forKey: 0)
}
}

guard let pub = ctxt.reversePorts[port] else {
return nil
}

Expand All @@ -820,6 +827,7 @@ public class SSHClient {
.tryOperation { session -> Int32 in
self.log.message("REVERSE Starting listener to forward on remote", SSH_LOG_INFO)

// We could pass the callback here, and then have somewhere on the libssh side a way to map
let rc = ssh_channel_listen_forward(session, address, port, nil)
if rc != SSH_OK {
throw SSHError(rc, forSession: session)
Expand Down
108 changes: 57 additions & 51 deletions SSH/SSHPortForward.swift
Original file line number Diff line number Diff line change
Expand Up @@ -217,93 +217,99 @@ public class SSHPortForwardListener {

public class SSHPortForwardClient {
let client: SSHClient
let conn: NWConnection
let forwardHost: NWEndpoint.Host
let localPort: NWEndpoint.Port
let queue: DispatchQueue
let remotePort: NWEndpoint.Port

let bindAddress: String?

var log: SSHLogger { get { client.log } }

let status = CurrentValueSubject<PortForwardState, Error>(.starting)
var isReady = false

var reverseForward: AnyCancellable?
var streams: [Stream] = []

// The listener is on the other side, here we just connect to a local port
public init(forward address: String, onPort localPort: UInt16,
toRemotePort remotePort: UInt16, using client: SSHClient) {
let p = NWEndpoint.Port(integerLiteral: localPort)
let host = NWEndpoint.Host(address)
self.conn = NWConnection(host: host, port: p, using: .tcp)
toRemotePort remotePort: UInt16, bindAddress: String? = nil, using client: SSHClient) {
self.localPort = NWEndpoint.Port(integerLiteral: localPort)
self.forwardHost = NWEndpoint.Host(address)
self.remotePort = NWEndpoint.Port(integerLiteral: remotePort)
self.queue = DispatchQueue(label: "r-fwd-\(localPort)")
self.bindAddress = bindAddress
self.client = client
}

public func connect() -> AnyPublisher<PortForwardState, Error> {
// TODO Expose address to bind to on remote server
// This is a different case than regular forward,
// because here we serve the requests from the other side,
// so the streams are received instead of generated here.
var stream: Stream?

self.conn.stateUpdateHandler = { state in
self.log.message("Listener state Updated \(state)", SSH_LOG_INFO)
self.isReady = false

switch state {
case .ready:
self.status.send(PortForwardState.ready)
self.isReady = true
startReverse()
case .waiting(let error):
// Just notify, the connection itself will be reopened after a wait.
self.status.send(PortForwardState.waiting(error))
case .failed(let error):
self.status.send(completion: .failure(SSHPortForwardError(title: "Connection state failed", error)))
default:
break
}
}

func startReverse() {
reverseForward = self.client.requestReverseForward(bindTo: nil, port: Int32(remotePort.rawValue))
.sink(receiveCompletion: { completion in
reverseForward = self.client.requestReverseForward(bindTo: bindAddress, port: Int32(remotePort.rawValue))
.sink(
receiveCompletion: { completion in
switch completion {
// If the Reverse Forward is closed, then close the connection.
// If the Reverse Forward is closed, then close the connection.
case .finished:
return self.close()
case .failure(let error):
self.status.send(completion: .failure(error))
self.isReady = false
self.close()
}
}, receiveValue: { s in
self.log.message("Reverse stream received. Connecting stream", SSH_LOG_INFO)
stream = s
s.connect(stdout: self.conn, stdin: self.conn)
s.handleCompletion = {
stream = nil
}
s.handleFailure = { error in
stream = nil
self.status.send(.error(error))
}
})
}

self.conn.start(queue: queue)

},
receiveValue: receive)

return status.eraseToAnyPublisher()
}

public func close() {
log.message("Closing Reverse Forward", SSH_LOG_INFO)
if isReady {
// Note we are not cancelling the already open connections
reverseForward?.cancel()
self.status.send(completion: .finished)
}

self.conn.cancel()
}

private func receive(stream: Stream) {
self.log.message("Reverse stream received. Establishing connection and piping stream", SSH_LOG_INFO)

self.streams.append(stream)
let conn = NWConnection(host: self.forwardHost, port: self.localPort, using: .tcp)
conn.stateUpdateHandler = { (state: NWConnection.State) in
self.log.message("Connection state Updated \(state)", SSH_LOG_INFO)
self.isReady = false

switch state {
case .ready:
// Notify that a connection has been established.
self.status.send(PortForwardState.ready)
self.isReady = true
case .waiting(let error):
// Just notify, the connection itself will be reopened after a wait.
self.status.send(PortForwardState.waiting(error))
case .failed(let error):
self.status.send(completion: .failure(SSHPortForwardError(title: "Connection state failed", error)))
default:
break
}
}
conn.start(queue: self.queue)
stream.connect(stdout: conn, stdin: conn)

func removeStream() {
if let idx = self.streams.firstIndex(where: { stream === $0 }) {
self.streams.remove(at: idx)
}
}
stream.handleCompletion = {
removeStream()
}
stream.handleFailure = { error in
removeStream()
self.status.send(.error(error))
}
}
}

Expand Down
7 changes: 5 additions & 2 deletions SSHTests/SSHPortForwardTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -231,13 +231,16 @@ extension SSHTests {
client = SSHPortForwardClient(forward: "www.guimp.com", onPort: 80,
toRemotePort: 8080, using: conn)
return client!
}.flatMap { $0.connect() }
}.flatMap { c -> AnyPublisher<PortForwardState, Error> in
expectForward.fulfill()
return c.connect()
}
.assertNoFailure()
.sink { event in
print("Received \(event)")
switch event {
case .ready:
expectForward.fulfill()
break
case .error(let error):
XCTFail("\(error)")
default:
Expand Down

0 comments on commit 49ea563

Please sign in to comment.