From 198ed752f98f9d65df500c67a4283c7ff5bfd2e2 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Tue, 11 Jul 2023 19:20:23 +0200 Subject: [PATCH 1/5] changes to proto --- libsqlx-server/src/allocation/mod.rs | 5 - libsqlx-server/src/config.rs | 51 +-- libsqlx-server/src/http/user/mod.rs | 1 - libsqlx-server/src/linc/bus.rs | 189 ++-------- libsqlx-server/src/linc/connection.rs | 398 ++++++--------------- libsqlx-server/src/linc/connection_pool.rs | 9 +- libsqlx-server/src/linc/handler.rs | 6 + libsqlx-server/src/linc/mod.rs | 30 +- libsqlx-server/src/linc/proto.rs | 110 +----- libsqlx-server/src/linc/server.rs | 16 +- 10 files changed, 203 insertions(+), 612 deletions(-) create mode 100644 libsqlx-server/src/linc/handler.rs diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index b0393165..38c29ff3 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -133,11 +133,8 @@ impl Allocation { } async fn new_conn(&mut self) -> ConnectionHandle { - dbg!(); let id = self.next_conn_id(); - dbg!(); let conn = block_in_place(|| self.database.connect()); - dbg!(); let (close_sender, exit) = oneshot::channel(); let (exec_sender, exec_receiver) = mpsc::channel(1); let conn = Connection { @@ -147,9 +144,7 @@ impl Allocation { exec: exec_receiver, }; - dbg!(); self.connections_futs.spawn(conn.run()); - dbg!(); ConnectionHandle { exec: exec_sender, diff --git a/libsqlx-server/src/config.rs b/libsqlx-server/src/config.rs index 4772b53f..cb2d68b5 100644 --- a/libsqlx-server/src/config.rs +++ b/libsqlx-server/src/config.rs @@ -1,8 +1,8 @@ use std::net::SocketAddr; use std::path::PathBuf; -use serde::Deserialize; use serde::de::Visitor; +use serde::Deserialize; #[derive(Deserialize, Debug, Clone)] pub struct Config { @@ -62,7 +62,7 @@ fn default_linc_addr() -> SocketAddr { } #[derive(Debug, Clone)] -struct Peer { +pub struct Peer { id: u64, addr: String, } @@ -70,29 +70,32 @@ struct Peer { impl<'de> Deserialize<'de> for Peer { fn deserialize(deserializer: D) -> Result where - D: serde::Deserializer<'de> { - struct V; - - impl Visitor<'_> for V { - type Value = Peer; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a string in the format :") - } - - fn visit_str(self, v: &str) -> Result - where - E: serde::de::Error, { - - let mut iter = v.split(":"); - let Some(id) = iter.next() else { return Err(E::custom("node id is missing")) }; - let Ok(id) = id.parse::() else { return Err(E::custom("failed to parse node id")) }; - let Some(addr) = iter.next() else { return Err(E::custom("node address is missing")) }; - Ok(Peer { id, addr: addr.to_string() }) - } + D: serde::Deserializer<'de>, + { + struct V; + + impl Visitor<'_> for V { + type Value = Peer; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a string in the format :") } - deserializer.deserialize_str(V) + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + let mut iter = v.split(":"); + let Some(id) = iter.next() else { return Err(E::custom("node id is missing")) }; + let Ok(id) = id.parse::() else { return Err(E::custom("failed to parse node id")) }; + let Some(addr) = iter.next() else { return Err(E::custom("node address is missing")) }; + Ok(Peer { + id, + addr: addr.to_string(), + }) + } } -} + deserializer.deserialize_str(V) + } +} diff --git a/libsqlx-server/src/http/user/mod.rs b/libsqlx-server/src/http/user/mod.rs index bc3265e9..f357499c 100644 
--- a/libsqlx-server/src/http/user/mod.rs +++ b/libsqlx-server/src/http/user/mod.rs @@ -46,6 +46,5 @@ async fn handle_hrana_pipeline( Json(req): Json, ) -> Json { let resp = db.hrana_pipeline(req).await; - dbg!(); Json(resp.unwrap()) } diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs index f9533347..7c52ec42 100644 --- a/libsqlx-server/src/linc/bus.rs +++ b/libsqlx-server/src/linc/bus.rs @@ -1,186 +1,59 @@ -use std::collections::{hash_map::Entry, HashMap}; use std::sync::Arc; -use color_eyre::eyre::{anyhow, bail}; -use parking_lot::Mutex; -use tokio::sync::{mpsc, Notify}; use uuid::Uuid; -use super::connection::{ConnectionHandle, Stream}; +use super::{connection::SendQueue, handler::Handler, Outbound, Inbound}; type NodeId = Uuid; type DatabaseId = Uuid; -#[must_use] -pub struct Subscription { - receiver: mpsc::Receiver, - bus: Bus, - database_id: DatabaseId, +pub struct Bus { + inner: Arc>, } -impl Drop for Subscription { - fn drop(&mut self) { - self.bus - .inner - .lock() - .subscriptions - .remove(&self.database_id); +impl Clone for Bus { + fn clone(&self) -> Self { + Self { inner: self.inner.clone() } } } -impl futures::Stream for Subscription { - type Item = Stream; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.receiver.poll_recv(cx) - } -} - -#[derive(Clone)] -pub struct Bus { - inner: Arc>, - pub node_id: NodeId, -} - -enum ConnectionSlot { - Handle(ConnectionHandle), - // Interest in the connection when it becomes available - Interest(Arc), -} - -struct BusInner { - connections: HashMap, - subscriptions: HashMap>, +struct BusInner { + node_id: NodeId, + handler: H, + send_queue: SendQueue, } -impl Bus { - pub fn new(node_id: NodeId) -> Self { +impl Bus { + pub fn new(node_id: NodeId, handler: H) -> Self { + let send_queue = SendQueue::new(); Self { - node_id, - inner: Arc::new(Mutex::new(BusInner { - connections: HashMap::new(), - subscriptions: HashMap::new(), - })), + inner: Arc::new(BusInner { + node_id, + handler, + send_queue, + }), } } - /// open a new stream to the database at `database_id` on the node `node_id` - pub async fn new_stream( - &self, - node_id: NodeId, - database_id: DatabaseId, - ) -> color_eyre::Result { - let get_conn = || { - let mut lock = self.inner.lock(); - match lock.connections.entry(node_id) { - Entry::Occupied(mut e) => match e.get_mut() { - ConnectionSlot::Handle(h) => Ok(h.clone()), - ConnectionSlot::Interest(notify) => Err(notify.clone()), - }, - Entry::Vacant(e) => { - let notify = Arc::new(Notify::new()); - e.insert(ConnectionSlot::Interest(notify.clone())); - Err(notify) - } - } - }; - - let conn = match get_conn() { - Ok(conn) => conn, - Err(notify) => { - notify.notified().await; - get_conn().map_err(|_| anyhow!("failed to create stream"))? 
- } - }; - - conn.new_stream(database_id).await + pub fn node_id(&self) -> NodeId { + self.inner.node_id } - /// Notify a subscription that new stream was openned - pub async fn notify_subscription( - &mut self, - database_id: DatabaseId, - stream: Stream, - ) -> color_eyre::Result<()> { - let maybe_sender = self.inner.lock().subscriptions.get(&database_id).cloned(); - - match maybe_sender { - Some(sender) => { - if sender.send(stream).await.is_err() { - bail!("subscription for {database_id} closed"); - } - - Ok(()) - } - None => { - bail!("no subscription for {database_id}") - } - } + pub async fn incomming(&self, incomming: Inbound) { + self.inner.handler.handle(self, incomming); } - #[cfg(test)] - pub fn is_empty(&self) -> bool { - self.inner.lock().connections.is_empty() + pub async fn dispatch(&self, msg: Outbound) { + assert!( + msg.to != self.node_id(), + "trying to send a message to ourself!" + ); + // This message is outbound. + self.inner.send_queue.enqueue(msg).await; } - #[must_use] - pub fn register_connection(&self, node_id: NodeId, conn: ConnectionHandle) -> Registration { - let mut lock = self.inner.lock(); - match lock.connections.entry(node_id) { - Entry::Occupied(mut e) => { - if let ConnectionSlot::Interest(ref notify) = e.get() { - notify.notify_waiters(); - } - - *e.get_mut() = ConnectionSlot::Handle(conn); - } - Entry::Vacant(e) => { - e.insert(ConnectionSlot::Handle(conn)); - } - } - - Registration { - bus: self.clone(), - node_id, - } - } - - pub fn subscribe(&self, database_id: DatabaseId) -> color_eyre::Result { - let (sender, receiver) = mpsc::channel(1); - { - let mut inner = self.inner.lock(); - - if inner.subscriptions.contains_key(&database_id) { - bail!("a subscription already exist for that database"); - } - - inner.subscriptions.insert(database_id, sender); - } - - Ok(Subscription { - receiver, - bus: self.clone(), - database_id, - }) - } -} - -pub struct Registration { - bus: Bus, - node_id: NodeId, -} + pub fn send_queue(&self) -> &SendQueue { + &self.inner.send_queue -impl Drop for Registration { - fn drop(&mut self) { - assert!(self - .bus - .inner - .lock() - .connections - .remove(&self.node_id) - .is_some()); } } diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index 1d598cef..55977623 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -2,156 +2,35 @@ use std::collections::HashMap; use async_bincode::tokio::AsyncBincodeStream; use async_bincode::AsyncDestination; -use color_eyre::eyre::{anyhow, bail}; +use color_eyre::eyre::bail; use futures::{SinkExt, StreamExt}; +use parking_lot::RwLock; use tokio::io::{AsyncRead, AsyncWrite}; -use tokio::sync::mpsc::error::TrySendError; -use tokio::sync::{mpsc, oneshot}; +use tokio::sync::mpsc; use tokio::time::{Duration, Instant}; -use tokio_util::sync::PollSender; -use crate::linc::proto::{NodeError, NodeMessage}; +use crate::linc::proto::ProtoError; use crate::linc::CURRENT_PROTO_VERSION; -use super::bus::{Bus, Registration}; -use super::proto::{Message, StreamId, StreamMessage}; -use super::{DatabaseId, NodeId}; -use super::{StreamIdAllocator, MAX_STREAM_MSG}; - -#[derive(Debug, Clone)] -pub struct ConnectionHandle { - connection_sender: mpsc::Sender, -} - -impl ConnectionHandle { - pub async fn new_stream(&self, database_id: DatabaseId) -> color_eyre::eyre::Result { - let (send, ret) = oneshot::channel(); - self.connection_sender - .send(ConnectionMessage::StreamCreate { - database_id, - ret: send, - }) - .await - 
.unwrap(); - - Ok(ret.await?) - } -} - -/// A Bidirectional stream between databases on two nodes. -#[derive(Debug)] -pub struct Stream { - stream_id: StreamId, - /// sender to the connection - sender: tokio_util::sync::PollSender, - /// incoming message for this stream - recv: tokio_stream::wrappers::ReceiverStream, -} - -impl futures::Sink for Stream { - type Error = tokio_util::sync::PollSendError; - - fn poll_ready( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.sender.poll_ready_unpin(cx) - } - - fn start_send( - mut self: std::pin::Pin<&mut Self>, - payload: StreamMessage, - ) -> Result<(), Self::Error> { - let stream_id = self.stream_id; - self.sender - .start_send_unpin(ConnectionMessage::Message(Message::Stream { - stream_id, - payload, - })) - } - - fn poll_flush( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.sender.poll_flush_unpin(cx) - } - - fn poll_close( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.sender.poll_close_unpin(cx) - } -} - -impl futures::Stream for Stream { - type Item = StreamMessage; - - fn poll_next( - mut self: std::pin::Pin<&mut Self>, - cx: &mut std::task::Context<'_>, - ) -> std::task::Poll> { - self.recv.poll_next_unpin(cx) - } -} - -impl Drop for Stream { - fn drop(&mut self) { - self.recv.close(); - assert!(self.recv.as_mut().try_recv().is_err()); - let mut sender = self.sender.clone(); - let id = self.stream_id; - if let Some(sender_ref) = sender.get_ref() { - // Try send here is mostly for turmoil, since it stops polling the future as soon as - // the test future returns which causes spawn to panic. In the tests, the channel will - // always have capacity. - if let Err(TrySendError::Full(m)) = - sender_ref.try_send(ConnectionMessage::CloseStream(id)) - { - tokio::task::spawn(async move { - let _ = sender.send(m).await; - }); - } - } - } -} - -struct StreamState { - sender: mpsc::Sender, -} +use super::bus::{Bus}; +use super::handler::Handler; +use super::proto::{Enveloppe, Message}; +use super::{NodeId, Outbound, Inbound}; /// A connection to another node. Manage the connection state, and (de)register streams with the /// `Bus` -pub struct Connection { +pub struct Connection { /// Id of the current node pub peer: Option, /// State of the connection pub state: ConnectionState, /// Sink/Stream for network messages - conn: AsyncBincodeStream, - /// Collection of streams for that connection - streams: HashMap, - /// internal connection messages - connection_messages: mpsc::Receiver, - connection_messages_sender: mpsc::Sender, + conn: AsyncBincodeStream, /// Are we the initiator of this connection? is_initiator: bool, - bus: Bus, - stream_id_allocator: StreamIdAllocator, - /// handle to the registration of this connection to the bus. 
- /// Dropping this deregister this connection from the bus - registration: Option, -} - -#[derive(Debug)] -pub enum ConnectionMessage { - StreamCreate { - database_id: DatabaseId, - ret: oneshot::Sender, - }, - CloseStream(StreamId), - Message(Message), + /// send queue for this connection + send_queue: Option>, + bus: Bus, } #[derive(Debug)] @@ -170,49 +49,61 @@ pub fn handshake_deadline() -> Instant { Instant::now() + HANDSHAKE_TIMEOUT } -impl Connection +// TODO: limit send queue depth +pub struct SendQueue { + senders: RwLock>>, +} + +impl SendQueue { + pub fn new() -> Self { + Self { + senders: Default::default(), + } + } + + pub async fn enqueue(&self, msg: Outbound) { + let sender = match self.senders.read().get(&msg.to) { + Some(sender) => sender.clone(), + None => todo!("no queue"), + }; + + sender.send(msg.enveloppe); + } + + pub fn register(&self, node_id: NodeId) -> mpsc::UnboundedReceiver { + let (sender, receiver) = mpsc::unbounded_channel(); + self.senders.write().insert(node_id, sender); + + receiver + } +} + +impl Connection where S: AsyncRead + AsyncWrite + Unpin, + H: Handler, { const MAX_CONNECTION_MESSAGES: usize = 128; - pub fn new_initiator(stream: S, bus: Bus) -> Self { - let (connection_messages_sender, connection_messages) = - mpsc::channel(Self::MAX_CONNECTION_MESSAGES); + pub fn new_initiator(stream: S, bus: Bus) -> Self { Self { peer: None, state: ConnectionState::Init, conn: AsyncBincodeStream::from(stream).for_async(), - streams: HashMap::new(), is_initiator: true, - bus, - stream_id_allocator: StreamIdAllocator::new(true), - connection_messages, - connection_messages_sender, - registration: None, + send_queue: None, + bus, } } - pub fn new_acceptor(stream: S, bus: Bus) -> Self { - let (connection_messages_sender, connection_messages) = - mpsc::channel(Self::MAX_CONNECTION_MESSAGES); + pub fn new_acceptor(stream: S, bus: Bus) -> Self { Connection { peer: None, state: ConnectionState::Connecting, - streams: HashMap::new(), - connection_messages, - connection_messages_sender, is_initiator: false, bus, + send_queue: None, conn: AsyncBincodeStream::from(stream).for_async(), - stream_id_allocator: StreamIdAllocator::new(false), - registration: None, - } - } - - pub fn handle(&self) -> ConnectionHandle { - ConnectionHandle { - connection_sender: self.connection_messages_sender.clone(), } } @@ -262,135 +153,34 @@ where self.state = ConnectionState::Close; } } - } - Some(command) = self.connection_messages.recv() => { - self.handle_command(command).await; }, + // TODO: pop send queue + Some(m) = self.send_queue.as_mut().unwrap().recv() => { + self.conn.feed(m).await.unwrap(); + // send as many as possible + while let Ok(m) = self.send_queue.as_mut().unwrap().try_recv() { + self.conn.feed(m).await.unwrap(); + } + self.conn.flush().await.unwrap(); + } else => { self.state = ConnectionState::Close; } } } - async fn handle_message(&mut self, message: Message) { - match message { - Message::Node(NodeMessage::OpenStream { - stream_id, - database_id, - }) => { - if self.streams.contains_key(&stream_id) { - self.send_message(Message::Node(NodeMessage::Error( - NodeError::StreamAlreadyExist(stream_id), - ))) - .await; - return; - } - let stream = self.create_stream(stream_id); - if let Err(e) = self.bus.notify_subscription(database_id, stream).await { - tracing::error!("{e}"); - self.send_message(Message::Node(NodeMessage::Error( - NodeError::UnknownDatabase(database_id, stream_id), - ))) - .await; - } - } - Message::Node(NodeMessage::Handshake { .. 
}) => { - self.close_error(anyhow!("unexpected handshake: closing connection")); - } - Message::Node(NodeMessage::CloseStream { stream_id: id }) => { - self.close_stream(id); - } - Message::Node(NodeMessage::Error(e @ NodeError::HandshakeVersionMismatch { .. })) => { - self.close_error(anyhow!("unexpected peer error: {e}")); - } - Message::Node(NodeMessage::Error(NodeError::UnknownStream(id))) => { - tracing::error!("unkown stream: {id}"); - self.close_stream(id); - } - Message::Node(NodeMessage::Error(e @ NodeError::StreamAlreadyExist(_))) => { - self.state = ConnectionState::CloseError(e.into()); - } - Message::Node(NodeMessage::Error(ref e @ NodeError::UnknownDatabase(_, stream_id))) => { - tracing::error!("{e}"); - self.close_stream(stream_id); - } - Message::Stream { stream_id, payload } => { - match self.streams.get_mut(&stream_id) { - Some(s) => { - // TODO: there is not stream-independant control-flow for now. - // When/if control-flow is implemented, it will be handled here. - if s.sender.send(payload).await.is_err() { - self.close_stream(stream_id); - } - } - None => { - self.send_message(Message::Node(NodeMessage::Error( - NodeError::UnknownStream(stream_id), - ))) - .await; - } - } - } - } + async fn handle_message(&mut self, enveloppe: Enveloppe) { + let incomming = Inbound { + from: self.peer.expect("peer id should be known at this point"), + enveloppe, + }; + self.bus.incomming(incomming).await; } fn close_error(&mut self, error: color_eyre::eyre::Error) { self.state = ConnectionState::CloseError(error); } - fn close_stream(&mut self, id: StreamId) { - self.streams.remove(&id); - } - - async fn handle_command(&mut self, command: ConnectionMessage) { - match command { - ConnectionMessage::Message(m) => { - self.send_message(m).await; - } - ConnectionMessage::CloseStream(stream_id) => { - self.close_stream(stream_id); - self.send_message(Message::Node(NodeMessage::CloseStream { stream_id })) - .await; - } - ConnectionMessage::StreamCreate { database_id, ret } => { - let Some(stream_id) = self.stream_id_allocator.allocate() else { - // TODO: We close the connection here, which will cause a reconnections, and - // reset the stream_id allocator. If that happens in practice, it should be very quick to - // re-establish a connection. If this is an issue, we can either start using - // i64 stream_ids, or use a smarter id allocator. 
- self.state = ConnectionState::CloseError(anyhow!("Ran out of stream ids")); - return - }; - assert_eq!(stream_id.is_positive(), self.is_initiator); - assert!(!self.streams.contains_key(&stream_id)); - let stream = self.create_stream(stream_id); - self.send_message(Message::Node(NodeMessage::OpenStream { - stream_id, - database_id, - })) - .await; - let _ = ret.send(stream); - } - } - } - - async fn send_message(&mut self, message: Message) { - if let Err(e) = self.conn.send(message).await { - self.close_error(e.into()); - } - } - - fn create_stream(&mut self, stream_id: StreamId) -> Stream { - let (sender, recv) = mpsc::channel(MAX_STREAM_MSG); - let stream = Stream { - stream_id, - sender: PollSender::new(self.connection_messages_sender.clone()), - recv: recv.into(), - }; - self.streams.insert(stream_id, StreamState { sender }); - stream - } - /// wait for a handshake response from peer pub async fn wait_handshake_response_with_deadline( &mut self, @@ -399,41 +189,49 @@ where assert!(matches!(self.state, ConnectionState::Connecting)); match tokio::time::timeout_at(deadline, self.conn.next()).await { - Ok(Some(Ok(Message::Node(NodeMessage::Handshake { - protocol_version, - node_id, - })))) => { + Ok(Some(Ok(Enveloppe { + message: + Message::Handshake { + protocol_version, + node_id, + }, + .. + }))) => { if protocol_version != CURRENT_PROTO_VERSION { - let _ = self - .conn - .send(Message::Node(NodeMessage::Error( - NodeError::HandshakeVersionMismatch { - expected: CURRENT_PROTO_VERSION, - }, - ))) - .await; + let msg = Enveloppe { + from: None, + to: None, + message: Message::Error(ProtoError::HandshakeVersionMismatch { + expected: CURRENT_PROTO_VERSION, + }), + }; + + let _ = self.conn.send(msg).await; bail!("handshake error: invalid peer protocol version"); } else { // when not initiating a connection, respond to handshake message with a // handshake message if !self.is_initiator { - self.conn - .send(Message::Node(NodeMessage::Handshake { + let msg = Enveloppe { + from: None, + to: None, + message: Message::Handshake { protocol_version: CURRENT_PROTO_VERSION, - node_id: self.bus.node_id, - })) - .await?; + node_id: self.bus.node_id(), + }, + }; + self.conn.send(msg).await?; } self.peer = Some(node_id); self.state = ConnectionState::Connected; - self.registration = Some(self.bus.register_connection(node_id, self.handle())); + self.send_queue = Some(self.bus.send_queue().register(node_id)); Ok(()) } } - Ok(Some(Ok(Message::Node(NodeMessage::Error(e))))) => { + Ok(Some(Ok(Enveloppe { message: Message::Error(e), ..}))) => { bail!("handshake error: {e}"); } Ok(Some(Ok(_))) => { @@ -452,12 +250,16 @@ where } async fn initiate_connection(&mut self) -> color_eyre::Result<()> { - self.conn - .send(Message::Node(NodeMessage::Handshake { + let msg = Enveloppe { + from: None, + to: None, + message: Message::Handshake { protocol_version: CURRENT_PROTO_VERSION, - node_id: self.bus.node_id, - })) - .await?; + node_id: self.bus.node_id(), + }, + }; + + self.conn.send(msg).await?; Ok(()) } diff --git a/libsqlx-server/src/linc/connection_pool.rs b/libsqlx-server/src/linc/connection_pool.rs index f5f29c61..812745d3 100644 --- a/libsqlx-server/src/linc/connection_pool.rs +++ b/libsqlx-server/src/linc/connection_pool.rs @@ -5,18 +5,19 @@ use tokio::task::JoinSet; use tokio::time::Duration; use super::connection::Connection; +use super::handler::Handler; use super::net::Connector; use super::{bus::Bus, NodeId}; /// Manages a pool of connections to other peers, handling re-connection. 
-struct ConnectionPool { +struct ConnectionPool { managed_peers: HashMap, connections: JoinSet, - bus: Bus, + bus: Bus, } -impl ConnectionPool { - pub fn new(bus: Bus, managed_peers: impl IntoIterator) -> Self { +impl ConnectionPool { + pub fn new(bus: Bus, managed_peers: impl IntoIterator) -> Self { Self { managed_peers: managed_peers.into_iter().collect(), connections: JoinSet::new(), diff --git a/libsqlx-server/src/linc/handler.rs b/libsqlx-server/src/linc/handler.rs new file mode 100644 index 00000000..c8db9a89 --- /dev/null +++ b/libsqlx-server/src/linc/handler.rs @@ -0,0 +1,6 @@ +use super::{bus::{Bus}, Inbound}; + +pub trait Handler: Sized + Send + Sync + 'static { + fn handle(&self, bus: &Bus, msg: Inbound); +} + diff --git a/libsqlx-server/src/linc/mod.rs b/libsqlx-server/src/linc/mod.rs index 30b06285..8a3747bd 100644 --- a/libsqlx-server/src/linc/mod.rs +++ b/libsqlx-server/src/linc/mod.rs @@ -1,6 +1,6 @@ use uuid::Uuid; -use self::proto::StreamId; +use self::proto::Enveloppe; pub mod bus; pub mod connection; @@ -8,6 +8,7 @@ pub mod connection_pool; pub mod net; pub mod proto; pub mod server; +pub mod handler; type NodeId = Uuid; type DatabaseId = Uuid; @@ -16,23 +17,16 @@ const CURRENT_PROTO_VERSION: u32 = 1; const MAX_STREAM_MSG: usize = 64; #[derive(Debug)] -pub struct StreamIdAllocator { - direction: i32, - next_id: i32, +pub struct Inbound { + /// Id of the node sending the message + pub from: NodeId, + /// payload + pub enveloppe: Enveloppe, } -impl StreamIdAllocator { - fn new(positive: bool) -> Self { - let direction = if positive { 1 } else { -1 }; - Self { - direction, - next_id: direction, - } - } - - pub fn allocate(&mut self) -> Option { - let id = self.next_id; - self.next_id = id.checked_add(self.direction)?; - Some(StreamId::new(id)) - } +#[derive(Debug)] +pub struct Outbound { + pub to: NodeId, + pub enveloppe: Enveloppe, } + diff --git a/libsqlx-server/src/linc/proto.rs b/libsqlx-server/src/linc/proto.rs index 7de1002a..617f2d87 100644 --- a/libsqlx-server/src/linc/proto.rs +++ b/libsqlx-server/src/linc/proto.rs @@ -1,112 +1,42 @@ -use std::fmt; - use bytes::Bytes; -use serde::{de::Error, Deserialize, Deserializer, Serialize}; +use serde::{Deserialize, Serialize}; use uuid::Uuid; -use super::DatabaseId; +use super::{DatabaseId}; pub type Program = String; -#[derive(Serialize, Deserialize, Debug, PartialEq, Eq, Hash, Clone, Copy)] -pub struct StreamId(#[serde(deserialize_with = "non_zero")] i32); - -impl fmt::Display for StreamId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - self.0.fmt(f) - } -} - -fn non_zero<'de, D>(d: D) -> Result -where - D: Deserializer<'de>, -{ - let value = i32::deserialize(d)?; - - if value == 0 { - return Err(D::Error::custom("invalid stream_id")); - } - - Ok(value) -} - -impl StreamId { - /// creates a new stream_id. - /// panics if val is zero. 
- pub fn new(val: i32) -> Self { - assert!(val != 0); - Self(val) - } - - pub fn is_positive(&self) -> bool { - self.0.is_positive() - } - - #[cfg(test)] - pub fn new_unchecked(i: i32) -> Self { - Self(i) - } -} - #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] -pub enum Message { - /// Messages destined to a node - Node(NodeMessage), - /// message destined to a database - Stream { - stream_id: StreamId, - payload: StreamMessage, - }, +pub struct Enveloppe { + pub from: Option, + pub to: Option, + pub message: Message, } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] -pub enum NodeMessage { +pub enum Message { /// Initial message exchanged between nodes when connecting Handshake { protocol_version: u32, node_id: Uuid, }, - /// Request to open a bi-directional stream between the client and the server - OpenStream { - /// Id to give to the newly opened stream - /// Initiator of the connection create streams with positive ids, - /// and acceptor of the connection create streams with negative ids. - stream_id: StreamId, - /// Id of the database to open the stream to. - database_id: Uuid, - }, - /// Close a previously opened stream - CloseStream { stream_id: StreamId }, - /// Error type returned while handling a node message - Error(NodeError), + Replication(ReplicationMessage), + Proxy(ProxyMessage), + Error(ProtoError), } #[derive(Debug, Serialize, Deserialize, thiserror::Error, PartialEq, Eq)] -pub enum NodeError { - /// The requested stream does not exist - #[error("unknown stream: {0}")] - UnknownStream(StreamId), +pub enum ProtoError { /// Incompatible protocol versions #[error("invalid protocol version, expected: {expected}")] HandshakeVersionMismatch { expected: u32 }, - #[error("stream {0} already exists")] - StreamAlreadyExist(StreamId), - #[error("cannot open stream {1}: unknown database {0}")] - UnknownDatabase(DatabaseId, StreamId), -} - -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub enum StreamMessage { - /// Replication message between a replica and a primary - Replication(ReplicationMessage), - /// Proxy message between a replica and a primary - Proxy(ProxyMessage), - #[cfg(test)] - Dummy, + #[error("unknown database {0}")] + UnknownDatabase(DatabaseId), } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] pub enum ReplicationMessage { + Handshake {}, HandshakeResponse { /// id of the replication log log_id: Uuid, @@ -126,8 +56,6 @@ pub enum ReplicationMessage { /// a batch of frames part of the transaction. frames: Vec, }, - /// Error occurred handling a replication message - Error(ReplicationError), } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] @@ -161,8 +89,6 @@ pub enum ProxyMessage { CancelRequest { req_id: u32 }, /// Close Connection with passed id. CloseConnection { connection_id: u32 }, - /// Error returned when handling a proxied query message. - Error(ProxyError), } /// Steps applied to the query builder transducer to build a response to a proxied query. @@ -204,11 +130,3 @@ pub struct Column { /// for now, the stringified version of a sqld::error::Error. 
#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] pub struct StepError(String); - -/// TBD -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub enum ProxyError {} - -/// TBD -#[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] -pub enum ReplicationError {} diff --git a/libsqlx-server/src/linc/server.rs b/libsqlx-server/src/linc/server.rs index 08c205ef..0594bd9e 100644 --- a/libsqlx-server/src/linc/server.rs +++ b/libsqlx-server/src/linc/server.rs @@ -3,17 +3,18 @@ use tokio::task::JoinSet; use crate::linc::connection::Connection; -use super::bus::Bus; +use super::bus::{Bus}; +use super::handler::Handler; -pub struct Server { +pub struct Server { /// reference to the bus - bus: Bus, + bus: Bus, /// Connection tasks owned by the server connections: JoinSet>, } -impl Server { - pub fn new(bus: Bus) -> Self { +impl Server { + pub fn new(bus: Bus) -> Self { Self { bus, connections: JoinSet::new(), @@ -25,7 +26,6 @@ impl Server { pub async fn close_connections(&mut self) { self.connections.abort_all(); while self.connections.join_next().await.is_some() {} - assert!(self.bus.is_empty()); } pub async fn run(mut self, mut listener: L) @@ -57,7 +57,7 @@ impl Server { { let bus = self.bus.clone(); let fut = async move { - let connection = Connection::new_acceptor(stream, bus.clone()); + let connection = Connection::new_acceptor(stream, bus); connection.run().await; Ok(()) }; @@ -71,7 +71,7 @@ mod test { use std::sync::Arc; use crate::linc::{ - proto::{ProxyMessage, StreamMessage}, + proto::{ProxyMessage}, DatabaseId, NodeId, }; From 58b3aaf74dadab21657eb4a6eede994576716465 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 12 Jul 2023 11:56:53 +0200 Subject: [PATCH 2/5] deliver message to allocations --- Cargo.lock | 23 ++++++- libsqlx-server/Cargo.toml | 2 + libsqlx-server/src/allocation/config.rs | 6 +- libsqlx-server/src/allocation/mod.rs | 35 ++++++++++- libsqlx-server/src/config.rs | 1 + libsqlx-server/src/http/admin.rs | 7 ++- libsqlx-server/src/http/user/extractors.rs | 7 ++- libsqlx-server/src/http/user/mod.rs | 4 ++ libsqlx-server/src/linc/bus.rs | 52 +++++++-------- libsqlx-server/src/linc/connection.rs | 18 +++--- libsqlx-server/src/linc/connection_pool.rs | 12 ++-- libsqlx-server/src/linc/handler.rs | 9 ++- libsqlx-server/src/linc/mod.rs | 23 ++++--- libsqlx-server/src/linc/proto.rs | 63 ++++++++++++++++--- libsqlx-server/src/linc/server.rs | 21 +++---- libsqlx-server/src/main.rs | 7 ++- libsqlx-server/src/manager.rs | 40 +++++++++--- libsqlx-server/src/meta.rs | 49 ++++++++++++--- libsqlx/Cargo.toml | 2 - libsqlx/src/database/libsql/injector/mod.rs | 2 +- libsqlx/src/database/libsql/mod.rs | 8 ++- .../database/libsql/replication_log/logger.rs | 29 ++++++--- libsqlx/src/error.rs | 8 --- sqld/src/query_result_builder.rs | 1 - 24 files changed, 303 insertions(+), 126 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 2eb87fc8..2a3cc0f6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2411,6 +2411,15 @@ dependencies = [ "simple_asn1", ] +[[package]] +name = "keccak" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f6d5ed8676d904364de097082f4e7d240b571b67989ced0240f08b7f966f940" +dependencies = [ + "cpufeatures", +] + [[package]] name = "lazy_static" version = "1.4.0" @@ -2511,7 +2520,6 @@ dependencies = [ "crc", "crossbeam", "fallible-iterator 0.3.0", - "futures", "itertools 0.11.0", "nix", "once_cell", @@ -2525,7 +2533,6 @@ dependencies = [ "sqlite3-parser 0.9.0", "tempfile", 
"thiserror", - "tokio", "tracing", "uuid", ] @@ -2535,6 +2542,7 @@ name = "libsqlx-server" version = "0.1.0" dependencies = [ "async-bincode", + "async-trait", "axum", "base64 0.21.2", "bincode", @@ -2554,6 +2562,7 @@ dependencies = [ "serde", "serde_json", "sha2", + "sha3", "sled", "thiserror", "tokio", @@ -3986,6 +3995,16 @@ dependencies = [ "sha2", ] +[[package]] +name = "sha3" +version = "0.10.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "75872d278a8f37ef87fa0ddbda7802605cb18344497949862c0d4dcb291eba60" +dependencies = [ + "digest", + "keccak", +] + [[package]] name = "sharded-slab" version = "0.1.4" diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index 6393f6cb..efff738e 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -7,6 +7,7 @@ edition = "2021" [dependencies] async-bincode = { version = "0.7.1", features = ["tokio"] } +async-trait = "0.1.71" axum = "0.6.18" base64 = "0.21.2" bincode = "1.3.3" @@ -26,6 +27,7 @@ regex = "1.9.1" serde = { version = "1.0.166", features = ["derive", "rc"] } serde_json = "1.0.100" sha2 = "0.10.7" +sha3 = "0.10.8" sled = "0.34.7" thiserror = "1.0.43" tokio = { version = "1.29.1", features = ["full"] } diff --git a/libsqlx-server/src/allocation/config.rs b/libsqlx-server/src/allocation/config.rs index f5839e9c..9d1bab34 100644 --- a/libsqlx-server/src/allocation/config.rs +++ b/libsqlx-server/src/allocation/config.rs @@ -1,5 +1,7 @@ use serde::{Deserialize, Serialize}; +use crate::linc::NodeId; + /// Structural supertype of AllocConfig, used for checking the meta version. Subsequent version of /// AllocConfig need to conform to this prototype. #[derive(Debug, Serialize, Deserialize)] @@ -10,12 +12,12 @@ struct ConfigVersion { #[derive(Debug, Serialize, Deserialize)] pub struct AllocConfig { pub max_conccurent_connection: u32, - pub id: String, + pub db_name: String, pub db_config: DbConfig, } #[derive(Debug, Serialize, Deserialize)] pub enum DbConfig { Primary {}, - Replica { primary_node_id: String }, + Replica { primary_node_id: NodeId }, } diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 38c29ff3..c9f88ed0 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,7 +1,7 @@ use std::path::PathBuf; use std::sync::Arc; -use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType}; +use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType, ReplicaType}; use libsqlx::Database as _; use tokio::sync::{mpsc, oneshot}; use tokio::task::{block_in_place, JoinSet}; @@ -9,6 +9,9 @@ use tokio::task::{block_in_place, JoinSet}; use crate::hrana; use crate::hrana::http::handle_pipeline; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; +use crate::linc::bus::Dispatch; +use crate::linc::{Inbound, NodeId}; +use crate::meta::DatabaseId; use self::config::{AllocConfig, DbConfig}; @@ -28,10 +31,15 @@ pub enum AllocationMessage { req: PipelineRequestBody, ret: oneshot::Sender>, }, + Inbound(Inbound), } pub enum Database { Primary(libsqlx::libsql::LibsqlDatabase), + Replica { + db: libsqlx::libsql::LibsqlDatabase, + primary_node_id: NodeId, + }, } struct Compactor; @@ -65,6 +73,7 @@ impl Database { fn connect(&self) -> Box { match self { Database::Primary(db) => Box::new(db.connect().unwrap()), + Database::Replica { db, .. 
} => Box::new(db.connect().unwrap()), } } } @@ -78,6 +87,9 @@ pub struct Allocation { pub max_concurrent_connections: u32, pub hrana_server: Arc, + /// handle to the message bus, to send messages + pub dispatcher: Arc, + pub db_name: String, } pub struct ConnectionHandle { @@ -115,11 +127,13 @@ impl Allocation { AllocationMessage::HranaPipelineReq { req, ret} => { let res = handle_pipeline(&self.hrana_server.clone(), req, || async { let conn= self.new_conn().await; - dbg!(); Ok(conn) }).await; let _ = ret.send(res); } + AllocationMessage::Inbound(msg) => { + self.handle_inbound(msg).await; + } } }, maybe_id = self.connections_futs.join_next() => { @@ -132,6 +146,23 @@ impl Allocation { } } + async fn handle_inbound(&mut self, msg: Inbound) { + debug_assert_eq!(msg.enveloppe.to, Some(DatabaseId::from_name(&self.db_name))); + + match msg.enveloppe.message { + crate::linc::proto::Message::Handshake { .. } => todo!(), + crate::linc::proto::Message::ReplicationHandshake { .. } => todo!(), + crate::linc::proto::Message::ReplicationHandshakeResponse { .. } => todo!(), + crate::linc::proto::Message::Replicate { .. } => todo!(), + crate::linc::proto::Message::Transaction { .. } => todo!(), + crate::linc::proto::Message::ProxyRequest { .. } => todo!(), + crate::linc::proto::Message::ProxyResponse { .. } => todo!(), + crate::linc::proto::Message::CancelRequest { .. } => todo!(), + crate::linc::proto::Message::CloseConnection { .. } => todo!(), + crate::linc::proto::Message::Error(_) => todo!(), + } + } + async fn new_conn(&mut self) -> ConnectionHandle { let id = self.next_conn_id(); let conn = block_in_place(|| self.database.connect()); diff --git a/libsqlx-server/src/config.rs b/libsqlx-server/src/config.rs index cb2d68b5..f0f9ca0c 100644 --- a/libsqlx-server/src/config.rs +++ b/libsqlx-server/src/config.rs @@ -26,6 +26,7 @@ impl Config { #[derive(Deserialize, Debug, Clone)] pub struct ClusterConfig { + pub id: u64, /// Address to bind this node to #[serde(default = "default_linc_addr")] pub addr: SocketAddr, diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index 346987c4..8a08187e 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -8,6 +8,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use crate::{ allocation::config::{AllocConfig, DbConfig}, + linc::NodeId, meta::Store, }; @@ -55,7 +56,7 @@ struct AllocateReq { #[serde(tag = "type", rename_all = "snake_case")] pub enum DbConfigReq { Primary {}, - Replica { primary_node_id: String }, + Replica { primary_node_id: NodeId }, } async fn allocate( @@ -64,7 +65,7 @@ async fn allocate( ) -> Result, Json> { let config = AllocConfig { max_conccurent_connection: req.max_conccurent_connection.unwrap_or(16), - id: req.alloc_id.clone(), + db_name: req.alloc_id.clone(), db_config: match req.config { DbConfigReq::Primary {} => DbConfig::Primary {}, DbConfigReq::Replica { primary_node_id } => DbConfig::Replica { primary_node_id }, @@ -93,7 +94,7 @@ async fn list_allocs( .list_allocs() .await .into_iter() - .map(|cfg| AllocView { id: cfg.id }) + .map(|cfg| AllocView { id: cfg.db_name }) .collect(); Ok(Json(ListAllocResp { allocs })) diff --git a/libsqlx-server/src/http/user/extractors.rs b/libsqlx-server/src/http/user/extractors.rs index 2b3f5a14..962eb060 100644 --- a/libsqlx-server/src/http/user/extractors.rs +++ b/libsqlx-server/src/http/user/extractors.rs @@ -4,7 +4,7 @@ use axum::async_trait; use axum::extract::FromRequestParts; use hyper::http::request::Parts; -use crate::database::Database; 
+use crate::{database::Database, meta::DatabaseId}; use super::{error::UserApiError, UserApiState}; @@ -18,8 +18,9 @@ impl FromRequestParts> for Database { ) -> Result { let Some(host) = parts.headers.get("host") else { return Err(UserApiError::MissingHost) }; let Ok(host_str) = std::str::from_utf8(host.as_bytes()) else {return Err(UserApiError::MissingHost)}; - let db_id = parse_host(host_str)?; - let Some(sender) = state.manager.alloc(db_id).await else { return Err(UserApiError::UnknownDatabase(db_id.to_owned())) }; + let db_name = parse_host(host_str)?; + let db_id = DatabaseId::from_name(db_name); + let Some(sender) = state.manager.alloc(db_id, state.bus.clone()).await else { return Err(UserApiError::UnknownDatabase(db_name.to_owned())) }; Ok(Database { sender }) } diff --git a/libsqlx-server/src/http/user/mod.rs b/libsqlx-server/src/http/user/mod.rs index f357499c..c947fb8b 100644 --- a/libsqlx-server/src/http/user/mod.rs +++ b/libsqlx-server/src/http/user/mod.rs @@ -8,6 +8,7 @@ use tokio::io::{AsyncRead, AsyncWrite}; use crate::database::Database; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; +use crate::linc::bus::Bus; use crate::manager::Manager; mod error; @@ -15,10 +16,12 @@ mod extractors; pub struct Config { pub manager: Arc, + pub bus: Arc>>, } struct UserApiState { manager: Arc, + bus: Arc>>, } pub async fn run_user_api(config: Config, listener: I) -> Result<()> @@ -28,6 +31,7 @@ where { let state = UserApiState { manager: config.manager, + bus: config.bus, }; let app = Router::new() diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs index 7c52ec42..bf9a2cc4 100644 --- a/libsqlx-server/src/linc/bus.rs +++ b/libsqlx-server/src/linc/bus.rs @@ -1,23 +1,8 @@ use std::sync::Arc; -use uuid::Uuid; - -use super::{connection::SendQueue, handler::Handler, Outbound, Inbound}; - -type NodeId = Uuid; -type DatabaseId = Uuid; +use super::{connection::SendQueue, handler::Handler, Inbound, NodeId, Outbound}; pub struct Bus { - inner: Arc>, -} - -impl Clone for Bus { - fn clone(&self) -> Self { - Self { inner: self.inner.clone() } - } -} - -struct BusInner { node_id: NodeId, handler: H, send_queue: SendQueue, @@ -27,33 +12,38 @@ impl Bus { pub fn new(node_id: NodeId, handler: H) -> Self { let send_queue = SendQueue::new(); Self { - inner: Arc::new(BusInner { - node_id, - handler, - send_queue, - }), + node_id, + handler, + send_queue, } } pub fn node_id(&self) -> NodeId { - self.inner.node_id + self.node_id } - pub async fn incomming(&self, incomming: Inbound) { - self.inner.handler.handle(self, incomming); + pub async fn incomming(self: &Arc, incomming: Inbound) { + self.handler.handle(self.clone(), incomming); } - pub async fn dispatch(&self, msg: Outbound) { + pub fn send_queue(&self) -> &SendQueue { + &self.send_queue + } +} + +#[async_trait::async_trait] +pub trait Dispatch: Send + Sync + 'static { + async fn dispatch(&self, msg: Outbound); +} + +#[async_trait::async_trait] +impl Dispatch for Bus { + async fn dispatch(&self, msg: Outbound) { assert!( msg.to != self.node_id(), "trying to send a message to ourself!" ); // This message is outbound. 
- self.inner.send_queue.enqueue(msg).await; - } - - pub fn send_queue(&self) -> &SendQueue { - &self.inner.send_queue - + self.send_queue.enqueue(msg).await; } } diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index 55977623..a96b8179 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::Arc; use async_bincode::tokio::AsyncBincodeStream; use async_bincode::AsyncDestination; @@ -12,10 +13,10 @@ use tokio::time::{Duration, Instant}; use crate::linc::proto::ProtoError; use crate::linc::CURRENT_PROTO_VERSION; -use super::bus::{Bus}; +use super::bus::Bus; use super::handler::Handler; use super::proto::{Enveloppe, Message}; -use super::{NodeId, Outbound, Inbound}; +use super::{Inbound, NodeId, Outbound}; /// A connection to another node. Manage the connection state, and (de)register streams with the /// `Bus` @@ -30,7 +31,7 @@ pub struct Connection { is_initiator: bool, /// send queue for this connection send_queue: Option>, - bus: Bus, + bus: Arc>, } #[derive(Debug)] @@ -85,18 +86,18 @@ where { const MAX_CONNECTION_MESSAGES: usize = 128; - pub fn new_initiator(stream: S, bus: Bus) -> Self { + pub fn new_initiator(stream: S, bus: Arc>) -> Self { Self { peer: None, state: ConnectionState::Init, conn: AsyncBincodeStream::from(stream).for_async(), is_initiator: true, send_queue: None, - bus, + bus, } } - pub fn new_acceptor(stream: S, bus: Bus) -> Self { + pub fn new_acceptor(stream: S, bus: Arc>) -> Self { Connection { peer: None, state: ConnectionState::Connecting, @@ -231,7 +232,10 @@ where Ok(()) } } - Ok(Some(Ok(Enveloppe { message: Message::Error(e), ..}))) => { + Ok(Some(Ok(Enveloppe { + message: Message::Error(e), + .. 
+ }))) => { bail!("handshake error: {e}"); } Ok(Some(Ok(_))) => { diff --git a/libsqlx-server/src/linc/connection_pool.rs b/libsqlx-server/src/linc/connection_pool.rs index 812745d3..26c5d923 100644 --- a/libsqlx-server/src/linc/connection_pool.rs +++ b/libsqlx-server/src/linc/connection_pool.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::sync::Arc; use itertools::Itertools; use tokio::task::JoinSet; @@ -13,11 +14,14 @@ use super::{bus::Bus, NodeId}; struct ConnectionPool { managed_peers: HashMap, connections: JoinSet, - bus: Bus, + bus: Arc>, } impl ConnectionPool { - pub fn new(bus: Bus, managed_peers: impl IntoIterator) -> Self { + pub fn new( + bus: Arc>, + managed_peers: impl IntoIterator, + ) -> Self { Self { managed_peers: managed_peers.into_iter().collect(), connections: JoinSet::new(), @@ -77,14 +81,14 @@ mod test { use tokio::sync::Notify; use tokio_stream::StreamExt; - use crate::linc::{server::Server, DatabaseId}; + use crate::linc::{server::Server, AllocId}; use super::*; #[test] fn manage_connections() { let mut sim = turmoil::Builder::new().build(); - let database_id = DatabaseId::new_v4(); + let database_id = AllocId::new_v4(); let notify = Arc::new(Notify::new()); let expected_msg = crate::linc::proto::StreamMessage::Proxy( diff --git a/libsqlx-server/src/linc/handler.rs b/libsqlx-server/src/linc/handler.rs index c8db9a89..6a6ae6f8 100644 --- a/libsqlx-server/src/linc/handler.rs +++ b/libsqlx-server/src/linc/handler.rs @@ -1,6 +1,9 @@ -use super::{bus::{Bus}, Inbound}; +use std::sync::Arc; +use super::bus::Bus; +use super::Inbound; + +#[async_trait::async_trait] pub trait Handler: Sized + Send + Sync + 'static { - fn handle(&self, bus: &Bus, msg: Inbound); + async fn handle(&self, bus: Arc>, msg: Inbound); } - diff --git a/libsqlx-server/src/linc/mod.rs b/libsqlx-server/src/linc/mod.rs index 8a3747bd..fa787e87 100644 --- a/libsqlx-server/src/linc/mod.rs +++ b/libsqlx-server/src/linc/mod.rs @@ -1,17 +1,14 @@ -use uuid::Uuid; - -use self::proto::Enveloppe; +use self::proto::{Enveloppe, Message}; pub mod bus; pub mod connection; pub mod connection_pool; +pub mod handler; pub mod net; pub mod proto; pub mod server; -pub mod handler; -type NodeId = Uuid; -type DatabaseId = Uuid; +pub type NodeId = u64; const CURRENT_PROTO_VERSION: u32 = 1; const MAX_STREAM_MSG: usize = 64; @@ -24,9 +21,21 @@ pub struct Inbound { pub enveloppe: Enveloppe, } +impl Inbound { + pub fn respond(&self, message: Message) -> Outbound { + Outbound { + to: self.from, + enveloppe: Enveloppe { + from: self.enveloppe.to, + to: self.enveloppe.from, + message, + }, + } + } +} + #[derive(Debug)] pub struct Outbound { pub to: NodeId, pub enveloppe: Enveloppe, } - diff --git a/libsqlx-server/src/linc/proto.rs b/libsqlx-server/src/linc/proto.rs index 617f2d87..c099cbd1 100644 --- a/libsqlx-server/src/linc/proto.rs +++ b/libsqlx-server/src/linc/proto.rs @@ -2,7 +2,9 @@ use bytes::Bytes; use serde::{Deserialize, Serialize}; use uuid::Uuid; -use super::{DatabaseId}; +use crate::meta::DatabaseId; + +use super::NodeId; pub type Program = String; @@ -18,10 +20,55 @@ pub enum Message { /// Initial message exchanged between nodes when connecting Handshake { protocol_version: u32, - node_id: Uuid, + node_id: NodeId, + }, + ReplicationHandshake { + database_name: String, + }, + ReplicationHandshakeResponse { + /// id of the replication log + log_id: Uuid, + /// current frame_no of the primary + current_frame_no: u64, + }, + Replicate { + /// next frame no to send + next_frame_no: u64, + }, + /// a batch of 
frames that are part of the same transaction + Transaction { + /// if not None, then the last frame is a commit frame, and this is the new size of the database. + size_after: Option, + /// frame_no of the last frame in frames + end_frame_no: u64, + /// a batch of frames part of the transaction. + frames: Vec, + }, + /// Proxy a query to a primary + ProxyRequest { + /// id of the connection to perform the query against + /// If the connection doesn't already exist it is created + /// Id of the request. + /// Responses to this request must have the same id. + connection_id: u32, + req_id: u32, + program: Program, + }, + /// Response to a proxied query + ProxyResponse { + /// id of the request this message is a response to. + req_id: u32, + /// Collection of steps to drive the query builder transducer. + row_step: Vec, + }, + /// Stop processing request `id`. + CancelRequest { + req_id: u32, + }, + /// Close Connection with passed id. + CloseConnection { + connection_id: u32, }, - Replication(ReplicationMessage), - Proxy(ProxyMessage), Error(ProtoError), } @@ -31,13 +78,15 @@ pub enum ProtoError { #[error("invalid protocol version, expected: {expected}")] HandshakeVersionMismatch { expected: u32 }, #[error("unknown database {0}")] - UnknownDatabase(DatabaseId), + UnknownDatabase(String), } #[derive(Debug, Serialize, Deserialize, PartialEq, Eq, Clone)] pub enum ReplicationMessage { - Handshake {}, - HandshakeResponse { + ReplicationHandshake { + database_name: String, + }, + ReplicationHandshakeResponse { /// id of the replication log log_id: Uuid, /// current frame_no of the primary diff --git a/libsqlx-server/src/linc/server.rs b/libsqlx-server/src/linc/server.rs index 0594bd9e..b462d0a1 100644 --- a/libsqlx-server/src/linc/server.rs +++ b/libsqlx-server/src/linc/server.rs @@ -1,20 +1,22 @@ +use std::sync::Arc; + use tokio::io::{AsyncRead, AsyncWrite}; use tokio::task::JoinSet; use crate::linc::connection::Connection; -use super::bus::{Bus}; +use super::bus::Bus; use super::handler::Handler; pub struct Server { /// reference to the bus - bus: Bus, + bus: Arc>, /// Connection tasks owned by the server connections: JoinSet>, } impl Server { - pub fn new(bus: Bus) -> Self { + pub fn new(bus: Arc>) -> Self { Self { bus, connections: JoinSet::new(), @@ -70,10 +72,7 @@ impl Server { mod test { use std::sync::Arc; - use crate::linc::{ - proto::{ProxyMessage}, - DatabaseId, NodeId, - }; + use crate::linc::{proto::ProxyMessage, AllocId, NodeId}; use super::*; @@ -125,7 +124,7 @@ mod test { let mut sim = turmoil::Builder::new().build(); let host_node_id = NodeId::new_v4(); - let stream_db_id = DatabaseId::new_v4(); + let stream_db_id = AllocId::new_v4(); let notify = Arc::new(Notify::new()); let expected_msg = StreamMessage::Proxy(ProxyMessage::ProxyRequest { connection_id: 12, @@ -195,7 +194,7 @@ mod test { let mut sim = turmoil::Builder::new().build(); let host_node_id = NodeId::new_v4(); - let database_id = DatabaseId::new_v4(); + let database_id = AllocId::new_v4(); let notify = Arc::new(Notify::new()); sim.host("host", { @@ -251,7 +250,7 @@ mod test { let host_node_id = NodeId::new_v4(); let notify = Arc::new(Notify::new()); let client_id = NodeId::new_v4(); - let database_id = DatabaseId::new_v4(); + let database_id = AllocId::new_v4(); let expected_msg = StreamMessage::Proxy(ProxyMessage::ProxyRequest { connection_id: 12, req_id: 1, @@ -309,7 +308,7 @@ mod test { let host_node_id = NodeId::new_v4(); let client_id = NodeId::new_v4(); - let database_id = DatabaseId::new_v4(); + let database_id = 
AllocId::new_v4(); sim.host("host", { move || async move { diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index a8360402..d5b0c35f 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -8,6 +8,7 @@ use config::{AdminApiConfig, UserApiConfig}; use http::admin::run_admin_api; use http::user::run_user_api; use hyper::server::conn::AddrIncoming; +use linc::bus::Bus; use manager::Manager; use meta::Store; use tokio::task::JoinSet; @@ -49,10 +50,11 @@ async fn spawn_user_api( set: &mut JoinSet>, config: &UserApiConfig, manager: Arc, + bus: Arc>>, ) -> Result<()> { let user_api_listener = tokio::net::TcpListener::bind(config.addr).await?; set.spawn(run_user_api( - http::user::Config { manager }, + http::user::Config { manager, bus }, AddrIncoming::from_listener(user_api_listener)?, )); @@ -71,9 +73,10 @@ async fn main() -> Result<()> { let store = Arc::new(Store::new(&config.db_path)); let manager = Arc::new(Manager::new(config.db_path.clone(), store.clone(), 100)); + let bus = Arc::new(Bus::new(config.cluster.id, manager.clone())); spawn_admin_api(&mut join_set, &config.admin_api, store.clone()).await?; - spawn_user_api(&mut join_set, &config.user_api, manager).await?; + spawn_user_api(&mut join_set, &config.user_api, manager, bus).await?; join_set.join_next().await; diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 48315e0a..62f86479 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -7,10 +7,13 @@ use tokio::task::JoinSet; use crate::allocation::{Allocation, AllocationMessage, Database}; use crate::hrana; -use crate::meta::Store; +use crate::linc::bus::Bus; +use crate::linc::handler::Handler; +use crate::linc::Inbound; +use crate::meta::{DatabaseId, Store}; pub struct Manager { - cache: Cache>, + cache: Cache>, meta_store: Arc, db_path: PathBuf, } @@ -27,13 +30,17 @@ impl Manager { } /// Returns a handle to an allocation, lazily initializing if it isn't already loaded. - pub async fn alloc(&self, alloc_id: &str) -> Option> { - if let Some(sender) = self.cache.get(alloc_id) { + pub async fn alloc( + self: &Arc, + database_id: DatabaseId, + bus: Arc>>, + ) -> Option> { + if let Some(sender) = self.cache.get(&database_id) { return Some(sender.clone()); } - if let Some(config) = self.meta_store.meta(alloc_id).await { - let path = self.db_path.join("dbs").join(alloc_id); + if let Some(config) = self.meta_store.meta(&database_id).await { + let path = self.db_path.join("dbs").join(database_id.to_string()); tokio::fs::create_dir_all(&path).await.unwrap(); let (alloc_sender, inbox) = mpsc::channel(MAX_ALLOC_MESSAGE_QUEUE_LEN); let alloc = Allocation { @@ -42,14 +49,14 @@ impl Manager { connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, - hrana_server: Arc::new(hrana::http::Server::new(None)), // TODO: handle self URL? + hrana_server: Arc::new(hrana::http::Server::new(None)), + dispatcher: bus, // TODO: handle self URL? 
+ db_name: config.db_name, }; tokio::spawn(alloc.run()); - self.cache - .insert(alloc_id.to_string(), alloc_sender.clone()) - .await; + self.cache.insert(database_id, alloc_sender.clone()).await; return Some(alloc_sender); } @@ -57,3 +64,16 @@ impl Manager { None } } + +#[async_trait::async_trait] +impl Handler for Arc { + async fn handle(&self, bus: Arc>, msg: Inbound) { + if let Some(sender) = self + .clone() + .alloc(msg.enveloppe.to.unwrap(), bus.clone()) + .await + { + let _ = sender.send(AllocationMessage::Inbound(msg)).await; + } + } +} diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index 06e37a76..b71b33eb 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -1,7 +1,11 @@ +use core::fmt; use std::path::Path; +use serde::{Deserialize, Serialize}; +use sha3::digest::{ExtendableOutput, Update, XofReader}; +use sha3::Shake128; use sled::Tree; -use uuid::Uuid; +use tokio::task::block_in_place; use crate::allocation::config::AllocConfig; @@ -11,6 +15,32 @@ pub struct Store { meta_store: Tree, } +#[derive(Serialize, Deserialize, PartialEq, Eq, Debug, Hash, Clone, Copy)] +pub struct DatabaseId([u8; 16]); + +impl DatabaseId { + pub fn from_name(name: &str) -> Self { + let mut hasher = Shake128::default(); + hasher.update(name.as_bytes()); + let mut reader = hasher.finalize_xof(); + let mut out = [0; 16]; + reader.read(&mut out); + Self(out) + } +} + +impl fmt::Display for DatabaseId { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:x}", u128::from_be_bytes(self.0)) + } +} + +impl AsRef<[u8]> for DatabaseId { + fn as_ref(&self) -> &[u8] { + &self.0 + } +} + impl Store { pub fn new(path: &Path) -> Self { std::fs::create_dir_all(&path).unwrap(); @@ -21,31 +51,32 @@ impl Store { Self { meta_store } } - pub async fn allocate(&self, alloc_id: &str, meta: &AllocConfig) { + pub async fn allocate(&self, database_name: &str, meta: &AllocConfig) { //TODO: Handle conflict - tokio::task::block_in_place(|| { + block_in_place(|| { let meta_bytes = bincode::serialize(meta).unwrap(); + let id = DatabaseId::from_name(database_name); self.meta_store - .compare_and_swap(alloc_id, None as Option<&[u8]>, Some(meta_bytes)) + .compare_and_swap(id, None as Option<&[u8]>, Some(meta_bytes)) .unwrap() .unwrap(); }); } - pub async fn deallocate(&self, _alloc_id: Uuid) { + pub async fn deallocate(&self, _database_name: &str) { todo!() } - pub async fn meta(&self, alloc_id: &str) -> Option { - tokio::task::block_in_place(|| { - let config = self.meta_store.get(alloc_id).unwrap()?; + pub async fn meta(&self, database_id: &DatabaseId) -> Option { + block_in_place(|| { + let config = self.meta_store.get(database_id).unwrap()?; let config = bincode::deserialize(config.as_ref()).unwrap(); Some(config) }) } pub async fn list_allocs(&self) -> Vec { - tokio::task::block_in_place(|| { + block_in_place(|| { let mut out = Vec::new(); for kv in self.meta_store.iter() { let (_k, v) = kv.unwrap(); diff --git a/libsqlx/Cargo.toml b/libsqlx/Cargo.toml index b0f00521..abdb39ad 100644 --- a/libsqlx/Cargo.toml +++ b/libsqlx/Cargo.toml @@ -12,8 +12,6 @@ serde = "1.0.164" serde_json = "1.0.99" rusqlite = { workspace = true } anyhow = "1.0.71" -futures = "0.3.28" -tokio = { version = "1.28.2", features = ["sync", "time"] } sqlite3-parser = "0.9.0" fallible-iterator = "0.3.0" bytes = "1.4.0" diff --git a/libsqlx/src/database/libsql/injector/mod.rs b/libsqlx/src/database/libsql/injector/mod.rs index 1682e3b4..19fd51ce 100644 --- 
a/libsqlx/src/database/libsql/injector/mod.rs +++ b/libsqlx/src/database/libsql/injector/mod.rs @@ -37,7 +37,7 @@ pub struct Injector { /// This trait trait is used to record the last committed frame_no to the log. /// The implementer can persist the pre and post commit frame no, and compare them in the event of /// a crash; if the pre and post commit frame_no don't match, then the log may be corrupted. -pub trait InjectorCommitHandler: 'static { +pub trait InjectorCommitHandler: Send + 'static { fn pre_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()>; fn post_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()>; } diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index 2844a204..9cb32c20 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -139,7 +139,12 @@ impl LibsqlDatabase { dirty: bool, ) -> crate::Result { let ty = PrimaryType { - logger: Arc::new(ReplicationLogger::open(&db_path, dirty, compactor)?), + logger: Arc::new(ReplicationLogger::open( + &db_path, + dirty, + compactor, + Box::new(|_| ()), + )?), }; Ok(Self::new(db_path, ty)) } @@ -174,7 +179,6 @@ impl Database for LibsqlDatabase { type Connection = LibsqlConnection<::Context>; fn connect(&self) -> Result { - dbg!(); Ok( LibsqlConnection::<::Context>::new( &self.db_path, diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index aebff0db..fe371258 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -19,7 +19,6 @@ use sqld_libsql_bindings::ffi::types::{ use sqld_libsql_bindings::ffi::PageHdrIter; use sqld_libsql_bindings::init_static_wal_method; use sqld_libsql_bindings::wal_hook::WalHook; -use tokio::sync::watch; use uuid::Uuid; use crate::database::frame::{Frame, FrameHeader}; @@ -329,7 +328,7 @@ impl ReplicationLoggerHookCtx { fn commit(&self) -> anyhow::Result<()> { let new_frame_no = self.logger.commit()?; - let _ = self.logger.new_frame_notifier.send(new_frame_no); + let _ = (self.logger.new_frame_notifier)(new_frame_no); Ok(()) } @@ -748,6 +747,8 @@ impl LogCompactor for () { } } +pub type FrameNotifierCb = Box; + pub struct ReplicationLogger { pub generation: Generation, pub log_file: RwLock, @@ -755,11 +756,16 @@ pub struct ReplicationLogger { db_path: PathBuf, /// a notifier channel other tasks can subscribe to, and get notified when new frames become /// available. 
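The replication logger now reports committed frames through a plain callback instead of a tokio watch channel. As a rough sketch (not part of the diff), and assuming `FrameNotifierCb` is a boxed `Fn(FrameNo) + Send + Sync` callback as the type alias above suggests, a caller can keep the old watch-channel ergonomics by wrapping a sender in the callback; patch 5/5 later in this series does exactly that when it builds a primary database. The helper name `watch_notifier` is illustrative only.

use tokio::sync::watch;

// Sketch: adapt the callback-based notifier back into a watch channel. The
// receiver half can be handed to any task that wants to await new frames.
fn watch_notifier(initial: FrameNo) -> (FrameNotifierCb, watch::Receiver<FrameNo>) {
    let (tx, rx) = watch::channel(initial);
    let cb: FrameNotifierCb = Box::new(move |frame_no| {
        // all receivers may already be gone; the notification is best-effort
        let _ = tx.send(frame_no);
    });
    (cb, rx)
}
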
- pub new_frame_notifier: watch::Sender, + pub new_frame_notifier: FrameNotifierCb, } impl ReplicationLogger { - pub fn open(db_path: &Path, dirty: bool, compactor: impl LogCompactor) -> crate::Result { + pub fn open( + db_path: &Path, + dirty: bool, + compactor: impl LogCompactor, + new_frame_notifier: FrameNotifierCb, + ) -> crate::Result { let log_path = db_path.join("wallog"); let data_path = db_path.join("data"); @@ -788,9 +794,14 @@ impl ReplicationLogger { }; if should_recover { - Self::recover(log_file, data_path, compactor) + Self::recover(log_file, data_path, compactor, new_frame_notifier) } else { - Self::from_log_file(db_path.to_path_buf(), log_file, compactor) + Self::from_log_file( + db_path.to_path_buf(), + log_file, + compactor, + new_frame_notifier, + ) } } @@ -798,12 +809,11 @@ impl ReplicationLogger { db_path: PathBuf, log_file: LogFile, compactor: impl LogCompactor, + new_frame_notifier: FrameNotifierCb, ) -> crate::Result { let header = log_file.header(); let generation_start_frame_no = header.start_frame_no + header.frame_count; - let (new_frame_notifier, _) = watch::channel(generation_start_frame_no); - Ok(Self { generation: Generation::new(generation_start_frame_no), compactor: Box::new(compactor), @@ -817,6 +827,7 @@ impl ReplicationLogger { log_file: LogFile, mut data_path: PathBuf, compactor: impl LogCompactor, + new_frame_notifier: FrameNotifierCb, ) -> crate::Result { // It is necessary to checkpoint before we restore the replication log, since the WAL may // contain pages that are not in the database file. @@ -849,7 +860,7 @@ impl ReplicationLogger { assert!(data_path.pop()); - Self::from_log_file(data_path, log_file, compactor) + Self::from_log_file(data_path, log_file, compactor, new_frame_notifier) } pub fn database_id(&self) -> anyhow::Result { diff --git a/libsqlx/src/error.rs b/libsqlx/src/error.rs index 07a71831..fd7828c1 100644 --- a/libsqlx/src/error.rs +++ b/libsqlx/src/error.rs @@ -46,11 +46,3 @@ pub enum Error { #[error(transparent)] LexerError(#[from] sqlite3_parser::lexer::sql::Error), } - -impl From for Error { - fn from(inner: tokio::sync::oneshot::error::RecvError) -> Self { - Self::Internal(format!( - "Failed to receive response via oneshot channel: {inner}" - )) - } -} diff --git a/sqld/src/query_result_builder.rs b/sqld/src/query_result_builder.rs index a9aeadd7..29a2dc91 100644 --- a/sqld/src/query_result_builder.rs +++ b/sqld/src/query_result_builder.rs @@ -642,7 +642,6 @@ pub mod test { } // this can be usefull to help debug the generated test case - dbg!(trace); b } From 20dc440c563e0ee5a2c3d643ade19815d8a9ca72 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Wed, 12 Jul 2023 19:50:20 +0200 Subject: [PATCH 3/5] replica: send replication request to primary --- libsqlx-server/src/allocation/mod.rs | 174 +++++++++++++++++-- libsqlx-server/src/linc/bus.rs | 18 +- libsqlx-server/src/linc/connection.rs | 10 +- libsqlx-server/src/linc/handler.rs | 1 + libsqlx-server/src/linc/mod.rs | 3 +- libsqlx-server/src/linc/proto.rs | 26 +-- libsqlx-server/src/manager.rs | 8 +- libsqlx/src/database/libsql/injector/hook.rs | 4 +- libsqlx/src/database/libsql/injector/mod.rs | 32 ++-- libsqlx/src/database/libsql/mod.rs | 30 +--- libsqlx/src/database/mod.rs | 9 +- libsqlx/src/database/proxy/database.rs | 2 +- libsqlx/src/database/proxy/mod.rs | 2 +- libsqlx/src/lib.rs | 5 +- 14 files changed, 236 insertions(+), 88 deletions(-) diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index c9f88ed0..2b6c5faf 100644 --- 
a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,16 +1,20 @@ use std::path::PathBuf; use std::sync::Arc; +use std::time::{Duration, Instant}; use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType, ReplicaType}; -use libsqlx::Database as _; +use libsqlx::proxy::WriteProxyDatabase; +use libsqlx::{Database as _, DescribeResponse, Frame, InjectableDatabase, Injector, FrameNo}; use tokio::sync::{mpsc, oneshot}; use tokio::task::{block_in_place, JoinSet}; +use tokio::time::timeout; use crate::hrana; use crate::hrana::http::handle_pipeline; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; use crate::linc::bus::Dispatch; -use crate::linc::{Inbound, NodeId}; +use crate::linc::proto::{Enveloppe, Message, Frames}; +use crate::linc::{Inbound, NodeId, Outbound}; use crate::meta::DatabaseId; use self::config::{AllocConfig, DbConfig}; @@ -24,7 +28,6 @@ pub struct ConnectionId { id: u32, close_sender: mpsc::Sender<()>, } - pub enum AllocationMessage { NewConnection(oneshot::Sender), HranaPipelineReq { @@ -34,11 +37,40 @@ pub enum AllocationMessage { Inbound(Inbound), } +pub struct DummyDb; +pub struct DummyConn; + +impl libsqlx::Connection for DummyConn { + fn execute_program( + &mut self, + _pgm: libsqlx::program::Program, + _result_builder: &mut dyn libsqlx::result_builder::ResultBuilder, + ) -> libsqlx::Result<()> { + todo!() + } + + fn describe(&self, _sql: String) -> libsqlx::Result { + todo!() + } +} + +impl libsqlx::Database for DummyDb { + type Connection = DummyConn; + + fn connect(&self) -> Result { + todo!() + } +} + +type ProxyDatabase = WriteProxyDatabase, DummyDb>; + pub enum Database { - Primary(libsqlx::libsql::LibsqlDatabase), + Primary(LibsqlDatabase), Replica { - db: libsqlx::libsql::LibsqlDatabase, + db: ProxyDatabase, + injector_handle: mpsc::Sender, primary_node_id: NodeId, + last_received_frame_ts: Option, }, } @@ -59,14 +91,107 @@ impl LogCompactor for Compactor { } } +const MAX_INJECTOR_BUFFER_CAP: usize = 32; + +struct Replicator { + dispatcher: Arc, + req_id: u32, + last_committed: FrameNo, + next_seq: u32, + database_id: DatabaseId, + primary_node_id: NodeId, + injector: Box, + receiver: mpsc::Receiver, +} + +impl Replicator { + async fn run(mut self) { + loop { + match timeout(Duration::from_secs(5), self.receiver.recv()).await { + Ok(Some(Frames { + req_id, + seq, + frames, + })) => { + // ignore frames from a previous call to Replicate + if req_id != self.req_id { continue } + if seq != self.next_seq { + // this is not the batch of frame we were expecting, drop what we have, and + // ask again from last checkpoint + self.query_replicate().await; + continue; + }; + self.next_seq += 1; + for bytes in frames { + let frame = Frame::try_from_bytes(bytes).unwrap(); + block_in_place(|| { + if let Some(last_committed) = self.injector.inject(frame).unwrap() { + self.last_committed = last_committed; + } + }); + } + } + Err(_) => self.query_replicate().await, + Ok(None) => break, + } + } + } + + async fn query_replicate(&mut self) { + self.req_id += 1; + self.next_seq = 0; + // clear buffered, uncommitted frames + self.injector.clear(); + self.dispatcher + .dispatch(Outbound { + to: self.primary_node_id, + enveloppe: Enveloppe { + database_id: Some(self.database_id), + message: Message::Replicate { + next_frame_no: self.last_committed + 1, + req_id: self.req_id - 1, + }, + }, + }) + .await; + } +} + impl Database { - pub fn from_config(config: &AllocConfig, path: PathBuf) -> Self { + pub fn 
from_config(config: &AllocConfig, path: PathBuf, dispatcher: Arc) -> Self { match config.db_config { DbConfig::Primary {} => { let db = LibsqlDatabase::new_primary(path, Compactor, false).unwrap(); Self::Primary(db) } - DbConfig::Replica { .. } => todo!(), + DbConfig::Replica { primary_node_id } => { + let rdb = LibsqlDatabase::new_replica(path, MAX_INJECTOR_BUFFER_CAP, ()).unwrap(); + let wdb = DummyDb; + let mut db = WriteProxyDatabase::new(rdb, wdb, Arc::new(|_| ())); + let injector = db.injector().unwrap(); + let (sender, receiver) = mpsc::channel(16); + let database_id = DatabaseId::from_name(&config.db_name); + + let replicator = Replicator { + dispatcher, + req_id: 0, + last_committed: 0, // TODO: load the last commited from meta file + next_seq: 0, + database_id, + primary_node_id, + injector, + receiver, + }; + + tokio::spawn(replicator.run()); + + Self::Replica { + db, + injector_handle: sender, + primary_node_id, + last_received_frame_ts: None, + } + } } } @@ -147,19 +272,32 @@ impl Allocation { } async fn handle_inbound(&mut self, msg: Inbound) { - debug_assert_eq!(msg.enveloppe.to, Some(DatabaseId::from_name(&self.db_name))); + debug_assert_eq!( + msg.enveloppe.database_id, + Some(DatabaseId::from_name(&self.db_name)) + ); match msg.enveloppe.message { - crate::linc::proto::Message::Handshake { .. } => todo!(), - crate::linc::proto::Message::ReplicationHandshake { .. } => todo!(), - crate::linc::proto::Message::ReplicationHandshakeResponse { .. } => todo!(), - crate::linc::proto::Message::Replicate { .. } => todo!(), - crate::linc::proto::Message::Transaction { .. } => todo!(), - crate::linc::proto::Message::ProxyRequest { .. } => todo!(), - crate::linc::proto::Message::ProxyResponse { .. } => todo!(), - crate::linc::proto::Message::CancelRequest { .. } => todo!(), - crate::linc::proto::Message::CloseConnection { .. } => todo!(), - crate::linc::proto::Message::Error(_) => todo!(), + Message::Handshake { .. } => todo!(), + Message::ReplicationHandshake { .. } => todo!(), + Message::ReplicationHandshakeResponse { .. } => todo!(), + Message::Replicate { .. } => todo!(), + Message::Frames(frames) => match &mut self.database { + Database::Replica { + injector_handle, + last_received_frame_ts, + .. + } => { + *last_received_frame_ts = Some(Instant::now()); + injector_handle.send(frames).await; + } + Database::Primary(_) => todo!("handle primary receiving txn"), + }, + Message::ProxyRequest { .. } => todo!(), + Message::ProxyResponse { .. } => todo!(), + Message::CancelRequest { .. } => todo!(), + Message::CloseConnection { .. 
} => todo!(), + Message::Error(_) => todo!(), } } diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs index bf9a2cc4..8beae22d 100644 --- a/libsqlx-server/src/linc/bus.rs +++ b/libsqlx-server/src/linc/bus.rs @@ -1,10 +1,16 @@ +use std::collections::HashSet; use std::sync::Arc; -use super::{connection::SendQueue, handler::Handler, Inbound, NodeId, Outbound}; +use parking_lot::RwLock; + +use super::connection::SendQueue; +use super::handler::Handler; +use super::{Inbound, NodeId, Outbound}; pub struct Bus { node_id: NodeId, handler: H, + peers: RwLock>, send_queue: SendQueue, } @@ -15,6 +21,7 @@ impl Bus { node_id, handler, send_queue, + peers: Default::default(), } } @@ -29,6 +36,15 @@ impl Bus { pub fn send_queue(&self) -> &SendQueue { &self.send_queue } + + pub fn connect(&self, node_id: NodeId) { + // TODO: handle peer already exists + self.peers.write().insert(node_id); + } + + pub fn disconnect(&self, node_id: NodeId) { + self.peers.write().remove(&node_id); + } } #[async_trait::async_trait] diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index a96b8179..170d1f2a 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -200,8 +200,7 @@ where }))) => { if protocol_version != CURRENT_PROTO_VERSION { let msg = Enveloppe { - from: None, - to: None, + database_id: None, message: Message::Error(ProtoError::HandshakeVersionMismatch { expected: CURRENT_PROTO_VERSION, }), @@ -215,8 +214,7 @@ where // handshake message if !self.is_initiator { let msg = Enveloppe { - from: None, - to: None, + database_id: None, message: Message::Handshake { protocol_version: CURRENT_PROTO_VERSION, node_id: self.bus.node_id(), @@ -228,6 +226,7 @@ where self.peer = Some(node_id); self.state = ConnectionState::Connected; self.send_queue = Some(self.bus.send_queue().register(node_id)); + self.bus.connect(node_id); Ok(()) } @@ -255,8 +254,7 @@ where async fn initiate_connection(&mut self) -> color_eyre::Result<()> { let msg = Enveloppe { - from: None, - to: None, + database_id: None, message: Message::Handshake { protocol_version: CURRENT_PROTO_VERSION, node_id: self.bus.node_id(), diff --git a/libsqlx-server/src/linc/handler.rs b/libsqlx-server/src/linc/handler.rs index 6a6ae6f8..6403906e 100644 --- a/libsqlx-server/src/linc/handler.rs +++ b/libsqlx-server/src/linc/handler.rs @@ -5,5 +5,6 @@ use super::Inbound; #[async_trait::async_trait] pub trait Handler: Sized + Send + Sync + 'static { + /// Handle inbound message async fn handle(&self, bus: Arc>, msg: Inbound); } diff --git a/libsqlx-server/src/linc/mod.rs b/libsqlx-server/src/linc/mod.rs index fa787e87..638f56e2 100644 --- a/libsqlx-server/src/linc/mod.rs +++ b/libsqlx-server/src/linc/mod.rs @@ -26,8 +26,7 @@ impl Inbound { Outbound { to: self.from, enveloppe: Enveloppe { - from: self.enveloppe.to, - to: self.enveloppe.from, + database_id: None, message, }, } diff --git a/libsqlx-server/src/linc/proto.rs b/libsqlx-server/src/linc/proto.rs index c099cbd1..93ac445e 100644 --- a/libsqlx-server/src/linc/proto.rs +++ b/libsqlx-server/src/linc/proto.rs @@ -10,11 +10,21 @@ pub type Program = String; #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] pub struct Enveloppe { - pub from: Option, - pub to: Option, + pub database_id: Option, pub message: Message, } +#[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] +/// a batch of frames to inject +pub struct Frames{ + /// must match the Replicate request id + pub req_id: u32, + /// sequence id, 
monotonically incremented, reset when req_id changes. + /// Used to detect gaps in received frames. + pub seq: u32, + pub frames: Vec, +} + #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] pub enum Message { /// Initial message exchanged between nodes when connecting @@ -32,18 +42,12 @@ pub enum Message { current_frame_no: u64, }, Replicate { + /// incremental request id, used when responding with a Frames message + req_id: u32, /// next frame no to send next_frame_no: u64, }, - /// a batch of frames that are part of the same transaction - Transaction { - /// if not None, then the last frame is a commit frame, and this is the new size of the database. - size_after: Option, - /// frame_no of the last frame in frames - end_frame_no: u64, - /// a batch of frames part of the transaction. - frames: Vec, - }, + Frames(Frames), /// Proxy a query to a primary ProxyRequest { /// id of the connection to perform the query against diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 62f86479..89604569 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -20,6 +20,10 @@ pub struct Manager { const MAX_ALLOC_MESSAGE_QUEUE_LEN: usize = 32; +trait IsSync: Sync {} + +impl IsSync for Allocation {} + impl Manager { pub fn new(db_path: PathBuf, meta_store: Arc, max_conccurent_allocs: u64) -> Self { Self { @@ -45,7 +49,7 @@ impl Manager { let (alloc_sender, inbox) = mpsc::channel(MAX_ALLOC_MESSAGE_QUEUE_LEN); let alloc = Allocation { inbox, - database: Database::from_config(&config, path), + database: Database::from_config(&config, path, bus.clone()), connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, @@ -70,7 +74,7 @@ impl Handler for Arc { async fn handle(&self, bus: Arc>, msg: Inbound) { if let Some(sender) = self .clone() - .alloc(msg.enveloppe.to.unwrap(), bus.clone()) + .alloc(msg.enveloppe.database_id.unwrap(), bus.clone()) .await { let _ = sender.send(AllocationMessage::Inbound(msg)).await; diff --git a/libsqlx/src/database/libsql/injector/hook.rs b/libsqlx/src/database/libsql/injector/hook.rs index f87172db..2cb5348d 100644 --- a/libsqlx/src/database/libsql/injector/hook.rs +++ b/libsqlx/src/database/libsql/injector/hook.rs @@ -42,7 +42,7 @@ impl InjectorHookCtx { wal: *mut Wal, ) -> anyhow::Result<()> { self.is_txn = true; - let buffer = self.buffer.borrow(); + let buffer = self.buffer.lock(); let (mut headers, last_frame_no, size_after) = make_page_header(buffer.iter().map(|f| &**f)); if size_after != 0 { @@ -157,7 +157,7 @@ unsafe impl WalHook for InjectorHook { return LIBSQL_INJECT_FATAL; } - ctx.buffer.borrow_mut().clear(); + ctx.buffer.lock().clear(); if !ctx.is_txn { LIBSQL_INJECT_OK diff --git a/libsqlx/src/database/libsql/injector/mod.rs b/libsqlx/src/database/libsql/injector/mod.rs index 19fd51ce..0c2c2207 100644 --- a/libsqlx/src/database/libsql/injector/mod.rs +++ b/libsqlx/src/database/libsql/injector/mod.rs @@ -1,15 +1,15 @@ -use std::cell::RefCell; use std::collections::VecDeque; use std::path::Path; -use std::rc::Rc; +use std::sync::Arc; +use parking_lot::Mutex; use rusqlite::OpenFlags; use crate::database::frame::Frame; use crate::database::libsql::injector::hook::{ INJECTOR_METHODS, LIBSQL_INJECT_FATAL, LIBSQL_INJECT_OK, LIBSQL_INJECT_OK_TXN, }; -use crate::database::FrameNo; +use crate::database::{FrameNo, InjectError}; use crate::seal::Seal; use hook::InjectorHookCtx; @@ -17,7 +17,7 @@ use hook::InjectorHookCtx; mod headers; mod hook; -pub type FrameBuffer = 
Rc>>; +pub type FrameBuffer = Arc>>; pub struct Injector { /// The injector is in a transaction state @@ -33,11 +33,22 @@ pub struct Injector { _hook_ctx: Seal>, } +impl crate::database::Injector for Injector { + fn inject(&mut self, frame: Frame) -> Result, InjectError> { + let res = self.inject_frame(frame).unwrap(); + Ok(res) + } + + fn clear(&mut self) { + self.buffer.lock().clear(); + } +} + /// Methods from this trait are called before and after performing a frame injection. /// This trait trait is used to record the last committed frame_no to the log. /// The implementer can persist the pre and post commit frame no, and compare them in the event of /// a crash; if the pre and post commit frame_no don't match, then the log may be corrupted. -pub trait InjectorCommitHandler: Send + 'static { +pub trait InjectorCommitHandler: Send + Sync + 'static { fn pre_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()>; fn post_commit(&mut self, frame_no: FrameNo) -> anyhow::Result<()>; } @@ -52,7 +63,6 @@ impl InjectorCommitHandler for Box { } } -#[cfg(test)] impl InjectorCommitHandler for () { fn pre_commit(&mut self, _frame_no: FrameNo) -> anyhow::Result<()> { Ok(()) @@ -95,8 +105,8 @@ impl Injector { /// Inject on frame into the log. If this was a commit frame, returns Ok(Some(FrameNo)). pub(crate) fn inject_frame(&mut self, frame: Frame) -> crate::Result> { let frame_close_txn = frame.header().size_after != 0; - self.buffer.borrow_mut().push_back(frame); - if frame_close_txn || self.buffer.borrow().len() >= self.capacity { + self.buffer.lock().push_back(frame); + if frame_close_txn || self.buffer.lock().len() >= self.capacity { if !self.is_txn { self.begin_txn(); } @@ -110,7 +120,7 @@ impl Injector { /// Trigger a dummy write, and flush the cache to trigger a call to xFrame. The buffer's frame /// are then injected into the wal. fn flush(&mut self) -> crate::Result> { - let last_frame_no = match self.buffer.borrow().back() { + let last_frame_no = match self.buffer.lock().back() { Some(f) => f.header().frame_no, None => { tracing::trace!("nothing to inject"); @@ -130,11 +140,11 @@ impl Injector { .pragma_update(None, "writable_schema", "reset")?; self.commit(); self.is_txn = false; - assert!(self.buffer.borrow().is_empty()); + assert!(self.buffer.lock().is_empty()); return Ok(Some(last_frame_no)); } else if e.extended_code == LIBSQL_INJECT_OK_TXN { self.is_txn = true; - assert!(self.buffer.borrow().is_empty()); + assert!(self.buffer.lock().is_empty()); return Ok(None); } else if e.extended_code == LIBSQL_INJECT_FATAL { todo!("handle fatal error"); diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index 9cb32c20..dbd1d285 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -4,8 +4,7 @@ use std::sync::Arc; use sqld_libsql_bindings::wal_hook::{TransparentMethods, WalHook, TRANSPARENT_METHODS}; use sqld_libsql_bindings::WalMethodsHook; -use crate::database::frame::Frame; -use crate::database::{Database, InjectError, InjectableDatabase}; +use crate::database::{Database, InjectableDatabase}; use crate::error::Error; use crate::result_builder::QueryBuilderConfig; @@ -68,18 +67,6 @@ pub trait LibsqlDbType { fn hook_context(&self) -> ::Context; } -pub struct PlainType; - -impl LibsqlDbType for PlainType { - type ConnectionHook = TransparentMethods; - - fn hook() -> &'static WalMethodsHook { - &TRANSPARENT_METHODS - } - - fn hook_context(&self) -> ::Context {} -} - /// A generic wrapper around a libsql database. 
/// `LibsqlDatabase` can be specialized into either a `ReplicaType` or a `PrimaryType`. /// In `PrimaryType` mode, the LibsqlDatabase maintains a replication log that can be replicated to @@ -125,12 +112,6 @@ impl LibsqlDatabase { } } -impl LibsqlDatabase { - pub fn new_plain(db_path: PathBuf) -> crate::Result { - Ok(Self::new(db_path, PlainType)) - } -} - impl LibsqlDatabase { pub fn new_primary( db_path: PathBuf, @@ -195,7 +176,7 @@ impl Database for LibsqlDatabase { } impl InjectableDatabase for LibsqlDatabase { - fn injector(&mut self) -> crate::Result> { + fn injector(&mut self) -> crate::Result> { let Some(commit_handler) = self.ty.commit_handler.take() else { panic!("there can be only one injector") }; Ok(Box::new(Injector::new( &self.db_path, @@ -205,13 +186,6 @@ impl InjectableDatabase for LibsqlDatabase { } } -impl super::Injector for Injector { - fn inject(&mut self, frame: Frame) -> Result<(), InjectError> { - self.inject_frame(frame).unwrap(); - Ok(()) - } -} - #[cfg(test)] mod test { use std::fs::File; diff --git a/libsqlx/src/database/mod.rs b/libsqlx/src/database/mod.rs index 62581402..368ac5ac 100644 --- a/libsqlx/src/database/mod.rs +++ b/libsqlx/src/database/mod.rs @@ -1,6 +1,5 @@ use std::time::Duration; -use self::frame::Frame; use crate::connection::Connection; use crate::error::Error; @@ -10,6 +9,8 @@ pub mod proxy; #[cfg(test)] mod test_utils; +pub use frame::Frame; + pub type FrameNo = u64; pub const TXN_TIMEOUT: Duration = Duration::from_secs(5); @@ -24,10 +25,12 @@ pub trait Database { } pub trait InjectableDatabase { - fn injector(&mut self) -> crate::Result>; + fn injector(&mut self) -> crate::Result>; } // Trait implemented by databases that support frame injection pub trait Injector { - fn inject(&mut self, frame: Frame) -> Result<(), InjectError>; + fn inject(&mut self, frame: Frame) -> Result, InjectError>; + /// clear internal buffer + fn clear(&mut self); } diff --git a/libsqlx/src/database/proxy/database.rs b/libsqlx/src/database/proxy/database.rs index 129cc5e2..e9add71f 100644 --- a/libsqlx/src/database/proxy/database.rs +++ b/libsqlx/src/database/proxy/database.rs @@ -41,7 +41,7 @@ impl InjectableDatabase for WriteProxyDatabase where RDB: InjectableDatabase, { - fn injector(&mut self) -> crate::Result> { + fn injector(&mut self) -> crate::Result> { self.read_db.injector() } } diff --git a/libsqlx/src/database/proxy/mod.rs b/libsqlx/src/database/proxy/mod.rs index 1b5b3226..62c6925d 100644 --- a/libsqlx/src/database/proxy/mod.rs +++ b/libsqlx/src/database/proxy/mod.rs @@ -8,4 +8,4 @@ mod database; pub use database::WriteProxyDatabase; // Waits until passed frameno has been replicated back to the database -type WaitFrameNoCb = Arc; +type WaitFrameNoCb = Arc; diff --git a/libsqlx/src/lib.rs b/libsqlx/src/lib.rs index a89c2771..e004317e 100644 --- a/libsqlx/src/lib.rs +++ b/libsqlx/src/lib.rs @@ -10,10 +10,11 @@ mod seal; pub type Result = std::result::Result; -pub use connection::Connection; +pub use connection::{Connection, DescribeResponse}; pub use database::libsql; pub use database::libsql::replication_log::FrameNo; pub use database::proxy; -pub use database::Database; +pub use database::Frame; +pub use database::{Database, InjectableDatabase, Injector}; pub use rusqlite; From 1989ae8b4b0d1165c6f4391b042fceffcee01d87 Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 13 Jul 2023 10:21:22 +0200 Subject: [PATCH 4/5] replica sends replicate request to primary --- libsqlx-server/Cargo.toml | 4 ++-- libsqlx-server/src/allocation/mod.rs | 12 +++++++--- 
libsqlx-server/src/config.rs | 8 +++---- libsqlx-server/src/linc/bus.rs | 2 +- libsqlx-server/src/linc/connection.rs | 10 +++++++-- libsqlx-server/src/linc/connection_pool.rs | 13 ++++++++--- libsqlx-server/src/linc/net.rs | 5 +++++ libsqlx-server/src/linc/server.rs | 5 ++++- libsqlx-server/src/main.rs | 26 +++++++++++++++++++--- 9 files changed, 66 insertions(+), 19 deletions(-) diff --git a/libsqlx-server/Cargo.toml b/libsqlx-server/Cargo.toml index efff738e..86beceda 100644 --- a/libsqlx-server/Cargo.toml +++ b/libsqlx-server/Cargo.toml @@ -11,7 +11,7 @@ async-trait = "0.1.71" axum = "0.6.18" base64 = "0.21.2" bincode = "1.3.3" -bytes = "1.4.0" +bytes = { version = "1.4.0", features = ["serde"] } clap = { version = "4.3.11", features = ["derive"] } color-eyre = "0.6.2" futures = "0.3.28" @@ -36,7 +36,7 @@ tokio-util = "0.7.8" toml = "0.7.6" tracing = "0.1.37" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } -uuid = { version = "1.4.0", features = ["v4"] } +uuid = { version = "1.4.0", features = ["v4", "serde"] } [dev-dependencies] turmoil = "0.5.5" diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 2b6c5faf..743dbdd4 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -58,7 +58,7 @@ impl libsqlx::Database for DummyDb { type Connection = DummyConn; fn connect(&self) -> Result { - todo!() + Ok(DummyConn) } } @@ -106,6 +106,9 @@ struct Replicator { impl Replicator { async fn run(mut self) { + dbg!(); + self.query_replicate().await; + dbg!(); loop { match timeout(Duration::from_secs(5), self.receiver.recv()).await { Ok(Some(Frames { @@ -281,7 +284,10 @@ impl Allocation { Message::Handshake { .. } => todo!(), Message::ReplicationHandshake { .. } => todo!(), Message::ReplicationHandshakeResponse { .. } => todo!(), - Message::Replicate { .. } => todo!(), + Message::Replicate { .. } => match &mut self.database { + Database::Primary(_) => todo!(), + Database::Replica { .. } => (), + }, Message::Frames(frames) => match &mut self.database { Database::Replica { injector_handle, @@ -289,7 +295,7 @@ impl Allocation { .. 
} => { *last_received_frame_ts = Some(Instant::now()); - injector_handle.send(frames).await; + injector_handle.send(frames).await.unwrap(); } Database::Primary(_) => todo!("handle primary receiving txn"), }, diff --git a/libsqlx-server/src/config.rs b/libsqlx-server/src/config.rs index f0f9ca0c..84b961eb 100644 --- a/libsqlx-server/src/config.rs +++ b/libsqlx-server/src/config.rs @@ -30,7 +30,7 @@ pub struct ClusterConfig { /// Address to bind this node to #[serde(default = "default_linc_addr")] pub addr: SocketAddr, - /// List of peers in the format `:` + /// List of peers in the format `@` pub peers: Vec, } @@ -64,8 +64,8 @@ fn default_linc_addr() -> SocketAddr { #[derive(Debug, Clone)] pub struct Peer { - id: u64, - addr: String, + pub id: u64, + pub addr: String, } impl<'de> Deserialize<'de> for Peer { @@ -86,7 +86,7 @@ impl<'de> Deserialize<'de> for Peer { where E: serde::de::Error, { - let mut iter = v.split(":"); + let mut iter = v.split("@"); let Some(id) = iter.next() else { return Err(E::custom("node id is missing")) }; let Ok(id) = id.parse::() else { return Err(E::custom("failed to parse node id")) }; let Some(addr) = iter.next() else { return Err(E::custom("node address is missing")) }; diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs index 8beae22d..4707c989 100644 --- a/libsqlx-server/src/linc/bus.rs +++ b/libsqlx-server/src/linc/bus.rs @@ -30,7 +30,7 @@ impl Bus { } pub async fn incomming(self: &Arc, incomming: Inbound) { - self.handler.handle(self.clone(), incomming); + self.handler.handle(self.clone(), incomming).await; } pub fn send_queue(&self) -> &SendQueue { diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index 170d1f2a..e12838cd 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -68,7 +68,8 @@ impl SendQueue { None => todo!("no queue"), }; - sender.send(msg.enveloppe); + dbg!(); + sender.send(msg.enveloppe).unwrap(); } pub fn register(&self, node_id: NodeId) -> mpsc::UnboundedReceiver { @@ -145,6 +146,7 @@ where m = self.conn.next() => { match m { Some(Ok(m)) => { + dbg!(); self.handle_message(m).await; } Some(Err(e)) => { @@ -157,11 +159,13 @@ where }, // TODO: pop send queue Some(m) = self.send_queue.as_mut().unwrap().recv() => { + dbg!(); self.conn.feed(m).await.unwrap(); // send as many as possible while let Ok(m) = self.send_queue.as_mut().unwrap().try_recv() { self.conn.feed(m).await.unwrap(); } + dbg!(); self.conn.flush().await.unwrap(); } else => { @@ -216,13 +220,15 @@ where let msg = Enveloppe { database_id: None, message: Message::Handshake { - protocol_version: CURRENT_PROTO_VERSION, + protocol_version: CURRENT_PROTO_VERSION, node_id: self.bus.node_id(), }, }; self.conn.send(msg).await?; } + tracing::info!("Connected to peer {node_id}"); + self.peer = Some(node_id); self.state = ConnectionState::Connected; self.send_queue = Some(self.bus.send_queue().register(node_id)); diff --git a/libsqlx-server/src/linc/connection_pool.rs b/libsqlx-server/src/linc/connection_pool.rs index 26c5d923..89a43a15 100644 --- a/libsqlx-server/src/linc/connection_pool.rs +++ b/libsqlx-server/src/linc/connection_pool.rs @@ -11,7 +11,7 @@ use super::net::Connector; use super::{bus::Bus, NodeId}; /// Manages a pool of connections to other peers, handling re-connection. 
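Note the peer-list syntax change above: node addresses are `host:port`, so splitting an entry on `:` was ambiguous, and peers are now written as `node_id@addr` with the visitor splitting on `@`. A free-standing sketch of the same parsing rule (the function name and error type are illustrative, not part of the patch), followed by a small test:

fn parse_peer(s: &str) -> Result<(u64, String), &'static str> {
    // split on '@' so the ':' inside the socket address stays intact
    let (id, addr) = s.split_once('@').ok_or("node address is missing")?;
    let id = id.parse::<u64>().map_err(|_| "failed to parse node id")?;
    Ok((id, addr.to_string()))
}

#[test]
fn parses_node_id_and_addr() {
    assert_eq!(
        parse_peer("1@127.0.0.1:5001").unwrap(),
        (1, "127.0.0.1:5001".to_string())
    );
}
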
-struct ConnectionPool { +pub struct ConnectionPool { managed_peers: HashMap, connections: JoinSet, bus: Arc>, @@ -23,16 +23,22 @@ impl ConnectionPool { managed_peers: impl IntoIterator, ) -> Self { Self { - managed_peers: managed_peers.into_iter().collect(), + managed_peers: managed_peers.into_iter().filter(|(id, _)| *id < bus.node_id()).collect(), connections: JoinSet::new(), bus, } } - pub async fn run(mut self) { + pub fn managed_count(&self) -> usize { + self.managed_peers.len() + } + + pub async fn run(mut self) -> color_eyre::Result<()> { self.init::().await; while self.tick::().await {} + + Ok(()) } pub async fn tick(&mut self) -> bool { @@ -66,6 +72,7 @@ impl ConnectionPool { let connection = Connection::new_initiator(stream, bus.clone()); connection.run().await; + dbg!(); peer_id }; diff --git a/libsqlx-server/src/linc/net.rs b/libsqlx-server/src/linc/net.rs index 430b6d08..2123c041 100644 --- a/libsqlx-server/src/linc/net.rs +++ b/libsqlx-server/src/linc/net.rs @@ -31,6 +31,7 @@ pub trait Listener { Self: 'a; fn accept(&self) -> Self::Future<'_>; + fn local_addr(&self) -> color_eyre::Result; } pub struct AcceptFut<'a>(&'a TcpListener); @@ -53,6 +54,10 @@ impl Listener for TcpListener { fn accept(&self) -> Self::Future<'_> { AcceptFut(self) } + + fn local_addr(&self) -> color_eyre::Result { + Ok(self.local_addr()?) + } } #[cfg(test)] diff --git a/libsqlx-server/src/linc/server.rs b/libsqlx-server/src/linc/server.rs index b462d0a1..f3eacec2 100644 --- a/libsqlx-server/src/linc/server.rs +++ b/libsqlx-server/src/linc/server.rs @@ -30,11 +30,14 @@ impl Server { while self.connections.join_next().await.is_some() {} } - pub async fn run(mut self, mut listener: L) + pub async fn run(mut self, mut listener: L) -> color_eyre::Result<()> where L: super::net::Listener, { + tracing::info!("Cluster server listening on {}", listener.local_addr()?); while self.tick(&mut listener).await {} + + Ok(()) } pub async fn tick(&mut self, listener: &mut L) -> bool diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index d5b0c35f..0d89f6f4 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -4,13 +4,14 @@ use std::sync::Arc; use clap::Parser; use color_eyre::eyre::Result; -use config::{AdminApiConfig, UserApiConfig}; +use config::{AdminApiConfig, UserApiConfig, ClusterConfig}; use http::admin::run_admin_api; use http::user::run_user_api; use hyper::server::conn::AddrIncoming; use linc::bus::Bus; use manager::Manager; use meta::Store; +use tokio::net::{TcpListener, TcpStream}; use tokio::task::JoinSet; use tracing::metadata::LevelFilter; use tracing_subscriber::prelude::*; @@ -36,7 +37,7 @@ async fn spawn_admin_api( config: &AdminApiConfig, meta_store: Arc, ) -> Result<()> { - let admin_api_listener = tokio::net::TcpListener::bind(config.addr).await?; + let admin_api_listener = TcpListener::bind(config.addr).await?; let fut = run_admin_api( http::admin::Config { meta_store }, AddrIncoming::from_listener(admin_api_listener)?, @@ -52,7 +53,7 @@ async fn spawn_user_api( manager: Arc, bus: Arc>>, ) -> Result<()> { - let user_api_listener = tokio::net::TcpListener::bind(config.addr).await?; + let user_api_listener = TcpListener::bind(config.addr).await?; set.spawn(run_user_api( http::user::Config { manager, bus }, AddrIncoming::from_listener(user_api_listener)?, @@ -61,6 +62,24 @@ async fn spawn_user_api( Ok(()) } +async fn spawn_cluster_networking( + set: &mut JoinSet>, + config: &ClusterConfig, + bus: Arc>>, +) -> Result<()> { + let server = 
linc::server::Server::new(bus.clone()); + + let listener = TcpListener::bind(config.addr).await?; + set.spawn(server.run(listener)); + + let pool = linc::connection_pool::ConnectionPool::new(bus, config.peers.iter().map(|p| (p.id, dbg!(p.addr.clone())))); + if pool.managed_count() > 0 { + set.spawn(pool.run::()); + } + + Ok(()) +} + #[tokio::main] async fn main() -> Result<()> { init(); @@ -75,6 +94,7 @@ async fn main() -> Result<()> { let manager = Arc::new(Manager::new(config.db_path.clone(), store.clone(), 100)); let bus = Arc::new(Bus::new(config.cluster.id, manager.clone())); + spawn_cluster_networking(&mut join_set, &config.cluster, bus.clone()).await?; spawn_admin_api(&mut join_set, &config.admin_api, store.clone()).await?; spawn_user_api(&mut join_set, &config.user_api, manager, bus).await?; From cc9c2c1f576f0c30102c35328dd4ac79387adc2a Mon Sep 17 00:00:00 2001 From: ad hoc Date: Thu, 13 Jul 2023 15:15:26 +0200 Subject: [PATCH 5/5] primary replicate to replica --- libsqlx-server/src/allocation/mod.rs | 167 +++++++++++++++--- libsqlx-server/src/linc/connection.rs | 6 +- libsqlx-server/src/linc/connection_pool.rs | 6 +- libsqlx-server/src/linc/proto.rs | 8 +- libsqlx-server/src/main.rs | 9 +- libsqlx/src/database/libsql/mod.rs | 8 +- .../database/libsql/replication_log/logger.rs | 13 +- libsqlx/src/lib.rs | 1 + 8 files changed, 176 insertions(+), 42 deletions(-) diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index 743dbdd4..fdd08a88 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -1,10 +1,16 @@ +use std::collections::HashMap; +use std::collections::hash_map::Entry; use std::path::PathBuf; use std::sync::Arc; use std::time::{Duration, Instant}; +use bytes::Bytes; use libsqlx::libsql::{LibsqlDatabase, LogCompactor, LogFile, PrimaryType, ReplicaType}; use libsqlx::proxy::WriteProxyDatabase; -use libsqlx::{Database as _, DescribeResponse, Frame, InjectableDatabase, Injector, FrameNo}; +use libsqlx::{ + Database as _, DescribeResponse, Frame, FrameNo, InjectableDatabase, Injector, LogReadError, + ReplicationLogger, +}; use tokio::sync::{mpsc, oneshot}; use tokio::task::{block_in_place, JoinSet}; use tokio::time::timeout; @@ -13,7 +19,7 @@ use crate::hrana; use crate::hrana::http::handle_pipeline; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; use crate::linc::bus::Dispatch; -use crate::linc::proto::{Enveloppe, Message, Frames}; +use crate::linc::proto::{Enveloppe, Frames, Message}; use crate::linc::{Inbound, NodeId, Outbound}; use crate::meta::DatabaseId; @@ -65,7 +71,11 @@ impl libsqlx::Database for DummyDb { type ProxyDatabase = WriteProxyDatabase, DummyDb>; pub enum Database { - Primary(LibsqlDatabase), + Primary { + db: LibsqlDatabase, + replica_streams: HashMap)>, + frame_notifier: tokio::sync::watch::Receiver, + }, Replica { db: ProxyDatabase, injector_handle: mpsc::Sender, @@ -96,7 +106,7 @@ const MAX_INJECTOR_BUFFER_CAP: usize = 32; struct Replicator { dispatcher: Arc, req_id: u32, - last_committed: FrameNo, + next_frame_no: FrameNo, next_seq: u32, database_id: DatabaseId, primary_node_id: NodeId, @@ -106,30 +116,36 @@ struct Replicator { impl Replicator { async fn run(mut self) { - dbg!(); self.query_replicate().await; - dbg!(); loop { match timeout(Duration::from_secs(5), self.receiver.recv()).await { Ok(Some(Frames { - req_id, - seq, + req_no: req_id, + seq_no: seq, frames, })) => { // ignore frames from a previous call to Replicate - if req_id != self.req_id { 
continue } - if seq != self.next_seq { + if req_id != self.req_id { + tracing::debug!(req_id, self.req_id, "wrong req_id"); + continue; + } + if seq != self.next_seq { // this is not the batch of frame we were expecting, drop what we have, and // ask again from last checkpoint + tracing::debug!(seq, self.next_seq, "wrong seq"); self.query_replicate().await; continue; }; self.next_seq += 1; + + tracing::debug!("injecting {} frames", frames.len()); + for bytes in frames { let frame = Frame::try_from_bytes(bytes).unwrap(); block_in_place(|| { if let Some(last_committed) = self.injector.inject(frame).unwrap() { - self.last_committed = last_committed; + tracing::debug!(last_committed); + self.next_frame_no = last_committed + 1; } }); } @@ -151,12 +167,71 @@ impl Replicator { enveloppe: Enveloppe { database_id: Some(self.database_id), message: Message::Replicate { - next_frame_no: self.last_committed + 1, - req_id: self.req_id - 1, + next_frame_no: self.next_frame_no, + req_no: self.req_id, }, }, }) - .await; + .await; + } +} + +struct FrameStreamer { + logger: Arc, + database_id: DatabaseId, + node_id: NodeId, + next_frame_no: FrameNo, + req_no: u32, + seq_no: u32, + dipatcher: Arc, + notifier: tokio::sync::watch::Receiver, + buffer: Vec, +} + +// the maximum number of frame a Frame messahe is allowed to contain +const FRAMES_MESSAGE_MAX_COUNT: usize = 5; + +impl FrameStreamer { + async fn run(mut self) { + loop { + match block_in_place(|| self.logger.get_frame(self.next_frame_no)) { + Ok(frame) => { + if self.buffer.len() > FRAMES_MESSAGE_MAX_COUNT { + self.send_frames().await; + } + self.buffer.push(frame.bytes()); + self.next_frame_no += 1; + } + Err(LogReadError::Ahead) => { + tracing::debug!("frame {} not yet avaiblable", self.next_frame_no); + if !self.buffer.is_empty() { + self.send_frames().await; + } + if self.notifier.wait_for(|fno| dbg!(*fno) >= self.next_frame_no).await.is_err() { + break; + } + } + Err(LogReadError::Error(_)) => todo!("handle log read error"), + Err(LogReadError::SnapshotRequired) => todo!("handle reading from snapshot"), + } + } + } + + async fn send_frames(&mut self) { + let frames = std::mem::take(&mut self.buffer); + let outbound = Outbound { + to: self.node_id, + enveloppe: Enveloppe { + database_id: Some(self.database_id), + message: Message::Frames(Frames { + req_no: self.req_no, + seq_no: self.seq_no, + frames, + }), + }, + }; + self.seq_no += 1; + self.dipatcher.dispatch(outbound).await; } } @@ -164,9 +239,21 @@ impl Database { pub fn from_config(config: &AllocConfig, path: PathBuf, dispatcher: Arc) -> Self { match config.db_config { DbConfig::Primary {} => { - let db = LibsqlDatabase::new_primary(path, Compactor, false).unwrap(); - Self::Primary(db) - } + let (sender, receiver) = tokio::sync::watch::channel(0); + let db = LibsqlDatabase::new_primary( + path, + Compactor, + false, + Box::new(move |fno| { let _ = sender.send(fno); } ), + ) + .unwrap(); + + Self::Primary { + db, + replica_streams: HashMap::new(), + frame_notifier: receiver, + } + }, DbConfig::Replica { primary_node_id } => { let rdb = LibsqlDatabase::new_replica(path, MAX_INJECTOR_BUFFER_CAP, ()).unwrap(); let wdb = DummyDb; @@ -178,7 +265,7 @@ impl Database { let replicator = Replicator { dispatcher, req_id: 0, - last_committed: 0, // TODO: load the last commited from meta file + next_frame_no: 0, // TODO: load the last commited from meta file next_seq: 0, database_id, primary_node_id, @@ -200,7 +287,7 @@ impl Database { fn connect(&self) -> Box { match self { - Database::Primary(db) => 
Box::new(db.connect().unwrap()), + Database::Primary { db, .. } => Box::new(db.connect().unwrap()), Database::Replica { db, .. } => Box::new(db.connect().unwrap()), } } @@ -281,12 +368,44 @@ impl Allocation { ); match msg.enveloppe.message { - Message::Handshake { .. } => todo!(), + Message::Handshake { .. } => unreachable!("handshake should have been caught earlier"), Message::ReplicationHandshake { .. } => todo!(), Message::ReplicationHandshakeResponse { .. } => todo!(), - Message::Replicate { .. } => match &mut self.database { - Database::Primary(_) => todo!(), - Database::Replica { .. } => (), + Message::Replicate { req_no, next_frame_no } => match &mut self.database { + Database::Primary { db, replica_streams, frame_notifier } => { + dbg!(next_frame_no); + let streamer = FrameStreamer { + logger: db.logger(), + database_id: DatabaseId::from_name(&self.db_name), + node_id: msg.from, + next_frame_no, + req_no, + seq_no: 0, + dipatcher: self.dispatcher.clone(), + notifier: frame_notifier.clone(), + buffer: Vec::new(), + }; + + match replica_streams.entry(msg.from) { + Entry::Occupied(mut e) => { + let (old_req_no, old_handle) = e.get_mut(); + // ignore req_no older that the current req_no + if *old_req_no < req_no { + let handle = tokio::spawn(streamer.run()); + let old_handle = std::mem::replace(old_handle, handle); + *old_req_no = req_no; + old_handle.abort(); + } + }, + Entry::Vacant(e) => { + let handle = tokio::spawn(streamer.run()); + // For some reason, not yielding causes the task not to be spawned + tokio::task::yield_now().await; + e.insert((req_no, handle)); + }, + } + }, + Database::Replica { .. } => todo!("not a primary!"), }, Message::Frames(frames) => match &mut self.database { Database::Replica { @@ -297,7 +416,7 @@ impl Allocation { *last_received_frame_ts = Some(Instant::now()); injector_handle.send(frames).await.unwrap(); } - Database::Primary(_) => todo!("handle primary receiving txn"), + Database::Primary { .. } => todo!("handle primary receiving txn"), }, Message::ProxyRequest { .. } => todo!(), Message::ProxyResponse { .. 
} => todo!(), diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index e12838cd..bf5bd97e 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -68,7 +68,6 @@ impl SendQueue { None => todo!("no queue"), }; - dbg!(); sender.send(msg.enveloppe).unwrap(); } @@ -146,7 +145,6 @@ where m = self.conn.next() => { match m { Some(Ok(m)) => { - dbg!(); self.handle_message(m).await; } Some(Err(e)) => { @@ -159,13 +157,11 @@ where }, // TODO: pop send queue Some(m) = self.send_queue.as_mut().unwrap().recv() => { - dbg!(); self.conn.feed(m).await.unwrap(); // send as many as possible while let Ok(m) = self.send_queue.as_mut().unwrap().try_recv() { self.conn.feed(m).await.unwrap(); } - dbg!(); self.conn.flush().await.unwrap(); } else => { @@ -220,7 +216,7 @@ where let msg = Enveloppe { database_id: None, message: Message::Handshake { - protocol_version: CURRENT_PROTO_VERSION, + protocol_version: CURRENT_PROTO_VERSION, node_id: self.bus.node_id(), }, }; diff --git a/libsqlx-server/src/linc/connection_pool.rs b/libsqlx-server/src/linc/connection_pool.rs index 89a43a15..b6113a80 100644 --- a/libsqlx-server/src/linc/connection_pool.rs +++ b/libsqlx-server/src/linc/connection_pool.rs @@ -23,7 +23,10 @@ impl ConnectionPool { managed_peers: impl IntoIterator, ) -> Self { Self { - managed_peers: managed_peers.into_iter().filter(|(id, _)| *id < bus.node_id()).collect(), + managed_peers: managed_peers + .into_iter() + .filter(|(id, _)| *id < bus.node_id()) + .collect(), connections: JoinSet::new(), bus, } @@ -72,7 +75,6 @@ impl ConnectionPool { let connection = Connection::new_initiator(stream, bus.clone()); connection.run().await; - dbg!(); peer_id }; diff --git a/libsqlx-server/src/linc/proto.rs b/libsqlx-server/src/linc/proto.rs index 93ac445e..bec6ff7a 100644 --- a/libsqlx-server/src/linc/proto.rs +++ b/libsqlx-server/src/linc/proto.rs @@ -16,12 +16,12 @@ pub struct Enveloppe { #[derive(Debug, Serialize, Deserialize, PartialEq, Eq)] /// a batch of frames to inject -pub struct Frames{ +pub struct Frames { /// must match the Replicate request id - pub req_id: u32, + pub req_no: u32, /// sequence id, monotonically incremented, reset when req_id changes. /// Used to detect gaps in received frames. 
- pub seq: u32, + pub seq_no: u32, pub frames: Vec, } @@ -43,7 +43,7 @@ pub enum Message { }, Replicate { /// incremental request id, used when responding with a Frames message - req_id: u32, + req_no: u32, /// next frame no to send next_frame_no: u64, }, diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index 0d89f6f4..454ae954 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -4,7 +4,7 @@ use std::sync::Arc; use clap::Parser; use color_eyre::eyre::Result; -use config::{AdminApiConfig, UserApiConfig, ClusterConfig}; +use config::{AdminApiConfig, ClusterConfig, UserApiConfig}; use http::admin::run_admin_api; use http::user::run_user_api; use hyper::server::conn::AddrIncoming; @@ -72,7 +72,10 @@ async fn spawn_cluster_networking( let listener = TcpListener::bind(config.addr).await?; set.spawn(server.run(listener)); - let pool = linc::connection_pool::ConnectionPool::new(bus, config.peers.iter().map(|p| (p.id, dbg!(p.addr.clone())))); + let pool = linc::connection_pool::ConnectionPool::new( + bus, + config.peers.iter().map(|p| (p.id, p.addr.clone())), + ); if pool.managed_count() > 0 { set.spawn(pool.run::()); } @@ -80,7 +83,7 @@ async fn spawn_cluster_networking( Ok(()) } -#[tokio::main] +#[tokio::main(flavor = "multi_thread", worker_threads = 10)] async fn main() -> Result<()> { init(); let args = Args::parse(); diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index dbd1d285..44952df6 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -15,6 +15,7 @@ use replication_log::logger::{ }; use self::injector::InjectorCommitHandler; +use self::replication_log::logger::FrameNotifierCb; pub use connection::LibsqlConnection; pub use replication_log::logger::{LogCompactor, LogFile}; @@ -118,17 +119,22 @@ impl LibsqlDatabase { compactor: impl LogCompactor, // whether the log is dirty and might need repair dirty: bool, + new_frame_notifier: FrameNotifierCb, ) -> crate::Result { let ty = PrimaryType { logger: Arc::new(ReplicationLogger::open( &db_path, dirty, compactor, - Box::new(|_| ()), + new_frame_notifier, )?), }; Ok(Self::new(db_path, ty)) } + + pub fn logger(&self) -> Arc { + self.ty.logger.clone() + } } impl LibsqlDatabase { diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index fe371258..e17c286c 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -441,6 +441,7 @@ impl LogFile { } pub fn commit(&mut self) -> crate::Result<()> { + dbg!(&self); self.header.frame_count += self.uncommitted_frame_count; self.uncommitted_frame_count = 0; self.commited_checksum = self.uncommitted_checksum; @@ -550,6 +551,7 @@ impl LogFile { /// If the requested frame is before the first frame in the log, or after the last frame, /// Ok(None) is returned. 
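The `req_no`/`seq_no` pair on `Frames` is what lets a replica reject stale or out-of-order batches. A minimal sketch of the receiver-side decision, mirroring the checks in `Replicator::run` above; the `check_frames` helper and `BatchCheck` enum are illustrative only and not part of the patch:

/// Outcome of validating an incoming Frames batch against the replica's
/// current replication request state.
enum BatchCheck {
    /// Inject the frames and advance the expected sequence number.
    Accept,
    /// Answer to an earlier Replicate request: drop it silently.
    StaleRequest,
    /// Sequence gap: clear the injector buffer and send a fresh Replicate
    /// request starting from the last committed frame.
    Resync,
}

fn check_frames(current_req_no: u32, expected_seq_no: u32, batch: &Frames) -> BatchCheck {
    if batch.req_no != current_req_no {
        BatchCheck::StaleRequest
    } else if batch.seq_no != expected_seq_no {
        BatchCheck::Resync
    } else {
        BatchCheck::Accept
    }
}
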
pub fn frame(&self, frame_no: FrameNo) -> std::result::Result { + dbg!(frame_no); if frame_no < self.header.start_frame_no { return Err(LogReadError::SnapshotRequired); } @@ -695,8 +697,12 @@ pub struct LogFileHeader { } impl LogFileHeader { - pub fn last_frame_no(&self) -> FrameNo { - self.start_frame_no + self.frame_count + pub fn last_frame_no(&self) -> Option { + if self.start_frame_no == 0 && self.frame_count == 0 { + None + } else { + Some(self.start_frame_no + self.frame_count - 1) + } } fn sqld_version(&self) -> Version { @@ -871,6 +877,7 @@ impl ReplicationLogger { /// Returns the new frame count and checksum to commit fn write_pages(&self, pages: &[WalPage]) -> anyhow::Result<()> { let mut log_file = self.log_file.write(); + dbg!(); for page in pages.iter() { log_file.push_page(page)?; } @@ -899,7 +906,7 @@ impl ReplicationLogger { fn commit(&self) -> anyhow::Result { let mut log_file = self.log_file.write(); log_file.commit()?; - Ok(log_file.header().last_frame_no()) + Ok(log_file.header().last_frame_no().expect("there should be at least one frame after commit")) } pub fn get_snapshot_file(&self, from: FrameNo) -> anyhow::Result> { diff --git a/libsqlx/src/lib.rs b/libsqlx/src/lib.rs index e004317e..13223d22 100644 --- a/libsqlx/src/lib.rs +++ b/libsqlx/src/lib.rs @@ -12,6 +12,7 @@ pub type Result = std::result::Result; pub use connection::{Connection, DescribeResponse}; pub use database::libsql; +pub use database::libsql::replication_log::logger::{LogReadError, ReplicationLogger}; pub use database::libsql::replication_log::FrameNo; pub use database::proxy; pub use database::Frame;
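
One detail from the last hunk worth spelling out: `LogFileHeader::last_frame_no` now returns `Option<FrameNo>`, because an empty log (`start_frame_no == 0 && frame_count == 0`) has no last frame to report and the new `start_frame_no + frame_count - 1` formula would underflow there. A small free-standing illustration of the same rule (the function and test names are illustrative only):

fn last_frame_no(start_frame_no: u64, frame_count: u64) -> Option<u64> {
    if start_frame_no == 0 && frame_count == 0 {
        None // nothing has been written to the log yet
    } else {
        Some(start_frame_no + frame_count - 1)
    }
}

#[test]
fn empty_log_has_no_last_frame() {
    assert_eq!(last_frame_no(0, 0), None);
    // a log holding frames 0, 1 and 2
    assert_eq!(last_frame_no(0, 3), Some(2));
}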