diff --git a/libsqlx-server/src/allocation/config.rs b/libsqlx-server/src/allocation/config.rs index f0c13870..13de097d 100644 --- a/libsqlx-server/src/allocation/config.rs +++ b/libsqlx-server/src/allocation/config.rs @@ -25,9 +25,11 @@ pub enum DbConfig { max_log_size: usize, /// Interval at which to force compaction replication_log_compact_interval: Option, + transaction_timeout_duration: Duration, }, Replica { primary_node_id: NodeId, proxy_request_timeout_duration: Duration, + transaction_timeout_duration: Duration, }, } diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 0978b549..7300feea 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -5,18 +5,20 @@ use std::path::PathBuf; use std::pin::Pin; use std::sync::Arc; use std::task::{ready, Context, Poll}; -use std::time::Instant; +use std::time::{Duration, Instant}; use either::Either; use libsqlx::libsql::LibsqlDatabase; use libsqlx::program::Program; use libsqlx::proxy::WriteProxyDatabase; +use libsqlx::result_builder::ResultBuilder; use libsqlx::{Database as _, InjectableDatabase}; use tokio::sync::{mpsc, oneshot}; use tokio::task::{block_in_place, JoinSet}; use tokio::time::Interval; use crate::allocation::primary::FrameStreamer; +use crate::allocation::timeout_notifier::timeout_monitor; use crate::compactor::CompactionQueue; use crate::hrana; use crate::hrana::http::handle_pipeline; @@ -30,17 +32,25 @@ use self::config::{AllocConfig, DbConfig}; use self::primary::compactor::Compactor; use self::primary::{PrimaryConnection, PrimaryDatabase, ProxyResponseBuilder}; use self::replica::{ProxyDatabase, RemoteDb, ReplicaConnection, Replicator}; +use self::timeout_notifier::TimeoutMonitor; pub mod config; mod primary; mod replica; +mod timeout_notifier; /// Maximum number of frame a Frame message is allowed to contain const FRAMES_MESSAGE_MAX_COUNT: usize = 5; /// Maximum number of frames in the injector buffer const MAX_INJECTOR_BUFFER_CAPACITY: usize = 32; -type ExecFn = Box; +pub enum ConnectionMessage { + Execute { + pgm: Program, + builder: Box, + }, + Describe, +} pub enum AllocationMessage { HranaPipelineReq { @@ -54,12 +64,14 @@ pub enum Database { Primary { db: PrimaryDatabase, compact_interval: Option>>, + transaction_timeout_duration: Duration, }, Replica { db: ProxyDatabase, injector_handle: mpsc::Sender, primary_id: NodeId, last_received_frame_ts: Option, + transaction_timeout_duration: Duration, }, } @@ -68,6 +80,7 @@ impl Database { if let Self::Primary { compact_interval: Some(ref mut interval), db, + .. } = self { ready!(interval.poll_tick(cx)); @@ -81,6 +94,13 @@ impl Database { Poll::Pending } + + fn txn_timeout_duration(&self) -> Duration { + match self { + Database::Primary { transaction_timeout_duration, .. } => *transaction_timeout_duration, + Database::Replica { transaction_timeout_duration, .. } => *transaction_timeout_duration, + } + } } impl Database { @@ -96,6 +116,7 @@ impl Database { DbConfig::Primary { max_log_size, replication_log_compact_interval, + transaction_timeout_duration, } => { let (sender, receiver) = tokio::sync::watch::channel(0); let db = LibsqlDatabase::new_primary( @@ -127,11 +148,13 @@ impl Database { snapshot_store: compaction_queue.snapshot_store.clone(), }, compact_interval, + transaction_timeout_duration, } } DbConfig::Replica { primary_node_id, proxy_request_timeout_duration, + transaction_timeout_duration, } => { let rdb = LibsqlDatabase::new_replica(path, MAX_INJECTOR_BUFFER_CAPACITY, ()).unwrap(); @@ -158,27 +181,40 @@ impl Database { injector_handle: sender, primary_id: primary_node_id, last_received_frame_ts: None, + transaction_timeout_duration, } } } } - fn connect(&self, connection_id: u32, alloc: &Allocation) -> impl ConnectionHandler { + fn connect( + &self, + connection_id: u32, + alloc: &Allocation, + on_txn_status_change_cb: impl Fn(bool) + Send + Sync + 'static, + ) -> impl ConnectionHandler { match self { Database::Primary { db: PrimaryDatabase { db, .. }, .. - } => Either::Right(PrimaryConnection { - conn: db.connect().unwrap(), - }), - Database::Replica { db, primary_id, .. } => Either::Left(ReplicaConnection { - conn: db.connect().unwrap(), - connection_id, - next_req_id: 0, - primary_node_id: *primary_id, - database_id: DatabaseId::from_name(&alloc.db_name), - dispatcher: alloc.dispatcher.clone(), - }), + } => { + let mut conn = db.connect().unwrap(); + conn.set_on_txn_status_change_cb(on_txn_status_change_cb); + Either::Right(PrimaryConnection { conn }) + } + Database::Replica { db, primary_id, .. } => { + let mut conn = db.connect().unwrap(); + conn.reader_mut() + .set_on_txn_status_change_cb(on_txn_status_change_cb); + Either::Left(ReplicaConnection { + conn, + connection_id, + next_req_id: 0, + primary_node_id: *primary_id, + database_id: DatabaseId::from_name(&alloc.db_name), + dispatcher: alloc.dispatcher.clone(), + }) + } } } @@ -204,25 +240,21 @@ pub struct Allocation { #[derive(Clone)] pub struct ConnectionHandle { - exec: mpsc::Sender, + messages: mpsc::Sender, inbound: mpsc::Sender, } impl ConnectionHandle { - pub async fn exec(&self, f: F) -> crate::Result - where - F: for<'a> FnOnce(&'a mut dyn libsqlx::Connection) -> R + Send + 'static, - R: Send + 'static, - { - let (sender, ret) = oneshot::channel(); - let cb = move |conn: &mut dyn libsqlx::Connection| { - let res = f(conn); - let _ = sender.send(res); - }; - - self.exec.send(Box::new(cb)).await.unwrap(); - - Ok(ret.await?) + pub async fn execute( + &self, + pgm: Program, + builder: Box, + ) -> crate::Result<()> { + self.messages + .send(ConnectionMessage::Execute { pgm, builder }) + .await + .unwrap(); + Ok(()) } } @@ -362,18 +394,9 @@ impl Allocation { let dispatcher = self.dispatcher.clone(); let database_id = DatabaseId::from_name(&self.db_name); let exec = |conn: ConnectionHandle| async move { - let _ = conn - .exec(move |conn| { - let builder = ProxyResponseBuilder::new( - dispatcher, - database_id, - to, - req_id, - connection_id, - ); - conn.execute_program(&program, Box::new(builder)).unwrap(); - }) - .await; + let builder = + ProxyResponseBuilder::new(dispatcher, database_id, to, req_id, connection_id); + conn.execute(program, Box::new(builder)).await.unwrap(); }; if self.database.is_primary() { @@ -395,20 +418,33 @@ impl Allocation { async fn new_conn(&mut self, remote: Option<(NodeId, u32)>) -> ConnectionHandle { let conn_id = self.next_conn_id(); - let conn = block_in_place(|| self.database.connect(conn_id, self)); - let (exec_sender, exec_receiver) = mpsc::channel(1); + let (timeout_monitor, notifier) = timeout_monitor(); + let timeout = self.database.txn_timeout_duration(); + let conn = block_in_place(|| { + self.database.connect(conn_id, self, move |is_txn| { + if is_txn { + notifier.timeout_at(Instant::now() + timeout); + } else { + notifier.disable(); + } + }) + }); + + let (messages_sender, messages_receiver) = mpsc::channel(1); let (inbound_sender, inbound_receiver) = mpsc::channel(1); let id = remote.unwrap_or((self.dispatcher.node_id(), conn_id)); let conn = Connection { id, conn, - exec: exec_receiver, + messages: messages_receiver, inbound: inbound_receiver, + last_txn_timedout: false, + timeout_monitor, }; self.connections_futs.spawn(conn.run()); let handle = ConnectionHandle { - exec: exec_sender, + messages: messages_sender, inbound: inbound_sender, }; self.connections @@ -436,7 +472,7 @@ impl Allocation { #[async_trait::async_trait] trait ConnectionHandler: Send { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>; - async fn handle_exec(&mut self, exec: ExecFn); + async fn handle_conn_message(&mut self, exec: ConnectionMessage); async fn handle_inbound(&mut self, msg: Inbound); } @@ -453,10 +489,10 @@ where } } - async fn handle_exec(&mut self, exec: ExecFn) { + async fn handle_conn_message(&mut self, msg: ConnectionMessage) { match self { - Either::Left(l) => l.handle_exec(exec).await, - Either::Right(r) => r.handle_exec(exec).await, + Either::Left(l) => l.handle_conn_message(msg).await, + Either::Right(r) => r.handle_conn_message(msg).await, } } async fn handle_inbound(&mut self, msg: Inbound) { @@ -470,21 +506,37 @@ where struct Connection { id: (NodeId, u32), conn: C, - exec: mpsc::Receiver, + messages: mpsc::Receiver, inbound: mpsc::Receiver, + last_txn_timedout: bool, + timeout_monitor: TimeoutMonitor, } impl Connection { async fn run(mut self) -> (NodeId, u32) { loop { - let fut = - futures::future::join(self.exec.recv(), poll_fn(|cx| self.conn.poll_ready(cx))); + let message_ready = + futures::future::join(self.messages.recv(), poll_fn(|cx| self.conn.poll_ready(cx))); + tokio::select! { + _ = &mut self.timeout_monitor => { + self.last_txn_timedout = true; + } Some(inbound) = self.inbound.recv() => { self.conn.handle_inbound(inbound).await; } - (Some(exec), _) = fut => { - self.conn.handle_exec(exec).await; + (Some(msg), _) = message_ready => { + if self.last_txn_timedout { + self.last_txn_timedout = false; + match msg { + ConnectionMessage::Execute { mut builder, .. } => { + let _ = builder.finnalize_error("transaction has timed out".into()); + }, + ConnectionMessage::Describe => todo!(), + } + } else { + self.conn.handle_conn_message(msg).await; + } }, else => break, } @@ -498,11 +550,15 @@ impl Connection { mod test { use std::time::Duration; - use libsqlx::result_builder::ResultBuilder; + use heed::EnvOpenOptions; + use libsqlx::result_builder::{ResultBuilder, StepResultsBuilder}; + use tempfile::tempdir; use tokio::sync::Notify; use crate::allocation::replica::ReplicaConnection; + use crate::init_dirs; use crate::linc::bus::Bus; + use crate::snapshot_store::SnapshotStore; use super::*; @@ -526,13 +582,16 @@ mod test { dispatcher: bus, }; - let (exec_sender, exec) = mpsc::channel(1); + let (messages_sender, messages) = mpsc::channel(1); let (_inbound_sender, inbound) = mpsc::channel(1); + let (timeout_monitor, _) = timeout_monitor(); let connection = Connection { id: (0, 0), conn, - exec, + messages, inbound, + timeout_monitor, + last_txn_timedout: false, }; let handle = tokio::spawn(connection.run()); @@ -546,16 +605,65 @@ mod test { } let builder = Box::new(Builder(notify.clone())); - exec_sender - .send(Box::new(move |conn| { - conn.execute_program(&Program::seq(&["create table test (c)"]), builder) - .unwrap(); - })) - .await - .unwrap(); + let msg = ConnectionMessage::Execute { + pgm: Program::seq(&["create table test (c)"]), + builder, + }; + messages_sender.send(msg).await.unwrap(); notify.notified().await; handle.abort(); } + + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn txn_timeout() { + let bus = Arc::new(Bus::new(0, |_, _| async {})); + let tmp = tempdir().unwrap(); + init_dirs(tmp.path()).await.unwrap(); + let config = AllocConfig { + max_conccurent_connection: 10, + db_name: "test/db".to_owned(), + db_config: DbConfig::Primary { + max_log_size: 100000, + replication_log_compact_interval: None, + transaction_timeout_duration: Duration::from_millis(100), + }, + }; + let (sender, inbox) = mpsc::channel(10); + let env = EnvOpenOptions::new().max_dbs(10).map_size(4096 * 100).open(tmp.path()).unwrap(); + let store = Arc::new(SnapshotStore::new(tmp.path().to_path_buf(), env.clone()).unwrap()); + let queue = Arc::new(CompactionQueue::new(env, tmp.path().to_path_buf(), store).unwrap()); + let mut alloc = Allocation { + inbox, + database: Database::from_config( + &config, + tmp.path().to_path_buf(), + bus.clone(), + queue, + ), + connections_futs: JoinSet::new(), + next_conn_id: 0, + max_concurrent_connections: config.max_conccurent_connection, + hrana_server: Arc::new(hrana::http::Server::new(None)), + dispatcher: bus, // TODO: handle self URL? + db_name: config.db_name, + connections: HashMap::new(), + }; + + let conn = alloc.new_conn(None).await; + tokio::spawn(alloc.run()); + + let (snd, rcv) = oneshot::channel(); + let builder = StepResultsBuilder::new(snd); + conn.execute(Program::seq(&["begin"]), Box::new(builder)).await.unwrap(); + rcv.await.unwrap().unwrap(); + + tokio::time::sleep(Duration::from_secs(1)).await; + + let (snd, rcv) = oneshot::channel(); + let builder = StepResultsBuilder::new(snd); + conn.execute(Program::seq(&["create table test (x)"]), Box::new(builder)).await.unwrap(); + assert!(rcv.await.unwrap().is_err()); + } } diff --git a/libsqlx-server/src/allocation/primary/mod.rs b/libsqlx-server/src/allocation/primary/mod.rs index ccd67c55..505333cb 100644 --- a/libsqlx-server/src/allocation/primary/mod.rs +++ b/libsqlx-server/src/allocation/primary/mod.rs @@ -7,7 +7,7 @@ use std::time::Duration; use bytes::Bytes; use libsqlx::libsql::{LibsqlDatabase, PrimaryType}; use libsqlx::result_builder::ResultBuilder; -use libsqlx::{Frame, FrameHeader, FrameNo, LogReadError, ReplicationLogger}; +use libsqlx::{Connection, Frame, FrameHeader, FrameNo, LogReadError, ReplicationLogger}; use tokio::task::block_in_place; use crate::linc::bus::Dispatch; @@ -16,7 +16,7 @@ use crate::linc::{Inbound, NodeId, Outbound}; use crate::meta::DatabaseId; use crate::snapshot_store::SnapshotStore; -use super::{ConnectionHandler, ExecFn, FRAMES_MESSAGE_MAX_COUNT}; +use super::{ConnectionHandler, ConnectionMessage, FRAMES_MESSAGE_MAX_COUNT}; pub mod compactor; @@ -317,8 +317,15 @@ impl ConnectionHandler for PrimaryConnection { Poll::Ready(()) } - async fn handle_exec(&mut self, exec: ExecFn) { - block_in_place(|| exec(&mut self.conn)); + async fn handle_conn_message(&mut self, msg: ConnectionMessage) { + match msg { + ConnectionMessage::Execute { pgm, builder } => { + self.conn.execute_program(&pgm, builder).unwrap() + } + ConnectionMessage::Describe => { + todo!() + } + } } async fn handle_inbound(&mut self, _msg: Inbound) { diff --git a/libsqlx-server/src/allocation/replica.rs b/libsqlx-server/src/allocation/replica.rs index ee8008d6..5ff7d897 100644 --- a/libsqlx-server/src/allocation/replica.rs +++ b/libsqlx-server/src/allocation/replica.rs @@ -9,7 +9,7 @@ use libsqlx::libsql::{LibsqlConnection, LibsqlDatabase, ReplicaType}; use libsqlx::program::Program; use libsqlx::proxy::{WriteProxyConnection, WriteProxyDatabase}; use libsqlx::result_builder::{Column, QueryBuilderConfig, ResultBuilder}; -use libsqlx::{DescribeResponse, Frame, FrameNo, Injector}; +use libsqlx::{Connection, DescribeResponse, Frame, FrameNo, Injector}; use parking_lot::Mutex; use tokio::sync::mpsc; use tokio::task::block_in_place; @@ -22,7 +22,7 @@ use crate::linc::Inbound; use crate::linc::{NodeId, Outbound}; use crate::meta::DatabaseId; -use super::{ConnectionHandler, ExecFn}; +use super::{ConnectionHandler, ConnectionMessage}; type ProxyConnection = WriteProxyConnection, RemoteConn>; pub type ProxyDatabase = WriteProxyDatabase, RemoteDb>; @@ -287,40 +287,45 @@ impl ConnectionHandler for ReplicaConnection { Poll::Ready(()) } - async fn handle_exec(&mut self, exec: ExecFn) { - block_in_place(|| exec(&mut self.conn)); - let msg = { - let mut lock = self.conn.writer().inner.current_req.lock(); - match *lock { - Some(ref mut req) if req.id.is_none() => { - let program = req - .pgm - .take() - .expect("unsent request should have a program"); - let req_id = self.next_req_id; - self.next_req_id += 1; - req.id = Some(req_id); - - let msg = Outbound { - to: self.primary_node_id, - enveloppe: Enveloppe { - database_id: Some(self.database_id), - message: Message::ProxyRequest { - connection_id: self.connection_id, - req_id, - program, - }, - }, - }; + async fn handle_conn_message(&mut self, msg: ConnectionMessage) { + match msg { + ConnectionMessage::Execute { pgm, builder } => { + self.conn.execute_program(&pgm, builder).unwrap(); + let msg = { + let mut lock = self.conn.writer().inner.current_req.lock(); + match *lock { + Some(ref mut req) if req.id.is_none() => { + let program = req + .pgm + .take() + .expect("unsent request should have a program"); + let req_id = self.next_req_id; + self.next_req_id += 1; + req.id = Some(req_id); + + let msg = Outbound { + to: self.primary_node_id, + enveloppe: Enveloppe { + database_id: Some(self.database_id), + message: Message::ProxyRequest { + connection_id: self.connection_id, + req_id, + program, + }, + }, + }; + + Some(msg) + } + _ => None, + } + }; - Some(msg) + if let Some(msg) = msg { + self.dispatcher.dispatch(msg).await; } - _ => None, } - }; - - if let Some(msg) = msg { - self.dispatcher.dispatch(msg).await; + ConnectionMessage::Describe => (), } } diff --git a/libsqlx-server/src/allocation/timeout_notifier.rs b/libsqlx-server/src/allocation/timeout_notifier.rs new file mode 100644 index 00000000..b64c71a8 --- /dev/null +++ b/libsqlx-server/src/allocation/timeout_notifier.rs @@ -0,0 +1,90 @@ +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll, Waker}; +use std::time::Instant; + +use futures::{Future, FutureExt}; +use parking_lot::Mutex; +use tokio::time::{sleep_until, Sleep}; + +pub fn timeout_monitor() -> (TimeoutMonitor, TimeoutNotifier) { + let inner = Arc::new(Mutex::new(TimeoutInner { + sleep: Box::pin(sleep_until(Instant::now().into())), + enabled: false, + waker: None, + })); + + ( + TimeoutMonitor { + inner: inner.clone(), + }, + TimeoutNotifier { inner }, + ) +} + +pub struct TimeoutMonitor { + inner: Arc>, +} + +pub struct TimeoutNotifier { + inner: Arc>, +} + +impl TimeoutNotifier { + pub fn disable(&self) { + self.inner.lock().enabled = false; + } + + pub fn timeout_at(&self, at: Instant) { + let mut inner = self.inner.lock(); + inner.enabled = true; + inner.sleep.as_mut().reset(at.into()); + if let Some(waker) = inner.waker.take() { + waker.wake() + } + } +} + +struct TimeoutInner { + sleep: Pin>, + enabled: bool, + waker: Option, +} + +impl Future for TimeoutMonitor { + type Output = (); + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let mut inner = self.inner.lock(); + if inner.enabled { + inner.sleep.poll_unpin(cx) + } else { + inner.waker.replace(cx.waker().clone()); + Poll::Pending + } + } +} + +#[cfg(test)] +mod test { + use std::time::Duration; + + use super::*; + + #[tokio::test] + async fn set_timeout() { + let (monitor, notifier) = timeout_monitor(); + notifier.timeout_at(Instant::now() + Duration::from_millis(100)); + monitor.await; + } + + #[tokio::test] + async fn disable_timeout() { + let (monitor, notifier) = timeout_monitor(); + notifier.timeout_at(Instant::now() + Duration::from_millis(1)); + notifier.disable(); + assert!(tokio::time::timeout(Duration::from_millis(10), monitor) + .await + .is_err()); + } +} diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs index a9ed0553..6d41c8b4 100644 --- a/libsqlx-server/src/hrana/batch.rs +++ b/libsqlx-server/src/hrana/batch.rs @@ -74,15 +74,10 @@ pub async fn execute_batch( db: &ConnectionHandle, pgm: Program, ) -> color_eyre::Result { - let fut = db - .exec(move |conn| -> color_eyre::Result<_> { - let (builder, ret) = HranaBatchProtoBuilder::new(); - conn.execute_program(&pgm, Box::new(builder))?; - Ok(ret) - }) - .await??; + let (builder, ret) = HranaBatchProtoBuilder::new(); + db.execute(pgm, Box::new(builder)).await?; - Ok(fut.await?) + Ok(ret.await?) } pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { @@ -111,17 +106,11 @@ pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { } pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> color_eyre::Result<()> { - let fut = conn - .exec(move |conn| -> color_eyre::Result<_> { - let (snd, rcv) = oneshot::channel(); - let builder = StepResultsBuilder::new(snd); - conn.execute_program(&pgm, Box::new(builder))?; - - Ok(rcv) - }) - .await??; + let (snd, rcv) = oneshot::channel(); + let builder = StepResultsBuilder::new(snd); + conn.execute(pgm, Box::new(builder)).await?; - fut.await?.into_iter().try_for_each(|result| match result { + rcv.await?.map_err(|e| anyhow!("{e}"))?.into_iter().try_for_each(|result| match result { StepResult::Ok => Ok(()), StepResult::Err(e) => match stmt_error_from_sqld_error(e) { Ok(stmt_err) => Err(anyhow!(stmt_err)), diff --git a/libsqlx-server/src/hrana/stmt.rs b/libsqlx-server/src/hrana/stmt.rs index 1a8c03f6..1b843367 100644 --- a/libsqlx-server/src/hrana/stmt.rs +++ b/libsqlx-server/src/hrana/stmt.rs @@ -47,17 +47,10 @@ pub async fn execute_stmt( conn: &ConnectionHandle, query: Query, ) -> color_eyre::Result { - let fut = conn - .exec(move |conn| -> color_eyre::Result<_> { - let (builder, ret) = SingleStatementBuilder::new(); - let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); - conn.execute_program(&pgm, Box::new(builder))?; - - Ok(ret) - }) - .await??; - - fut.await? + let (builder, ret) = SingleStatementBuilder::new(); + let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); + conn.execute(pgm, Box::new(builder)).await?; + ret.await? .map_err(|sqld_error| match stmt_error_from_sqld_error(sqld_error) { Ok(stmt_error) => anyhow!(stmt_error), Err(sqld_error) => anyhow!(sqld_error), diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 9b51b7ed..ff718674 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -63,6 +63,8 @@ pub struct Primary { #[serde(default = "default_max_log_size")] pub max_replication_log_size: bytesize::ByteSize, pub replication_log_compact_interval: Option, + #[serde(default = "default_txn_timeout")] + transaction_timeout_duration: HumanDuration, } #[derive(Debug)] @@ -112,6 +114,8 @@ pub enum DbConfigReq { primary_node_id: NodeId, #[serde(default = "default_proxy_timeout")] proxy_request_timeout_duration: HumanDuration, + #[serde(default = "default_txn_timeout")] + transaction_timeout_duration: HumanDuration, }, } @@ -123,6 +127,10 @@ const fn default_proxy_timeout() -> HumanDuration { HumanDuration(Duration::from_secs(5)) } +const fn default_txn_timeout() -> HumanDuration { + HumanDuration(Duration::from_secs(5)) +} + async fn allocate( State(state): State>, Json(req): Json, @@ -134,18 +142,22 @@ async fn allocate( DbConfigReq::Primary(Primary { max_replication_log_size, replication_log_compact_interval, + transaction_timeout_duration, }) => DbConfig::Primary { max_log_size: max_replication_log_size.as_u64() as usize, replication_log_compact_interval: replication_log_compact_interval .as_deref() .copied(), + transaction_timeout_duration: *transaction_timeout_duration, }, DbConfigReq::Replica { primary_node_id, proxy_request_timeout_duration, + transaction_timeout_duration, } => DbConfig::Replica { primary_node_id, proxy_request_timeout_duration: *proxy_request_timeout_duration, + transaction_timeout_duration: *transaction_timeout_duration, }, }, }; diff --git a/libsqlx-server/src/snapshot_store.rs b/libsqlx-server/src/snapshot_store.rs index 965c65a9..32f7f0e9 100644 --- a/libsqlx-server/src/snapshot_store.rs +++ b/libsqlx-server/src/snapshot_store.rs @@ -91,10 +91,6 @@ impl SnapshotStore { end_frame_no: u64::MAX.into(), }; - for entry in self.database.lazily_decode_data().iter(&txn).unwrap() { - let (k, _) = entry.unwrap(); - } - match self .database .get_lower_than_or_equal_to(&txn, &key) diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs index 1f0ea7ab..0ad8b780 100644 --- a/libsqlx/src/database/libsql/connection.rs +++ b/libsqlx/src/database/libsql/connection.rs @@ -1,12 +1,10 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; -use std::time::Instant; use rusqlite::{OpenFlags, Statement, StatementStatus}; use sqld_libsql_bindings::wal_hook::{WalHook, WalMethodsHook}; use crate::connection::{Connection, DescribeCol, DescribeParam, DescribeResponse}; -use crate::database::TXN_TIMEOUT; use crate::error::Error; use crate::program::{Cond, Program, Step}; use crate::query::Query; @@ -50,10 +48,12 @@ where } pub struct LibsqlConnection { - timeout_deadline: Option, conn: sqld_libsql_bindings::Connection<'static>, // holds a ref to _context, must be dropped first. row_stats_handler: Option>, builder_config: QueryBuilderConfig, + /// `true` is the connection is in an open connection state + is_txn: bool, + on_txn_status_change_cb: Option>, _context: Seal::Context>>, } @@ -65,6 +65,7 @@ impl LibsqlConnection { hook_ctx: ::Context, row_stats_callback: Option>, builder_config: QueryBuilderConfig, + on_txn_status_change_cb: Option>, ) -> Result> { let mut ctx = Box::new(hook_ctx); let this = LibsqlConnection { @@ -74,9 +75,10 @@ impl LibsqlConnection { unsafe { &mut *(ctx.as_mut() as *mut _) }, None, )?, - timeout_deadline: None, + on_txn_status_change_cb, builder_config, row_stats_handler: row_stats_callback, + is_txn: false, _context: Seal::new(ctx), }; @@ -105,18 +107,12 @@ impl LibsqlConnection { let mut results = Vec::with_capacity(pgm.steps.len()); builder.init(&self.builder_config)?; - let is_autocommit_before = self.conn.is_autocommit(); for step in pgm.steps() { let res = self.execute_step(step, &results, builder)?; results.push(res); } - // A transaction is still open, set up a timeout - if is_autocommit_before && !self.conn.is_autocommit() { - self.timeout_deadline = Some(Instant::now() + TXN_TIMEOUT) - } - let is_txn = !self.conn.is_autocommit(); if !builder.finnalize(is_txn, None)? && is_txn { let _ = self.conn.execute("ROLLBACK", ()); @@ -160,6 +156,15 @@ impl LibsqlConnection { builder.finish_step(affected_row_count, last_insert_rowid)?; + let is_txn = !self.conn.is_autocommit(); + if self.is_txn != is_txn { + // txn status changed + if let Some(ref cb) = self.on_txn_status_change_cb { + cb(is_txn) + } + } + self.is_txn = is_txn; + Ok(enabled) } @@ -217,6 +222,10 @@ impl LibsqlConnection { Ok((affected_row_count, last_insert_rowid)) } + + pub fn set_on_txn_status_change_cb(&mut self, cb: impl Fn(bool) + Send + Sync + 'static) { + self.on_txn_status_change_cb = Some(Box::new(cb)); + } } fn eval_cond(cond: &Cond, results: &[bool]) -> Result { diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index 1cc884a8..a42bfdc7 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -178,6 +178,7 @@ impl Database for LibsqlDatabase { QueryBuilderConfig { max_size: Some(self.response_size_limit), }, + None, )?) } } diff --git a/libsqlx/src/database/mod.rs b/libsqlx/src/database/mod.rs index 61c39c64..43fa0dac 100644 --- a/libsqlx/src/database/mod.rs +++ b/libsqlx/src/database/mod.rs @@ -1,5 +1,3 @@ -use std::time::Duration; - use crate::connection::Connection; use crate::error::Error; @@ -13,8 +11,6 @@ pub use frame::{Frame, FrameHeader}; pub type FrameNo = u64; -pub const TXN_TIMEOUT: Duration = Duration::from_secs(5); - #[derive(Debug)] pub enum InjectError {} diff --git a/libsqlx/src/result_builder.rs b/libsqlx/src/result_builder.rs index d69ac35b..c5f159e7 100644 --- a/libsqlx/src/result_builder.rs +++ b/libsqlx/src/result_builder.rs @@ -196,7 +196,7 @@ impl StepResultsBuilder { } } -impl>> ResultBuilder for StepResultsBuilder { +impl, String>>> ResultBuilder for StepResultsBuilder { fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { self.current = None; self.step_results.clear(); @@ -248,9 +248,16 @@ impl>> ResultBuilder for StepResultsBuilder { self.ret .take() .expect("finnalize called more than once") - .send(std::mem::take(&mut self.step_results)); + .send(Ok(std::mem::take(&mut self.step_results))); Ok(true) } + + fn finnalize_error(&mut self, e: String) { + self.ret + .take() + .expect("finnalize called more than once") + .send(Err(e)); + } } impl ResultBuilder for () {} @@ -362,6 +369,10 @@ impl ResultBuilder for Take { ) -> Result { self.inner.finnalize(is_txn, frame_no) } + + fn finnalize_error(&mut self, e: String) { + self.inner.finnalize_error(e) + } } #[cfg(test)]