diff --git a/bin/connector.rs b/bin/connector.rs new file mode 100644 index 0000000..e95e81c --- /dev/null +++ b/bin/connector.rs @@ -0,0 +1,26 @@ +use crate::config::Config; +use rusty_rtss::postgres::PgConnector; +use sqlx::postgres::PgPool; + +use super::Payload; +use super::Result; + +pub async fn get_pool_from_config(config: &Config) -> Result { + PgPool::connect(&config.postgres.uri) + .await + .map_err(Into::into) +} + +pub async fn get_connector_from_pool( + pool: &PgPool, + config: &Config, +) -> Result> { + let channels = config.postgres.listen_channels.clone(); + + PgConnector::builder() + .with_pool(pool) + .add_channels(channels) + .build() + .await + .map_err(|_| "Unable to get listener from pool".into()) +} diff --git a/bin/listener.rs b/bin/listener.rs deleted file mode 100644 index e0f5475..0000000 --- a/bin/listener.rs +++ /dev/null @@ -1,22 +0,0 @@ -use crate::config::Config; -use sqlx::postgres::PgPool; - -use super::Payload; -use super::Result; -use rusty_rtss::postgres::PgListener; - -pub async fn get_pool_from_config(config: &Config) -> Result { - PgPool::connect(&config.postgres.uri) - .await - .map_err(Into::into) -} - -pub async fn get_listener_from_pool(pool: &PgPool, config: &Config) -> Result> { - let channels = &config.postgres.listen_channels; - - let channels = channels.iter().map(|x| x.as_str()).collect(); - - PgListener::from_pool(pool, channels) - .await - .map_err(|_| "unable to get listener from pool".into()) -} diff --git a/bin/main.rs b/bin/main.rs index 13a1baf..16016ac 100644 --- a/bin/main.rs +++ b/bin/main.rs @@ -1,11 +1,10 @@ #![feature(result_option_inspect)] -use std::sync::Arc; use repository::SubmisisonRepository; -use rusty_rtss::{app::App, sse::SsePublisher}; +use rusty_rtss::{app::App, postgres::PgConnector, sse::SsePublisher}; mod config; -mod listener; +mod connector; mod payload; mod publisher; mod repository; @@ -20,13 +19,13 @@ type Payload = payload::Payload; #[derive(Clone)] pub struct SharedState { - app: Arc>>, + app: App, SsePublisher>, repository: SubmisisonRepository, } #[tokio::main] async fn main() { - env_logger::init(); + env_logger::builder().format_timestamp(None).init(); let config = match config::load_config() { Ok(x) => x, @@ -36,12 +35,12 @@ async fn main() { } }; - let pool = listener::get_pool_from_config(&config) + let pool = connector::get_pool_from_config(&config) .await .expect("Unable to create connection pool"); log::info!("Connected to database"); - let listener = listener::get_listener_from_pool(&pool, &config) + let connector = connector::get_connector_from_pool(&pool, &config) .await .expect("Unable to create listener from connection pool"); log::info!("Listened to channel"); @@ -52,9 +51,18 @@ async fn main() { let publisher = publisher::get_publisher(); log::info!("Created publisher"); - let app = Arc::new(App::new(listener, publisher).expect("Unable to create app")); + let app = App::new(connector, publisher) + .await + .expect("Unable to create app"); log::info!("Created app"); + let _app = app.clone(); + tokio::spawn(async move { + if let Err(e) = _app.handle_connection().await { + log::error!("Handle connection join error: {e:?}"); + } + }); + let shared_state = SharedState { app, repository }; let router = router::get_router(shared_state); diff --git a/bin/publisher.rs b/bin/publisher.rs index 855464b..b512b36 100644 --- a/bin/publisher.rs +++ b/bin/publisher.rs @@ -1,5 +1,5 @@ use rusty_rtss::sse::SsePublisher; -pub fn get_publisher() -> SsePublisher { +pub fn get_publisher() -> SsePublisher { SsePublisher::new() } diff --git a/src/app.rs b/src/app.rs index b50b2ea..9ce6942 100644 --- a/src/app.rs +++ b/src/app.rs @@ -1,55 +1,138 @@ +use crate::listener::Connector; + use super::{listener::Listener, publisher::Publisher}; -use futures_util::StreamExt; +use futures_util::{Stream, StreamExt}; use std::sync::Arc; -use tokio::task::JoinHandle; -pub struct App

