Skip to content

Commit

Permalink
Added .just and .fail static funcs to dry code a little bit
Browse files Browse the repository at this point in the history
  • Loading branch information
yury committed Feb 22, 2021
1 parent 918867b commit 7381c2f
Show file tree
Hide file tree
Showing 12 changed files with 76 additions and 60 deletions.
4 changes: 2 additions & 2 deletions Blink/Commands/ssh/CopyFiles.swift
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ public class BlinkCopy: NSObject {
}

func localTranslator(to path: String) -> AnyPublisher<Translator, Error> {
return Just(BlinkFiles.Local()).mapError {$0 as Error}.eraseToAnyPublisher()
return .just(BlinkFiles.Local())
}

func remoteTranslator(toFilePath filePath: String, atHost hostPath: String, using proto: BlinkFilesProtocols, isSource: Bool = true) -> AnyPublisher<Translator, Error> {
Expand All @@ -236,7 +236,7 @@ public class BlinkCopy: NSObject {
sshOptions = try sshCommand.connectionOptions.get()
} catch {
let message = SSHCommand.message(for: error)
return Fail(error: CommandError(message: message)).eraseToAnyPublisher()
return .fail(error: CommandError(message: message))
}

let config = SSHClientConfigProvider.config(command: sshCommand, config: sshOptions, using: device)
Expand Down
2 changes: 1 addition & 1 deletion Blink/Commands/ssh/SSHPool.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class SSHPool {
guard let conn = connection(for: host, with: config) else {
return shared.startConnection(host, with: config, proxy: proxy)
}
return Just(conn).mapError { $0 as Error }.eraseToAnyPublisher()
return .just(conn)
}

private func startConnection(_ host: String, with config: SSHClientConfig,
Expand Down
9 changes: 6 additions & 3 deletions BlinkFiles/CopyFiles.swift
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ extension Translator {
return self.create(name: name, flags: O_WRONLY, mode: S_IRWXU)
.flatMap { f -> CopyProgressInfo in
if size == 0 {
return Just((name, 0, 0)).mapError { $0 as Error }.eraseToAnyPublisher()
return Just((name, 0, 0)).setFailureType(to: Error.self).eraseToAnyPublisher()
}
return f.copyFile(from: t, name: name, size: size)
}.eraseToAnyPublisher()
Expand Down Expand Up @@ -117,7 +117,8 @@ extension File {
print("File Copied bytes \(totalWritten)")
totalWritten += written
let report = Just((name, fileSize, written))
.mapError { $0 as Error }.eraseToAnyPublisher()
.setFailureType(to: Error.self)
.eraseToAnyPublisher()

if totalWritten == fileSize {
// Close and send the final report
Expand All @@ -134,7 +135,9 @@ extension File {
.tryCatch { error -> CopyProgressInfo in
// Closing the file while reading may provoke an error. Capture it here and if we are done, we ignore it.
if totalWritten == fileSize {
return Just((name, fileSize, 0)).mapError {$0 as Error}.eraseToAnyPublisher()
return Just((name, fileSize, 0))
.setFailureType(to: Error.self)
.eraseToAnyPublisher()
} else {
throw error
}
Expand Down
11 changes: 11 additions & 0 deletions SSH/Publishers.swift
Original file line number Diff line number Diff line change
Expand Up @@ -181,6 +181,17 @@ extension AnyPublisher where Output == ssh_session, Failure == Error {
}
}

// TODO: Move to own module?
public extension AnyPublisher {
@inlinable static func just(_ output: Output) -> Self {
.init(Just(output).setFailureType(to: Failure.self))
}

@inlinable static func fail(error: Failure) -> Self {
.init(Fail(error: error))
}
}

fileprivate class UnfairLock {
private var pLock: UnsafeMutablePointer<os_unfair_lock>

Expand Down
16 changes: 8 additions & 8 deletions SSH/SCP.swift
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,9 @@ public class SCPClient: CopierFrom, CopierTo {
public static func execute(using ssh: SSHClient,
as mode: SCPMode,
root quotedPath: String) -> AnyPublisher<SCPClient, Error> {
guard let scp = ssh_scp_new(ssh.session, Int32(mode.scpMode()), quotedPath) else {
return Fail(error: SSHError(-1, forSession: ssh.session))
.eraseToAnyPublisher()
guard let scp = ssh_scp_new(ssh.session, Int32(mode.scpMode()), quotedPath)
else {
return .fail(error: SSHError(-1, forSession: ssh.session))
}

let client = SCPClient(scp, client: ssh)
Expand Down Expand Up @@ -130,7 +130,7 @@ public class SCPClient: CopierFrom, CopierTo {

// Wrap each scp call so we are sure it is running in the proper place.
func connection() -> AnyPublisher<ssh_scp, Error> {
return Just(scp).subscribe(on: self.ssh.rloop).mapError { $0 as Error }.eraseToAnyPublisher()
AnyPublisher.just(scp).subscribe(on: self.ssh.rloop).eraseToAnyPublisher()
}
}

Expand Down Expand Up @@ -191,7 +191,7 @@ extension SCPClient {
}.flatMap { t -> CopyProgressInfo in
// On empty file, just report, nothing to copy
if size == 0 {
return Just((name, 0, 0)).mapError { $0 as Error }.eraseToAnyPublisher()
return .just((name, 0, 0))
}
return self.copyFileFrom(t, name: name, size: size)
}
Expand Down Expand Up @@ -335,9 +335,9 @@ extension SCPClient {
case SSH_SCP_REQUEST_ENDDIR:
currentDir = dirsFifo[0]
dirsFifo.removeFirst()
return Just(req).mapError { $0 as Error }.eraseToAnyPublisher()
return .just(req)
default:
return Just(req).mapError { $0 as Error }.eraseToAnyPublisher()
return .just(req)
}
}
.filter { $0 == SSH_SCP_REQUEST_NEWFILE }
Expand Down Expand Up @@ -448,7 +448,7 @@ extension SCPClient {
})
.flatMap(maxPublishers: .max(1)) { data -> AnyPublisher<Int, Error> in
if data.count == 0 {
return Just(0).mapError { $0 as Error }.eraseToAnyPublisher()
return .just(0)
}
return t.write(data, max: data.count)
}
Expand Down
6 changes: 3 additions & 3 deletions SSH/SFTP.swift
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ public class SFTPClient : BlinkFiles.Translator {

public func directoryFilesAndAttributes() -> AnyPublisher<[FileAttributes], Error> {
if fileType != .typeDirectory {
return Fail(error: FileError(title: "Not a directory.", in: session)).eraseToAnyPublisher()
return .fail(error: FileError(title: "Not a directory.", in: session))
}

return connection().trySFTP { sftp -> [FileAttributes] in
Expand Down Expand Up @@ -193,7 +193,7 @@ public class SFTPClient : BlinkFiles.Translator {

public func open(flags: Int32) -> AnyPublisher<File, Error> {
if fileType != .typeRegular {
return Fail(error: FileError(title: "Not a file.", in: session)).eraseToAnyPublisher()
return .fail(error: FileError(title: "Not a file.", in: session))
}

return connection().trySFTP { sftp -> SFTPFile in
Expand All @@ -207,7 +207,7 @@ public class SFTPClient : BlinkFiles.Translator {

public func create(name: String, flags: Int32, mode: mode_t = S_IRWXU) -> AnyPublisher<BlinkFiles.File, Error> {
if fileType != .typeDirectory {
return Fail(error: FileError(title: "Not a directory.", in: session)).eraseToAnyPublisher()
return .fail(error: FileError(title: "Not a directory.", in: session))
}

return connection().trySFTP { sftp -> SFTPFile in
Expand Down
59 changes: 28 additions & 31 deletions SSH/SSHClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,6 @@ import Foundation
import Network
import LibSSH


typealias SSHConnection = AnyPublisher<ssh_session, Error>
typealias SSHChannel = AnyPublisher<ssh_channel, Error>

Expand Down Expand Up @@ -311,16 +310,16 @@ public class SSHClient {
}

func connection() -> SSHConnection {
Just(self.session).mapError{ $0 as Error}.subscribe(on: rloop)
AnyPublisher.just(self.session).subscribe(on: rloop)
.eraseToAnyPublisher()
}

func newChannel() -> SSHChannel {
guard let channel = ssh_channel_new(self.session) else {
return Fail(error: SSHError(title: "Could not create channel")).eraseToAnyPublisher()
return .fail(error: SSHError(title: "Could not create channel"))
}
ssh_channel_set_blocking(channel, 0)
return Just(channel).mapError { $0 as Error }.subscribe(on: rloop)
return AnyPublisher.just(channel).subscribe(on: rloop)
.eraseToAnyPublisher()
}

Expand All @@ -331,7 +330,7 @@ public class SSHClient {
do {
c = try SSHClient(to: host, with: opts, proxyCb: proxyCb)
} catch {
return Fail(error: error).eraseToAnyPublisher()
return .fail(error: error)
}

// Done this way we don't have to handle cancellations here.
Expand All @@ -345,14 +344,14 @@ public class SSHClient {
if client.options.requestVerifyHostCallback != nil {
return client.verifyKnownHost()
}
return Just(client).mapError { $0 as Error }.eraseToAnyPublisher()
return .just(client)
}
.flatMap{ $0.auth() }
.flatMap{ client -> AnyPublisher<SSHClient, Error> in
if client.options.keepAliveInterval != nil {
client.startKeepAliveTimer()
}
return Just(client).mapError { $0 as Error }.eraseToAnyPublisher()
return .just(client)
}
// If cancelled, the connection will be closed without being passed to the user or
// once the command is dumped.
Expand Down Expand Up @@ -401,13 +400,12 @@ public class SSHClient {
// None of the calls here require to contact the server, so we can do them without wrapping.
let rc = ssh_get_server_publickey(session, &serverPublicKey)
if rc < 0 {
return Fail(error: SSHError(title: "Could not get server publickey")).eraseToAnyPublisher()
return .fail(error: SSHError(title: "Could not get server publickey"))
}

let state = ssh_get_publickey_hash(serverPublicKey, SSH_PUBLICKEY_HASH_SHA256, &hash, &hlen)
if state < 0 {
return Fail(error: SSHError(title: "Could not get server publickey hash"))
.eraseToAnyPublisher()
return .fail(error: SSHError(title: "Could not get server publickey hash"))
}

let hexString = String(cString: ssh_get_hexa(hash, hlen))
Expand All @@ -416,20 +414,19 @@ public class SSHClient {
let rc3 = ssh_session_is_known_server(session)
switch rc3 {
case SSH_KNOWN_HOSTS_OK:
return Just(self).mapError { $0 as Error }.eraseToAnyPublisher()
return .just(self)

case SSH_KNOWN_HOSTS_CHANGED:
return self.options.requestVerifyHostCallback!(.changed(serverFingerprint: hexString)).flatMap { answer -> AnyPublisher<SSHClient, Error> in
if answer == .affirmative {
let rc = ssh_session_update_known_hosts(self.session)
if rc != SSH_OK {
return Fail(error: SSHError(title: "Could not update known_hosts file.")).eraseToAnyPublisher()
return .fail(error: SSHError(title: "Could not update known_hosts file."))
}

return Just(self).mapError({ $0 as Error }).eraseToAnyPublisher()
return .just(self)
}

return Fail(error: SSHError(title: "Could not verify host authenticity.")).eraseToAnyPublisher()
return .fail(error: SSHError(title: "Could not verify host authenticity."))
}.eraseToAnyPublisher()

case SSH_KNOWN_HOSTS_UNKNOWN:
Expand All @@ -438,41 +435,41 @@ public class SSHClient {
let rc = ssh_session_update_known_hosts(self.session)

if rc < 0 {
return Fail(error: SSHError(title: "Error updating known_hosts file.")).eraseToAnyPublisher()
return .fail(error: SSHError(title: "Error updating known_hosts file."))
}

return Just(self).mapError({ $0 as Error }).eraseToAnyPublisher()
return .just(self)
}

return Fail(error: SSHError(title: "Could not verify host authenticity.")).eraseToAnyPublisher()
return .fail(error: SSHError(title: "Could not verify host authenticity."))
}.eraseToAnyPublisher()


/// The server gave use a key of a type while we had an other type recorded. It is a possible attack.
case SSH_KNOWN_HOSTS_OTHER:
// Stop connection because we could not verify the authenticity. And we could make the other side dispaly it.
return Fail(error: SSHError(title: "The server gave use a key of a type while we had an other type recorded. It is a possible attack.")).eraseToAnyPublisher()
return .fail(error: SSHError(title: "The server gave use a key of a type while we had an other type recorded. It is a possible attack."))
/// There had been an eror checking the host.
case SSH_KNOWN_HOSTS_ERROR:
return Fail(error: SSHError(title: "Could not verify host authenticity.")).eraseToAnyPublisher()
return .fail(error: SSHError(title: "Could not verify host authenticity."))
/// The known host file does not exist. The host is thus unknown. File will be created if host key is accepted
case SSH_KNOWN_HOSTS_NOT_FOUND:
return self.options.requestVerifyHostCallback!(.notFound(serverFingerprint: hexString)).flatMap { answer -> AnyPublisher<SSHClient, Error> in
if answer == .affirmative {
let rc = ssh_session_update_known_hosts(self.session)

if rc != SSH_OK {
return Fail(error: SSHError(title: "Error updating known_hosts file.")).eraseToAnyPublisher()
return .fail(error: SSHError(title: "Error updating known_hosts file."))
}

return Just(self).mapError({ $0 as Error }).eraseToAnyPublisher()
return .just(self)
}

return Fail(error: SSHError(title: "Could not verify host authenticity.")).eraseToAnyPublisher()
return .fail(error: SSHError(title: "Could not verify host authenticity."))
}.eraseToAnyPublisher()

default:
return Fail(error: SSHError(title: "Unknown code received during host key exchange. Possible library error.")).eraseToAnyPublisher()
return .fail(error: SSHError(title: "Unknown code received during host key exchange. Possible library error."))
}
}

Expand All @@ -488,7 +485,8 @@ public class SSHClient {
withTimeInterval: Double(timeout),
repeats: false) {_ in timerFired = true }
return conn
}.eraseToAnyPublisher()
}
.eraseToAnyPublisher()
.tryOperation { session in
if timerFired {
throw SSHError(title: "Connection to \(self.host) timed out.")
Expand Down Expand Up @@ -543,7 +541,7 @@ public class SSHClient {
// Return the Client if any method worked, otherwise return an error
func tryAuth(_ methods: [Authenticator], tried: [Authenticator]) -> AnyPublisher<SSHClient, Error> {
if methods.count == 0 {
return Fail(error: SSHError.authError(msg: "Could not authenticate, no valid methods to try.")).eraseToAnyPublisher()
return .fail(error: SSHError.authError(msg: "Could not authenticate, no valid methods to try."))
}

let method = methods.first!
Expand All @@ -554,7 +552,7 @@ public class SSHClient {
.flatMap { result -> AnyPublisher<SSHClient, Error> in
switch result {
case .Success:
return Just(self).mapError { $0 as Error }.eraseToAnyPublisher()
return .just(self)
case .Partial:
return tryAuth(self.validAuthMethods(), tried: tried)
default:
Expand All @@ -570,7 +568,7 @@ public class SSHClient {
tried.append(methods.removeFirst())

// Return a failure and close the connection that's still open
return Fail(error: SSHError.authFailed(methods: tried)).eraseToAnyPublisher()
return .fail(error: SSHError.authFailed(methods: tried))
}

tried.append(methods.removeFirst())
Expand Down Expand Up @@ -721,8 +719,7 @@ public class SSHClient {
public func requestReverseForward(bindTo address: String?, port: Int32) -> AnyPublisher<Stream, Error> {
if port == 0 {
// TODO Passing 0 will allocate an available port, and be returned at the last param on listen
return Fail(error: SSHError(title: "Ask server to allocate a bound port not allowed"))
.eraseToAnyPublisher()
return .fail(error: SSHError(title: "Ask server to allocate a bound port not allowed"))
}

self.log.message("REVERSE Forward requested to address \(address) on port \(port)", SSH_LOG_INFO)
Expand Down Expand Up @@ -775,7 +772,7 @@ public class SSHClient {
return vars.enumerated().publisher
.flatMap { (_, arg1) -> AnyPublisher<ssh_channel, Error> in
let (key, value) = arg1
return Just(channel).mapError {$0 as Error}.eraseToAnyPublisher()
return .just(channel)
.tryChannel { channel in
self.log.message("Requesting Env Var \(key)", SSH_LOG_INFO)
let rc = ssh_channel_request_env(channel, key, value)
Expand Down
12 changes: 8 additions & 4 deletions SSH/Streams.swift
Original file line number Diff line number Diff line change
Expand Up @@ -171,29 +171,33 @@ public class Stream : Reader, Writer, WriterTo {
}

public func sendEOF() -> AnyPublisher<Void, Error> {
return Just(channel).mapError { $0 as Error }.eraseToAnyPublisher()
return AnyPublisher
.just(channel)
.tryChannel { chan in
let rc = ssh_channel_send_eof(self.channel)
if rc != SSH_OK {
throw SSHError(rc, forSession: self.client.session)
}
self.stdinCancellable?.cancel()
}.subscribe(on: client.rloop)
}
.subscribe(on: client.rloop)
.eraseToAnyPublisher()
}

/**
* Resize the current stream.
*/
public func resizePty(rows: Int32, columns: Int32) -> AnyPublisher<Void, Error> {
return Just(channel).mapError { $0 as Error }.eraseToAnyPublisher()
return AnyPublisher
.just(channel)
.tryChannel { chan in
self.log.message("Resizing PTY: \(rows)x\(columns)", SSH_LOG_INFO)
let rc = ssh_channel_change_pty_size(self.channel, columns, rows)
if rc != SSH_OK {
throw SSHError(rc, forSession: self.client.session)
}
}.subscribe(on: client.rloop)
}
.subscribe(on: client.rloop)
.eraseToAnyPublisher()
}

Expand Down
Loading

0 comments on commit 7381c2f

Please sign in to comment.