Skip to content

Commit

Permalink
Fix concurrency issues on SSHPool
Browse files Browse the repository at this point in the history
  • Loading branch information
Carlos Cabanero committed Jun 10, 2024
1 parent efa84d2 commit 840b3c1
Showing 1 changed file with 50 additions and 38 deletions.
88 changes: 50 additions & 38 deletions Blink/Commands/ssh/SSHPool.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,12 +41,13 @@ import SSH
class SSHPool {
static let shared = SSHPool()
private var controls: [SSHClientControl] = []

private let queue = DispatchQueue(label: "SSHPoolControlQueue", attributes: .concurrent)

private init() {}

static func dial(_ host: String,
with config: SSHClientConfig,
withControlMaster: ControlMasterOption = .no,
static func dial(_ host: String,
with config: SSHClientConfig,
withControlMaster: ControlMasterOption = .no,
withProxy proxy: SSH.SSHClient.ExecProxyCommandCallback? = nil) -> AnyPublisher<SSH.SSHClient, Error> {

// Do not use an existing socket.
Expand All @@ -62,7 +63,7 @@ class SSHPool {
shared.removeControl(ctrl)
}
}

return shared.startConnection(host, with: config, proxy: proxy)
}

Expand All @@ -84,7 +85,9 @@ class SSHPool {
},
receiveValue: { conn in
let control = SSHClientControl(for: conn, on: host, with: config, running: runLoop, exposed: exposed)
SSHPool.shared.controls.append(control)
self.queue.sync {
SSHPool.shared.controls.append(control)
}
pb.send(conn)
})

Expand All @@ -111,13 +114,21 @@ class SSHPool {
}

private static func control(on connection: SSH.SSHClient) -> SSHClientControl? {
shared.controls.first { $0.connection === connection }
shared.control(on: connection)
}

private func control(on connection: SSH.SSHClient) -> SSHClientControl? {
queue.sync {
controls.first { $0.connection === connection }
}
}

private func control(for host: String, with config: SSHClientConfig) -> SSHClientControl? {
return controls.first { $0.isConnection(for: host, with: config) }
queue.sync {
controls.first { $0.isConnection(for: host, with: config) }
}
}

private func enforcePersistance(_ control: SSHClientControl) {
print("Current channels \(control.numChannels)")
print("\(control.localTunnels)")
Expand Down Expand Up @@ -153,7 +164,7 @@ extension SSHPool {
c.numShells += 1
}
}

static func deregister(shellOn connection: SSH.SSHClient) {
guard let c = control(on: connection) else {
return
Expand All @@ -165,13 +176,13 @@ extension SSHPool {

// Forward Tunnels
extension SSHPool {
static func register(_ listener: SSHPortForwardListener,
portForwardInfo: PortForwardInfo,
static func register(_ listener: SSHPortForwardListener,
portForwardInfo: PortForwardInfo,
on connection: SSH.SSHClient) {
let c = control(on: connection)
c?.localTunnels[portForwardInfo] = listener
}

static func deregister(localForward: PortForwardInfo, on connection: SSH.SSHClient) {
guard let c = control(on: connection) else {
return
Expand All @@ -186,15 +197,15 @@ extension SSHPool {
guard let c = control(on: connection) else {
return false
}

return c.localTunnels[localForward] != nil
}
}

// Remote Tunnels
extension SSHPool {
static func register(_ client: SSHPortForwardClient,
portForwardInfo: PortForwardInfo,
static func register(_ client: SSHPortForwardClient,
portForwardInfo: PortForwardInfo,
on connection: SSH.SSHClient) {
let c = control(on: connection)
c?.remoteTunnels[portForwardInfo] = client
Expand All @@ -214,7 +225,7 @@ extension SSHPool {
guard let c = control(on: connection) else {
return false
}

return c.remoteTunnels[remoteForward] != nil
}
}
Expand Down Expand Up @@ -242,7 +253,7 @@ extension SSHPool {
guard let c = control(on: connection) else {
return false
}

return c.socks[socksBindAddress] != nil
}
}
Expand All @@ -252,21 +263,23 @@ extension SSHPool {
let c = control(on: connection)
c?.streams.append((command, stream))
}

private func removeControl(_ control: SSHClientControl) {
// For now, we just stop the connection as is
// We could use a delegate just to notify when a connection is dead, and the control could
// take care of figuring out when the connection it contains must go.
guard
let idx = controls.firstIndex(where: { $0 === control })
else {
return
queue.async(flags: .barrier) {
// For now, we just stop the connection as is
// We could use a delegate just to notify when a connection is dead, and the control could
// take care of figuring out when the connection it contains must go.
guard
let idx = self.controls.firstIndex(where: { $0 === control })
else {
return
}

// Removing references to connection to deinit.
// We could also handle the pool with references to the connection.
// But the shell or time based persistance may become more difficult.
self.controls.remove(at: idx)
}

// Removing references to connection to deinit.
// We could also handle the pool with references to the connection.
// But the shell or time based persistance may become more difficult.
controls.remove(at: idx)
}
}

Expand All @@ -276,7 +289,7 @@ fileprivate class SSHClientControl {
let config: SSHClientConfig
let runLoop: RunLoop
let exposed: Bool

var numShells: Int = 0
//var shells: [(SSHCommand, SSH.Stream)] = []

Expand All @@ -291,16 +304,16 @@ fileprivate class SSHClientControl {
return numShells + streams.count + localTunnels.count + remoteTunnels.count + socks.count
}
}

init(for connection: SSH.SSHClient, on host: String, with config: SSHClientConfig, running runLoop: RunLoop, exposed: Bool) {
self.connection = connection
self.host = host
self.config = config
self.runLoop = runLoop
self.exposed = exposed
}


// Other parameters could specify how the connection should be treated by the pool
// (timeouts, etc...)
func isConnection(for host: String, with config: SSHClientConfig) -> Bool {
Expand All @@ -311,7 +324,7 @@ fileprivate class SSHClientControl {
return self.host == host && config == self.config ? true : false
}
}
/*
/*
fileprivate protocol TunnelControl {
func close()
}
Expand All @@ -334,4 +347,3 @@ extension OptionalBindAddressInfo: Hashable {
hasher.combine(self.port)
}
}

0 comments on commit 840b3c1

Please sign in to comment.