Skip to content

Commit

Permalink
Permit configuring the notice callback
Browse files Browse the repository at this point in the history
Right now the behavior is hardcoded to log any received notices at the
info level. Add a `notice_callback` configuration option that permits
installing an arbitrary callback to handle any received notices.

As discussed in sfackler#588.
  • Loading branch information
benesch committed Sep 22, 2020
1 parent 4237843 commit 4af6fcd
Show file tree
Hide file tree
Showing 3 changed files with 60 additions and 8 deletions.
36 changes: 31 additions & 5 deletions postgres/src/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,16 @@
use crate::connection::Connection;
use crate::Client;
use log::info;
use std::fmt;
use std::path::Path;
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;
use tokio::runtime;
#[doc(inline)]
pub use tokio_postgres::config::{ChannelBinding, Host, SslMode, TargetSessionAttrs};
use tokio_postgres::error::DbError;
use tokio_postgres::tls::{MakeTlsConnect, TlsConnect};
use tokio_postgres::{Error, Socket};

Expand Down Expand Up @@ -90,6 +93,7 @@ use tokio_postgres::{Error, Socket};
#[derive(Clone)]
pub struct Config {
config: tokio_postgres::Config,
notice_callback: Arc<dyn Fn(DbError) + Send + Sync>,
}

impl fmt::Debug for Config {
Expand All @@ -109,9 +113,7 @@ impl Default for Config {
impl Config {
/// Creates a new configuration.
pub fn new() -> Config {
Config {
config: tokio_postgres::Config::new(),
}
tokio_postgres::Config::new().into()
}

/// Sets the user to authenticate with.
Expand Down Expand Up @@ -307,6 +309,25 @@ impl Config {
self.config.get_channel_binding()
}

/// Sets the notice callback.
///
/// This callback will be invoked with the contents of every
/// [`AsyncMessage::Notice`] that is received by the connection. Notices use
/// the same structure as errors, but they are not "errors" per-se.
///
/// Notices are distinct from notifications, which are instead accessible
/// via the [`Notifications`] API.
///
/// [`AsyncMessage::Notice`]: tokio_postgres::AsyncMessage::Notice
/// [`Notifications`]: crate::Notifications
pub fn notice_callback<F>(&mut self, f: F) -> &mut Config
where
F: Fn(DbError) + Send + Sync + 'static,
{
self.notice_callback = Arc::new(f);
self
}

/// Opens a connection to a PostgreSQL database.
pub fn connect<T>(&self, tls: T) -> Result<Client, Error>
where
Expand All @@ -323,7 +344,7 @@ impl Config {

let (client, connection) = runtime.block_on(self.config.connect(tls))?;

let connection = Connection::new(runtime, connection);
let connection = Connection::new(runtime, connection, self.notice_callback.clone());
Ok(Client::new(connection, client))
}
}
Expand All @@ -338,6 +359,11 @@ impl FromStr for Config {

impl From<tokio_postgres::Config> for Config {
fn from(config: tokio_postgres::Config) -> Config {
Config { config }
Config {
config,
notice_callback: Arc::new(|notice| {
info!("{}: {}", notice.severity(), notice.message())
}),
}
}
}
14 changes: 11 additions & 3 deletions postgres/src/connection.rs
Original file line number Diff line number Diff line change
@@ -1,24 +1,30 @@
use crate::{Error, Notification};
use futures::future;
use futures::{pin_mut, Stream};
use log::info;
use std::collections::VecDeque;
use std::future::Future;
use std::ops::{Deref, DerefMut};
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite};
use tokio::runtime::Runtime;
use tokio_postgres::error::DbError;
use tokio_postgres::AsyncMessage;

pub struct Connection {
runtime: Runtime,
connection: Pin<Box<dyn Stream<Item = Result<AsyncMessage, Error>> + Send>>,
notifications: VecDeque<Notification>,
notice_callback: Arc<dyn Fn(DbError)>,
}

impl Connection {
pub fn new<S, T>(runtime: Runtime, connection: tokio_postgres::Connection<S, T>) -> Connection
pub fn new<S, T>(
runtime: Runtime,
connection: tokio_postgres::Connection<S, T>,
notice_callback: Arc<dyn Fn(DbError)>,
) -> Connection
where
S: AsyncRead + AsyncWrite + Unpin + 'static + Send,
T: AsyncRead + AsyncWrite + Unpin + 'static + Send,
Expand All @@ -27,6 +33,7 @@ impl Connection {
runtime,
connection: Box::pin(ConnectionStream { connection }),
notifications: VecDeque::new(),
notice_callback,
}
}

Expand Down Expand Up @@ -55,6 +62,7 @@ impl Connection {
{
let connection = &mut self.connection;
let notifications = &mut self.notifications;
let notice_callback = &mut self.notice_callback;
self.runtime.block_on({
future::poll_fn(|cx| {
let done = loop {
Expand All @@ -63,7 +71,7 @@ impl Connection {
notifications.push_back(notification);
}
Poll::Ready(Some(Ok(AsyncMessage::Notice(notice)))) => {
info!("{}: {}", notice.severity(), notice.message());
notice_callback(notice)
}
Poll::Ready(Some(Ok(_))) => {}
Poll::Ready(Some(Err(e))) => return Poll::Ready(Err(e)),
Expand Down
18 changes: 18 additions & 0 deletions postgres/src/test.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
use std::io::{Read, Write};
use std::str::FromStr;
use std::sync::mpsc;
use std::thread;
use std::time::Duration;
use tokio_postgres::error::SqlState;
Expand Down Expand Up @@ -476,6 +478,22 @@ fn notifications_timeout_iter() {
assert_eq!(notifications[1].payload(), "world");
}

#[test]
fn notice_callback() {
let (notice_tx, notice_rx) = mpsc::sync_channel(64);
let mut client = Config::from_str("host=localhost port=5433 user=postgres")
.unwrap()
.notice_callback(move |n| notice_tx.send(n).unwrap())
.connect(NoTls)
.unwrap();

client
.batch_execute("DO $$BEGIN RAISE NOTICE 'custom'; END$$")
.unwrap();

assert_eq!(notice_rx.recv().unwrap().message(), "custom");
}

#[test]
fn explicit_close() {
let client = Client::connect("host=localhost port=5433 user=postgres", NoTls).unwrap();
Expand Down

0 comments on commit 4af6fcd

Please sign in to comment.