{ - _handle: JoinHandle<()>, - publisher: Arc

, +struct Inner { + connector: C, + publisher: P, +} + +/// Wrapper around [Inner](Inner) +/// All of the logic should be perform in [Inner](Inner) +pub struct App { + inner: Arc>, } -impl

App

{ - pub fn new(listener: L, publisher: P) -> Result> +impl App { + pub async fn new(connector: C, publisher: P) -> Result> where - L: Listener + 'static, P: Publisher + 'static, T: Send + Sync + 'static, + C: Connector + 'static, + ::Listener: Listener + 'static, { - let publisher = Arc::new(publisher); + let inner = Arc::new(Inner { + connector, + publisher, + }); - let cloned_publisher = Arc::clone(&publisher); - let handle = tokio::spawn(async move { - let cloned_publisher = cloned_publisher; - let stream = listener.into_stream(); + let app = App { inner }; - stream - .for_each_concurrent(10, move |payload| { - let cloned_publisher = Arc::clone(&cloned_publisher); + Ok(app) + } - async move { - let cloned_publisher = cloned_publisher; + pub async fn add_subscriber(&self, subscriber: S) -> Result<(), Box> + where + P: Publisher + 'static, + { + let inner = Arc::clone(&self.inner); - cloned_publisher.publish(payload).await; - } - }) - .await - }); + if let Err(e) = inner.add_subscriber(subscriber).await { + log::warn!("Unable to add subscriber: {e:?}"); + }; - Ok(App { - _handle: handle, - publisher, - }) + tokio::task::yield_now().await; + + Ok(()) } - pub async fn add_subscriber(&self, subscriber: S) -> Result<(), Box> + pub fn add_stream(&self, stream: S) -> tokio::task::JoinHandle<()> + where + S: Stream + Send + 'static, + P: Publisher + 'static, + T: Send + Sync + 'static, + C: Send + Sync + 'static, + { + Arc::clone(&self.inner).add_stream(stream) + } + + pub async fn handle_connection(&self) -> Result<(), tokio::task::JoinError> + where + C: Connector + 'static, + ::Listener: Listener, + P: Publisher + 'static, + T: Send + Sync + 'static, + { + Arc::clone(&self.inner).handle_connection().await + } +} + +impl Clone for App { + fn clone(&self) -> Self { + App { + inner: Arc::clone(&self.inner), + } + } +} + +impl Inner { + pub async fn add_subscriber( + self: Arc, + subscriber: S, + ) -> Result<(), Box> where P: Publisher + 'static, { self.publisher.add_subscriber(subscriber); - tokio::task::yield_now().await; - Ok(()) } + + pub fn add_stream(self: Arc, stream: S) -> tokio::task::JoinHandle<()> + where + S: Stream + Send + 'static, + P: Publisher + 'static, + T: Send + Sync + 'static, + C: Send + Sync + 'static, + { + tokio::spawn(async move { + stream + .for_each_concurrent(10, move |payload| Arc::clone(&self).handle_payload(payload)) + .await; + + log::info!("Stream end"); + }) + } + + async fn handle_payload(self: Arc, payload: T) + where + P: Publisher + 'static, + T: Send + Sync + 'static, + { + self.publisher.publish(payload).await + } + + /// Future become ready after connector give up on connection + async fn handle_connection(self: Arc) -> Result<(), tokio::task::JoinError> + where + C: Connector + 'static, + ::Listener: Listener, + P: Publisher + 'static, + T: Send + Sync + 'static, + { + tokio::spawn(async move { + while let Some(listener) = self.connector.connect().await { + let stream_handle = Arc::clone(&self).add_stream(listener.into_stream()); + + stream_handle.await? + } + log::info!("Connector give up"); + Ok(()) + }) + .await + .flatten() + } } diff --git a/src/lib.rs b/src/lib.rs index e0b1aef..744afa8 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,3 +1,5 @@ +#![feature(result_flattening)] + pub mod app; pub mod listener; pub mod postgres; diff --git a/src/listener.rs b/src/listener.rs index 315c407..a7a3943 100644 --- a/src/listener.rs +++ b/src/listener.rs @@ -7,3 +7,14 @@ pub trait Listener: Send + Sync { fn into_stream(self) -> Self::S; } + +#[async_trait::async_trait] +pub trait Connector: Send + Sync { + type Listener: Listener; + + /// `None` indicates that there will be no connection continue + /// the default implementation is also `None` + async fn connect(&self) -> Option { + None + } +} diff --git a/src/postgres.rs b/src/postgres.rs deleted file mode 100644 index 2e88010..0000000 --- a/src/postgres.rs +++ /dev/null @@ -1,63 +0,0 @@ -use std::marker::PhantomData; - -use futures::stream::BoxStream; -use futures_util::StreamExt; -use sqlx::postgres::PgNotification; - -use super::listener::Listener; - -/// postgres implementation of [`Listener`](super::event::Listener) -pub struct PgListener

{ - listener: sqlx::postgres::PgListener, - _payload: PhantomData

, -} - -impl

Listener for PgListener

