diff --git a/openssl/src/ssl/mod.rs b/openssl/src/ssl/mod.rs index 024e4ca565..5a600053cf 100644 --- a/openssl/src/ssl/mod.rs +++ b/openssl/src/ssl/mod.rs @@ -1432,7 +1432,8 @@ impl SslContextBuilder { } } - /// Sets the callback for generating an application cookie for stateless handshakes. + /// Sets the callback for generating an application cookie for TLS1.3 + /// stateless handshakes. /// /// The callback will be called with the SSL context and a slice into which the cookie /// should be written. The callback should return the number of bytes written. @@ -1454,7 +1455,8 @@ impl SslContextBuilder { } } - /// Sets the callback for verifying an application cookie for stateless handshakes. + /// Sets the callback for verifying an application cookie for TLS1.3 + /// stateless handshakes. /// /// The callback will be called with the SSL context and the cookie supplied by the /// client. It should return true if and only if the cookie is valid. @@ -2632,22 +2634,7 @@ impl Ssl { where S: Read + Write, { - let mut stream = SslStream::new_base(self, stream); - let ret = unsafe { ffi::SSL_connect(stream.ssl.as_ptr()) }; - if ret > 0 { - Ok(stream) - } else { - let error = stream.make_error(ret); - match error.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock( - MidHandshakeSslStream { stream, error }, - )), - _ => Err(HandshakeError::Failure(MidHandshakeSslStream { - stream, - error, - })), - } - } + SslStreamBuilder::new(self, stream).connect() } /// Initiates a server-side TLS handshake. @@ -2664,22 +2651,7 @@ impl Ssl { where S: Read + Write, { - let mut stream = SslStream::new_base(self, stream); - let ret = unsafe { ffi::SSL_accept(stream.ssl.as_ptr()) }; - if ret > 0 { - Ok(stream) - } else { - let error = stream.make_error(ret); - match error.code() { - ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock( - MidHandshakeSslStream { stream, error }, - )), - _ => Err(HandshakeError::Failure(MidHandshakeSslStream { - stream, - error, - })), - } - } + SslStreamBuilder::new(self, stream).accept() } } @@ -2951,6 +2923,114 @@ impl Write for SslStream { } } +/// A partially constructed `SslStream`, useful for unusual handshakes. +pub struct SslStreamBuilder { + inner: SslStream +} + +impl SslStreamBuilder + where S: Read + Write +{ + /// Begin creating an `SslStream` atop `stream` + pub fn new(ssl: Ssl, stream: S) -> Self { + Self { + inner: SslStream::new_base(ssl, stream), + } + } + + /// Perform a stateless server-side handshake + /// + /// Requires that cookie generation and verification callbacks were + /// set on the SSL context. + /// + /// Returns `Ok(true)` if a complete ClientHello containing a valid cookie + /// was read, in which case the handshake should be continued via + /// `accept`. If a HelloRetryRequest containing a fresh cookie was + /// transmitted, `Ok(false)` is returned instead. If the handshake cannot + /// proceed at all, `Err` is returned. + /// + /// This corresponds to [`SSL_stateless`] + /// + /// [`SSL_stateless`]: https://www.openssl.org/docs/manmaster/man3/SSL_stateless.html + #[cfg(ossl111)] + pub fn stateless(&mut self) -> Result { + match unsafe { ffi::SSL_stateless(self.inner.ssl.as_ptr()) } { + 1 => Ok(true), + 0 => Ok(false), + -1 => Err(ErrorStack::get()), + _ => unreachable!(), + } + } + + /// See `Ssl::connect` + pub fn connect(self) -> Result, HandshakeError> { + let mut stream = self.inner; + let ret = unsafe { ffi::SSL_connect(stream.ssl.as_ptr()) }; + if ret > 0 { + Ok(stream) + } else { + let error = stream.make_error(ret); + match error.code() { + ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock( + MidHandshakeSslStream { stream, error }, + )), + _ => Err(HandshakeError::Failure(MidHandshakeSslStream { + stream, + error, + })), + } + } + } + + /// See `Ssl::accept` + pub fn accept(self) -> Result, HandshakeError> { + let mut stream = self.inner; + let ret = unsafe { ffi::SSL_accept(stream.ssl.as_ptr()) }; + if ret > 0 { + Ok(stream) + } else { + let error = stream.make_error(ret); + match error.code() { + ErrorCode::WANT_READ | ErrorCode::WANT_WRITE => Err(HandshakeError::WouldBlock( + MidHandshakeSslStream { stream, error }, + )), + _ => Err(HandshakeError::Failure(MidHandshakeSslStream { + stream, + error, + })), + } + } + } + + // Future work: early IO methods +} + +impl SslStreamBuilder { + /// Returns a shared reference to the underlying stream. + pub fn get_ref(&self) -> &S { + unsafe { + let bio = self.inner.ssl.get_raw_rbio(); + bio::get_ref(bio) + } + } + + /// Returns a mutable reference to the underlying stream. + /// + /// # Warning + /// + /// It is inadvisable to read from or write to the underlying stream as it + /// will most likely corrupt the SSL session. + pub fn get_mut(&mut self) -> &mut S { + unsafe { + let bio = self.inner.ssl.get_raw_rbio(); + bio::get_mut(bio) + } + } + + /// Returns a shared reference to the `Ssl` object associated with this builder. + pub fn ssl(&self) -> &SslRef { &self.inner.ssl } +} + /// The result of a shutdown request. #[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum ShutdownResult {