diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 7e7e2b9b..7d1e8fe6 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,5 +1,7 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; +use std::mem::size_of; +use std::ops::Deref; use std::path::PathBuf; use std::sync::Arc; use std::time::{Duration, Instant}; @@ -7,12 +9,14 @@ use std::time::{Duration, Instant}; use bytes::Bytes; use either::Either; use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType, ReplicaType}; +use libsqlx::program::Program; use libsqlx::proxy::{WriteProxyConnection, WriteProxyDatabase}; -use libsqlx::result_builder::ResultBuilder; +use libsqlx::result_builder::{Column, QueryBuilderConfig, ResultBuilder}; use libsqlx::{ Database as _, DescribeResponse, Frame, FrameNo, InjectableDatabase, Injector, LogReadError, ReplicationLogger, }; +use parking_lot::Mutex; use tokio::sync::{mpsc, oneshot}; use tokio::task::{block_in_place, JoinSet}; use tokio::time::timeout; @@ -20,28 +24,26 @@ use tokio::time::timeout; use crate::hrana; use crate::hrana::http::handle_pipeline; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; -use crate::linc::bus::Dispatch; -use crate::linc::proto::{Enveloppe, Frames, Message}; +use crate::linc::bus::{Bus, Dispatch}; +use crate::linc::proto::{ + BuilderStep, Enveloppe, Frames, Message, ProxyResponse, StepError, Value, +}; use crate::linc::{Inbound, NodeId, Outbound}; +use crate::manager::Manager; use crate::meta::DatabaseId; use self::config::{AllocConfig, DbConfig}; pub mod config; -type LibsqlConnection = Either< - libsqlx::libsql::LibsqlConnection, - WriteProxyConnection, DummyConn>, ->; -type ExecFn = Box; +/// the maximum number of frames a Frames message is allowed to contain +const FRAMES_MESSAGE_MAX_COUNT: usize = 5; + +type ProxyConnection = + WriteProxyConnection, RemoteConn>; +type ExecFn = Box; -#[derive(Clone)] -pub struct ConnectionId { - id: u32, - close_sender: mpsc::Sender<()>, -} pub enum AllocationMessage { - NewConnection(oneshot::Sender), HranaPipelineReq { req: PipelineRequestBody, ret: oneshot::Sender>, @@ -49,43 +51,240 @@ pub enum AllocationMessage { Inbound(Inbound), } -pub struct DummyDb; -pub struct DummyConn; +pub struct RemoteDb; -impl libsqlx::Connection for DummyConn { - fn execute_program( +#[derive(Clone)] +pub struct RemoteConn { + inner: Arc, +} + +struct Request { + id: Option, + builder: Box, + pgm: Option, + next_seq_no: u32, +} + +pub struct RemoteConnInner { + current_req: Mutex>, +} + +impl Deref for RemoteConn { + type Target = RemoteConnInner; + + fn deref(&self) -> &Self::Target { + self.inner.as_ref() + } +} + +impl libsqlx::Connection for RemoteConn { + fn execute_program( &mut self, - _pgm: &libsqlx::program::Program, - _result_builder: B, + program: &libsqlx::program::Program, + builder: Box, ) -> libsqlx::Result<()> { - todo!() + // When we need to proxy a query, we place it in the current request slot. When we are + // back in an async context, we'll send it to the primary, and asynchronously drive the + // builder.
+ let mut lock = self.inner.current_req.lock(); + *lock = match *lock { + Some(_) => unreachable!("concurrent request on the same connection!"), + None => Some(Request { + id: None, + builder, + pgm: Some(program.clone()), + next_seq_no: 0, + }), + }; + + Ok(()) } fn describe(&self, _sql: String) -> libsqlx::Result { - todo!() + unreachable!("Describe request should not be proxied") } } -impl libsqlx::Database for DummyDb { - type Connection = DummyConn; +impl libsqlx::Database for RemoteDb { + type Connection = RemoteConn; fn connect(&self) -> Result { - Ok(DummyConn) + Ok(RemoteConn { + inner: Arc::new(RemoteConnInner { + current_req: Default::default(), + }), + }) } } -type ProxyDatabase = WriteProxyDatabase, DummyDb>; +pub type ProxyDatabase = WriteProxyDatabase, RemoteDb>; + +pub struct PrimaryDatabase { + pub db: LibsqlDatabase, + pub replica_streams: HashMap)>, + pub frame_notifier: tokio::sync::watch::Receiver, +} + +struct ProxyResponseBuilder { + dispatcher: Arc, + buffer: Vec, + to: NodeId, + database_id: DatabaseId, + req_id: u32, + connection_id: u32, + next_seq_no: u32, +} + +const MAX_STEP_BATCH_SIZE: usize = 100_000_000; // ~100MB + +impl ProxyResponseBuilder { + fn maybe_send(&mut self) { + // FIXME: this is stupid: compute current buffer size on the go instead + let size = self + .buffer + .iter() + .map(|s| match s { + BuilderStep::FinishStep(_, _) => 2 * 8, + BuilderStep::StepError(StepError(s)) => s.len(), + BuilderStep::ColsDesc(ref d) => d + .iter() + .map(|c| c.name.len() + c.decl_ty.as_ref().map(|t| t.len()).unwrap_or_default()) + .sum(), + BuilderStep::Finnalize { .. } => 9, + BuilderStep::AddRowValue(v) => match v { + crate::linc::proto::Value::Text(s) | crate::linc::proto::Value::Blob(s) => { + s.len() + } + _ => size_of::(), + }, + _ => 8, + }) + .sum::(); + + if size > MAX_STEP_BATCH_SIZE { + self.send() + } + } + + fn send(&mut self) { + let msg = Outbound { + to: self.to, + enveloppe: Enveloppe { + database_id: Some(self.database_id), + message: Message::ProxyResponse(crate::linc::proto::ProxyResponse { + connection_id: self.connection_id, + req_id: self.req_id, + row_steps: std::mem::take(&mut self.buffer), + seq_no: self.next_seq_no, + }), + }, + }; + + self.next_seq_no += 1; + tokio::runtime::Handle::current().block_on(self.dispatcher.dispatch(msg)); + } +} + +impl ResultBuilder for ProxyResponseBuilder { + fn init( + &mut self, + _config: &libsqlx::result_builder::QueryBuilderConfig, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::Init); + self.maybe_send(); + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::BeginStep); + self.maybe_send(); + Ok(()) + } + + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::FinishStep( + affected_row_count, + last_insert_rowid, + )); + self.maybe_send(); + Ok(()) + } + + fn step_error( + &mut self, + error: libsqlx::error::Error, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer + .push(BuilderStep::StepError(StepError(error.to_string()))); + self.maybe_send(); + Ok(()) + } + + fn cols_description( + &mut self, + cols: &mut dyn Iterator, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer + .push(BuilderStep::ColsDesc(cols.map(Into::into).collect())); + self.maybe_send(); + Ok(())
} + + fn begin_rows(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::BeginRows); + self.maybe_send(); + Ok(()) + } + + fn begin_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::BeginRow); + self.maybe_send(); + Ok(()) + } + + fn add_row_value( + &mut self, + v: libsqlx::result_builder::ValueRef, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::AddRowValue(v.into())); + self.maybe_send(); + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::FinishRow); + self.maybe_send(); + Ok(()) + } + + fn finish_rows(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::FinishRows); + self.maybe_send(); + Ok(()) + } + + fn finnalize( + &mut self, + is_txn: bool, + frame_no: Option, + ) -> Result { + self.buffer + .push(BuilderStep::Finnalize { is_txn, frame_no }); + self.send(); + Ok(true) + } +} pub enum Database { - Primary { - db: LibsqlDatabase, - replica_streams: HashMap)>, - frame_notifier: tokio::sync::watch::Receiver, - }, + Primary(PrimaryDatabase), Replica { db: ProxyDatabase, injector_handle: mpsc::Sender, - primary_node_id: NodeId, + primary_id: NodeId, last_received_frame_ts: Option, }, } @@ -194,9 +393,6 @@ struct FrameStreamer { buffer: Vec, } -// the maximum number of frame a Frame messahe is allowed to contain -const FRAMES_MESSAGE_MAX_COUNT: usize = 5; - impl FrameStreamer { async fn run(mut self) { loop { @@ -261,15 +457,15 @@ impl Database { ) .unwrap(); - Self::Primary { + Self::Primary(PrimaryDatabase { db, replica_streams: HashMap::new(), frame_notifier: receiver, - } + }) } DbConfig::Replica { primary_node_id } => { let rdb = LibsqlDatabase::new_replica(path, MAX_INJECTOR_BUFFER_CAP, ()).unwrap(); - let wdb = DummyDb; + let wdb = RemoteDb; let mut db = WriteProxyDatabase::new(rdb, wdb, Arc::new(|_| ())); let injector = db.injector().unwrap(); let (sender, receiver) = mpsc::channel(16); @@ -291,17 +487,191 @@ impl Database { Self::Replica { db, injector_handle: sender, - primary_node_id, + primary_id: primary_node_id, last_received_frame_ts: None, } } } } - fn connect(&self) -> LibsqlConnection { + fn connect(&self, connection_id: u32, alloc: &Allocation) -> impl ConnectionHandler { + match self { + Database::Primary(PrimaryDatabase { db, .. }) => Either::Right(PrimaryConnection { + conn: db.connect().unwrap(), + }), + Database::Replica { db, primary_id, .. 
} => Either::Left(ReplicaConnection { + conn: db.connect().unwrap(), + connection_id, + next_req_id: 0, + primary_id: *primary_id, + database_id: DatabaseId::from_name(&alloc.db_name), + dispatcher: alloc.bus.clone(), + }), + } + } + + pub fn is_primary(&self) -> bool { + matches!(self, Self::Primary(..)) + } +} + +struct PrimaryConnection { + conn: libsqlx::libsql::LibsqlConnection, +} + +#[async_trait::async_trait] +impl ConnectionHandler for PrimaryConnection { + fn exec_ready(&self) -> bool { + true + } + + async fn handle_exec(&mut self, exec: ExecFn) { + block_in_place(|| exec(&mut self.conn)); + } + + async fn handle_inbound(&mut self, _msg: Inbound) { + tracing::debug!("primary connection received message, ignoring.") + } +} + +struct ReplicaConnection { + conn: ProxyConnection, + connection_id: u32, + next_req_id: u32, + primary_id: NodeId, + database_id: DatabaseId, + dispatcher: Arc, +} + +impl ReplicaConnection { + fn handle_proxy_response(&mut self, resp: ProxyResponse) { + let mut lock = self.conn.writer().inner.current_req.lock(); + let finnalized = match *lock { + Some(ref mut req) if req.id == Some(resp.req_id) && resp.seq_no == req.next_seq_no => { + self.next_req_id += 1; + // TODO: pass actual config + let config = QueryBuilderConfig { max_size: None }; + let mut finnalized = false; + for step in resp.row_steps.iter() { + if finnalized { break }; + match step { + BuilderStep::Init => req.builder.init(&config).unwrap(), + BuilderStep::BeginStep => req.builder.begin_step().unwrap(), + BuilderStep::FinishStep(affected_row_count, last_insert_rowid) => req + .builder + .finish_step(*affected_row_count, *last_insert_rowid) + .unwrap(), + BuilderStep::StepError(e) => req.builder.step_error(todo!()).unwrap(), + BuilderStep::ColsDesc(cols) => req + .builder + .cols_description(&mut cols.iter().map(|c| Column { + name: &c.name, + decl_ty: c.decl_ty.as_deref(), + })) + .unwrap(), + BuilderStep::BeginRows => req.builder.begin_rows().unwrap(), + BuilderStep::BeginRow => req.builder.begin_row().unwrap(), + BuilderStep::AddRowValue(v) => req.builder.add_row_value(v.into()).unwrap(), + BuilderStep::FinishRow => req.builder.finish_row().unwrap(), + BuilderStep::FinishRows => req.builder.finish_rows().unwrap(), + BuilderStep::Finnalize { is_txn, frame_no } => { + let _ = req.builder.finnalize(*is_txn, *frame_no).unwrap(); + finnalized = true; + } + } + } + finnalized + } + Some(_) => todo!("error processing response"), + None => { + tracing::error!("received builder message, but there is no pending request"); + false + } + }; + + if finnalized { + *lock = None; + } + } +} + +#[async_trait::async_trait] +impl ConnectionHandler for ReplicaConnection { + fn exec_ready(&self) -> bool { + // we are currently handling a request on this connection + self.conn.writer().current_req.lock().is_none() + } + + async fn handle_exec(&mut self, exec: ExecFn) { + block_in_place(|| exec(&mut self.conn)); + let msg = { + let mut lock = self.conn.writer().inner.current_req.lock(); + match *lock { + Some(ref mut req) if req.id.is_none() => { + let program = req + .pgm + .take() + .expect("unsent request should have a program"); + let req_id = self.next_req_id; + self.next_req_id += 1; + req.id = Some(req_id); + + let msg = Outbound { + to: self.primary_id, + enveloppe: Enveloppe { + database_id: Some(self.database_id), + message: Message::ProxyRequest { + connection_id: self.connection_id, + req_id, + program, + }, + }, + }; + + Some(msg) + } + _ => None, + } + }; + + if let Some(msg) = msg { + 
self.dispatcher.dispatch(msg).await; + } + } + + async fn handle_inbound(&mut self, msg: Inbound) { + match msg.enveloppe.message { + Message::ProxyResponse(resp) => { + self.handle_proxy_response(resp); + } + _ => (), // ignore anything else + } + } +} + +#[async_trait::async_trait] +impl ConnectionHandler for Either +where + L: ConnectionHandler, + R: ConnectionHandler, +{ + fn exec_ready(&self) -> bool { + match self { + Either::Left(l) => l.exec_ready(), + Either::Right(r) => r.exec_ready(), + } + } + + async fn handle_exec(&mut self, exec: ExecFn) { + match self { + Either::Left(l) => l.handle_exec(exec).await, + Either::Right(r) => r.handle_exec(exec).await, + } + } + async fn handle_inbound(&mut self, msg: Inbound) { match self { - Database::Primary { db, .. } => Either::Left(db.connect().unwrap()), - Database::Replica { db, .. } => Either::Right(db.connect().unwrap()), + Either::Left(l) => l.handle_inbound(msg).await, + Either::Right(r) => r.handle_inbound(msg).await, } } } @@ -310,29 +680,31 @@ pub struct Allocation { pub inbox: mpsc::Receiver, pub database: Database, /// spawned connection futures, returning their connection id on completion. - pub connections_futs: JoinSet, + pub connections_futs: JoinSet<(NodeId, u32)>, pub next_conn_id: u32, pub max_concurrent_connections: u32, + pub connections: HashMap>, pub hrana_server: Arc, - /// handle to the message bus, to send messages - pub dispatcher: Arc, + /// handle to the message bus + pub bus: Arc>>, pub db_name: String, } +#[derive(Clone)] pub struct ConnectionHandle { exec: mpsc::Sender, - exit: oneshot::Sender<()>, + inbound: mpsc::Sender, } impl ConnectionHandle { pub async fn exec(&self, f: F) -> crate::Result where - F: for<'a> FnOnce(&'a mut LibsqlConnection) -> R + Send + 'static, + F: for<'a> FnOnce(&'a mut dyn libsqlx::Connection) -> R + Send + 'static, R: Send + 'static, { let (sender, ret) = oneshot::channel(); - let cb = move |conn: &mut LibsqlConnection| { + let cb = move |conn: &mut dyn libsqlx::Connection| { let res = f(conn); let _ = sender.send(res); }; @@ -349,15 +721,12 @@ impl Allocation { tokio::select! { Some(msg) = self.inbox.recv() => { match msg { - AllocationMessage::NewConnection(ret) => { - let _ =ret.send(self.new_conn().await); - }, - AllocationMessage::HranaPipelineReq { req, ret} => { - let res = handle_pipeline(&self.hrana_server.clone(), req, || async { - let conn= self.new_conn().await; + AllocationMessage::HranaPipelineReq { req, ret } => { + let server = self.hrana_server.clone(); + handle_pipeline(server, req, ret, || async { + let conn = self.new_conn(None).await; Ok(conn) - }).await; - let _ = ret.send(res); + }).await.unwrap(); } AllocationMessage::Inbound(msg) => { self.handle_inbound(msg).await; @@ -388,11 +757,12 @@ impl Allocation { req_no, next_frame_no, } => match &mut self.database { - Database::Primary { + Database::Primary(PrimaryDatabase { db, replica_streams, frame_notifier, - } => { + .. + }) => { let streamer = FrameStreamer { logger: db.logger(), database_id: DatabaseId::from_name(&self.db_name), @@ -400,7 +770,7 @@ impl Allocation { next_frame_no, req_no, seq_no: 0, - dipatcher: self.dispatcher.clone(), + dipatcher: self.bus.clone() as _, notifier: frame_notifier.clone(), buffer: Vec::new(), }; @@ -435,62 +805,139 @@ impl Allocation { *last_received_frame_ts = Some(Instant::now()); injector_handle.send(frames).await.unwrap(); } - Database::Primary { .. } => todo!("handle primary receiving txn"), + Database::Primary(PrimaryDatabase { .. 
}) => todo!("handle primary receiving txn"), }, - Message::ProxyRequest { .. } => todo!(), - Message::ProxyResponse { .. } => todo!(), + Message::ProxyRequest { + connection_id, + req_id, + program, + } => { + self.handle_proxy(msg.from, connection_id, req_id, program) + .await + } + Message::ProxyResponse(ref r) => { + if let Some(conn) = self + .connections + .get(&self.bus.node_id()) + .and_then(|m| m.get(&r.connection_id).cloned()) + { + conn.inbound.send(msg).await.unwrap(); + } + } Message::CancelRequest { .. } => todo!(), Message::CloseConnection { .. } => todo!(), Message::Error(_) => todo!(), } } - async fn new_conn(&mut self) -> ConnectionHandle { - let id = self.next_conn_id(); - let conn = block_in_place(|| self.database.connect()); - let (close_sender, exit) = oneshot::channel(); + async fn handle_proxy( + &mut self, + node_id: NodeId, + connection_id: u32, + req_id: u32, + program: Program, + ) { + let dispatcher = self.bus.clone(); + let database_id = DatabaseId::from_name(&self.db_name); + let exec = |conn: ConnectionHandle| async move { + let _ = conn + .exec(move |conn| { + let builder = ProxyResponseBuilder { + dispatcher, + req_id, + buffer: Vec::new(), + to: node_id, + database_id, + connection_id, + next_seq_no: 0, + }; + conn.execute_program(&program, Box::new(builder)).unwrap(); + }) + .await; + }; + + if self.database.is_primary() { + match self + .connections + .get(&node_id) + .and_then(|m| m.get(&connection_id).cloned()) + { + Some(handle) => { + tokio::spawn(exec(handle)); + } + None => { + let handle = self.new_conn(Some((node_id, connection_id))).await; + tokio::spawn(exec(handle)); + } + } + } + } + + async fn new_conn(&mut self, remote: Option<(NodeId, u32)>) -> ConnectionHandle { + let conn_id = self.next_conn_id(); + let conn = block_in_place(|| self.database.connect(conn_id, self)); let (exec_sender, exec_receiver) = mpsc::channel(1); + let (inbound_sender, inbound_receiver) = mpsc::channel(1); + let id = remote.unwrap_or((self.bus.node_id(), conn_id)); let conn = Connection { id, conn, - exit, exec: exec_receiver, + inbound: inbound_receiver, }; self.connections_futs.spawn(conn.run()); - - ConnectionHandle { + let handle = ConnectionHandle { exec: exec_sender, - exit: close_sender, - } + inbound: inbound_sender, + }; + self.connections + .entry(id.0) + .or_insert_with(HashMap::new) + .insert(id.1, handle.clone()); + handle } fn next_conn_id(&mut self) -> u32 { loop { self.next_conn_id = self.next_conn_id.wrapping_add(1); - return self.next_conn_id; - // if !self.connections.contains_key(&self.next_conn_id) { - // return self.next_conn_id; - // } + if self + .connections + .get(&self.bus.node_id()) + .and_then(|m| m.get(&self.next_conn_id)) + .is_none() + { + return self.next_conn_id; + } } } } -struct Connection { - id: u32, - conn: LibsqlConnection, - exit: oneshot::Receiver<()>, +struct Connection { + id: (NodeId, u32), + conn: C, exec: mpsc::Receiver, + inbound: mpsc::Receiver, +} + +#[async_trait::async_trait] +trait ConnectionHandler: Send { + fn exec_ready(&self) -> bool; + async fn handle_exec(&mut self, exec: ExecFn); + async fn handle_inbound(&mut self, msg: Inbound); } -impl Connection { - async fn run(mut self) -> u32 { +impl Connection { + async fn run(mut self) -> (NodeId, u32) { loop { tokio::select! 
{ - _ = &mut self.exit => break, - Some(exec) = self.exec.recv() => { - tokio::task::block_in_place(|| exec(&mut self.conn)); + Some(inbound) = self.inbound.recv() => { + self.conn.handle_inbound(inbound).await; } + Some(exec) = self.exec.recv(), if self.conn.exec_ready() => { + self.conn.handle_exec(exec).await; + }, + else => break, } } diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs index c4131c45..a9ed0553 100644 --- a/libsqlx-server/src/hrana/batch.rs +++ b/libsqlx-server/src/hrana/batch.rs @@ -8,7 +8,6 @@ use super::stmt::{proto_stmt_to_query, stmt_error_from_sqld_error}; use super::{proto, ProtocolError, Version}; use color_eyre::eyre::anyhow; -use libsqlx::Connection; use libsqlx::analysis::Statement; use libsqlx::program::{Cond, Program, Step}; use libsqlx::query::{Params, Query}; @@ -78,7 +77,7 @@ pub async fn execute_batch( let fut = db .exec(move |conn| -> color_eyre::Result<_> { let (builder, ret) = HranaBatchProtoBuilder::new(); - conn.execute_program(&pgm, builder)?; + conn.execute_program(&pgm, Box::new(builder))?; Ok(ret) }) .await??; @@ -116,20 +115,18 @@ pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> color_ey .exec(move |conn| -> color_eyre::Result<_> { let (snd, rcv) = oneshot::channel(); let builder = StepResultsBuilder::new(snd); - conn.execute_program(&pgm, builder)?; + conn.execute_program(&pgm, Box::new(builder))?; Ok(rcv) }) .await??; - fut.await? - .into_iter() - .try_for_each(|result| match result { - StepResult::Ok => Ok(()), - StepResult::Err(e) => match stmt_error_from_sqld_error(e) { - Ok(stmt_err) => Err(anyhow!(stmt_err)), - Err(sqld_err) => Err(anyhow!(sqld_err)), - }, - StepResult::Skipped => Err(anyhow!("Statement in sequence was not executed")), - }) + fut.await?.into_iter().try_for_each(|result| match result { + StepResult::Ok => Ok(()), + StepResult::Err(e) => match stmt_error_from_sqld_error(e) { + Ok(stmt_err) => Err(anyhow!(stmt_err)), + Err(sqld_err) => Err(anyhow!(sqld_err)), + }, + StepResult::Skipped => Err(anyhow!("Statement in sequence was not executed")), + }) } diff --git a/libsqlx-server/src/hrana/http/mod.rs b/libsqlx-server/src/hrana/http/mod.rs index 651ab3f0..521d33ff 100644 --- a/libsqlx-server/src/hrana/http/mod.rs +++ b/libsqlx-server/src/hrana/http/mod.rs @@ -1,7 +1,10 @@ +use std::sync::Arc; + use color_eyre::eyre::Context; use futures::Future; use parking_lot::Mutex; use serde::{de::DeserializeOwned, Serialize}; +use tokio::sync::oneshot; use crate::allocation::ConnectionHandle; @@ -47,31 +50,38 @@ fn handle_index() -> color_eyre::Result> { } pub async fn handle_pipeline( - server: &Server, + server: Arc, req: PipelineRequestBody, + ret: oneshot::Sender>, mk_conn: F, -) -> color_eyre::Result +) -> color_eyre::Result<()> where F: FnOnce() -> Fut, Fut: Future>, { - let mut stream_guard = stream::acquire(server, req.baton.as_deref(), mk_conn).await?; - - let mut results = Vec::with_capacity(req.requests.len()); - for request in req.requests.into_iter() { - let result = request::handle(&mut stream_guard, request) - .await - .context("Could not execute a request in pipeline")?; - results.push(result); - } - - let resp_body = proto::PipelineResponseBody { - baton: stream_guard.release(), - base_url: server.self_url.clone(), - results, - }; - - Ok(resp_body) + let mut stream_guard = stream::acquire(server.clone(), req.baton.as_deref(), mk_conn).await?; + + tokio::spawn(async move { + let f = async move { + let mut results = Vec::with_capacity(req.requests.len()); + for request 
in req.requests.into_iter() { + let result = request::handle(&mut stream_guard, request) + .await + .context("Could not execute a request in pipeline")?; + results.push(result); + } + + Ok(proto::PipelineResponseBody { + baton: stream_guard.release(), + base_url: server.self_url.clone(), + results, + }) + }; + + let _ = ret.send(f.await); + }); + + Ok(()) } async fn read_request_json( diff --git a/libsqlx-server/src/hrana/http/request.rs b/libsqlx-server/src/hrana/http/request.rs index ac6d8912..eb1623cd 100644 --- a/libsqlx-server/src/hrana/http/request.rs +++ b/libsqlx-server/src/hrana/http/request.rs @@ -13,7 +13,7 @@ pub enum StreamResponseError { } pub async fn handle( - stream_guard: &mut stream::Guard<'_>, + stream_guard: &mut stream::Guard, request: proto::StreamRequest, ) -> color_eyre::Result { let result = match try_handle(stream_guard, request).await { @@ -31,7 +31,7 @@ pub async fn handle( } async fn try_handle( - stream_guard: &mut stream::Guard<'_>, + stream_guard: &mut stream::Guard, request: proto::StreamRequest, ) -> color_eyre::Result { Ok(match request { diff --git a/libsqlx-server/src/hrana/http/stream.rs b/libsqlx-server/src/hrana/http/stream.rs index 5f40537e..25c1e719 100644 --- a/libsqlx-server/src/hrana/http/stream.rs +++ b/libsqlx-server/src/hrana/http/stream.rs @@ -1,6 +1,7 @@ use std::cmp::Reverse; use std::collections::{HashMap, VecDeque}; use std::pin::Pin; +use std::sync::Arc; use std::{future, mem, task}; use base64::prelude::{Engine as _, BASE64_STANDARD_NO_PAD}; @@ -67,8 +68,8 @@ struct Stream { /// Guard object that is used to access a stream from the outside. The guard makes sure that the /// stream's entry in [`ServerStreamState::handles`] is either removed or replaced with /// [`Handle::Available`] after the guard goes out of scope. -pub struct Guard<'srv> { - server: &'srv Server, +pub struct Guard { + server: Arc, /// The guarded stream. This is only set to `None` in the destructor. stream: Option>, /// If set to `true`, the destructor will release the stream for further use (saving it as @@ -101,18 +102,18 @@ impl ServerStreamState { /// Acquire a guard to a new or existing stream. If baton is `Some`, we try to look up the stream, /// otherwise we create a new stream. 
-pub async fn acquire<'srv, F, Fut>( - server: &'srv Server, +pub async fn acquire( + server: Arc, baton: Option<&str>, mk_conn: F, -) -> color_eyre::Result> +) -> color_eyre::Result where F: FnOnce() -> Fut, Fut: Future>, { let stream = match baton { Some(baton) => { - let (stream_id, baton_seq) = decode_baton(server, baton)?; + let (stream_id, baton_seq) = decode_baton(&server, baton)?; let mut state = server.stream_state.lock(); let handle = state.handles.get_mut(&stream_id); @@ -182,7 +183,7 @@ where }) } -impl<'srv> Guard<'srv> { +impl Guard { pub fn get_db(&self) -> Result<&ConnectionHandle, ProtocolError> { let stream = self.stream.as_ref().unwrap(); stream.conn.as_ref().ok_or(ProtocolError::BatonStreamClosed) @@ -211,7 +212,7 @@ impl<'srv> Guard<'srv> { if stream.conn.is_some() { self.release = true; // tell destructor to make the stream available again Some(encode_baton( - self.server, + &self.server, stream.stream_id, stream.baton_seq, )) @@ -221,7 +222,7 @@ impl<'srv> Guard<'srv> { } } -impl<'srv> Drop for Guard<'srv> { +impl Drop for Guard { fn drop(&mut self) { let stream = self.stream.take().unwrap(); let stream_id = stream.stream_id; diff --git a/libsqlx-server/src/hrana/result_builder.rs b/libsqlx-server/src/hrana/result_builder.rs index 1047f091..e91bca28 100644 --- a/libsqlx-server/src/hrana/result_builder.rs +++ b/libsqlx-server/src/hrana/result_builder.rs @@ -11,16 +11,22 @@ use super::proto; pub struct SingleStatementBuilder { builder: StatementBuilder, - ret: oneshot::Sender>, + ret: Option>>, } impl SingleStatementBuilder { - pub fn new() -> (Self, oneshot::Receiver>) { + pub fn new() -> ( + Self, + oneshot::Receiver>, + ) { let (ret, rcv) = oneshot::channel(); - (Self { - builder: StatementBuilder::default(), - ret, - }, rcv) + ( + Self { + builder: StatementBuilder::default(), + ret: Some(ret), + }, + rcv, + ) } } @@ -38,7 +44,8 @@ impl ResultBuilder for SingleStatementBuilder { affected_row_count: u64, last_insert_rowid: Option, ) -> Result<(), QueryResultBuilderError> { - self.builder.finish_step(affected_row_count, last_insert_rowid) + self.builder + .finish_step(affected_row_count, last_insert_rowid) } fn step_error(&mut self, error: libsqlx::error::Error) -> Result<(), QueryResultBuilderError> { @@ -61,19 +68,16 @@ impl ResultBuilder for SingleStatementBuilder { } fn finnalize( - self, + &mut self, _is_txn: bool, _frame_no: Option, - ) -> Result - where Self: Sized - { - let res = self.builder.into_ret(); - let _ = self.ret.send(res); + ) -> Result { + let res = self.builder.take_ret(); + let _ = self.ret.take().unwrap().send(res); Ok(true) } } - #[derive(Debug, Default)] struct StatementBuilder { has_step: bool, @@ -191,12 +195,12 @@ impl StatementBuilder { Ok(()) } - pub fn into_ret(self) -> Result { - match self.err { + pub fn take_ret(&mut self) -> Result { + match self.err.take() { Some(err) => Err(err), None => Ok(proto::StmtResult { - cols: self.cols, - rows: self.rows, + cols: std::mem::take(&mut self.cols), + rows: std::mem::take(&mut self.rows), affected_row_count: self.affected_row_count, last_insert_rowid: self.last_insert_rowid, }), @@ -262,23 +266,24 @@ pub struct HranaBatchProtoBuilder { current_size: u64, max_response_size: u64, step_empty: bool, - ret: oneshot::Sender + ret: oneshot::Sender, } impl HranaBatchProtoBuilder { pub fn new() -> (Self, oneshot::Receiver) { let (ret, rcv) = oneshot::channel(); - (Self { - step_results: Vec::new(), - step_errors: Vec::new(), - stmt_builder: StatementBuilder::default(), - current_size: 0, - 
max_response_size: u64::MAX, - step_empty: false, - ret, - }, - rcv) - + ( + Self { + step_results: Vec::new(), + step_errors: Vec::new(), + stmt_builder: StatementBuilder::default(), + current_size: 0, + max_response_size: u64::MAX, + step_empty: false, + ret, + }, + rcv, + ) } pub fn into_ret(self) -> proto::BatchResult { proto::BatchResult { @@ -314,7 +319,7 @@ impl ResultBuilder for HranaBatchProtoBuilder { max_response_size: self.max_response_size - self.current_size, ..Default::default() }; - match std::mem::replace(&mut self.stmt_builder, new_builder).into_ret() { + match std::mem::replace(&mut self.stmt_builder, new_builder).take_ret() { Ok(res) => { self.step_results.push((!self.step_empty).then_some(res)); self.step_errors.push(None); diff --git a/libsqlx-server/src/hrana/stmt.rs b/libsqlx-server/src/hrana/stmt.rs index e6c002a1..1a8c03f6 100644 --- a/libsqlx-server/src/hrana/stmt.rs +++ b/libsqlx-server/src/hrana/stmt.rs @@ -3,7 +3,6 @@ use std::collections::HashMap; use color_eyre::eyre::{anyhow, bail}; use libsqlx::analysis::Statement; use libsqlx::query::{Params, Query, Value}; -use libsqlx::Connection; use super::result_builder::SingleStatementBuilder; use super::{proto, ProtocolError, Version}; @@ -52,7 +51,7 @@ pub async fn execute_stmt( .exec(move |conn| -> color_eyre::Result<_> { let (builder, ret) = SingleStatementBuilder::new(); let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); - conn.execute_program(&pgm, builder)?; + conn.execute_program(&pgm, Box::new(builder))?; Ok(ret) }) diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index bf5bd97e..09e2ec44 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -163,7 +163,7 @@ where self.conn.feed(m).await.unwrap(); } self.conn.flush().await.unwrap(); - } + }, else => { self.state = ConnectionState::Close; } diff --git a/libsqlx-server/src/linc/proto.rs b/libsqlx-server/src/linc/proto.rs index bec6ff7a..a9aa529d 100644 --- a/libsqlx-server/src/linc/proto.rs +++ b/libsqlx-server/src/linc/proto.rs @@ -1,4 +1,5 @@ use bytes::Bytes; +use libsqlx::{program::Program, FrameNo}; use serde::{Deserialize, Serialize}; use uuid::Uuid; @@ -6,9 +7,7 @@ use crate::meta::DatabaseId; use super::NodeId; -pub type Program = String; - -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, Serialize, Deserialize)] pub struct Enveloppe { pub database_id: Option, pub message: Message, @@ -25,7 +24,18 @@ pub struct Frames { pub frames: Vec, } -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +#[derive(Debug, Serialize, Deserialize)] +/// Response to a proxied query +pub struct ProxyResponse { + pub connection_id: u32, + /// id of the request this message is a response to. + pub req_id: u32, + pub seq_no: u32, + /// Collection of steps to drive the query builder transducer. + pub row_steps: Vec, +} + +#[derive(Debug, Serialize, Deserialize)] pub enum Message { /// Initial message exchanged between nodes when connecting Handshake { @@ -58,13 +68,7 @@ pub enum Message { req_id: u32, program: Program, }, - /// Response to a proxied query - ProxyResponse { - /// id of the request this message is a response to. - req_id: u32, - /// Collection of steps to drive the query builder transducer. - row_step: Vec, - }, + ProxyResponse(ProxyResponse), /// Stop processing request `id`. 
CancelRequest { req_id: u32, @@ -85,101 +89,79 @@ pub enum ProtoError { UnknownDatabase(String), } -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub enum ReplicationMessage { - ReplicationHandshake { - database_name: String, - }, - ReplicationHandshakeResponse { - /// id of the replication log - log_id: Uuid, - /// current frame_no of the primary - current_frame_no: u64, - }, - Replicate { - /// next frame no to send - next_frame_no: u64, - }, - /// a batch of frames that are part of the same transaction - Transaction { - /// if not None, then the last frame is a commit frame, and this is the new size of the database. - size_after: Option, - /// frame_no of the last frame in frames - end_frame_no: u64, - /// a batch of frames part of the transaction. - frames: Vec, - }, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub struct Frame { - /// Page id of that frame - page_id: u32, - /// Data - data: Bytes, -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub enum ProxyMessage { - /// Proxy a query to a primary - ProxyRequest { - /// id of the connection to perform the query against - /// If the connection doesn't already exist it is created - /// Id of the request. - /// Responses to this request must have the same id. - connection_id: u32, - req_id: u32, - program: Program, - }, - /// Response to a proxied query - ProxyResponse { - /// id of the request this message is a response to. - req_id: u32, - /// Collection of steps to drive the query builder transducer. - row_step: Vec, - }, - /// Stop processing request `id`. - CancelRequest { req_id: u32 }, - /// Close Connection with passed id. - CloseConnection { connection_id: u32 }, -} - /// Steps applied to the query builder transducer to build a response to a proxied query. /// Those types closely mirror those of the `QueryBuilderTrait`. -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] +#[derive(Debug, Serialize, Deserialize, Clone)] pub enum BuilderStep { + Init, BeginStep, - FinishStep(u64, Option), + FinishStep(u64, Option), StepError(StepError), ColsDesc(Vec), BeginRows, BeginRow, AddRowValue(Value), FinishRow, - FinishRos, - Finish(ConnectionState), + FinishRows, + Finnalize { + is_txn: bool, + frame_no: Option, + }, } -// State of the connection after a query was executed -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub enum ConnectionState { - /// The connection is still in a open transaction state - OpenTxn, - /// The connection is idle. - Idle, +#[derive(Debug, Serialize, Deserialize, Clone)] +pub enum Value { + Null, + Integer(i64), + Real(f64), + // TODO: how to stream blobs/string??? 
+ Text(Vec), + Blob(Vec), } -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub enum Value {} +impl<'a> Into> for &'a Value { + fn into(self) -> libsqlx::result_builder::ValueRef<'a> { + use libsqlx::result_builder::ValueRef; + match self { + Value::Null => ValueRef::Null, + Value::Integer(i) => ValueRef::Integer(*i), + Value::Real(x) => ValueRef::Real(*x), + Value::Text(ref t) => ValueRef::Text(t), + Value::Blob(ref b) => ValueRef::Blob(b), + } + } +} + +impl From> for Value { + fn from(value: libsqlx::result_builder::ValueRef) -> Self { + use libsqlx::result_builder::ValueRef; + match value { + ValueRef::Null => Self::Null, + ValueRef::Integer(i) => Self::Integer(i), + ValueRef::Real(x) => Self::Real(x), + ValueRef::Text(s) => Self::Text(s.into()), + ValueRef::Blob(b) => Self::Blob(b.into()), + } + } +} #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] pub struct Column { /// name of the column - name: String, + pub name: String, /// Declared type of the column, if any. - decl_ty: Option, + pub decl_ty: Option, +} + +impl From> for Column { + fn from(value: libsqlx::result_builder::Column) -> Self { + Self { + name: value.name.to_string(), + decl_ty: value.decl_ty.map(ToOwned::to_owned), + } + } } /// for now, the stringified version of a sqld::error::Error. #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub struct StepError(String); +pub struct StepError(pub String); diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 89604569..01870144 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -1,3 +1,4 @@ +use std::collections::HashMap; use std::path::PathBuf; use std::sync::Arc; @@ -20,10 +21,6 @@ pub struct Manager { const MAX_ALLOC_MESSAGE_QUEUE_LEN: usize = 32; -trait IsSync: Sync {} - -impl IsSync for Allocation {} - impl Manager { pub fn new(db_path: PathBuf, meta_store: Arc, max_conccurent_allocs: u64) -> Self { Self { @@ -54,8 +51,9 @@ impl Manager { next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, hrana_server: Arc::new(hrana::http::Server::new(None)), - dispatcher: bus, // TODO: handle self URL? + bus, // TODO: handle self URL? db_name: config.db_name, + connections: HashMap::new(), }; tokio::spawn(alloc.run()); diff --git a/libsqlx/src/analysis.rs b/libsqlx/src/analysis.rs index 97ef5f5b..0706ebff 100644 --- a/libsqlx/src/analysis.rs +++ b/libsqlx/src/analysis.rs @@ -1,9 +1,10 @@ use fallible_iterator::FallibleIterator; +use serde::{Deserialize, Serialize}; use sqlite3_parser::ast::{Cmd, PragmaBody, QualifiedName, Stmt}; use sqlite3_parser::lexer::sql::{Parser, ParserError}; /// A group of statements to be executed together. -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Statement { pub stmt: String, pub kind: StmtKind, @@ -19,7 +20,7 @@ impl Default for Statement { } /// Classify statement in categories of interest. 
-#[derive(Debug, PartialEq, Clone, Copy)] +#[derive(Debug, PartialEq, Clone, Copy, Serialize, Deserialize)] pub enum StmtKind { /// The begining of a transaction TxnBegin, diff --git a/libsqlx/src/connection.rs b/libsqlx/src/connection.rs index fa027997..e767073a 100644 --- a/libsqlx/src/connection.rs +++ b/libsqlx/src/connection.rs @@ -24,10 +24,10 @@ pub struct DescribeCol { pub trait Connection { /// Executes a query program - fn execute_program( + fn execute_program( &mut self, pgm: &Program, - result_builder: B, + result_builder: Box, ) -> crate::Result<()>; /// Parse the SQL statement and return information about it. @@ -39,10 +39,10 @@ where T: Connection, X: Connection, { - fn execute_program( + fn execute_program( &mut self, pgm: &Program, - result_builder: B, + result_builder: Box, ) -> crate::Result<()> { match self { Either::Left(c) => c.execute_program(pgm, result_builder), diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs index 27ee59e1..1f0ea7ab 100644 --- a/libsqlx/src/database/libsql/connection.rs +++ b/libsqlx/src/database/libsql/connection.rs @@ -101,14 +101,14 @@ impl LibsqlConnection { &self.conn } - fn run(&mut self, pgm: &Program, mut builder: B) -> Result<()> { + fn run(&mut self, pgm: &Program, builder: &mut dyn ResultBuilder) -> Result<()> { let mut results = Vec::with_capacity(pgm.steps.len()); builder.init(&self.builder_config)?; let is_autocommit_before = self.conn.is_autocommit(); for step in pgm.steps() { - let res = self.execute_step(step, &results, &mut builder)?; + let res = self.execute_step(step, &results, builder)?; results.push(res); } @@ -125,11 +125,11 @@ impl LibsqlConnection { Ok(()) } - fn execute_step( + fn execute_step( &mut self, step: &Step, results: &[bool], - builder: &mut B, + builder: &mut dyn ResultBuilder, ) -> Result { builder.begin_step()?; let mut enabled = match step.cond.as_ref() { @@ -163,10 +163,10 @@ impl LibsqlConnection { Ok(enabled) } - fn execute_query( + fn execute_query( &self, query: &Query, - builder: &mut B, + builder: &mut dyn ResultBuilder, ) -> Result<(u64, Option)> { tracing::trace!("executing query: {}", query.stmt.stmt); @@ -240,12 +240,12 @@ fn eval_cond(cond: &Cond, results: &[bool]) -> Result { } impl Connection for LibsqlConnection { - fn execute_program( + fn execute_program( &mut self, pgm: &Program, - builder: B, + mut builder: Box, ) -> crate::Result<()> { - self.run(pgm, builder) + self.run(pgm, &mut *builder) } fn describe(&self, sql: String) -> crate::Result { diff --git a/libsqlx/src/database/proxy/connection.rs b/libsqlx/src/database/proxy/connection.rs index 68d10c00..a06b6620 100644 --- a/libsqlx/src/database/proxy/connection.rs +++ b/libsqlx/src/database/proxy/connection.rs @@ -1,3 +1,7 @@ +use std::sync::Arc; + +use parking_lot::Mutex; + use crate::connection::{Connection, DescribeResponse}; use crate::database::FrameNo; use crate::program::Program; @@ -14,31 +18,48 @@ pub(crate) struct ConnState { /// A connection that proxies write operations to the `WriteDb` and the read operations to the /// `ReadDb` -pub struct WriteProxyConnection { - pub(crate) read_db: ReadDb, - pub(crate) write_db: WriteDb, +pub struct WriteProxyConnection { + pub(crate) read_conn: R, + pub(crate) write_conn: W, pub(crate) wait_frame_no_cb: WaitFrameNoCb, - pub(crate) state: ConnState, + pub(crate) state: Arc>, +} + +impl WriteProxyConnection { + pub fn writer_mut(&mut self) -> &mut W { + &mut self.write_conn + } + + pub fn writer(&self) -> &W { + &self.write_conn + } + + 
pub fn reader_mut(&mut self) -> &mut R { + &mut self.read_conn + } + + pub fn reader(&self) -> &R { + &self.read_conn + } } -struct MaybeRemoteExecBuilder<'a, 'b, B, W> { - builder: B, - conn: &'a mut W, - pgm: &'b Program, - state: &'a mut ConnState, +struct MaybeRemoteExecBuilder { + builder: Option>, + conn: W, + pgm: Program, + state: Arc>, } -impl<'a, 'b, B, W> ResultBuilder for MaybeRemoteExecBuilder<'a, 'b, B, W> +impl ResultBuilder for MaybeRemoteExecBuilder where - W: Connection, - B: ResultBuilder, + W: Connection + Send + 'static, { fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { - self.builder.init(config) + self.builder.as_mut().unwrap().init(config) } fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { - self.builder.begin_step() + self.builder.as_mut().unwrap().begin_step() } fn finish_step( @@ -47,45 +68,47 @@ where last_insert_rowid: Option, ) -> Result<(), QueryResultBuilderError> { self.builder + .as_mut() + .unwrap() .finish_step(affected_row_count, last_insert_rowid) } fn step_error(&mut self, error: crate::error::Error) -> Result<(), QueryResultBuilderError> { - self.builder.step_error(error) + self.builder.as_mut().unwrap().step_error(error) } fn cols_description( &mut self, cols: &mut dyn Iterator, ) -> Result<(), QueryResultBuilderError> { - self.builder.cols_description(cols) + self.builder.as_mut().unwrap().cols_description(cols) } fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { - self.builder.begin_rows() + self.builder.as_mut().unwrap().begin_rows() } fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { - self.builder.begin_row() + self.builder.as_mut().unwrap().begin_row() } fn add_row_value( &mut self, v: rusqlite::types::ValueRef, ) -> Result<(), QueryResultBuilderError> { - self.builder.add_row_value(v) + self.builder.as_mut().unwrap().add_row_value(v) } fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { - self.builder.finish_row() + self.builder.as_mut().unwrap().finish_row() } fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { - self.builder.finish_rows() + self.builder.as_mut().unwrap().finish_rows() } fn finnalize( - self, + &mut self, is_txn: bool, frame_no: Option, ) -> Result { @@ -93,70 +116,75 @@ where // a read only connection is not allowed to leave an open transaction. We mispredicted the // final state of the connection, so we rollback, and execute again on the write proxy. 
let builder = ExtractFrameNoBuilder { - builder: self.builder, - state: self.state, + builder: self + .builder + .take() + .expect("finnalize called more than once"), + state: self.state.clone(), }; - self.conn.execute_program(self.pgm, builder).unwrap(); + self.conn + .execute_program(&self.pgm, Box::new(builder)) + .unwrap(); Ok(false) } else { - self.builder.finnalize(is_txn, frame_no) + self.builder.as_mut().unwrap().finnalize(is_txn, frame_no) } } } -impl Connection for WriteProxyConnection +impl Connection for WriteProxyConnection where - ReadDb: Connection, - WriteDb: Connection, + R: Connection, + W: Connection + Clone + Send + 'static, { - fn execute_program( + fn execute_program( &mut self, pgm: &Program, - builder: B, + builder: Box, ) -> crate::Result<()> { - if !self.state.is_txn && pgm.is_read_only() { - if let Some(frame_no) = self.state.last_frame_no { + if !self.state.lock().is_txn && pgm.is_read_only() { + if let Some(frame_no) = self.state.lock().last_frame_no { (self.wait_frame_no_cb)(frame_no); } let builder = MaybeRemoteExecBuilder { - builder, - conn: &mut self.write_db, - state: &mut self.state, - pgm, + builder: Some(builder), + conn: self.write_conn.clone(), + state: self.state.clone(), + pgm: pgm.clone(), }; // We know that this program won't perform any writes. We attempt to run it on the // replica. If it leaves an open transaction, then this program is an interactive // transaction, so we rollback the replica, and execute again on the primary. - self.read_db.execute_program(pgm, builder)?; + self.read_conn.execute_program(pgm, Box::new(builder))?; // rollback(&mut self.conn.read_db); Ok(()) } else { let builder = ExtractFrameNoBuilder { builder, - state: &mut self.state, + state: self.state.clone(), }; - self.write_db.execute_program(pgm, builder)?; + self.write_conn.execute_program(pgm, Box::new(builder))?; Ok(()) } } fn describe(&self, sql: String) -> crate::Result { - if let Some(frame_no) = self.state.last_frame_no { + if let Some(frame_no) = self.state.lock().last_frame_no { (self.wait_frame_no_cb)(frame_no); } - self.read_db.describe(sql) + self.read_conn.describe(sql) } } -struct ExtractFrameNoBuilder<'a, B> { - builder: B, - state: &'a mut ConnState, +struct ExtractFrameNoBuilder { + builder: Box, + state: Arc>, } -impl ResultBuilder for ExtractFrameNoBuilder<'_, B> { +impl ResultBuilder for ExtractFrameNoBuilder { fn init(&mut self, config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { self.builder.init(config) } @@ -209,12 +237,13 @@ impl ResultBuilder for ExtractFrameNoBuilder<'_, B> { } fn finnalize( - self, + &mut self, is_txn: bool, frame_no: Option, ) -> Result { - self.state.last_frame_no = frame_no; - self.state.is_txn = is_txn; + let mut state = self.state.lock(); + state.last_frame_no = frame_no; + state.is_txn = is_txn; self.builder.finnalize(is_txn, frame_no) } } @@ -225,7 +254,6 @@ mod test { use std::rc::Rc; use std::sync::Arc; - use crate::connection::Connection; use crate::database::test_utils::MockDatabase; use crate::database::{proxy::database::WriteProxyDatabase, Database}; use crate::program::Program; diff --git a/libsqlx/src/database/proxy/database.rs b/libsqlx/src/database/proxy/database.rs index e9add71f..adc82ed8 100644 --- a/libsqlx/src/database/proxy/database.rs +++ b/libsqlx/src/database/proxy/database.rs @@ -24,13 +24,14 @@ impl Database for WriteProxyDatabase where RDB: Database, WDB: Database, + WDB::Connection: Clone + Send + 'static, { type Connection = WriteProxyConnection; /// Create a new connection 
to the database fn connect(&self) -> Result { Ok(WriteProxyConnection { - read_db: self.read_db.connect()?, - write_db: self.write_db.connect()?, + read_conn: self.read_db.connect()?, + write_conn: self.write_db.connect()?, wait_frame_no_cb: self.wait_frame_no_cb.clone(), state: Default::default(), }) diff --git a/libsqlx/src/program.rs b/libsqlx/src/program.rs index 0b5c7980..b2b627af 100644 --- a/libsqlx/src/program.rs +++ b/libsqlx/src/program.rs @@ -1,8 +1,10 @@ use std::sync::Arc; +use serde::{Deserialize, Serialize}; + use crate::query::Query; -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Program { pub steps: Arc<[Step]>, } @@ -59,13 +61,13 @@ impl Program { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Step { pub cond: Option, pub query: Query, } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub enum Cond { Ok { step: usize }, Err { step: usize }, diff --git a/libsqlx/src/query.rs b/libsqlx/src/query.rs index 2d37e514..d3b1e5eb 100644 --- a/libsqlx/src/query.rs +++ b/libsqlx/src/query.rs @@ -46,7 +46,7 @@ impl TryFrom> for Value { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, Serialize, Deserialize)] pub struct Query { pub stmt: Statement, pub params: Params, @@ -67,7 +67,7 @@ impl ToSql for Value { } } -#[derive(Debug, Serialize, Clone)] +#[derive(Debug, Serialize, Clone, Deserialize)] pub enum Params { Named(HashMap), Positional(Vec), diff --git a/libsqlx/src/result_builder.rs b/libsqlx/src/result_builder.rs index 98f598c1..fed13fd3 100644 --- a/libsqlx/src/result_builder.rs +++ b/libsqlx/src/result_builder.rs @@ -80,7 +80,7 @@ pub struct QueryBuilderConfig { pub max_size: Option, } -pub trait ResultBuilder { +pub trait ResultBuilder: Send + 'static { /// (Re)initialize the builder. This method can be called multiple times. fn init(&mut self, _config: &QueryBuilderConfig) -> Result<(), QueryResultBuilderError> { Ok(()) @@ -132,13 +132,10 @@ pub trait ResultBuilder { /// finish the builder, and pass the transaction state. /// If false is returned, and is_txn is true, then the transaction is rolledback. fn finnalize( - self, + &mut self, _is_txn: bool, _frame_no: Option, - ) -> Result - where - Self: Sized, - { + ) -> Result { Ok(true) } } @@ -171,15 +168,15 @@ pub struct StepResultsBuilder { current: Option, step_results: Vec, is_skipped: bool, - ret: R + ret: Option, } -pub trait RetChannel { +pub trait RetChannel: Send + 'static { fn send(self, t: T); } #[cfg(feature = "tokio")] -impl RetChannel for tokio::sync::oneshot::Sender { +impl RetChannel for tokio::sync::oneshot::Sender { fn send(self, t: T) { let _ = self.send(t); } @@ -191,7 +188,7 @@ impl StepResultsBuilder { current: None, step_results: Vec::new(), is_skipped: false, - ret, + ret: Some(ret), } } } @@ -241,11 +238,14 @@ impl>> ResultBuilder for StepResultsBuilder { } fn finnalize( - self, + &mut self, _is_txn: bool, _frame_no: Option, ) -> Result { - self.ret.send(self.step_results); + self.ret + .take() + .expect("finnalize called more than once") + .send(std::mem::take(&mut self.step_results)); Ok(true) } } @@ -353,7 +353,7 @@ impl ResultBuilder for Take { } fn finnalize( - self, + &mut self, is_txn: bool, frame_no: Option, ) -> Result {