-where - P: Send + Sync + From, -{ - type Data = P; - type S = BoxStream<'static, Self::Data>; - - fn into_stream(self) -> Self::S { - self.listener - .into_stream() - .filter_map(|result: Result| async move { - result.ok().map(Into::into) - }) - .boxed() - } -} - -impl

PgListener

{ - /// Consume [`PgListenerConfig`](PgListenerConfig), then connect and listen to the specify channel - pub async fn connect(config: PgListenerConfig<'_>) -> Result> { - let mut con = sqlx::postgres::PgListener::connect(config.url).await?; - - con.listen_all(config.channels).await?; - - Ok(PgListener { - listener: con, - _payload: Default::default(), - }) - } - - pub async fn from_pool( - pool: &sqlx::postgres::PgPool, - channels: Vec<&str>, - ) -> Result> { - let mut con = sqlx::postgres::PgListener::connect_with(pool).await?; - - con.listen_all(channels).await?; - - Ok(PgListener { - listener: con, - _payload: Default::default(), - }) - } -} - -pub struct PgListenerConfig<'a> { - pub url: &'a str, - pub channels: Vec<&'a str>, -} diff --git a/src/postgres/builder.rs b/src/postgres/builder.rs new file mode 100644 index 0000000..0f5edef --- /dev/null +++ b/src/postgres/builder.rs @@ -0,0 +1,73 @@ +use std::marker::PhantomData; + +use sqlx::PgPool; + +use super::connector::PgConnector; + +pub struct PgConnectorBuilder

{ + url: Option, + pool: Option, + listen_channels: Vec, + _payload: PhantomData

, +} + +impl

PgConnectorBuilder

{ + pub(super) fn new() -> Self { + Self { + url: None, + pool: None, + listen_channels: Vec::with_capacity(1), + _payload: Default::default(), + } + } + + pub fn with_url(mut self, url: String) -> Self { + self.url = Some(url); + + self + } + + pub fn with_pool(mut self, pool: &PgPool) -> Self { + self.pool = Some(pool.clone()); + + self + } + + pub fn add_channel(mut self, channel: String) -> Self { + self.listen_channels.push(channel); + + self + } + + pub fn add_channels(mut self, channels: Vec) -> Self { + self.listen_channels.reserve(channels.len()); + + for channel in channels { + self.listen_channels.push(channel); + } + + self + } + + pub async fn build(self) -> Result, Box> { + let url = self.url; + let pool = self.pool; + let listen_channels = self.listen_channels; + + match (url, pool) { + (None, None) | (Some(..), Some(..)) => { + Err("Either url or pool needed to be supplied".into()) + } + (None, Some(pool)) => Ok(Self::build_with_pool(pool, listen_channels)), + (Some(url), None) => Ok(Self::build_with_url(url, listen_channels)), + } + } + + fn build_with_url(url: String, listen_channels: Vec) -> PgConnector

{ + PgConnector::from_url(url, listen_channels) + } + + fn build_with_pool(pool: PgPool, listen_channels: Vec) -> PgConnector

{ + PgConnector::from_pool(pool, listen_channels) + } +} diff --git a/src/postgres/connector.rs b/src/postgres/connector.rs new file mode 100644 index 0000000..adc11d4 --- /dev/null +++ b/src/postgres/connector.rs @@ -0,0 +1,86 @@ +use std::{marker::PhantomData, time::Duration}; + +use sqlx::PgPool; +use tokio::{sync::Mutex, task::JoinHandle}; + +use crate::listener::Connector; + +use super::{builder::PgConnectorBuilder, PgListener}; + +pub struct PgConnector

{ + connection_method: ConnectionMethod, + last_attempt_handle: Mutex>>, + listen_channels: Vec, + _payload: PhantomData

, +} + +enum ConnectionMethod { + Url(String), + Pool(PgPool), +} + +impl

PgConnector

{ + fn new(connection_method: ConnectionMethod, listen_channels: Vec) -> Self { + Self { + connection_method, + listen_channels, + last_attempt_handle: Default::default(), + _payload: PhantomData, + } + } + + pub fn builder() -> PgConnectorBuilder

{ + PgConnectorBuilder::new() + } + + pub fn from_pool(pool: PgPool, channels: Vec) -> Self { + Self::new(ConnectionMethod::Pool(pool), channels) + } + + pub fn from_url(url: String, channels: Vec) -> Self { + Self::new(ConnectionMethod::Url(url), channels) + } +} + +#[async_trait::async_trait] +impl

Connector for PgConnector

+where + P: Send + Sync + From, +{ + type Listener = PgListener

; + + async fn connect(&self) -> Option { + let mut listener = loop { + let mut lock = self.last_attempt_handle.lock().await; + + if let Some(handle) = lock.take() { + if let Err(e) = handle.await { + log::error!("Reconnect join error: {e:?}"); + } + } + + let new_handle = tokio::spawn(async move { + tokio::time::sleep(Duration::from_secs(180)).await; + }); + + *lock = Some(new_handle); + drop(lock); + + log::trace!("Trying to connect to database..."); + if let Some(listener) = match &self.connection_method { + ConnectionMethod::Url(url) => sqlx::postgres::PgListener::connect(url).await.ok(), + ConnectionMethod::Pool(pool) => { + sqlx::postgres::PgListener::connect_with(pool).await.ok() + } + } { + break listener; + } + }; + + let channels: Vec<&str> = self.listen_channels.iter().map(|x| x.as_str()).collect(); + + listener.listen_all(channels).await.ok()?; + + Some(PgListener::new(listener)) + } +} diff --git a/src/postgres/listener.rs b/src/postgres/listener.rs new file mode 100644 index 0000000..affb1db --- /dev/null +++ b/src/postgres/listener.rs @@ -0,0 +1,39 @@ +use std::marker::PhantomData; + +use futures::stream::BoxStream; +use futures_util::StreamExt; +use sqlx::postgres::PgNotification; + +use crate::listener::Listener; + +/// postgres implementation of [`Listener`](super::event::Listener) +pub struct PgListener

{ + listener: sqlx::postgres::PgListener, + _payload: PhantomData

, +} + +impl

Listener for PgListener

+where + P: Send + Sync + From, +{ + type Data = P; + type S = BoxStream<'static, Self::Data>; + + fn into_stream(self) -> Self::S { + self.listener + .into_stream() + .filter_map(|result: Result| async move { + result.ok().map(Into::into) + }) + .boxed() + } +} + +impl

PgListener

{ + pub fn new(listener: sqlx::postgres::PgListener) -> Self { + PgListener { + listener, + _payload: PhantomData, + } + } +} diff --git a/src/postgres/mod.rs b/src/postgres/mod.rs new file mode 100644 index 0000000..996b788 --- /dev/null +++ b/src/postgres/mod.rs @@ -0,0 +1,6 @@ +mod builder; +mod connector; +mod listener; + +pub use connector::PgConnector; +pub use listener::PgListener; diff --git a/tests/test.rs b/tests/test.rs index cc1f048..6ddaff3 100644 --- a/tests/test.rs +++ b/tests/test.rs @@ -13,7 +13,10 @@ mod mock { use dashmap::DashMap; use futures::channel::mpsc::{UnboundedReceiver, UnboundedSender}; use futures_util::stream::BoxStream; - use rusty_rtss::{listener::Listener, publisher::Publisher}; + use rusty_rtss::{ + listener::{Connector, Listener}, + publisher::Publisher, + }; type Writer = UnboundedSender; type Reader = UnboundedReceiver; @@ -93,29 +96,63 @@ mod mock { Self { rx } } } + + pub struct MockConnector { + rx: Mutex>, + } + + impl MockConnector { + pub fn new(rx: Reader) -> Self { + Self { + rx: Mutex::new(Some(rx)), + } + } + } + + #[async_trait::async_trait] + impl Connector for MockConnector { + type Listener = MockListener; + + async fn connect(&self) -> Option { + let rx = self.rx.lock().unwrap().take(); + + if let Some(rx) = rx { + Some(MockListener::new(rx)) + } else { + None + } + } + } } -fn get_app() -> ( - App, +async fn get_app() -> ( + App, UnboundedSender, ) { let (input, rx) = futures::channel::mpsc::unbounded(); - let listener = mock::MockListener::new(rx); + let connector = mock::MockConnector::new(rx); let publisher = mock::MockFanoutPublisher::new(); ( - App::new(listener, publisher).expect("unable to create app"), + App::new(connector, publisher) + .await + .expect("unable to create app"), input, ) } #[tokio::test] async fn test_one_subscriber() { - let (app, mut input) = get_app(); + let (app, mut input) = get_app().await; let (tx, mut rx) = futures::channel::mpsc::unbounded(); let subscriber = mock::MockSubscriber::new(tx); + let _app = app.clone(); + tokio::spawn(async move { + let _ = _app.handle_connection().await; + }); + app.add_subscriber(subscriber).await.unwrap(); let d = poll!(rx.next()); @@ -134,7 +171,7 @@ async fn test_one_subscriber() { #[tokio::test] async fn test_many_subscriber() { - let (app, mut input) = get_app(); + let (app, mut input) = get_app().await; let (tx1, mut rx1) = futures::channel::mpsc::unbounded(); let subscriber1 = mock::MockSubscriber::new(tx1); @@ -149,6 +186,11 @@ async fn test_many_subscriber() { app.add_subscriber(subscriber2).await.unwrap(); app.add_subscriber(subscriber3).await.unwrap(); + let _app = app.clone(); + tokio::spawn(async move { + let _ = _app.handle_connection().await; + }); + let d1 = poll!(rx1.next()); let d2 = poll!(rx2.next()); let d3 = poll!(rx3.next());