From 49ea56331c2d866087c6104a3a83dd1a2f7bc52e Mon Sep 17 00:00:00 2001 From: Carlos Cabanero Date: Wed, 24 Mar 2021 16:05:54 -0400 Subject: [PATCH] Reverse fwd supports port 0 and improved connections - Port 0 is now allowed to be randomly assigned by server and mapped back from our side once the first connection is done. - Map a new connection per new channel. --- SSH/SSHClient.swift | 18 +++-- SSH/SSHPortForward.swift | 108 +++++++++++++++-------------- SSHTests/SSHPortForwardTests.swift | 7 +- 3 files changed, 75 insertions(+), 58 deletions(-) diff --git a/SSH/SSHClient.swift b/SSH/SSHClient.swift index 0d5733640..f131efcef 100644 --- a/SSH/SSHClient.swift +++ b/SSH/SSHClient.swift @@ -788,9 +788,8 @@ public class SSHClient { } public func requestReverseForward(bindTo address: String?, port: Int32) -> AnyPublisher { - if port == 0 { - // TODO Passing 0 will allocate an available port, and be returned at the last param on listen - return .fail(error: SSHError(title: "Ask server to allocate a bound port not allowed")) + if let _ = self.reversePorts[port] { + return .fail(error: SSHError(title: "Reverse forward already exits for that port.")) } self.log.message("REVERSE Forward requested to address \(address) on port \(port)", SSH_LOG_INFO) @@ -800,8 +799,16 @@ public class SSHClient { ctxt.log.message("REVERSE Forward callback", SSH_LOG_INFO) - guard let pub = ctxt.reversePorts[Int32(port)] else { - // Should never happen + let port = Int32(port) + + // If there is no associated port, check if it may be on 0 + if ctxt.reversePorts[port] == nil { + if let _ = ctxt.reversePorts[0] { + ctxt.reversePorts[port] = ctxt.reversePorts.removeValue(forKey: 0) + } + } + + guard let pub = ctxt.reversePorts[port] else { return nil } @@ -820,6 +827,7 @@ public class SSHClient { .tryOperation { session -> Int32 in self.log.message("REVERSE Starting listener to forward on remote", SSH_LOG_INFO) + // We could pass the callback here, and then have somewhere on the libssh side a way to map let rc = ssh_channel_listen_forward(session, address, port, nil) if rc != SSH_OK { throw SSHError(rc, forSession: session) diff --git a/SSH/SSHPortForward.swift b/SSH/SSHPortForward.swift index 6c08c0243..670cf45d9 100644 --- a/SSH/SSHPortForward.swift +++ b/SSH/SSHPortForward.swift @@ -217,59 +217,39 @@ public class SSHPortForwardListener { public class SSHPortForwardClient { let client: SSHClient - let conn: NWConnection + let forwardHost: NWEndpoint.Host + let localPort: NWEndpoint.Port let queue: DispatchQueue let remotePort: NWEndpoint.Port - + let bindAddress: String? + var log: SSHLogger { get { client.log } } let status = CurrentValueSubject(.starting) var isReady = false var reverseForward: AnyCancellable? + var streams: [Stream] = [] - // The listener is on the other side, here we just connect to a local port public init(forward address: String, onPort localPort: UInt16, - toRemotePort remotePort: UInt16, using client: SSHClient) { - let p = NWEndpoint.Port(integerLiteral: localPort) - let host = NWEndpoint.Host(address) - self.conn = NWConnection(host: host, port: p, using: .tcp) + toRemotePort remotePort: UInt16, bindAddress: String? = nil, using client: SSHClient) { + self.localPort = NWEndpoint.Port(integerLiteral: localPort) + self.forwardHost = NWEndpoint.Host(address) self.remotePort = NWEndpoint.Port(integerLiteral: remotePort) self.queue = DispatchQueue(label: "r-fwd-\(localPort)") + self.bindAddress = bindAddress self.client = client } public func connect() -> AnyPublisher { - // TODO Expose address to bind to on remote server // This is a different case than regular forward, // because here we serve the requests from the other side, // so the streams are received instead of generated here. - var stream: Stream? - - self.conn.stateUpdateHandler = { state in - self.log.message("Listener state Updated \(state)", SSH_LOG_INFO) - self.isReady = false - - switch state { - case .ready: - self.status.send(PortForwardState.ready) - self.isReady = true - startReverse() - case .waiting(let error): - // Just notify, the connection itself will be reopened after a wait. - self.status.send(PortForwardState.waiting(error)) - case .failed(let error): - self.status.send(completion: .failure(SSHPortForwardError(title: "Connection state failed", error))) - default: - break - } - } - - func startReverse() { - reverseForward = self.client.requestReverseForward(bindTo: nil, port: Int32(remotePort.rawValue)) - .sink(receiveCompletion: { completion in + reverseForward = self.client.requestReverseForward(bindTo: bindAddress, port: Int32(remotePort.rawValue)) + .sink( + receiveCompletion: { completion in switch completion { - // If the Reverse Forward is closed, then close the connection. + // If the Reverse Forward is closed, then close the connection. case .finished: return self.close() case .failure(let error): @@ -277,33 +257,59 @@ public class SSHPortForwardClient { self.isReady = false self.close() } - }, receiveValue: { s in - self.log.message("Reverse stream received. Connecting stream", SSH_LOG_INFO) - stream = s - s.connect(stdout: self.conn, stdin: self.conn) - s.handleCompletion = { - stream = nil - } - s.handleFailure = { error in - stream = nil - self.status.send(.error(error)) - } - }) - } - - self.conn.start(queue: queue) - + }, + receiveValue: receive) + return status.eraseToAnyPublisher() } public func close() { log.message("Closing Reverse Forward", SSH_LOG_INFO) if isReady { + // Note we are not cancelling the already open connections reverseForward?.cancel() self.status.send(completion: .finished) } - - self.conn.cancel() + } + + private func receive(stream: Stream) { + self.log.message("Reverse stream received. Establishing connection and piping stream", SSH_LOG_INFO) + + self.streams.append(stream) + let conn = NWConnection(host: self.forwardHost, port: self.localPort, using: .tcp) + conn.stateUpdateHandler = { (state: NWConnection.State) in + self.log.message("Connection state Updated \(state)", SSH_LOG_INFO) + self.isReady = false + + switch state { + case .ready: + // Notify that a connection has been established. + self.status.send(PortForwardState.ready) + self.isReady = true + case .waiting(let error): + // Just notify, the connection itself will be reopened after a wait. + self.status.send(PortForwardState.waiting(error)) + case .failed(let error): + self.status.send(completion: .failure(SSHPortForwardError(title: "Connection state failed", error))) + default: + break + } + } + conn.start(queue: self.queue) + stream.connect(stdout: conn, stdin: conn) + + func removeStream() { + if let idx = self.streams.firstIndex(where: { stream === $0 }) { + self.streams.remove(at: idx) + } + } + stream.handleCompletion = { + removeStream() + } + stream.handleFailure = { error in + removeStream() + self.status.send(.error(error)) + } } } diff --git a/SSHTests/SSHPortForwardTests.swift b/SSHTests/SSHPortForwardTests.swift index 740242339..793f7106d 100644 --- a/SSHTests/SSHPortForwardTests.swift +++ b/SSHTests/SSHPortForwardTests.swift @@ -231,13 +231,16 @@ extension SSHTests { client = SSHPortForwardClient(forward: "www.guimp.com", onPort: 80, toRemotePort: 8080, using: conn) return client! - }.flatMap { $0.connect() } + }.flatMap { c -> AnyPublisher in + expectForward.fulfill() + return c.connect() + } .assertNoFailure() .sink { event in print("Received \(event)") switch event { case .ready: - expectForward.fulfill() + break case .error(let error): XCTFail("\(error)") default: