Skip to content

Commit

Permalink
Make Stream easier to send across threads (surrealdb#3042)
Browse files Browse the repository at this point in the history
  • Loading branch information
rushmorem authored Nov 29, 2023
1 parent a70ddb2 commit a8ed51f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 21 deletions.
32 changes: 17 additions & 15 deletions lib/src/api/method/live.rs
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Router;
use crate::api::err::Error;
use crate::api::opt::Range;
use crate::api::Connection;
use crate::api::ExtraFeatures;
use crate::api::Result;
use crate::dbs;
use crate::method::OnceLockExt;
use crate::method::Query;
use crate::opt::from_value;
use crate::opt::Resource;
Expand All @@ -26,6 +26,7 @@ use crate::sql::Thing;
use crate::sql::Uuid;
use crate::sql::Value;
use crate::Notification;
use crate::Surreal;
use channel::Receiver;
use futures::StreamExt;
use serde::de::DeserializeOwned;
Expand All @@ -43,7 +44,7 @@ const ID: &str = "id";
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Live<'r, C: Connection, R> {
pub(super) router: Result<&'r Router<C>>,
pub(super) client: &'r Surreal<C>,
pub(super) resource: Result<Resource>,
pub(super) range: Option<Range<Id>>,
pub(super) response_type: PhantomData<R>,
Expand All @@ -53,13 +54,13 @@ macro_rules! into_future {
() => {
fn into_future(self) -> Self::IntoFuture {
let Live {
router,
client,
resource,
range,
..
} = self;
Box::pin(async move {
let router = router?;
let router = client.router.extract()?;
if !router.features.contains(&ExtraFeatures::LiveQueries) {
return Err(Error::LiveQueriesNotSupported.into());
}
Expand Down Expand Up @@ -104,9 +105,9 @@ macro_rules! into_future {
param.other = vec![id.clone()];
conn.execute_unit(router, param).await?;
Ok(Stream {
router,
id,
rx,
client: client.clone(),
response_type: PhantomData,
})
})
Expand Down Expand Up @@ -212,7 +213,7 @@ impl<'r, Client> IntoFuture for Live<'r, Client, Value>
where
Client: Connection,
{
type Output = Result<Stream<'r, Client, Value>>;
type Output = Result<Stream<Client, Value>>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + Sync + 'r>>;

into_future! {}
Expand All @@ -223,7 +224,7 @@ where
Client: Connection,
R: DeserializeOwned,
{
type Output = Result<Stream<'r, Client, Option<R>>>;
type Output = Result<Stream<Client, Option<R>>>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + Sync + 'r>>;

into_future! {}
Expand All @@ -234,7 +235,7 @@ where
Client: Connection,
R: DeserializeOwned,
{
type Output = Result<Stream<'r, Client, Vec<R>>>;
type Output = Result<Stream<Client, Vec<R>>>;
type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + Sync + 'r>>;

into_future! {}
Expand All @@ -243,14 +244,14 @@ where
/// A stream of live query notifications
#[derive(Debug)]
#[must_use = "streams do nothing unless you poll them"]
pub struct Stream<'r, C: Connection, R> {
router: &'r Router<C>,
pub struct Stream<C: Connection, R> {
client: Surreal<C>,
id: Value,
rx: Receiver<dbs::Notification>,
response_type: PhantomData<R>,
}

impl<Client, R> Stream<'_, Client, R>
impl<Client, R> Stream<Client, R>
where
Client: Connection,
{
Expand All @@ -260,8 +261,9 @@ where
/// If the stream is dropped without calling this method, the process
/// will be killed next time it tries to send a notification to the stream.
pub async fn close(self) -> Result<()> {
let router = self.client.router.extract()?;
let mut conn = Client::new(Method::Kill);
conn.execute_unit(self.router, Param::new(vec![self.id])).await
conn.execute_unit(router, Param::new(vec![self.id])).await
}
}

Expand All @@ -281,7 +283,7 @@ macro_rules! poll_next {
};
}

impl<C> futures::Stream for Stream<'_, C, Value>
impl<C> futures::Stream for Stream<C, Value>
where
C: Connection,
{
Expand All @@ -303,7 +305,7 @@ macro_rules! poll_next_and_convert {
};
}

impl<C, R> futures::Stream for Stream<'_, C, Option<R>>
impl<C, R> futures::Stream for Stream<C, Option<R>>
where
C: Connection,
R: DeserializeOwned + Unpin,
Expand All @@ -313,7 +315,7 @@ where
poll_next_and_convert! {}
}

impl<C, R> futures::Stream for Stream<'_, C, Vec<R>>
impl<C, R> futures::Stream for Stream<C, Vec<R>>
where
C: Connection,
R: DeserializeOwned + Unpin,
Expand Down
2 changes: 1 addition & 1 deletion lib/src/api/method/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ where
/// ```
pub fn select<R>(&self, resource: impl opt::IntoResource<R>) -> Select<C, R> {
Select {
router: self.router.extract(),
client: self,
resource: resource.into_resource(),
range: None,
response_type: PhantomData,
Expand Down
11 changes: 6 additions & 5 deletions lib/src/api/method/select.rs
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
use crate::api::conn::Method;
use crate::api::conn::Param;
use crate::api::conn::Router;
use crate::api::method::OnceLockExt;
use crate::api::opt::Range;
use crate::api::opt::Resource;
use crate::api::Connection;
use crate::api::Result;
use crate::method::Live;
use crate::sql::Id;
use crate::sql::Value;
use crate::Surreal;
use serde::de::DeserializeOwned;
use std::future::Future;
use std::future::IntoFuture;
Expand All @@ -18,7 +19,7 @@ use std::pin::Pin;
#[derive(Debug)]
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Select<'r, C: Connection, R> {
pub(super) router: Result<&'r Router<C>>,
pub(super) client: &'r Surreal<C>,
pub(super) resource: Result<Resource>,
pub(super) range: Option<Range<Id>>,
pub(super) response_type: PhantomData<R>,
Expand All @@ -28,7 +29,7 @@ macro_rules! into_future {
($method:ident) => {
fn into_future(self) -> Self::IntoFuture {
let Select {
router,
client,
resource,
range,
..
Expand All @@ -39,7 +40,7 @@ macro_rules! into_future {
None => resource?.into(),
};
let mut conn = Client::new(Method::Select);
conn.$method(router?, Param::new(vec![param])).await
conn.$method(client.router.extract()?, Param::new(vec![param])).await
})
}
};
Expand Down Expand Up @@ -153,7 +154,7 @@ where
/// ```
pub fn live(self) -> Live<'r, C, R> {
Live {
router: self.router,
client: self.client,
resource: self.resource,
range: self.range,
response_type: self.response_type,
Expand Down

0 comments on commit a8ed51f

Please sign in to comment.