From b587be53332bedd53994f6e273b1654e6c0a9cc6 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 20 Jul 2023 17:55:55 +0200 Subject: [PATCH] reorganize allocation file --- libsqlx-server/src/allocation/mod.rs | 660 ++--------------------- libsqlx-server/src/allocation/primary.rs | 275 ++++++++++ libsqlx-server/src/allocation/replica.rs | 342 ++++++++++++ 3 files changed, 676 insertions(+), 601 deletions(-) create mode 100644 libsqlx-server/src/allocation/primary.rs create mode 100644 libsqlx-server/src/allocation/replica.rs diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 7bea3ddd..8c7ba873 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,49 +1,40 @@ use std::collections::hash_map::Entry; use std::collections::HashMap; use std::future::poll_fn; -use std::mem::size_of; -use std::ops::Deref; use std::path::PathBuf; -use std::pin::Pin; use std::sync::Arc; use std::task::{Context, Poll}; -use std::time::{Duration, Instant}; +use std::time::Instant; -use bytes::Bytes; use either::Either; -use futures::Future; -use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType, ReplicaType}; +use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile}; use libsqlx::program::Program; -use libsqlx::proxy::{WriteProxyConnection, WriteProxyDatabase}; -use libsqlx::result_builder::{Column, QueryBuilderConfig, ResultBuilder}; -use libsqlx::{ - Database as _, DescribeResponse, Frame, FrameNo, InjectableDatabase, Injector, LogReadError, - ReplicationLogger, -}; -use parking_lot::Mutex; +use libsqlx::proxy::WriteProxyDatabase; +use libsqlx::{Database as _, InjectableDatabase}; use tokio::sync::{mpsc, oneshot}; use tokio::task::{block_in_place, JoinSet}; -use tokio::time::{timeout, Sleep}; +use crate::allocation::primary::FrameStreamer; use crate::hrana; use crate::hrana::http::handle_pipeline; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; use 
crate::linc::bus::Dispatch; -use crate::linc::proto::{ - BuilderStep, Enveloppe, Frames, Message, ProxyResponse, StepError, Value, -}; -use crate::linc::{Inbound, NodeId, Outbound}; +use crate::linc::proto::{Frames, Message}; +use crate::linc::{Inbound, NodeId}; use crate::meta::DatabaseId; use self::config::{AllocConfig, DbConfig}; +use self::primary::{PrimaryConnection, PrimaryDatabase, ProxyResponseBuilder}; +use self::replica::{ProxyDatabase, RemoteDb, ReplicaConnection, Replicator}; pub mod config; +mod primary; +mod replica; /// the maximum number of frame a Frame messahe is allowed to contain const FRAMES_MESSAGE_MAX_COUNT: usize = 5; +const MAX_INJECTOR_BUFFER_CAP: usize = 32; -type ProxyConnection = - WriteProxyConnection, RemoteConn>; type ExecFn = Box; pub enum AllocationMessage { @@ -54,240 +45,6 @@ pub enum AllocationMessage { Inbound(Inbound), } -pub struct RemoteDb { - proxy_request_timeout_duration: Duration, -} - -#[derive(Clone)] -pub struct RemoteConn { - inner: Arc, -} - -struct Request { - id: Option, - builder: Box, - pgm: Option, - next_seq_no: u32, - timeout: Pin>, -} - -pub struct RemoteConnInner { - current_req: Mutex>, - request_timeout_duration: Duration, -} - -impl Deref for RemoteConn { - type Target = RemoteConnInner; - - fn deref(&self) -> &Self::Target { - self.inner.as_ref() - } -} - -impl libsqlx::Connection for RemoteConn { - fn execute_program( - &mut self, - program: &libsqlx::program::Program, - builder: Box, - ) -> libsqlx::Result<()> { - // When we need to proxy a query, we place it in the current request slot. When we are - // back in a async context, we'll send it to the primary, and asynchrously drive the - // builder. 
- let mut lock = self.inner.current_req.lock(); - *lock = match *lock { - Some(_) => unreachable!("conccurent request on the same connection!"), - None => Some(Request { - id: None, - builder, - pgm: Some(program.clone()), - next_seq_no: 0, - timeout: Box::pin(tokio::time::sleep(self.inner.request_timeout_duration)), - }), - }; - - Ok(()) - } - - fn describe(&self, _sql: String) -> libsqlx::Result { - unreachable!("Describe request should not be proxied") - } -} - -impl libsqlx::Database for RemoteDb { - type Connection = RemoteConn; - - fn connect(&self) -> Result { - Ok(RemoteConn { - inner: Arc::new(RemoteConnInner { - current_req: Default::default(), - request_timeout_duration: self.proxy_request_timeout_duration, - }), - }) - } -} - -pub type ProxyDatabase = WriteProxyDatabase, RemoteDb>; - -pub struct PrimaryDatabase { - pub db: LibsqlDatabase, - pub replica_streams: HashMap)>, - pub frame_notifier: tokio::sync::watch::Receiver, -} - -struct ProxyResponseBuilder { - dispatcher: Arc, - buffer: Vec, - to: NodeId, - database_id: DatabaseId, - req_id: u32, - connection_id: u32, - next_seq_no: u32, -} - -const MAX_STEP_BATCH_SIZE: usize = 100_000_000; // ~100kb - -impl ProxyResponseBuilder { - fn maybe_send(&mut self) { - // FIXME: this is stupid: compute current buffer size on the go instead - let size = self - .buffer - .iter() - .map(|s| match s { - BuilderStep::FinishStep(_, _) => 2 * 8, - BuilderStep::StepError(StepError(s)) => s.len(), - BuilderStep::ColsDesc(ref d) => d - .iter() - .map(|c| c.name.len() + c.decl_ty.as_ref().map(|t| t.len()).unwrap_or_default()) - .sum(), - BuilderStep::Finnalize { .. 
} => 9, - BuilderStep::AddRowValue(v) => match v { - crate::linc::proto::Value::Text(s) | crate::linc::proto::Value::Blob(s) => { - s.len() - } - _ => size_of::(), - }, - _ => 8, - }) - .sum::(); - - if size > MAX_STEP_BATCH_SIZE { - self.send() - } - } - - fn send(&mut self) { - let msg = Outbound { - to: self.to, - enveloppe: Enveloppe { - database_id: Some(self.database_id), - message: Message::ProxyResponse(crate::linc::proto::ProxyResponse { - connection_id: self.connection_id, - req_id: self.req_id, - row_steps: std::mem::take(&mut self.buffer), - seq_no: self.next_seq_no, - }), - }, - }; - - self.next_seq_no += 1; - tokio::runtime::Handle::current().block_on(self.dispatcher.dispatch(msg)); - } -} - -impl ResultBuilder for ProxyResponseBuilder { - fn init( - &mut self, - _config: &libsqlx::result_builder::QueryBuilderConfig, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::Init); - self.maybe_send(); - Ok(()) - } - - fn begin_step(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::BeginStep); - self.maybe_send(); - Ok(()) - } - - fn finish_step( - &mut self, - affected_row_count: u64, - last_insert_rowid: Option, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::FinishStep( - affected_row_count, - last_insert_rowid, - )); - self.maybe_send(); - Ok(()) - } - - fn step_error( - &mut self, - error: libsqlx::error::Error, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer - .push(BuilderStep::StepError(StepError(error.to_string()))); - self.maybe_send(); - Ok(()) - } - - fn cols_description( - &mut self, - cols: &mut dyn Iterator, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer - .push(BuilderStep::ColsDesc(cols.map(Into::into).collect())); - self.maybe_send(); - Ok(()) - } - - fn begin_rows(&mut self) -> Result<(), 
libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::BeginRows); - self.maybe_send(); - Ok(()) - } - - fn begin_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::BeginRow); - self.maybe_send(); - Ok(()) - } - - fn add_row_value( - &mut self, - v: libsqlx::result_builder::ValueRef, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::AddRowValue(v.into())); - self.maybe_send(); - Ok(()) - } - - fn finish_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::FinishRow); - self.maybe_send(); - Ok(()) - } - - fn finish_rows(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { - self.buffer.push(BuilderStep::FinishRows); - self.maybe_send(); - Ok(()) - } - - fn finnalize( - &mut self, - is_txn: bool, - frame_no: Option, - ) -> Result { - self.buffer - .push(BuilderStep::Finnalize { is_txn, frame_no }); - self.send(); - Ok(true) - } -} - pub enum Database { Primary(PrimaryDatabase), Replica { @@ -315,142 +72,6 @@ impl LogCompactor for Compactor { } } -const MAX_INJECTOR_BUFFER_CAP: usize = 32; - -struct Replicator { - dispatcher: Arc, - req_id: u32, - next_frame_no: FrameNo, - next_seq: u32, - database_id: DatabaseId, - primary_node_id: NodeId, - injector: Box, - receiver: mpsc::Receiver, -} - -impl Replicator { - async fn run(mut self) { - self.query_replicate().await; - loop { - match timeout(Duration::from_secs(5), self.receiver.recv()).await { - Ok(Some(Frames { - req_no: req_id, - seq_no: seq, - frames, - })) => { - // ignore frames from a previous call to Replicate - if req_id != self.req_id { - tracing::debug!(req_id, self.req_id, "wrong req_id"); - continue; - } - if seq != self.next_seq { - // this is not the batch of frame we were expecting, drop what we have, and - // ask again from last checkpoint - tracing::debug!(seq, self.next_seq, 
"wrong seq"); - self.query_replicate().await; - continue; - }; - self.next_seq += 1; - - tracing::debug!("injecting {} frames", frames.len()); - - for bytes in frames { - let frame = Frame::try_from_bytes(bytes).unwrap(); - block_in_place(|| { - if let Some(last_committed) = self.injector.inject(frame).unwrap() { - tracing::debug!(last_committed); - self.next_frame_no = last_committed + 1; - } - }); - } - } - Err(_) => self.query_replicate().await, - Ok(None) => break, - } - } - } - - async fn query_replicate(&mut self) { - self.req_id += 1; - self.next_seq = 0; - // clear buffered, uncommitted frames - self.injector.clear(); - self.dispatcher - .dispatch(Outbound { - to: self.primary_node_id, - enveloppe: Enveloppe { - database_id: Some(self.database_id), - message: Message::Replicate { - next_frame_no: self.next_frame_no, - req_no: self.req_id, - }, - }, - }) - .await; - } -} - -struct FrameStreamer { - logger: Arc, - database_id: DatabaseId, - node_id: NodeId, - next_frame_no: FrameNo, - req_no: u32, - seq_no: u32, - dipatcher: Arc, - notifier: tokio::sync::watch::Receiver, - buffer: Vec, -} - -impl FrameStreamer { - async fn run(mut self) { - loop { - match block_in_place(|| self.logger.get_frame(self.next_frame_no)) { - Ok(frame) => { - if self.buffer.len() > FRAMES_MESSAGE_MAX_COUNT { - self.send_frames().await; - } - self.buffer.push(frame.bytes()); - self.next_frame_no += 1; - } - Err(LogReadError::Ahead) => { - tracing::debug!("frame {} not yet avaiblable", self.next_frame_no); - if !self.buffer.is_empty() { - self.send_frames().await; - } - if self - .notifier - .wait_for(|fno| *fno >= self.next_frame_no) - .await - .is_err() - { - break; - } - } - Err(LogReadError::Error(_)) => todo!("handle log read error"), - Err(LogReadError::SnapshotRequired) => todo!("handle reading from snapshot"), - } - } - } - - async fn send_frames(&mut self) { - let frames = std::mem::take(&mut self.buffer); - let outbound = Outbound { - to: self.node_id, - enveloppe: Enveloppe 
{ - database_id: Some(self.database_id), - message: Message::Frames(Frames { - req_no: self.req_no, - seq_no: self.seq_no, - frames, - }), - }, - }; - self.seq_no += 1; - self.dipatcher.dispatch(outbound).await; - } -} - impl Database { pub fn from_config(config: &AllocConfig, path: PathBuf, dispatcher: Arc) -> Self { match config.db_config { @@ -485,16 +106,14 @@ impl Database { let (sender, receiver) = mpsc::channel(16); let database_id = DatabaseId::from_name(&config.db_name); - let replicator = Replicator { + let replicator = Replicator::new( dispatcher, - req_id: 0, - next_frame_no: 0, // TODO: load the last commited from meta file - next_seq: 0, + 0, database_id, primary_node_id, injector, receiver, - }; + ); tokio::spawn(replicator.run()); @@ -529,195 +148,6 @@ impl Database { } } -struct PrimaryConnection { - conn: libsqlx::libsql::LibsqlConnection, -} - -#[async_trait::async_trait] -impl ConnectionHandler for PrimaryConnection { - fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<()> { - Poll::Ready(()) - } - - async fn handle_exec(&mut self, exec: ExecFn) { - block_in_place(|| exec(&mut self.conn)); - } - - async fn handle_inbound(&mut self, _msg: Inbound) { - tracing::debug!("primary connection received message, ignoring.") - } -} - -struct ReplicaConnection { - conn: ProxyConnection, - connection_id: u32, - next_req_id: u32, - primary_node_id: NodeId, - database_id: DatabaseId, - dispatcher: Arc, -} - -impl ReplicaConnection { - fn handle_proxy_response(&mut self, resp: ProxyResponse) { - let mut lock = self.conn.writer().inner.current_req.lock(); - let finnalized = match *lock { - Some(ref mut req) if req.id == Some(resp.req_id) && resp.seq_no == req.next_seq_no => { - self.next_req_id += 1; - // TODO: pass actual config - let config = QueryBuilderConfig { max_size: None }; - let mut finnalized = false; - for step in resp.row_steps.into_iter() { - if finnalized { - break; - }; - match step { - BuilderStep::Init => 
req.builder.init(&config).unwrap(), - BuilderStep::BeginStep => req.builder.begin_step().unwrap(), - BuilderStep::FinishStep(affected_row_count, last_insert_rowid) => req - .builder - .finish_step(affected_row_count, last_insert_rowid) - .unwrap(), - BuilderStep::StepError(e) => req - .builder - .step_error(todo!("handle proxy step error")) - .unwrap(), - BuilderStep::ColsDesc(cols) => req - .builder - .cols_description(&mut cols.iter().map(|c| Column { - name: &c.name, - decl_ty: c.decl_ty.as_deref(), - })) - .unwrap(), - BuilderStep::BeginRows => req.builder.begin_rows().unwrap(), - BuilderStep::BeginRow => req.builder.begin_row().unwrap(), - BuilderStep::AddRowValue(v) => { - req.builder.add_row_value((&v).into()).unwrap() - } - BuilderStep::FinishRow => req.builder.finish_row().unwrap(), - BuilderStep::FinishRows => req.builder.finish_rows().unwrap(), - BuilderStep::Finnalize { is_txn, frame_no } => { - let _ = req.builder.finnalize(is_txn, frame_no).unwrap(); - finnalized = true; - } - BuilderStep::FinnalizeError(e) => { - req.builder.finnalize_error(e); - finnalized = true; - } - } - } - finnalized - } - Some(_) => todo!("error processing response"), - None => { - tracing::error!("received builder message, but there is no pending request"); - false - } - }; - - if finnalized { - *lock = None; - } - } -} - -#[async_trait::async_trait] -impl ConnectionHandler for ReplicaConnection { - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { - // we are currently handling a request on this connection - // self.conn.writer().current_req.timeout.poll() - let mut req = self.conn.writer().current_req.lock(); - let should_abort_query = match &mut *req { - Some(ref mut req) => match req.timeout.as_mut().poll(cx) { - Poll::Ready(_) => { - req.builder.finnalize_error("request timed out".to_string()); - true - } - Poll::Pending => return Poll::Pending, - }, - None => return Poll::Ready(()), - }; - - if should_abort_query { - *req = None - } - - Poll::Ready(()) - } - 
- async fn handle_exec(&mut self, exec: ExecFn) { - block_in_place(|| exec(&mut self.conn)); - let msg = { - let mut lock = self.conn.writer().inner.current_req.lock(); - match *lock { - Some(ref mut req) if req.id.is_none() => { - let program = req - .pgm - .take() - .expect("unsent request should have a program"); - let req_id = self.next_req_id; - self.next_req_id += 1; - req.id = Some(req_id); - - let msg = Outbound { - to: self.primary_node_id, - enveloppe: Enveloppe { - database_id: Some(self.database_id), - message: Message::ProxyRequest { - connection_id: self.connection_id, - req_id, - program, - }, - }, - }; - - Some(msg) - } - _ => None, - } - }; - - if let Some(msg) = msg { - self.dispatcher.dispatch(msg).await; - } - } - - async fn handle_inbound(&mut self, msg: Inbound) { - match msg.enveloppe.message { - Message::ProxyResponse(resp) => { - self.handle_proxy_response(resp); - } - _ => (), // ignore anything else - } - } -} - -#[async_trait::async_trait] -impl ConnectionHandler for Either -where - L: ConnectionHandler, - R: ConnectionHandler, -{ - fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { - match self { - Either::Left(l) => l.poll_ready(cx), - Either::Right(r) => r.poll_ready(cx), - } - } - - async fn handle_exec(&mut self, exec: ExecFn) { - match self { - Either::Left(l) => l.handle_exec(exec).await, - Either::Right(r) => r.handle_exec(exec).await, - } - } - async fn handle_inbound(&mut self, msg: Inbound) { - match self { - Either::Left(l) => l.handle_inbound(msg).await, - Either::Right(r) => r.handle_inbound(msg).await, - } - } -} - pub struct Allocation { pub inbox: mpsc::Receiver, pub database: Database, @@ -874,7 +304,7 @@ impl Allocation { async fn handle_proxy( &mut self, - node_id: NodeId, + to: NodeId, connection_id: u32, req_id: u32, program: Program, @@ -884,15 +314,13 @@ impl Allocation { let exec = |conn: ConnectionHandle| async move { let _ = conn .exec(move |conn| { - let builder = ProxyResponseBuilder { + let builder 
= ProxyResponseBuilder::new( dispatcher, - req_id, - buffer: Vec::new(), - to: node_id, database_id, + to, + req_id, connection_id, - next_seq_no: 0, - }; + ); conn.execute_program(&program, Box::new(builder)).unwrap(); }) .await; @@ -901,14 +329,14 @@ impl Allocation { if self.database.is_primary() { match self .connections - .get(&node_id) + .get(&to) .and_then(|m| m.get(&connection_id).cloned()) { Some(handle) => { tokio::spawn(exec(handle)); } None => { - let handle = self.new_conn(Some((node_id, connection_id))).await; + let handle = self.new_conn(Some((to, connection_id))).await; tokio::spawn(exec(handle)); } } @@ -955,13 +383,6 @@ impl Allocation { } } -struct Connection { - id: (NodeId, u32), - conn: C, - exec: mpsc::Receiver, - inbound: mpsc::Receiver, -} - #[async_trait::async_trait] trait ConnectionHandler: Send { fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()>; @@ -969,6 +390,40 @@ trait ConnectionHandler: Send { async fn handle_inbound(&mut self, msg: Inbound); } +#[async_trait::async_trait] +impl ConnectionHandler for Either +where + L: ConnectionHandler, + R: ConnectionHandler, +{ + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { + match self { + Either::Left(l) => l.poll_ready(cx), + Either::Right(r) => r.poll_ready(cx), + } + } + + async fn handle_exec(&mut self, exec: ExecFn) { + match self { + Either::Left(l) => l.handle_exec(exec).await, + Either::Right(r) => r.handle_exec(exec).await, + } + } + async fn handle_inbound(&mut self, msg: Inbound) { + match self { + Either::Left(l) => l.handle_inbound(msg).await, + Either::Right(r) => r.handle_inbound(msg).await, + } + } +} + +struct Connection { + id: (NodeId, u32), + conn: C, + exec: mpsc::Receiver, + inbound: mpsc::Receiver, +} + impl Connection { async fn run(mut self) -> (NodeId, u32) { loop { @@ -991,8 +446,11 @@ impl Connection { #[cfg(test)] mod test { + use std::time::Duration; + use tokio::sync::Notify; + use crate::allocation::replica::ReplicaConnection; use 
crate::linc::bus::Bus; use super::*; diff --git a/libsqlx-server/src/allocation/primary.rs b/libsqlx-server/src/allocation/primary.rs new file mode 100644 index 00000000..15ac4dbd --- /dev/null +++ b/libsqlx-server/src/allocation/primary.rs @@ -0,0 +1,275 @@ +use std::collections::HashMap; +use std::mem::size_of; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use bytes::Bytes; +use libsqlx::libsql::{LibsqlDatabase, PrimaryType}; +use libsqlx::result_builder::ResultBuilder; +use libsqlx::{FrameNo, LogReadError, ReplicationLogger}; +use tokio::task::block_in_place; + +use crate::linc::bus::Dispatch; +use crate::linc::proto::{BuilderStep, Enveloppe, Frames, Message, StepError, Value}; +use crate::linc::{Inbound, NodeId, Outbound}; +use crate::meta::DatabaseId; + +use super::{ConnectionHandler, ExecFn, FRAMES_MESSAGE_MAX_COUNT}; + +const MAX_STEP_BATCH_SIZE: usize = 100_000_000; // ~100MB — NOTE(review): earlier comment said ~100kb; confirm whether 100_000 was intended + // +pub struct PrimaryDatabase { + pub db: LibsqlDatabase, + pub replica_streams: HashMap)>, + pub frame_notifier: tokio::sync::watch::Receiver, +} + +pub struct ProxyResponseBuilder { + dispatcher: Arc, + buffer: Vec, + database_id: DatabaseId, + to: NodeId, + req_id: u32, + connection_id: u32, + next_seq_no: u32, +} + +impl ProxyResponseBuilder { + pub fn new( + dispatcher: Arc, + database_id: DatabaseId, + to: NodeId, + req_id: u32, + connection_id: u32, + ) -> Self { + Self { + dispatcher, + buffer: Vec::new(), + database_id, + to, + req_id, + connection_id, + next_seq_no: 0, + } + } + + fn maybe_send(&mut self) { + // FIXME: this is stupid: compute current buffer size on the go instead + let size = self + .buffer + .iter() + .map(|s| match s { + BuilderStep::FinishStep(_, _) => 2 * 8, + BuilderStep::StepError(StepError(s)) => s.len(), + BuilderStep::ColsDesc(ref d) => d + .iter() + .map(|c| c.name.len() + c.decl_ty.as_ref().map(|t| t.len()).unwrap_or_default()) + .sum(), + BuilderStep::Finnalize { .. 
} => 9, + BuilderStep::AddRowValue(v) => match v { + crate::linc::proto::Value::Text(s) | crate::linc::proto::Value::Blob(s) => { + s.len() + } + _ => size_of::(), + }, + _ => 8, + }) + .sum::(); + + if size > MAX_STEP_BATCH_SIZE { + self.send() + } + } + + fn send(&mut self) { + let msg = Outbound { + to: self.to, + enveloppe: Enveloppe { + database_id: Some(self.database_id), + message: Message::ProxyResponse(crate::linc::proto::ProxyResponse { + connection_id: self.connection_id, + req_id: self.req_id, + row_steps: std::mem::take(&mut self.buffer), + seq_no: self.next_seq_no, + }), + }, + }; + + self.next_seq_no += 1; + tokio::runtime::Handle::current().block_on(self.dispatcher.dispatch(msg)); + } +} + +impl ResultBuilder for ProxyResponseBuilder { + fn init( + &mut self, + _config: &libsqlx::result_builder::QueryBuilderConfig, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::Init); + self.maybe_send(); + Ok(()) + } + + fn begin_step(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::BeginStep); + self.maybe_send(); + Ok(()) + } + + fn finish_step( + &mut self, + affected_row_count: u64, + last_insert_rowid: Option, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::FinishStep( + affected_row_count, + last_insert_rowid, + )); + self.maybe_send(); + Ok(()) + } + + fn step_error( + &mut self, + error: libsqlx::error::Error, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer + .push(BuilderStep::StepError(StepError(error.to_string()))); + self.maybe_send(); + Ok(()) + } + + fn cols_description( + &mut self, + cols: &mut dyn Iterator, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer + .push(BuilderStep::ColsDesc(cols.map(Into::into).collect())); + self.maybe_send(); + Ok(()) + } + + fn begin_rows(&mut self) -> Result<(), 
libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::BeginRows); + self.maybe_send(); + Ok(()) + } + + fn begin_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::BeginRow); + self.maybe_send(); + Ok(()) + } + + fn add_row_value( + &mut self, + v: libsqlx::result_builder::ValueRef, + ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::AddRowValue(v.into())); + self.maybe_send(); + Ok(()) + } + + fn finish_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::FinishRow); + self.maybe_send(); + Ok(()) + } + + fn finish_rows(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + self.buffer.push(BuilderStep::FinishRows); + self.maybe_send(); + Ok(()) + } + + fn finnalize( + &mut self, + is_txn: bool, + frame_no: Option, + ) -> Result { + self.buffer + .push(BuilderStep::Finnalize { is_txn, frame_no }); + self.send(); + Ok(true) + } +} + +pub struct FrameStreamer { + pub logger: Arc, + pub database_id: DatabaseId, + pub node_id: NodeId, + pub next_frame_no: FrameNo, + pub req_no: u32, + pub seq_no: u32, + pub dipatcher: Arc, + pub notifier: tokio::sync::watch::Receiver, + pub buffer: Vec, +} + +impl FrameStreamer { + pub async fn run(mut self) { + loop { + match block_in_place(|| self.logger.get_frame(self.next_frame_no)) { + Ok(frame) => { + if self.buffer.len() > FRAMES_MESSAGE_MAX_COUNT { + self.send_frames().await; + } + self.buffer.push(frame.bytes()); + self.next_frame_no += 1; + } + Err(LogReadError::Ahead) => { + tracing::debug!("frame {} not yet avaiblable", self.next_frame_no); + if !self.buffer.is_empty() { + self.send_frames().await; + } + if self + .notifier + .wait_for(|fno| *fno >= self.next_frame_no) + .await + .is_err() + { + break; + } + } + Err(LogReadError::Error(_)) => todo!("handle log read error"), + 
Err(LogReadError::SnapshotRequired) => todo!("handle reading from snapshot"), + } + } + } + + async fn send_frames(&mut self) { + let frames = std::mem::take(&mut self.buffer); + let outbound = Outbound { + to: self.node_id, + enveloppe: Enveloppe { + database_id: Some(self.database_id), + message: Message::Frames(Frames { + req_no: self.req_no, + seq_no: self.seq_no, + frames, + }), + }, + }; + self.seq_no += 1; + self.dipatcher.dispatch(outbound).await; + } +} + +pub struct PrimaryConnection { + pub conn: libsqlx::libsql::LibsqlConnection, +} + +#[async_trait::async_trait] +impl ConnectionHandler for PrimaryConnection { + fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<()> { + Poll::Ready(()) + } + + async fn handle_exec(&mut self, exec: ExecFn) { + block_in_place(|| exec(&mut self.conn)); + } + + async fn handle_inbound(&mut self, _msg: Inbound) { + tracing::debug!("primary connection received message, ignoring.") + } +} diff --git a/libsqlx-server/src/allocation/replica.rs b/libsqlx-server/src/allocation/replica.rs new file mode 100644 index 00000000..297d27fb --- /dev/null +++ b/libsqlx-server/src/allocation/replica.rs @@ -0,0 +1,342 @@ +use std::ops::Deref; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; +use std::time::Duration; + +use futures::Future; +use libsqlx::libsql::{LibsqlConnection, LibsqlDatabase, ReplicaType}; +use libsqlx::program::Program; +use libsqlx::proxy::{WriteProxyConnection, WriteProxyDatabase}; +use libsqlx::result_builder::{Column, QueryBuilderConfig, ResultBuilder}; +use libsqlx::{DescribeResponse, Frame, FrameNo, Injector}; +use parking_lot::Mutex; +use tokio::{ + sync::mpsc, + task::block_in_place, + time::{timeout, Sleep}, +}; + +use crate::linc::proto::{BuilderStep, ProxyResponse}; +use crate::linc::Inbound; +use crate::{ + linc::{ + bus::Dispatch, + proto::{Enveloppe, Frames, Message}, + NodeId, Outbound, + }, + meta::DatabaseId, +}; + +use super::{ConnectionHandler, ExecFn}; + +type 
ProxyConnection = WriteProxyConnection, RemoteConn>; +pub type ProxyDatabase = WriteProxyDatabase, RemoteDb>; + +pub struct RemoteDb { + pub proxy_request_timeout_duration: Duration, +} + +#[derive(Clone)] +pub struct RemoteConn { + inner: Arc, +} + +struct Request { + id: Option, + builder: Box, + pgm: Option, + next_seq_no: u32, + timeout: Pin>, +} + +pub struct RemoteConnInner { + current_req: Mutex>, + request_timeout_duration: Duration, +} + +impl Deref for RemoteConn { + type Target = RemoteConnInner; + + fn deref(&self) -> &Self::Target { + self.inner.as_ref() + } +} + +impl libsqlx::Connection for RemoteConn { + fn execute_program( + &mut self, + program: &libsqlx::program::Program, + builder: Box, + ) -> libsqlx::Result<()> { + // When we need to proxy a query, we place it in the current request slot. When we are + // back in a async context, we'll send it to the primary, and asynchrously drive the + // builder. + let mut lock = self.inner.current_req.lock(); + *lock = match *lock { + Some(_) => unreachable!("conccurent request on the same connection!"), + None => Some(Request { + id: None, + builder, + pgm: Some(program.clone()), + next_seq_no: 0, + timeout: Box::pin(tokio::time::sleep(self.inner.request_timeout_duration)), + }), + }; + + Ok(()) + } + + fn describe(&self, _sql: String) -> libsqlx::Result { + unreachable!("Describe request should not be proxied") + } +} + +impl libsqlx::Database for RemoteDb { + type Connection = RemoteConn; + + fn connect(&self) -> Result { + Ok(RemoteConn { + inner: Arc::new(RemoteConnInner { + current_req: Default::default(), + request_timeout_duration: self.proxy_request_timeout_duration, + }), + }) + } +} + +pub struct Replicator { + dispatcher: Arc, + req_id: u32, + next_frame_no: FrameNo, + next_seq: u32, + database_id: DatabaseId, + primary_node_id: NodeId, + injector: Box, + receiver: mpsc::Receiver, +} + +impl Replicator { + pub fn new( + dispatcher: Arc, + next_frame_no: FrameNo, + database_id: DatabaseId, + 
primary_node_id: NodeId, + injector: Box, + receiver: mpsc::Receiver, + ) -> Self { + Self { + dispatcher, + req_id: 0, + next_frame_no, + next_seq: 0, + database_id, + primary_node_id, + injector, + receiver, + } + } + + pub async fn run(mut self) { + self.query_replicate().await; + loop { + match timeout(Duration::from_secs(5), self.receiver.recv()).await { + Ok(Some(Frames { + req_no: req_id, + seq_no: seq, + frames, + })) => { + // ignore frames from a previous call to Replicate + if req_id != self.req_id { + tracing::debug!(req_id, self.req_id, "wrong req_id"); + continue; + } + if seq != self.next_seq { + // this is not the batch of frame we were expecting, drop what we have, and + // ask again from last checkpoint + tracing::debug!(seq, self.next_seq, "wrong seq"); + self.query_replicate().await; + continue; + }; + self.next_seq += 1; + + tracing::debug!("injecting {} frames", frames.len()); + + for bytes in frames { + let frame = Frame::try_from_bytes(bytes).unwrap(); + block_in_place(|| { + if let Some(last_committed) = self.injector.inject(frame).unwrap() { + tracing::debug!(last_committed); + self.next_frame_no = last_committed + 1; + } + }); + } + } + Err(_) => self.query_replicate().await, + Ok(None) => break, + } + } + } + + async fn query_replicate(&mut self) { + self.req_id += 1; + self.next_seq = 0; + // clear buffered, uncommitted frames + self.injector.clear(); + self.dispatcher + .dispatch(Outbound { + to: self.primary_node_id, + enveloppe: Enveloppe { + database_id: Some(self.database_id), + message: Message::Replicate { + next_frame_no: self.next_frame_no, + req_no: self.req_id, + }, + }, + }) + .await; + } +} + +pub struct ReplicaConnection { + pub conn: ProxyConnection, + pub connection_id: u32, + pub next_req_id: u32, + pub primary_node_id: NodeId, + pub database_id: DatabaseId, + pub dispatcher: Arc, +} + +impl ReplicaConnection { + fn handle_proxy_response(&mut self, resp: ProxyResponse) { + let mut lock = 
self.conn.writer().inner.current_req.lock(); + let finnalized = match *lock { + Some(ref mut req) if req.id == Some(resp.req_id) && resp.seq_no == req.next_seq_no => { + self.next_req_id += 1; + // TODO: pass actual config + let config = QueryBuilderConfig { max_size: None }; + let mut finnalized = false; + for step in resp.row_steps.into_iter() { + if finnalized { + break; + }; + match step { + BuilderStep::Init => req.builder.init(&config).unwrap(), + BuilderStep::BeginStep => req.builder.begin_step().unwrap(), + BuilderStep::FinishStep(affected_row_count, last_insert_rowid) => req + .builder + .finish_step(affected_row_count, last_insert_rowid) + .unwrap(), + BuilderStep::StepError(e) => req + .builder + .step_error(todo!("handle proxy step error")) + .unwrap(), + BuilderStep::ColsDesc(cols) => req + .builder + .cols_description(&mut &mut cols.iter().map(|c| Column { + name: &c.name, + decl_ty: c.decl_ty.as_deref(), + })) + .unwrap(), + BuilderStep::BeginRows => req.builder.begin_rows().unwrap(), + BuilderStep::BeginRow => req.builder.begin_row().unwrap(), + BuilderStep::AddRowValue(v) => { + req.builder.add_row_value((&v).into()).unwrap() + } + BuilderStep::FinishRow => req.builder.finish_row().unwrap(), + BuilderStep::FinishRows => req.builder.finish_rows().unwrap(), + BuilderStep::Finnalize { is_txn, frame_no } => { + let _ = req.builder.finnalize(is_txn, frame_no).unwrap(); + finnalized = true; + } + BuilderStep::FinnalizeError(e) => { + req.builder.finnalize_error(e); + finnalized = true; + } + } + } + finnalized + } + Some(_) => todo!("error processing response"), + None => { + tracing::error!("received builder message, but there is no pending request"); + false + } + }; + + if finnalized { + *lock = None; + } + } +} + +#[async_trait::async_trait] +impl ConnectionHandler for ReplicaConnection { + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<()> { + // we are currently handling a request on this connection + // 
self.conn.writer().current_req.timeout.poll() + let mut req = self.conn.writer().current_req.lock(); + let should_abort_query = match &mut *req { + Some(ref mut req) => match req.timeout.as_mut().poll(cx) { + Poll::Ready(_) => { + req.builder.finnalize_error("request timed out".to_string()); + true + } + Poll::Pending => return Poll::Pending, + }, + None => return Poll::Ready(()), + }; + + if should_abort_query { + *req = None + } + + Poll::Ready(()) + } + + async fn handle_exec(&mut self, exec: ExecFn) { + block_in_place(|| exec(&mut self.conn)); + let msg = { + let mut lock = self.conn.writer().inner.current_req.lock(); + match *lock { + Some(ref mut req) if req.id.is_none() => { + let program = req + .pgm + .take() + .expect("unsent request should have a program"); + let req_id = self.next_req_id; + self.next_req_id += 1; + req.id = Some(req_id); + + let msg = Outbound { + to: self.primary_node_id, + enveloppe: Enveloppe { + database_id: Some(self.database_id), + message: Message::ProxyRequest { + connection_id: self.connection_id, + req_id, + program, + }, + }, + }; + + Some(msg) + } + _ => None, + } + }; + + if let Some(msg) = msg { + self.dispatcher.dispatch(msg).await; + } + } + + async fn handle_inbound(&mut self, msg: Inbound) { + match msg.enveloppe.message { + Message::ProxyResponse(resp) => { + self.handle_proxy_response(resp); + } + _ => (), // ignore anything else + } + } +}