Skip to content

Commit

Permalink
tokio-postgres-openssl
Browse files Browse the repository at this point in the history
  • Loading branch information
sfackler committed Jun 27, 2018
1 parent 5c89b35 commit 369f6e0
Show file tree
Hide file tree
Showing 4 changed files with 203 additions and 0 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ members = [
"postgres-openssl",
"postgres-native-tls",
"tokio-postgres",
"tokio-postgres-openssl",
]

[patch.crates-io]
Expand Down
15 changes: 15 additions & 0 deletions tokio-postgres-openssl/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
[package]
name = "tokio-postgres-openssl"
version = "0.1.0"
authors = ["Steven Fackler <[email protected]>"]

[dependencies]
bytes = "0.4"
futures = "0.1"
openssl = "0.10"
tokio-io = "0.1"
tokio-openssl = "0.2"
tokio-postgres = { version = "0.3", path = "../tokio-postgres" }

[dev-dependencies]
tokio = "0.1.7"
127 changes: 127 additions & 0 deletions tokio-postgres-openssl/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
extern crate bytes;
extern crate futures;
extern crate openssl;
extern crate tokio_io;
extern crate tokio_openssl;
extern crate tokio_postgres;

#[cfg(test)]
extern crate tokio;

use bytes::{Buf, BufMut};
use futures::{Future, IntoFuture, Poll};
use openssl::error::ErrorStack;
use openssl::ssl::{ConnectConfiguration, SslConnector, SslMethod};
use std::error::Error;
use std::io::{self, Read, Write};
use tokio_io::{AsyncRead, AsyncWrite};
use tokio_openssl::ConnectConfigurationExt;
use tokio_postgres::tls::{Socket, TlsConnect, TlsStream};

#[cfg(test)]
mod test;

pub struct TlsConnector {
connector: SslConnector,
callback: Box<Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + Sync + Send>,
}

impl TlsConnector {
pub fn new() -> Result<TlsConnector, ErrorStack> {
let connector = SslConnector::builder(SslMethod::tls())?.build();
Ok(TlsConnector::with_connector(connector))
}

pub fn with_connector(connector: SslConnector) -> TlsConnector {
TlsConnector {
connector,
callback: Box::new(|_| Ok(())),
}
}

pub fn set_callback<F>(&mut self, f: F)
where
F: Fn(&mut ConnectConfiguration) -> Result<(), ErrorStack> + 'static + Sync + Send,
{
self.callback = Box::new(f);
}
}

impl TlsConnect for TlsConnector {
fn connect(
&self,
domain: &str,
socket: Socket,
) -> Box<Future<Item = Box<TlsStream>, Error = Box<Error + Sync + Send>> + Sync + Send> {
let f = self
.connector
.configure()
.and_then(|mut ssl| (self.callback)(&mut ssl).map(|_| ssl))
.map_err(|e| {
let e: Box<Error + Sync + Send> = Box::new(e);
e
})
.into_future()
.and_then({
let domain = domain.to_string();
move |ssl| {
ssl.connect_async(&domain, socket)
.map(|s| {
let s: Box<TlsStream> = Box::new(SslStream(s));
s
})
.map_err(|e| {
let e: Box<Error + Sync + Send> = Box::new(e);
e
})
}
});
Box::new(f)
}
}

struct SslStream(tokio_openssl::SslStream<Socket>);

impl Read for SslStream {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
self.0.read(buf)
}
}

impl AsyncRead for SslStream {
unsafe fn prepare_uninitialized_buffer(&self, buf: &mut [u8]) -> bool {
self.0.prepare_uninitialized_buffer(buf)
}

fn read_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
where
B: BufMut,
{
self.0.read_buf(buf)
}
}

impl Write for SslStream {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
self.0.write(buf)
}

fn flush(&mut self) -> io::Result<()> {
self.0.flush()
}
}

impl AsyncWrite for SslStream {
fn shutdown(&mut self) -> Poll<(), io::Error> {
self.0.shutdown()
}

fn write_buf<B>(&mut self, buf: &mut B) -> Poll<usize, io::Error>
where
B: Buf,
{
self.0.write_buf(buf)
}
}

impl TlsStream for SslStream {}
60 changes: 60 additions & 0 deletions tokio-postgres-openssl/src/test.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
use futures::{Future, Stream};
use openssl::ssl::{SslConnector, SslMethod};
use tokio::runtime::current_thread::Runtime;
use tokio_postgres::{self, TlsMode};

use TlsConnector;

fn smoke_test(url: &str, tls: TlsMode) {
let mut runtime = Runtime::new().unwrap();

let handshake = tokio_postgres::connect(url.parse().unwrap(), tls);
let (mut client, connection) = runtime.block_on(handshake).unwrap();
let connection = connection.map_err(|e| panic!("{}", e));
runtime.handle().spawn(connection).unwrap();

let prepare = client.prepare("SELECT 1::INT4");
let statement = runtime.block_on(prepare).unwrap();
let select = client.query(&statement, &[]).collect().map(|rows| {
assert_eq!(rows.len(), 1);
assert_eq!(rows[0].get::<_, i32>(0), 1);
});
runtime.block_on(select).unwrap();

drop(statement);
drop(client);
runtime.run().unwrap();
}

#[test]
fn require() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
let connector = TlsConnector::with_connector(builder.build());
smoke_test(
"postgres://ssl_user@localhost:5433/postgres",
TlsMode::Require(Box::new(connector)),
);
}

#[test]
fn prefer() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
let connector = TlsConnector::with_connector(builder.build());
smoke_test(
"postgres://ssl_user@localhost:5433/postgres",
TlsMode::Prefer(Box::new(connector)),
);
}

#[test]
fn scram_user() {
let mut builder = SslConnector::builder(SslMethod::tls()).unwrap();
builder.set_ca_file("../test/server.crt").unwrap();
let connector = TlsConnector::with_connector(builder.build());
smoke_test(
"postgres://scram_user:password@localhost:5433/postgres",
TlsMode::Require(Box::new(connector)),
);
}

0 comments on commit 369f6e0

Please sign in to comment.