diff --git a/docs/LINC.md b/docs/LINC.md deleted file mode 100644 index 9a915c65..00000000 --- a/docs/LINC.md +++ /dev/null @@ -1,282 +0,0 @@ -# Libsql Inter-Node Communication protocol: LINC protocol - -## Overview - -This document describes the version 1 of Libsql Inter-Node Communication (LINC) -protocol. - -The first version of the protocol aims to merge the existing two -protocol (proxy and replication) into a single one, and adds support for multi-tenancy. - -LINC v1 is designed to handle 3 tasks: -- inter-node communication -- database replication -- proxying of request from replicas to primaries - -LINC makes use of streams to multiplex messages between databases on different nodes. - -LINC v1 is implemented on top of TCP. - -LINC uses bincode for message serialization and deserialization. - -## Connection protocol - -Each node is identified by a `node_id`, and an address. -At startup, a sqld node is configured with list of peers (`(node_id, node_addr)`). A connection between two peers is initiated by the peer with the greatest node_id. - -```mermaid -graph TD -node4 --> node3 -node4 --> node2 -node4 --> node1 -node3 --> node2 -node3 --> node1 -node2 --> node1 -node1 -``` - -A new node node can be added to the cluster with no reconfiguration as long as its `node_id` is greater than all other `node_id` in the cluster and it has the address of all the other nodes. In this case, the new node will initiate a connection with all other nodes. - -On disconnection, the initiator of the connection attempts to reconnect. - -## Messages - -```rust -enum Message { - /// Messages destined to a node - Node(NodeMessage), - /// message destined to a stream - Stream { - stream_id: StreamId, - payload: StreamMessage, - }, -} - -enum NodeMessage { - /// Initial message exchanged between nodes when connecting - Handshake { - protocol_version: String, - node_id: String, - }, - /// Request to open a bi-directional stream between the client and the server - OpenStream { - /// Id to give to the newly opened stream - stream_id: StreamId, - /// Id of the database to open the stream to. - database_id: Uuid, - }, - /// Close a previously opened stream - CloseStream { - id: StreamId, - }, - /// Error type returned while handling a node message - Error(NodeError), -} - -enum NodeError { - UnknownStream(StreamId), - HandshakeVersionMismatch { expected: u32 }, - StreamAlreadyExist(StreamId), - UnknownDatabase(DatabaseId, StreamId), -} - -enum StreamMessage { - /// Replication message between a replica and a primary - Replication(ReplicationMessage), - /// Proxy message between a replica and a primary - Proxy(ProxyMessage), - Error(StreamError), -} - -enum ReplicationMessage { - HandshakeResponse { - /// id of the replication log - log_id: Uuid, - /// current frame_no of the primary - current_frame_no: u64, - }, - /// Replication request - Replicate { - /// next frame no to send - next_frame_no: u64, - }, - /// a batch of frames that are part of the same transaction - Transaction { - /// if not None, then the last frame is a commit frame, and this is the new size of the database. - size_after: Option, - /// frame_no of the last frame in frames - end_frame_no: u64 - /// a batch of frames part of the transaction. - frames: Vec - }, - /// Error occurred handling a replication message - Error(StreamError) -} - -struct Frame { - /// Page id of that frame - page_id: u32, - /// Data - data: Bytes, -} - -enum ProxyMessage { - /// Proxy a query to a primary - ProxyRequest { - /// id of the connection to perform the query against - /// If the connection doesn't already exist it is created - /// Id of the request. - /// Responses to this request must have the same id. - connection_id: u32, - req_id: u32, - query: Query, - }, - /// Response to a proxied query - ProxyResponse { - /// id of the request this message is a response to. - req_id: u32, - /// Collection of steps to drive the query builder transducer. - row_step: [RowStep] - }, - /// Stop processing request `id`. - CancelRequest { - req_id: u32, - }, - /// Close Connection with passed id. - CloseConnection { - connection_id: u32, - }, -} - -/// Steps applied to the query builder transducer to build a response to a proxied query. -/// Those types closely mirror those of the `QueryBuilderTrait`. -enum BuilderStep { - BeginStep, - FinishStep(u64, Option), - StepError(StepError), - ColsDesc([Column]), - BeginRows, - BeginRow, - AddRowValue(Value), - FinishRow, - FinishRos, - Finish(ConnectionState) -} - -// State of the connection after a query was executed -enum ConnectionState { - /// The connection is still in a open transaction state - OpenTxn, - /// The connection is idle. - Idle, -} - -struct Column { - /// name of the column - name: string, - /// Declared type of the column, if any. - decl_ty: Option, -} - -/// for now, the stringified version of a sqld::error::Error. -struct StepError(String); - -enum StreamError { - NotAPrimary, - AlreadyReplicating, -} -``` - -## Node Handshake - -When a node connects to another node, it first need to perform a handshake. The -handshake is initialized by the initializer of the connection. It sends the -following message: - -```typescipt -type NodeHandshake = { - version: string, // protocol version - node_id: string, -} -``` - -If a peer receives a connection from a peer with a id smaller than his, it must reject the handshake with a `IllegalConnection` error - -## Streams - -Messages destined to a particular database are sent as part of a stream. A -stream is created by sending a `NodeMessage::OpenStream`, specifying the id of -the stream to open, along with the id of the database for which to open this -stream. If the requested database is not on the destination node, the -destination node respond with a `NodeError::UnknownDatabase` error, and the stream in not -opened. - -If a node receives a message for a stream that was not opened before, it responds a `NodeError::UnknownStream` - -A stream is closed by sending a `CloseStream` with the id of the stream. If the -stream does not exist an `NodeError::UnknownStream` error is returned. - -Streams can be opened by either peer. Each stream is identified with by `i32` -stream id. The peer that initiated the original connection allocates positive -stream ids, while the acceptor peer allocates negative ids. 0 is not a legal -value for a stream_id. The receiver of a request for a stream with id 0 must -close the connection immediately. - -The peer opening a stream is responsible for sending the close message. The -other peer can close the stream at any point, but must not send close message -for that stream. On subsequent message to that stream, it will respond with an -`UnknownStream` message, forcing the initiator to deal with recreating a -stream if necessary. - -## Sub-protocols - -### Replication - -The replica is responsible for initiating the replication protocol. This is -done by opening a stream to a primary. If the destination of the stream is not a -primary database, it responds with a `StreamError::NotAPrimary` error and immediately close -the stream. If the destination database is a primary, it responds to the stream -open request with a `ReplicationMessage::HandshakeResponse` message. This message informs the -replica of the current log version, and of the primary current replication -index (frame_no). - -The replica compares the log version it received from the primary with the one it has, if any. If the -versions don't match, the replica deletes its state and start replicating again from the start. - -After a successful handshake, the replica sends a `ReplicationMessage::Replicate` message with the -next frame_no it's expecting. For example if the replica has not replicated any -frame yet, it sends `ReplicationMessage::Replicate { next_frame_no: 0 }` to -signify to the primary that it's expecting to be sent frame 0. The primary -sends the smallest frame with a `frame_no` satisfying `frame_no >= -next_frame_no`. Because logs can be compacted, the next frame_no the primary -sends to the replica isn't necessarily the one the replica is expecting. It's correct to send -the smallest frame >= next_frame_no because frame_nos only move forward in the event of a compaction: a -frame can only be missing if it was written too more recently, hence _moving -forward_ in the log. The primary ensure consistency by moving commit points -accordingly. It is an error for the primary to send a frame_no strictly less -than the requested frame_no, frame_nos can be received in any order. - -In the event of a disconnection, it is the replica's duty to re-initiate the replication protocol. - -Sending a replicate request twice on the same stream is an error. If a primary -receives more than a single `Replicate` request, it closes the stream and sends -a `StreamError::AlreadyReplicating` request. The replica can re-open a stream and start -replicating again if necessary. - -### Proxy - -Replicas can proxy queries to their primary. Replica can start sending proxy request after they have sent a replication request. - -To proxy a query, a replica sends a `ProxyRequest`. Proxied query on a same connection are serialized. The replica sets the connection id -and the request id for the proxied query. If no connection exists for the -passed id on the primary, one is created. The query is executed on the primary, -and the result rows are returned in `ProxyResponse`. The result rows can be split -into multiple `ProxyResponse`, enabling row streaming. A replica can send a `CancelRequest` to interrupt a request. Any -`ProxyResponse` for that `request_id` can be dropped by the replica, and the -primary should stop sending any more `ProxyResponse` message upon receiving the -cancel request. The primary must rollback a cancelled request. - -The primary can reduce the amount of concurrent open transaction by closing the -underlying SQLite connection for proxied connections that are not in a open -transaction state (`is_autocommit` is true). Subsequent requests on that -connection id will re-open a connection, if necessary. diff --git a/libsqlx-server/src/allocation/mod.rs b/libsqlx-server/src/allocation/mod.rs index f5e19ba6..20be7cc4 100644 --- a/libsqlx-server/src/allocation/mod.rs +++ b/libsqlx-server/src/allocation/mod.rs @@ -20,9 +20,8 @@ use tokio::time::Interval; use crate::allocation::primary::FrameStreamer; use crate::allocation::timeout_notifier::timeout_monitor; use crate::compactor::CompactionQueue; -use crate::hrana; -use crate::hrana::http::handle_pipeline; -use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; +use crate::error::Error; +use crate::hrana::proto::DescribeResult; use crate::linc::bus::Dispatch; use crate::linc::proto::{Frames, Message}; use crate::linc::{Inbound, NodeId}; @@ -54,9 +53,8 @@ pub enum ConnectionMessage { } pub enum AllocationMessage { - HranaPipelineReq { - req: PipelineRequestBody, - ret: oneshot::Sender>, + Connect { + ret: oneshot::Sender>, }, Inbound(Inbound), } @@ -117,7 +115,7 @@ impl Database { dispatcher: Arc, compaction_queue: Arc, replica_commit_store: Arc, - ) -> Self { + ) -> crate::Result { let database_id = DatabaseId::from_name(&config.db_name); match config.db_config { @@ -139,8 +137,7 @@ impl Database { Box::new(move |fno| { let _ = sender.send(Some(fno)); }), - ) - .unwrap(); + )?; let compact_interval = replication_log_compact_interval.map(|d| { let mut i = tokio::time::interval(d / 2); @@ -148,7 +145,7 @@ impl Database { Box::pin(i) }); - Self::Primary { + Ok(Self::Primary { db: PrimaryDatabase { db: Arc::new(db), replica_streams: HashMap::new(), @@ -157,7 +154,7 @@ impl Database { }, compact_interval, transaction_timeout_duration, - } + }) } DbConfig::Replica { primary_node_id, @@ -165,26 +162,24 @@ impl Database { transaction_timeout_duration, } => { let next_frame_no = - block_in_place(|| replica_commit_store.get_commit_index(database_id)) + block_in_place(|| replica_commit_store.get_commit_index(database_id))? .map(|fno| fno + 1) .unwrap_or(0); - let commit_callback = Arc::new(move |fno| { - replica_commit_store.commit(database_id, fno); - }); + let commit_callback = + Arc::new(move |fno| replica_commit_store.commit(database_id, fno).is_ok()); let rdb = LibsqlDatabase::new_replica( path, MAX_INJECTOR_BUFFER_CAPACITY, commit_callback, - ) - .unwrap(); + )?; let wdb = RemoteDb { proxy_request_timeout_duration, }; - let mut db = WriteProxyDatabase::new(rdb, wdb, Arc::new(|_| ())); - let injector = db.injector().unwrap(); + let db = WriteProxyDatabase::new(rdb, wdb, Arc::new(|_| ())); + let injector = db.injector()?; let (sender, receiver) = mpsc::channel(16); let replicator = Replicator::new( @@ -198,13 +193,13 @@ impl Database { tokio::spawn(replicator.run()); - Self::Replica { + Ok(Self::Replica { db, injector_handle: sender, primary_id: primary_node_id, last_received_frame_ts: None, transaction_timeout_duration, - } + }) } } } @@ -214,28 +209,28 @@ impl Database { connection_id: u32, alloc: &Allocation, on_txn_status_change_cb: impl Fn(bool) + Send + Sync + 'static, - ) -> impl ConnectionHandler { + ) -> crate::Result { match self { Database::Primary { db: PrimaryDatabase { db, .. }, .. } => { - let mut conn = db.connect().unwrap(); + let mut conn = db.connect()?; conn.set_on_txn_status_change_cb(on_txn_status_change_cb); - Either::Right(PrimaryConnection { conn }) + Ok(Either::Right(PrimaryConnection { conn })) } Database::Replica { db, primary_id, .. } => { - let mut conn = db.connect().unwrap(); + let mut conn = db.connect()?; conn.reader_mut() .set_on_txn_status_change_cb(on_txn_status_change_cb); - Either::Left(ReplicaConnection { + Ok(Either::Left(ReplicaConnection { conn, connection_id, next_req_id: 0, primary_node_id: *primary_id, database_id: DatabaseId::from_name(&alloc.db_name), dispatcher: alloc.dispatcher.clone(), - }) + })) } } } @@ -254,7 +249,6 @@ pub struct Allocation { pub max_concurrent_connections: u32, pub connections: HashMap>, - pub hrana_server: Arc, /// handle to the message bus pub dispatcher: Arc, pub db_name: String, @@ -267,16 +261,16 @@ pub struct ConnectionHandle { } impl ConnectionHandle { - pub async fn execute( - &self, - pgm: Program, - builder: Box, - ) -> crate::Result<()> { - self.messages - .send(ConnectionMessage::Execute { pgm, builder }) - .await - .unwrap(); - Ok(()) + pub async fn execute(&self, pgm: Program, builder: Box) { + let msg = ConnectionMessage::Execute { pgm, builder }; + if let Err(e) = self.messages.send(msg).await { + let ConnectionMessage::Execute { mut builder, .. } = e.0 else { unreachable!() }; + builder.finnalize_error("connection closed".to_string()); + } + } + + pub async fn describe(&self, sql: String) -> crate::Result { + todo!() } } @@ -288,15 +282,14 @@ impl Allocation { _ = fut => (), Some(msg) = self.inbox.recv() => { match msg { - AllocationMessage::HranaPipelineReq { req, ret } => { - let server = self.hrana_server.clone(); - handle_pipeline(server, req, ret, || async { - let conn = self.new_conn(None).await; - Ok(conn) - }).await.unwrap(); + AllocationMessage::Connect { ret } => { + let _ = ret.send(self.new_conn(None).await); } AllocationMessage::Inbound(msg) => { - self.handle_inbound(msg).await; + if let Err(e) = self.handle_inbound(msg).await { + tracing::error!("allocation loop finished with error: {e}"); + return + } } } }, @@ -310,7 +303,7 @@ impl Allocation { } } - async fn handle_inbound(&mut self, msg: Inbound) { + async fn handle_inbound(&mut self, msg: Inbound) -> crate::Result<()> { debug_assert_eq!( msg.enveloppe.database_id, Some(DatabaseId::from_name(&self.db_name)) @@ -361,7 +354,7 @@ impl Allocation { } Entry::Vacant(e) => { let handle = tokio::spawn(streamer.run()); - // For some reason, not yielding causes the task not to be spawned + // For some reason, yielding here is necessary for the task to start running tokio::task::yield_now().await; e.insert((req_no, handle)); } @@ -369,20 +362,19 @@ impl Allocation { } Database::Replica { .. } => todo!("not a primary!"), }, - Message::Frames(frames) => match &mut self.database { - Database::Replica { + Message::Frames(frames) => { + if let Database::Replica { injector_handle, last_received_frame_ts, .. - } => { + } = &mut self.database + { *last_received_frame_ts = Some(Instant::now()); - injector_handle.send(frames).await.unwrap(); + if injector_handle.send(frames).await.is_err() { + return Err(Error::InjectorExited); + } } - Database::Primary { - db: PrimaryDatabase { .. }, - .. - } => todo!("handle primary receiving txn"), - }, + } Message::ProxyRequest { connection_id, req_id, @@ -397,13 +389,17 @@ impl Allocation { .get(&self.dispatcher.node_id()) .and_then(|m| m.get(&r.connection_id).cloned()) { - conn.inbound.send(msg).await.unwrap(); + if conn.inbound.send(msg).await.is_err() { + tracing::error!("cannot process message: connection is closed"); + } } } Message::CancelRequest { .. } => todo!(), Message::CloseConnection { .. } => todo!(), Message::Error(_) => todo!(), } + + Ok(()) } async fn handle_proxy( @@ -415,11 +411,8 @@ impl Allocation { ) { let dispatcher = self.dispatcher.clone(); let database_id = DatabaseId::from_name(&self.db_name); - let exec = |conn: ConnectionHandle| async move { - let builder = - ProxyResponseBuilder::new(dispatcher, database_id, to, req_id, connection_id); - conn.execute(program, Box::new(builder)).await.unwrap(); - }; + let mut builder = + ProxyResponseBuilder::new(dispatcher, database_id, to, req_id, connection_id); if self.database.is_primary() { match self @@ -428,17 +421,21 @@ impl Allocation { .and_then(|m| m.get(&connection_id).cloned()) { Some(handle) => { - tokio::spawn(exec(handle)); - } - None => { - let handle = self.new_conn(Some((to, connection_id))).await; - tokio::spawn(exec(handle)); + tokio::spawn(async move { handle.execute(program, Box::new(builder)).await }); } + None => match self.new_conn(Some((to, connection_id))).await { + Ok(handle) => { + tokio::spawn( + async move { handle.execute(program, Box::new(builder)).await }, + ); + } + Err(e) => builder.finnalize_error(format!("error creating connection: {e}")), + }, } } } - async fn new_conn(&mut self, remote: Option<(NodeId, u32)>) -> ConnectionHandle { + async fn new_conn(&mut self, remote: Option<(NodeId, u32)>) -> crate::Result { let conn_id = self.next_conn_id(); let (timeout_monitor, notifier) = timeout_monitor(); let timeout = self.database.txn_timeout_duration(); @@ -450,7 +447,7 @@ impl Allocation { notifier.disable(); } }) - }); + })?; let (messages_sender, messages_receiver) = mpsc::channel(1); let (inbound_sender, inbound_receiver) = mpsc::channel(1); @@ -465,15 +462,18 @@ impl Allocation { }; self.connections_futs.spawn(conn.run()); + let handle = ConnectionHandle { messages: messages_sender, inbound: inbound_sender, }; + self.connections .entry(id.0) .or_insert_with(HashMap::new) .insert(id.1, handle.clone()); - handle + + Ok(handle) } fn next_conn_id(&mut self) -> u32 { @@ -564,6 +564,8 @@ impl Connection { } } + tracing::debug!("connection exited: {:?}", self.id); + self.id } } @@ -578,10 +580,10 @@ mod test { use tokio::sync::Notify; use crate::allocation::replica::ReplicaConnection; + use crate::init_dirs; use crate::linc::bus::Bus; use crate::replica_commit_store::ReplicaCommitStore; use crate::snapshot_store::SnapshotStore; - use crate::{init_dirs, replica_commit_store}; use super::*; @@ -654,7 +656,7 @@ mod test { transaction_timeout_duration: Duration::from_millis(100), }, }; - let (sender, inbox) = mpsc::channel(10); + let (_sender, inbox) = mpsc::channel(10); let env = EnvOpenOptions::new() .max_dbs(10) .map_size(4096 * 100) @@ -672,24 +674,24 @@ mod test { bus.clone(), queue, replica_commit_store, - ), + ) + .unwrap(), connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, hrana_server: Arc::new(hrana::http::Server::new(None)), - dispatcher: bus, // TODO: handle self URL? + dispatcher: bus, db_name: config.db_name, connections: HashMap::new(), }; - let conn = alloc.new_conn(None).await; + let conn = alloc.new_conn(None).await.unwrap(); tokio::spawn(alloc.run()); let (snd, rcv) = oneshot::channel(); let builder = StepResultsBuilder::new(snd); conn.execute(Program::seq(&["begin"]), Box::new(builder)) - .await - .unwrap(); + .await; rcv.await.unwrap().unwrap(); tokio::time::sleep(Duration::from_secs(1)).await; @@ -697,8 +699,7 @@ mod test { let (snd, rcv) = oneshot::channel(); let builder = StepResultsBuilder::new(snd); conn.execute(Program::seq(&["create table test (x)"]), Box::new(builder)) - .await - .unwrap(); + .await; assert!(rcv.await.unwrap().is_err()); } } diff --git a/libsqlx-server/src/allocation/primary/compactor.rs b/libsqlx-server/src/allocation/primary/compactor.rs index 5bc4c9a3..3ee2aca4 100644 --- a/libsqlx-server/src/allocation/primary/compactor.rs +++ b/libsqlx-server/src/allocation/primary/compactor.rs @@ -57,7 +57,7 @@ impl LogCompactor for Compactor { self.queue.push(&CompactionJob { database_id: self.database_id, log_id, - }); + })?; Ok(()) } diff --git a/libsqlx-server/src/allocation/primary/mod.rs b/libsqlx-server/src/allocation/primary/mod.rs index ee9e0d96..63e60e61 100644 --- a/libsqlx-server/src/allocation/primary/mod.rs +++ b/libsqlx-server/src/allocation/primary/mod.rs @@ -6,7 +6,7 @@ use std::time::Duration; use bytes::Bytes; use libsqlx::libsql::{LibsqlDatabase, PrimaryType}; -use libsqlx::result_builder::ResultBuilder; +use libsqlx::result_builder::{QueryResultBuilderError, ResultBuilder}; use libsqlx::{Connection, Frame, FrameHeader, FrameNo, LogReadError, ReplicationLogger}; use tokio::task::block_in_place; @@ -58,7 +58,7 @@ impl ProxyResponseBuilder { } } - fn maybe_send(&mut self) { + fn maybe_send(&mut self) -> crate::Result<()> { // FIXME: this is stupid: compute current buffer size on the go instead let size = self .buffer @@ -82,11 +82,13 @@ impl ProxyResponseBuilder { .sum::(); if size > MAX_STEP_BATCH_SIZE { - self.send() + self.send()?; } + + Ok(()) } - fn send(&mut self) { + fn send(&mut self) -> crate::Result<()> { let msg = Outbound { to: self.to, enveloppe: Enveloppe { @@ -101,7 +103,9 @@ impl ProxyResponseBuilder { }; self.next_seq_no += 1; - tokio::runtime::Handle::current().block_on(self.dispatcher.dispatch(msg)); + tokio::runtime::Handle::current().block_on(self.dispatcher.dispatch(msg))?; + + Ok(()) } } @@ -109,15 +113,17 @@ impl ResultBuilder for ProxyResponseBuilder { fn init( &mut self, _config: &libsqlx::result_builder::QueryBuilderConfig, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + ) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::Init); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } - fn begin_step(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + fn begin_step(&mut self) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::BeginStep); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } @@ -125,65 +131,70 @@ impl ResultBuilder for ProxyResponseBuilder { &mut self, affected_row_count: u64, last_insert_rowid: Option, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + ) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::FinishStep( affected_row_count, last_insert_rowid, )); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } - fn step_error( - &mut self, - error: libsqlx::error::Error, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + fn step_error(&mut self, error: libsqlx::error::Error) -> Result<(), QueryResultBuilderError> { self.buffer .push(BuilderStep::StepError(StepError(error.to_string()))); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } fn cols_description( &mut self, cols: &mut dyn Iterator, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + ) -> Result<(), QueryResultBuilderError> { self.buffer .push(BuilderStep::ColsDesc(cols.map(Into::into).collect())); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } - fn begin_rows(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + fn begin_rows(&mut self) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::BeginRows); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } - fn begin_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + fn begin_row(&mut self) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::BeginRow); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } fn add_row_value( &mut self, v: libsqlx::result_builder::ValueRef, - ) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + ) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::AddRowValue(v.into())); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } - fn finish_row(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + fn finish_row(&mut self) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::FinishRow); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } - fn finish_rows(&mut self) -> Result<(), libsqlx::result_builder::QueryResultBuilderError> { + fn finish_rows(&mut self) -> Result<(), QueryResultBuilderError> { self.buffer.push(BuilderStep::FinishRows); - self.maybe_send(); + self.maybe_send() + .map_err(QueryResultBuilderError::from_any)?; Ok(()) } @@ -191,10 +202,10 @@ impl ResultBuilder for ProxyResponseBuilder { &mut self, is_txn: bool, frame_no: Option, - ) -> Result { + ) -> Result { self.buffer .push(BuilderStep::Finnalize { is_txn, frame_no }); - self.send(); + self.send().map_err(QueryResultBuilderError::from_any)?; Ok(true) } } @@ -238,26 +249,31 @@ impl FrameStreamer { } } Err(LogReadError::Error(_)) => todo!("handle log read error"), - Err(LogReadError::SnapshotRequired) => self.send_snapshot().await, + Err(LogReadError::SnapshotRequired) => { + if let Err(e) = self.send_snapshot().await { + tracing::error!("error sending snapshot: {e}"); + break; + } + } } } } - async fn send_snapshot(&mut self) { + async fn send_snapshot(&mut self) -> crate::Result<()> { tracing::debug!("sending frames from snapshot"); loop { match self .snapshot_store - .locate_file(self.database_id, self.next_frame_no) + .locate_file(self.database_id, self.next_frame_no)? { Some(file) => { let mut iter = file.frames_iter_from(self.next_frame_no).peekable(); while let Some(frame) = block_in_place(|| iter.next()) { - let frame = frame.unwrap(); + let frame = frame?; // TODO: factorize in maybe_send if self.buffer.len() > FRAMES_MESSAGE_MAX_COUNT { - self.send_frames().await; + self.send_frames().await?; } let size_after = iter .peek() @@ -287,9 +303,11 @@ impl FrameStreamer { } } } + + Ok(()) } - async fn send_frames(&mut self) { + async fn send_frames(&mut self) -> crate::Result<()> { let frames = std::mem::take(&mut self.buffer); let outbound = Outbound { to: self.node_id, @@ -303,7 +321,9 @@ impl FrameStreamer { }, }; self.seq_no += 1; - self.dipatcher.dispatch(outbound).await; + self.dipatcher.dispatch(outbound).await?; + + Ok(()) } } @@ -320,9 +340,7 @@ impl ConnectionHandler for PrimaryConnection { async fn handle_conn_message(&mut self, msg: ConnectionMessage) { match msg { ConnectionMessage::Execute { pgm, builder } => { - block_in_place(|| { - self.conn.execute_program(&pgm, builder).unwrap() - }) + block_in_place(|| self.conn.execute_program(&pgm, builder)) } ConnectionMessage::Describe => { todo!() diff --git a/libsqlx-server/src/allocation/replica.rs b/libsqlx-server/src/allocation/replica.rs index b1422e5b..3d6e81e0 100644 --- a/libsqlx-server/src/allocation/replica.rs +++ b/libsqlx-server/src/allocation/replica.rs @@ -60,7 +60,7 @@ impl libsqlx::Connection for RemoteConn { &mut self, program: &libsqlx::program::Program, builder: Box, - ) -> libsqlx::Result<()> { + ) { // When we need to proxy a query, we place it in the current request slot. When we are // back in a async context, we'll send it to the primary, and asynchrously drive the // builder. @@ -75,8 +75,6 @@ impl libsqlx::Connection for RemoteConn { timeout: Box::pin(tokio::time::sleep(self.inner.request_timeout_duration)), }), }; - - Ok(()) } fn describe(&self, _sql: String) -> libsqlx::Result { @@ -130,7 +128,15 @@ impl Replicator { } pub async fn run(mut self) { - self.query_replicate().await; + macro_rules! ok_or_log { + ($e:expr) => { + if let Err(e) = $e { + tracing::warn!("failed to start replication process: {e}"); + } + }; + } + + ok_or_log!(self.query_replicate().await); loop { match timeout(Duration::from_secs(5), self.receiver.recv()).await { Ok(Some(Frames { @@ -147,7 +153,7 @@ impl Replicator { // this is not the batch of frame we were expecting, drop what we have, and // ask again from last checkpoint tracing::debug!(seq, self.next_seq, "wrong seq"); - self.query_replicate().await; + ok_or_log!(self.query_replicate().await); continue; }; self.next_seq += 1; @@ -155,23 +161,32 @@ impl Replicator { tracing::debug!("injecting {} frames", frames.len()); for bytes in frames { - let frame = Frame::try_from_bytes(bytes).unwrap(); - block_in_place(|| { - if let Some(last_committed) = self.injector.inject(frame).unwrap() { - tracing::debug!(last_committed); - self.next_frame_no = last_committed + 1; - } - }); + let inject = || -> crate::Result<()> { + let frame = Frame::try_from_bytes(bytes)?; + block_in_place(|| { + if let Some(last_committed) = self.injector.inject(frame).unwrap() { + tracing::debug!(last_committed); + self.next_frame_no = last_committed + 1; + } + Ok(()) + }) + }; + + if let Err(e) = inject() { + tracing::error!("error injecting frames: {e}"); + ok_or_log!(self.query_replicate().await); + break; + } } } // no news from primary for the past 5 secs, send a request again - Err(_) => self.query_replicate().await, + Err(_) => ok_or_log!(self.query_replicate().await), Ok(None) => break, } } } - async fn query_replicate(&mut self) { + async fn query_replicate(&mut self) -> crate::Result<()> { tracing::debug!("seinding replication request"); self.req_id += 1; self.next_seq = 0; @@ -188,7 +203,9 @@ impl Replicator { }, }, }) - .await; + .await?; + + Ok(()) } } @@ -290,7 +307,7 @@ impl ConnectionHandler for ReplicaConnection { async fn handle_conn_message(&mut self, msg: ConnectionMessage) { match msg { ConnectionMessage::Execute { pgm, builder } => { - self.conn.execute_program(&pgm, builder).unwrap(); + self.conn.execute_program(&pgm, builder); let msg = { let mut lock = self.conn.writer().inner.current_req.lock(); match *lock { diff --git a/libsqlx-server/src/compactor.rs b/libsqlx-server/src/compactor.rs index 2039343a..060a1dda 100644 --- a/libsqlx-server/src/compactor.rs +++ b/libsqlx-server/src/compactor.rs @@ -47,7 +47,7 @@ impl CompactionQueue { env: heed::Env, db_path: PathBuf, snapshot_store: Arc, - ) -> color_eyre::Result { + ) -> crate::Result { let mut txn = env.write_txn()?; let queue = env.create_database(&mut txn, Some(Self::COMPACTION_QUEUE_DB_NAME))?; let next_id = match queue.last(&mut txn)? { @@ -67,43 +67,47 @@ impl CompactionQueue { }) } - pub fn push(&self, job: &CompactionJob) { + pub fn push(&self, job: &CompactionJob) -> crate::Result<()> { tracing::debug!("new compaction job available: {job:?}"); - let mut txn = self.env.write_txn().unwrap(); + let mut txn = self.env.write_txn()?; let id = self.next_id.fetch_add(1, Ordering::Relaxed); - self.queue.put(&mut txn, &id, job).unwrap(); - txn.commit().unwrap(); + self.queue.put(&mut txn, &id, job)?; + txn.commit()?; self.notify.send_replace(Some(id)); + + Ok(()) } - pub async fn peek(&self) -> (u64, CompactionJob) { + pub async fn peek(&self) -> crate::Result<(u64, CompactionJob)> { let id = self.next_id.load(Ordering::Relaxed); - let txn = block_in_place(|| self.env.read_txn().unwrap()); - match block_in_place(|| self.queue.first(&txn).unwrap()) { - Some(job) => job, - None => { - drop(txn); - self.notify - .subscribe() - .wait_for(|x| x.map(|x| x >= id).unwrap_or_default()) - .await - .unwrap(); - block_in_place(|| { - let txn = self.env.read_txn().unwrap(); - self.queue.first(&txn).unwrap().unwrap() - }) + let peek = || { + block_in_place(|| -> crate::Result<_> { + let txn = self.env.read_txn()?; + Ok(self.queue.first(&txn)?) + }) + }; + + loop { + match peek()? { + Some(job) => return Ok(job), + None => { + self.notify + .subscribe() + .wait_for(|x| x.map(|x| x >= id).unwrap_or_default()) + .await + .expect("we're holding the other side of the channel!"); + } } } } - fn complete(&self, txn: &mut heed::RwTxn, job_id: u64) { - block_in_place(|| { - self.queue.delete(txn, &job_id).unwrap(); - }); + fn complete(&self, txn: &mut heed::RwTxn, job_id: u64) -> crate::Result<()> { + block_in_place(|| self.queue.delete(txn, &job_id))?; + Ok(()) } - async fn compact(&self) -> color_eyre::Result<()> { - let (job_id, job) = self.peek().await; + async fn compact(&self) -> crate::Result<()> { + let (job_id, job) = self.peek().await?; tracing::debug!("starting new compaction job: {job:?}"); let to_compact_path = self.snapshot_queue_dir().join(job.log_id.to_string()); let (start_fno, end_fno) = tokio::task::spawn_blocking({ @@ -127,12 +131,15 @@ impl CompactionQueue { builder.finish() } }) - .await??; + .await + .map_err(|_| { + crate::error::Error::Internal(color_eyre::eyre::anyhow!("compaction thread panicked")) + })??; let mut txn = self.env.write_txn()?; - self.complete(&mut txn, job_id); + self.complete(&mut txn, job_id)?; self.snapshot_store - .register(&mut txn, job.database_id, start_fno, end_fno, job.log_id); + .register(&mut txn, job.database_id, start_fno, end_fno, job.log_id)?; txn.commit()?; std::fs::remove_file(to_compact_path)?; @@ -193,6 +200,7 @@ pub struct SnapshotFrame { impl SnapshotFrame { const SIZE: usize = size_of::() + 4096; + #[cfg(test)] pub fn try_from_bytes(data: Bytes) -> crate::Result { if data.len() != Self::SIZE { color_eyre::eyre::bail!("invalid snapshot frame") @@ -220,7 +228,7 @@ impl SnapshotBuilder { snapshot_id: Uuid, start_fno: FrameNo, end_fno: FrameNo, - ) -> color_eyre::Result { + ) -> crate::Result { let temp_dir = db_path.join("tmp"); let mut target = BufWriter::new(NamedTempFile::new_in(&temp_dir)?); // reserve header space @@ -242,7 +250,7 @@ impl SnapshotBuilder { }) } - pub fn push_frame(&mut self, frame: Frame) -> color_eyre::Result<()> { + pub fn push_frame(&mut self, frame: Frame) -> crate::Result<()> { assert!(frame.header().frame_no < self.last_seen_frame_no); self.last_seen_frame_no = frame.header().frame_no; @@ -265,16 +273,21 @@ impl SnapshotBuilder { } /// Persist the snapshot, and returns the name and size is frame on the snapshot. - pub fn finish(mut self) -> color_eyre::Result<(FrameNo, FrameNo)> { + pub fn finish(mut self) -> crate::Result<(FrameNo, FrameNo)> { self.snapshot_file.flush()?; - let file = self.snapshot_file.into_inner()?; + let file = self + .snapshot_file + .into_inner() + .map_err(|e| crate::error::Error::Internal(e.into()))?; + file.as_file().write_all_at(bytes_of(&self.header), 0)?; let path = self .db_path .join("snapshots") .join(self.snapshot_id.to_string()); - file.persist(path)?; + file.persist(path) + .map_err(|e| crate::error::Error::Internal(e.into()))?; Ok((self.header.start_frame_no, self.header.end_frame_no)) } @@ -286,7 +299,7 @@ pub struct SnapshotFile { } impl SnapshotFile { - pub fn open(path: &Path) -> color_eyre::Result { + pub fn open(path: &Path) -> crate::Result { let file = File::open(path)?; let mut header_buf = [0; size_of::()]; file.read_exact_at(&mut header_buf, 0)?; diff --git a/libsqlx-server/src/database.rs b/libsqlx-server/src/database.rs index 4945cd70..2147a85d 100644 --- a/libsqlx-server/src/database.rs +++ b/libsqlx-server/src/database.rs @@ -1,22 +1,20 @@ use tokio::sync::{mpsc, oneshot}; -use crate::allocation::AllocationMessage; -use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; +use crate::allocation::{AllocationMessage, ConnectionHandle}; pub struct Database { pub sender: mpsc::Sender, } impl Database { - pub async fn hrana_pipeline( - &self, - req: PipelineRequestBody, - ) -> crate::Result { - let (sender, ret) = oneshot::channel(); + pub async fn connect(&self) -> crate::Result { + let (ret, conn) = oneshot::channel(); self.sender - .send(AllocationMessage::HranaPipelineReq { req, ret: sender }) + .send(AllocationMessage::Connect { ret }) .await - .unwrap(); - ret.await.unwrap() + .map_err(|_| crate::error::Error::AllocationClosed)?; + + conn.await + .map_err(|_| crate::error::Error::ConnectionClosed)? } } diff --git a/libsqlx-server/src/error.rs b/libsqlx-server/src/error.rs new file mode 100644 index 00000000..f62bf20c --- /dev/null +++ b/libsqlx-server/src/error.rs @@ -0,0 +1,21 @@ +use crate::meta::AllocationError; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error(transparent)] + Libsqlx(#[from] libsqlx::error::Error), + #[error("replica injector loop exited")] + InjectorExited, + #[error("connection closed")] + ConnectionClosed, + #[error(transparent)] + Io(#[from] std::io::Error), + #[error("allocation closed")] + AllocationClosed, + #[error("internal error: {0}")] + Internal(color_eyre::eyre::Error), + #[error(transparent)] + Heed(#[from] heed::Error), + #[error(transparent)] + Allocation(#[from] AllocationError), +} diff --git a/libsqlx-server/src/hrana/batch.rs b/libsqlx-server/src/hrana/batch.rs index 14cfb1c3..d01bc8ac 100644 --- a/libsqlx-server/src/hrana/batch.rs +++ b/libsqlx-server/src/hrana/batch.rs @@ -1,20 +1,22 @@ use std::collections::HashMap; +use std::sync::Arc; -use crate::allocation::ConnectionHandle; -use crate::hrana::stmt::StmtError; - -use super::result_builder::HranaBatchProtoBuilder; -use super::stmt::{proto_stmt_to_query, stmt_error_from_sqld_error}; -use super::{proto, ProtocolError, Version}; +// use crate::auth::Authenticated; -use color_eyre::eyre::anyhow; use libsqlx::analysis::Statement; use libsqlx::program::{Cond, Program, Step}; use libsqlx::query::{Params, Query}; use libsqlx::result_builder::{StepResult, StepResultsBuilder}; use tokio::sync::oneshot; -fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> color_eyre::Result { +use crate::allocation::ConnectionHandle; + +use super::error::HranaError; +use super::result_builder::HranaBatchProtoBuilder; +use super::stmt::{proto_stmt_to_query, StmtError}; +use super::{proto, ProtocolError, Version}; + +fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> Result { let try_convert_step = |step: i32| -> Result { let step = usize::try_from(step).map_err(|_| ProtocolError::BatchCondBadStep)?; if step >= max_step_i { @@ -22,6 +24,7 @@ fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> color_eyre: } Ok(step) }; + let cond = match cond { proto::BatchCond::Ok { step } => Cond::Ok { step: try_convert_step(*step)?, @@ -36,13 +39,13 @@ fn proto_cond_to_cond(cond: &proto::BatchCond, max_step_i: usize) -> color_eyre: conds: conds .iter() .map(|cond| proto_cond_to_cond(cond, max_step_i)) - .collect::>()?, + .collect::>()?, }, proto::BatchCond::Or { conds } => Cond::Or { conds: conds .iter() .map(|cond| proto_cond_to_cond(cond, max_step_i)) - .collect::>()?, + .collect::>()?, }, }; @@ -53,7 +56,7 @@ pub fn proto_batch_to_program( batch: &proto::Batch, sqls: &HashMap, version: Version, -) -> color_eyre::Result { +) -> Result { let mut steps = Vec::with_capacity(batch.steps.len()); for (step_i, step) in batch.steps.iter().enumerate() { let query = proto_stmt_to_query(&step.stmt, sqls, version)?; @@ -71,19 +74,27 @@ pub fn proto_batch_to_program( } pub async fn execute_batch( - db: &ConnectionHandle, + conn: &ConnectionHandle, + // auth: Authenticated, pgm: Program, -) -> color_eyre::Result { +) -> Result { let (builder, ret) = HranaBatchProtoBuilder::new(); - db.execute(pgm, Box::new(builder)).await?; + conn.execute( + pgm, + // auth, + Box::new(builder), + ) + .await; - Ok(ret.await?) + Ok(ret + .await + .map_err(|_| crate::error::Error::ConnectionClosed)?) } -pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { +pub fn proto_sequence_to_program(sql: &str) -> Result { let stmts = Statement::parse(sql) - .collect::>>() - .map_err(|err| anyhow!(StmtError::SqlParse { source: err.into() }))?; + .collect::, libsqlx::error::Error>>() + .map_err(|err| StmtError::SqlParse { source: err.into() })?; let steps = stmts .into_iter() @@ -100,25 +111,32 @@ pub fn proto_sequence_to_program(sql: &str) -> color_eyre::Result { }; Step { cond, query } }) - .collect(); + .collect::>(); Ok(Program { steps }) } -pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> color_eyre::Result<()> { - let (snd, rcv) = oneshot::channel(); - let builder = StepResultsBuilder::new(snd); - conn.execute(pgm, Box::new(builder)).await?; +pub async fn execute_sequence( + conn: &ConnectionHandle, + // auth: Authenticated, + pgm: Program, +) -> Result<(), HranaError> { + let (send, ret) = oneshot::channel(); + let builder = StepResultsBuilder::new(send); + conn.execute( + pgm, + // auth, + Box::new(builder), + ) + .await; - rcv.await? - .map_err(|e| anyhow!("{e}"))? + ret.await + .unwrap() + .unwrap() .into_iter() .try_for_each(|result| match result { StepResult::Ok => Ok(()), - StepResult::Err(e) => match stmt_error_from_sqld_error(e) { - Ok(stmt_err) => Err(anyhow!(stmt_err)), - Err(sqld_err) => Err(anyhow!(sqld_err)), - }, - StepResult::Skipped => Err(anyhow!("Statement in sequence was not executed")), + StepResult::Err(e) => Err(crate::error::Error::from(e))?, + StepResult::Skipped => todo!(), // Err(anyhow!("Statement in sequence was not executed")), }) } diff --git a/libsqlx-server/src/hrana/error.rs b/libsqlx-server/src/hrana/error.rs new file mode 100644 index 00000000..2324887a --- /dev/null +++ b/libsqlx-server/src/hrana/error.rs @@ -0,0 +1,30 @@ +use super::http::request::StreamResponseError; +use super::http::StreamError; +use super::stmt::StmtError; +use super::ProtocolError; + +#[derive(Debug, thiserror::Error)] +pub enum HranaError { + #[error(transparent)] + Stmt(#[from] StmtError), + #[error(transparent)] + Proto(#[from] ProtocolError), + #[error(transparent)] + Stream(#[from] StreamError), + #[error(transparent)] + StreamResponse(#[from] StreamResponseError), + #[error(transparent)] + Libsqlx(crate::error::Error), +} + +impl HranaError { + pub fn code(&self) -> Option<&str>{ + match self { + HranaError::Stmt(e) => Some(e.code()), + HranaError::StreamResponse(e) => Some(e.code()), + HranaError::Stream(_) + | HranaError::Libsqlx(_) + | HranaError::Proto(_) => None, + } + } +} diff --git a/libsqlx-server/src/hrana/http/mod.rs b/libsqlx-server/src/hrana/http/mod.rs index 521d33ff..790a3c4f 100644 --- a/libsqlx-server/src/hrana/http/mod.rs +++ b/libsqlx-server/src/hrana/http/mod.rs @@ -1,19 +1,12 @@ -use std::sync::Arc; - -use color_eyre::eyre::Context; -use futures::Future; use parking_lot::Mutex; -use serde::{de::DeserializeOwned, Serialize}; -use tokio::sync::oneshot; - -use crate::allocation::ConnectionHandle; - -use self::proto::{PipelineRequestBody, PipelineResponseBody}; -use super::ProtocolError; +use super::error::HranaError; +// use crate::auth::Authenticated; +use crate::database::Database; +pub use stream::StreamError; pub mod proto; -mod request; +pub mod request; mod stream; pub struct Server { @@ -22,12 +15,6 @@ pub struct Server { stream_state: Mutex, } -#[derive(Debug)] -pub enum Route { - GetIndex, - PostPipeline, -} - impl Server { pub fn new(self_url: Option) -> Self { Self { @@ -42,90 +29,23 @@ impl Server { } } -fn handle_index() -> color_eyre::Result> { - Ok(text_response( - hyper::StatusCode::OK, - "Hello, this is HTTP API v2 (Hrana over HTTP)".into(), - )) -} - -pub async fn handle_pipeline( - server: Arc, - req: PipelineRequestBody, - ret: oneshot::Sender>, - mk_conn: F, -) -> color_eyre::Result<()> -where - F: FnOnce() -> Fut, - Fut: Future>, -{ - let mut stream_guard = stream::acquire(server.clone(), req.baton.as_deref(), mk_conn).await?; - - tokio::spawn(async move { - let f = async move { - let mut results = Vec::with_capacity(req.requests.len()); - for request in req.requests.into_iter() { - let result = request::handle(&mut stream_guard, request) - .await - .context("Could not execute a request in pipeline")?; - results.push(result); - } - - Ok(proto::PipelineResponseBody { - baton: stream_guard.release(), - base_url: server.self_url.clone(), - results, - }) - }; - - let _ = ret.send(f.await); - }); - - Ok(()) -} - -async fn read_request_json( - req: hyper::Request, -) -> color_eyre::Result { - let req_body = hyper::body::to_bytes(req.into_body()) - .await - .context("Could not read request body")?; - let req_body = serde_json::from_slice(&req_body) - .map_err(|err| ProtocolError::Deserialize { source: err }) - .context("Could not deserialize JSON request body")?; - Ok(req_body) -} - -fn protocol_error_response(err: ProtocolError) -> hyper::Response { - text_response(hyper::StatusCode::BAD_REQUEST, err.to_string()) -} - -fn stream_error_response(err: stream::StreamError) -> hyper::Response { - json_response( - hyper::StatusCode::INTERNAL_SERVER_ERROR, - &proto::Error { - message: err.to_string(), - code: err.code().into(), - }, - ) -} +pub async fn handle_pipeline( + server: &Server, + // auth: Authenticated, + req: proto::PipelineRequestBody, + db: Database, +) -> crate::Result { + let mut stream_guard = stream::acquire(server, req.baton.as_deref(), db).await?; -fn json_response( - status: hyper::StatusCode, - resp_body: &T, -) -> hyper::Response { - let resp_body = serde_json::to_vec(resp_body).unwrap(); - hyper::Response::builder() - .status(status) - .header(hyper::http::header::CONTENT_TYPE, "application/json") - .body(hyper::Body::from(resp_body)) - .unwrap() -} + let mut results = Vec::with_capacity(req.requests.len()); + for request in req.requests.into_iter() { + let result = request::handle(&mut stream_guard, /*auth,*/ request).await?; + results.push(result); + } -fn text_response(status: hyper::StatusCode, resp_body: String) -> hyper::Response { - hyper::Response::builder() - .status(status) - .header(hyper::http::header::CONTENT_TYPE, "text/plain") - .body(hyper::Body::from(resp_body)) - .unwrap() + Ok(proto::PipelineResponseBody { + baton: stream_guard.release(), + base_url: server.self_url.clone(), + results, + }) } diff --git a/libsqlx-server/src/hrana/http/request.rs b/libsqlx-server/src/hrana/http/request.rs index eb1623cd..eaf84eb5 100644 --- a/libsqlx-server/src/hrana/http/request.rs +++ b/libsqlx-server/src/hrana/http/request.rs @@ -1,7 +1,9 @@ -use color_eyre::eyre::{anyhow, bail}; +use crate::hrana::error::HranaError; +use crate::hrana::ProtocolError; -use super::super::{batch, stmt, ProtocolError, Version}; +use super::super::{batch, stmt, Version}; use super::{proto, stream}; +// use crate::auth::Authenticated; /// An error from executing a [`proto::StreamRequest`] #[derive(thiserror::Error, Debug)] @@ -13,77 +15,75 @@ pub enum StreamResponseError { } pub async fn handle( - stream_guard: &mut stream::Guard, + stream_guard: &mut stream::Guard<'_>, + // auth: Authenticated, request: proto::StreamRequest, -) -> color_eyre::Result { - let result = match try_handle(stream_guard, request).await { +) -> Result { + let result = match try_handle(stream_guard /*, auth*/, request).await { Ok(response) => proto::StreamResult::Ok { response }, Err(err) => { - let resp_err = err.downcast::()?; - let error = proto::Error { - message: resp_err.to_string(), - code: resp_err.code().into(), - }; - proto::StreamResult::Error { error } + if let HranaError::StreamResponse(err) = err { + let error = proto::Error { + message: err.to_string(), + code: err.code().into(), + }; + proto::StreamResult::Error { error } + } else { + Err(err)? + } } }; Ok(result) } async fn try_handle( - stream_guard: &mut stream::Guard, + stream_guard: &mut stream::Guard<'_>, + // auth: Authenticated, request: proto::StreamRequest, -) -> color_eyre::Result { +) -> crate::Result { Ok(match request { proto::StreamRequest::Close(_req) => { - stream_guard.close_db(); + stream_guard.close_conn(); proto::StreamResponse::Close(proto::CloseStreamResp {}) } proto::StreamRequest::Execute(req) => { - let db = stream_guard.get_db()?; + let db = stream_guard.get_conn()?; let sqls = stream_guard.sqls(); - let query = stmt::proto_stmt_to_query(&req.stmt, sqls, Version::Hrana2) - .map_err(catch_stmt_error)?; - let result = stmt::execute_stmt(db, query) - .await - .map_err(catch_stmt_error)?; + let query = stmt::proto_stmt_to_query(&req.stmt, sqls, Version::Hrana2)?; + let result = stmt::execute_stmt(db, /*auth,*/ query).await?; proto::StreamResponse::Execute(proto::ExecuteStreamResp { result }) } proto::StreamRequest::Batch(req) => { - let db = stream_guard.get_db()?; + let db = stream_guard.get_conn()?; let sqls = stream_guard.sqls(); let pgm = batch::proto_batch_to_program(&req.batch, sqls, Version::Hrana2)?; - let result = batch::execute_batch(db, pgm).await?; + let result = batch::execute_batch(db, /*auth,*/ pgm).await?; proto::StreamResponse::Batch(proto::BatchStreamResp { result }) } proto::StreamRequest::Sequence(req) => { - let db = stream_guard.get_db()?; + let db = stream_guard.get_conn()?; let sqls = stream_guard.sqls(); let sql = stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, Version::Hrana2)?; - let pgm = batch::proto_sequence_to_program(sql).map_err(catch_stmt_error)?; - batch::execute_sequence(db, pgm) - .await - .map_err(catch_stmt_error)?; + let pgm = batch::proto_sequence_to_program(sql)?; + batch::execute_sequence(db, /*auth,*/ pgm).await?; proto::StreamResponse::Sequence(proto::SequenceStreamResp {}) } proto::StreamRequest::Describe(req) => { - let db = stream_guard.get_db()?; + let db = stream_guard.get_conn()?; let sqls = stream_guard.sqls(); let sql = stmt::proto_sql_to_sql(req.sql.as_deref(), req.sql_id, sqls, Version::Hrana2)?; - let result = stmt::describe_stmt(db, sql.into()) - .await - .map_err(catch_stmt_error)?; + let result = stmt::describe_stmt(db, /* auth,*/ sql.into()).await?; proto::StreamResponse::Describe(proto::DescribeStreamResp { result }) } proto::StreamRequest::StoreSql(req) => { let sqls = stream_guard.sqls_mut(); let sql_id = req.sql_id; if sqls.contains_key(&sql_id) { - bail!(ProtocolError::SqlExists { sql_id }) + Err(ProtocolError::SqlExists { sql_id })? } else if sqls.len() >= MAX_SQL_COUNT { - bail!(StreamResponseError::SqlTooMany { count: sqls.len() }) + Err(StreamResponseError::SqlTooMany { count: sqls.len() })? } sqls.insert(sql_id, req.sql); proto::StreamResponse::StoreSql(proto::StoreSqlStreamResp {}) @@ -98,13 +98,6 @@ async fn try_handle( const MAX_SQL_COUNT: usize = 50; -fn catch_stmt_error(err: color_eyre::eyre::Error) -> color_eyre::eyre::Error { - match err.downcast::() { - Ok(stmt_err) => anyhow!(StreamResponseError::Stmt(stmt_err)), - Err(err) => err, - } -} - impl StreamResponseError { pub fn code(&self) -> &'static str { match self { diff --git a/libsqlx-server/src/hrana/http/stream.rs b/libsqlx-server/src/hrana/http/stream.rs index 25c1e719..1320df90 100644 --- a/libsqlx-server/src/hrana/http/stream.rs +++ b/libsqlx-server/src/hrana/http/stream.rs @@ -1,19 +1,18 @@ -use std::cmp::Reverse; -use std::collections::{HashMap, VecDeque}; -use std::pin::Pin; -use std::sync::Arc; -use std::{future, mem, task}; - use base64::prelude::{Engine as _, BASE64_STANDARD_NO_PAD}; -use color_eyre::eyre::{anyhow, WrapErr}; use futures::Future; use hmac::Mac as _; use priority_queue::PriorityQueue; +use std::cmp::Reverse; +use std::collections::{HashMap, VecDeque}; +use std::pin::Pin; +use std::{future, mem, task}; use tokio::time::{Duration, Instant}; use super::super::ProtocolError; use super::Server; use crate::allocation::ConnectionHandle; +use crate::database::Database; +use crate::hrana::error::HranaError; /// Mutable state related to streams, owned by [`Server`] and protected with a mutex. pub struct ServerStreamState { @@ -68,8 +67,8 @@ struct Stream { /// Guard object that is used to access a stream from the outside. The guard makes sure that the /// stream's entry in [`ServerStreamState::handles`] is either removed or replaced with /// [`Handle::Available`] after the guard goes out of scope. -pub struct Guard { - server: Arc, +pub struct Guard<'srv> { + server: &'srv Server, /// The guarded stream. This is only set to `None` in the destructor. stream: Option>, /// If set to `true`, the destructor will release the stream for further use (saving it as @@ -102,42 +101,37 @@ impl ServerStreamState { /// Acquire a guard to a new or existing stream. If baton is `Some`, we try to look up the stream, /// otherwise we create a new stream. -pub async fn acquire( - server: Arc, +pub async fn acquire<'srv>( + server: &'srv Server, baton: Option<&str>, - mk_conn: F, -) -> color_eyre::Result -where - F: FnOnce() -> Fut, - Fut: Future>, -{ + db: Database, +) -> Result, HranaError> { let stream = match baton { Some(baton) => { - let (stream_id, baton_seq) = decode_baton(&server, baton)?; + let (stream_id, baton_seq) = decode_baton(server, baton)?; let mut state = server.stream_state.lock(); let handle = state.handles.get_mut(&stream_id); match handle { None => { - return Err(ProtocolError::BatonInvalid(format!( - "Stream handle for {stream_id} was not found" - )) - .into()) + Err(ProtocolError::BatonInvalid { + reason: format!("Stream handle for {stream_id} was not found"), + })?; } Some(Handle::Acquired) => { - return Err(ProtocolError::BatonReused) - .context(format!("Stream handle for {stream_id} is acquired")); - } - Some(Handle::Expired) => { - return Err(StreamError::StreamExpired) - .context(format!("Stream handle for {stream_id} is expired")); + Err(ProtocolError::BatonReused { + reason: format!("Stream handle for {stream_id} is acquired"), + })?; } + Some(Handle::Expired) => Err(StreamError::StreamExpired)?, Some(Handle::Available(stream)) => { if stream.baton_seq != baton_seq { - return Err(ProtocolError::BatonReused).context(format!( - "Expected baton seq {}, received {baton_seq}", - stream.baton_seq - )); + Err(ProtocolError::BatonReused { + reason: format!( + "Expected baton seq {}, received {baton_seq}", + stream.baton_seq + ), + })?; } } }; @@ -154,10 +148,7 @@ where stream } None => { - let conn = mk_conn() - .await - .context("Could not create a database connection")?; - + let conn = db.connect().await?; let mut state = server.stream_state.lock(); let stream = Box::new(Stream { conn: Some(conn), @@ -183,15 +174,15 @@ where }) } -impl Guard { - pub fn get_db(&self) -> Result<&ConnectionHandle, ProtocolError> { +impl<'srv> Guard<'srv> { + pub fn get_conn(&self) -> Result<&ConnectionHandle, ProtocolError> { let stream = self.stream.as_ref().unwrap(); stream.conn.as_ref().ok_or(ProtocolError::BatonStreamClosed) } /// Closes the database connection. The next call to [`Guard::release()`] will then remove the /// stream. - pub fn close_db(&mut self) { + pub fn close_conn(&mut self) { let stream = self.stream.as_mut().unwrap(); stream.conn = None; } @@ -212,7 +203,7 @@ impl Guard { if stream.conn.is_some() { self.release = true; // tell destructor to make the stream available again Some(encode_baton( - &self.server, + self.server, stream.stream_id, stream.baton_seq, )) @@ -222,7 +213,7 @@ impl Guard { } } -impl Drop for Guard { +impl<'srv> Drop for Guard<'srv> { fn drop(&mut self) { let stream = self.stream.take().unwrap(); let stream_id = stream.stream_id; @@ -289,17 +280,18 @@ fn encode_baton(server: &Server, stream_id: u64, baton_seq: u64) -> String { /// Decodes a baton encoded with `encode_baton()` and returns `(stream_id, baton_seq)`. Always /// returns a [`ProtocolError::BatonInvalid`] if the baton is invalid, but it attaches an anyhow /// context that describes the precise cause. -fn decode_baton(server: &Server, baton_str: &str) -> color_eyre::Result<(u64, u64)> { - let baton_data = BASE64_STANDARD_NO_PAD.decode(baton_str).map_err(|err| { - ProtocolError::BatonInvalid(format!("Could not base64-decode baton: {err}")) - })?; +fn decode_baton(server: &Server, baton_str: &str) -> crate::Result<(u64, u64), HranaError> { + let baton_data = + BASE64_STANDARD_NO_PAD + .decode(baton_str) + .map_err(|err| ProtocolError::BatonInvalid { + reason: format!("Could not base64-decode baton: {err}"), + })?; if baton_data.len() != 48 { - return Err(ProtocolError::BatonInvalid(format!( - "Baton has invalid size of {} bytes", - baton_data.len() - )) - .into()); + Err(ProtocolError::BatonInvalid { + reason: format!("Baton has invalid size of {} bytes", baton_data.len()), + })?; } let payload = &baton_data[0..16]; @@ -307,11 +299,10 @@ fn decode_baton(server: &Server, baton_str: &str) -> color_eyre::Result<(u64, u6 let mut hmac = hmac::Hmac::::new_from_slice(&server.baton_key).unwrap(); hmac.update(payload); - hmac.verify_slice(received_mac).map_err(|_| { - anyhow!(ProtocolError::BatonInvalid( - "Invalid MAC on baton".to_string() - )) - })?; + hmac.verify_slice(received_mac) + .map_err(|_| ProtocolError::BatonInvalid { + reason: "Invalid MAC on baton".into(), + })?; let stream_id = u64::from_be_bytes(payload[0..8].try_into().unwrap()); let baton_seq = u64::from_be_bytes(payload[8..16].try_into().unwrap()); diff --git a/libsqlx-server/src/hrana/mod.rs b/libsqlx-server/src/hrana/mod.rs index fc85fcfe..8f4c7f68 100644 --- a/libsqlx-server/src/hrana/mod.rs +++ b/libsqlx-server/src/hrana/mod.rs @@ -6,6 +6,7 @@ pub mod proto; mod result_builder; pub mod stmt; // pub mod ws; +pub mod error; #[derive(Debug, Copy, Clone, PartialOrd, Ord, PartialEq, Eq)] pub enum Version { @@ -50,10 +51,10 @@ pub enum ProtocolError { #[error("Invalid reference to step in a batch condition")] BatchCondBadStep, - #[error("Received an invalid baton: {0}")] - BatonInvalid(String), - #[error("Received a baton that has already been used")] - BatonReused, + #[error("Received an invalid baton: {reason}")] + BatonInvalid { reason: String }, + #[error("Received a baton that has already been used: {reason}")] + BatonReused { reason: String }, #[error("Stream for this baton was closed")] BatonStreamClosed, diff --git a/libsqlx-server/src/hrana/result_builder.rs b/libsqlx-server/src/hrana/result_builder.rs index c0c597bf..3985795b 100644 --- a/libsqlx-server/src/hrana/result_builder.rs +++ b/libsqlx-server/src/hrana/result_builder.rs @@ -5,19 +5,20 @@ use bytes::Bytes; use libsqlx::{result_builder::*, FrameNo}; use tokio::sync::oneshot; -use crate::hrana::stmt::{proto_error_from_stmt_error, stmt_error_from_sqld_error}; +use crate::hrana::stmt::proto_error_from_stmt_error; +use super::error::HranaError; use super::proto; pub struct SingleStatementBuilder { builder: StatementBuilder, - ret: Option>>, + ret: Option>>, } impl SingleStatementBuilder { pub fn new() -> ( Self, - oneshot::Receiver>, + oneshot::Receiver>, ) { let (ret, rcv) = oneshot::channel(); ( @@ -199,9 +200,9 @@ impl StatementBuilder { Ok(()) } - pub fn take_ret(&mut self) -> Result { + pub fn take_ret(&mut self) -> crate::Result { match self.err.take() { - Some(err) => Err(err), + Some(err) => Err(crate::error::Error::from(err))?, None => Ok(proto::StmtResult { cols: std::mem::take(&mut self.cols), rows: std::mem::take(&mut self.rows), @@ -270,7 +271,7 @@ pub struct HranaBatchProtoBuilder { current_size: u64, max_response_size: u64, step_empty: bool, - ret: oneshot::Sender, + ret: Option>, } impl HranaBatchProtoBuilder { @@ -284,15 +285,16 @@ impl HranaBatchProtoBuilder { current_size: 0, max_response_size: u64::MAX, step_empty: false, - ret, + ret: Some(ret), }, rcv, ) } - pub fn into_ret(self) -> proto::BatchResult { + + pub fn into_ret(&mut self) -> proto::BatchResult { proto::BatchResult { - step_results: self.step_results, - step_errors: self.step_errors, + step_results: std::mem::take(&mut self.step_results), + step_errors: std::mem::take(&mut self.step_errors), } } } @@ -331,7 +333,7 @@ impl ResultBuilder for HranaBatchProtoBuilder { Err(e) => { self.step_results.push(None); self.step_errors.push(Some(proto_error_from_stmt_error( - &stmt_error_from_sqld_error(e).map_err(QueryResultBuilderError::from_any)?, + Err(HranaError::from(e)).map_err(QueryResultBuilderError::from_any)?, ))); } } @@ -359,6 +361,18 @@ impl ResultBuilder for HranaBatchProtoBuilder { self.stmt_builder.add_row_value(v) } + fn finnalize( + &mut self, + _is_txn: bool, + _frame_no: Option, + ) -> Result { + if let Some(ret) = self.ret.take() { + let _ = ret.send(self.into_ret()); + } + + Ok(false) + } + fn finnalize_error(&mut self, _e: String) { todo!() } diff --git a/libsqlx-server/src/hrana/stmt.rs b/libsqlx-server/src/hrana/stmt.rs index 1b843367..d1a7799e 100644 --- a/libsqlx-server/src/hrana/stmt.rs +++ b/libsqlx-server/src/hrana/stmt.rs @@ -1,12 +1,16 @@ use std::collections::HashMap; -use color_eyre::eyre::{anyhow, bail}; +use futures::FutureExt; use libsqlx::analysis::Statement; +use libsqlx::program::Program; use libsqlx::query::{Params, Query, Value}; +use libsqlx::DescribeResponse; +use super::error::HranaError; use super::result_builder::SingleStatementBuilder; use super::{proto, ProtocolError, Version}; use crate::allocation::ConnectionHandle; +// use crate::auth::Authenticated; use crate::hrana; /// An error during execution of an SQL statement. @@ -18,8 +22,8 @@ pub enum StmtError { SqlNoStmt, #[error("SQL string contains more than one statement")] SqlManyStmts, - #[error("Arguments do not match SQL parameters: {msg}")] - ArgsInvalid { msg: String }, + #[error("Arguments do not match SQL parameters: {source}")] + ArgsInvalid { source: color_eyre::eyre::Error }, #[error("Specifying both positional and named arguments is not supported")] ArgsBothPositionalAndNamed, @@ -34,7 +38,7 @@ pub enum StmtError { }, #[error("SQL input error: {message} (at offset {offset})")] SqlInputError { - source: color_eyre::eyre::Error, + source: libsqlx::rusqlite::ffi::Error, message: String, offset: i32, }, @@ -45,31 +49,32 @@ pub enum StmtError { pub async fn execute_stmt( conn: &ConnectionHandle, + // auth: Authenticated, query: Query, -) -> color_eyre::Result { +) -> crate::Result { let (builder, ret) = SingleStatementBuilder::new(); - let pgm = libsqlx::program::Program::from_queries(std::iter::once(query)); - conn.execute(pgm, Box::new(builder)).await?; - ret.await? - .map_err(|sqld_error| match stmt_error_from_sqld_error(sqld_error) { - Ok(stmt_error) => anyhow!(stmt_error), - Err(sqld_error) => anyhow!(sqld_error), - }) + conn.execute( + Program::from_queries(Some(query)), /*, auth*/ + Box::new(builder), + ) + .await; + ret.await + .map_err(|_| crate::error::Error::ConnectionClosed)? } pub async fn describe_stmt( - _db: &ConnectionHandle, - _sql: String, -) -> color_eyre::Result { - todo!(); - // match db.describe(sql).await? { - // Ok(describe_response) => todo!(), - // // Ok(proto_describe_result_from_describe_response( - // // describe_response, - // // )), + db: &ConnectionHandle, + // auth: Authenticated, + sql: String, +) -> crate::Result { + todo!() + // match db.describe(sql/*, auth*/).await? { + // Ok(describe_response) => Ok(proto_describe_result_from_describe_response( + // describe_response, + // )), // Err(sqld_error) => match stmt_error_from_sqld_error(sqld_error) { - // Ok(stmt_error) => bail!(stmt_error), - // Err(sqld_error) => bail!(sqld_error), + // Ok(stmt_error) => Err(stmt_error)?, + // Err(sqld_error) => Err(sqld_error)?, // }, // } } @@ -77,19 +82,19 @@ pub async fn describe_stmt( pub fn proto_stmt_to_query( proto_stmt: &proto::Stmt, sqls: &HashMap, - version: Version, -) -> color_eyre::Result { - let sql = proto_sql_to_sql(proto_stmt.sql.as_deref(), proto_stmt.sql_id, sqls, version)?; + verion: Version, +) -> crate::Result { + let sql = proto_sql_to_sql(proto_stmt.sql.as_deref(), proto_stmt.sql_id, sqls, verion)?; let mut stmt_iter = Statement::parse(sql); let stmt = match stmt_iter.next() { Some(Ok(stmt)) => stmt, - Some(Err(err)) => bail!(StmtError::SqlParse { source: err.into() }), - None => bail!(StmtError::SqlNoStmt), + Some(Err(err)) => Err(StmtError::SqlParse { source: err.into() })?, + None => Err(StmtError::SqlNoStmt)?, }; if stmt_iter.next().is_some() { - bail!(StmtError::SqlManyStmts) + Err(StmtError::SqlManyStmts)? } let params = if proto_stmt.named_args.is_empty() { @@ -103,7 +108,7 @@ pub fn proto_stmt_to_query( .collect(); Params::Named(values) } else { - bail!(StmtError::ArgsBothPositionalAndNamed) + Err(StmtError::ArgsBothPositionalAndNamed)? }; let want_rows = proto_stmt.want_rows.unwrap_or(true); @@ -162,63 +167,78 @@ fn proto_value_from_value(value: Value) -> proto::Value { } } -// fn proto_describe_result_from_describe_response( -// response: DescribeResponse, -// ) -> proto::DescribeResult { -// proto::DescribeResult { -// params: response -// .params -// .into_iter() -// .map(|p| proto::DescribeParam { name: p.name }) -// .collect(), -// cols: response -// .cols -// .into_iter() -// .map(|c| proto::DescribeCol { -// name: c.name, -// decltype: c.decltype, -// }) -// .collect(), -// is_explain: response.is_explain, -// is_readonly: response.is_readonly, -// } -// } +fn proto_describe_result_from_describe_response( + response: DescribeResponse, +) -> proto::DescribeResult { + proto::DescribeResult { + params: response + .params + .into_iter() + .map(|p| proto::DescribeParam { name: p.name }) + .collect(), + cols: response + .cols + .into_iter() + .map(|c| proto::DescribeCol { + name: c.name, + decltype: c.decltype, + }) + .collect(), + is_explain: response.is_explain, + is_readonly: response.is_readonly, + } +} -pub fn stmt_error_from_sqld_error( - sqld_error: libsqlx::error::Error, -) -> Result { - Ok(match sqld_error { - libsqlx::error::Error::LibSqlInvalidQueryParams(msg) => StmtError::ArgsInvalid { msg }, - libsqlx::error::Error::LibSqlTxTimeout => StmtError::TransactionTimeout, - libsqlx::error::Error::LibSqlTxBusy => StmtError::TransactionBusy, - libsqlx::error::Error::Blocked(reason) => StmtError::Blocked { reason }, - libsqlx::error::Error::RusqliteError(rusqlite_error) => match rusqlite_error { - libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, Some(message)) => { - StmtError::SqliteError { - source: sqlite_error, - message, - } - } - libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, None) => { - StmtError::SqliteError { - message: sqlite_error.to_string(), - source: sqlite_error, +impl From for HranaError { + fn from(error: crate::error::Error) -> Self { + if let crate::error::Error::Libsqlx(e) = error { + match e { + libsqlx::error::Error::LibSqlInvalidQueryParams(source) => StmtError::ArgsInvalid { + source: color_eyre::eyre::anyhow!("{source}"), } + .into(), + libsqlx::error::Error::LibSqlTxTimeout => StmtError::TransactionTimeout.into(), + libsqlx::error::Error::LibSqlTxBusy => StmtError::TransactionBusy.into(), + libsqlx::error::Error::Blocked(reason) => StmtError::Blocked { reason }.into(), + libsqlx::error::Error::RusqliteError(rusqlite_error) => match rusqlite_error { + libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, Some(message)) => { + StmtError::SqliteError { + source: sqlite_error, + message, + } + .into() + } + libsqlx::error::RusqliteError::SqliteFailure(sqlite_error, None) => { + StmtError::SqliteError { + message: sqlite_error.to_string(), + source: sqlite_error, + } + .into() + } + libsqlx::error::RusqliteError::SqlInputError { + error: sqlite_error, + msg: message, + offset, + .. + } => StmtError::SqlInputError { + source: sqlite_error, + message, + offset, + } + .into(), + rusqlite_error => { + return crate::error::Error::from(libsqlx::error::Error::RusqliteError( + rusqlite_error, + )) + .into() + } + }, + sqld_error => return crate::error::Error::from(sqld_error).into(), } - libsqlx::error::RusqliteError::SqlInputError { - error: sqlite_error, - msg: message, - offset, - .. - } => StmtError::SqlInputError { - source: sqlite_error.into(), - message, - offset, - }, - rusqlite_error => return Err(libsqlx::error::Error::RusqliteError(rusqlite_error)), - }, - sqld_error => return Err(sqld_error), - }) + } else { + Self::Libsqlx(error) + } + } } pub fn proto_error_from_stmt_error(error: &StmtError) -> hrana::proto::Error { diff --git a/libsqlx-server/src/hrana/ws/mod.rs b/libsqlx-server/src/hrana/ws/mod.rs index 32a34957..bcdb5209 100644 --- a/libsqlx-server/src/hrana/ws/mod.rs +++ b/libsqlx-server/src/hrana/ws/mod.rs @@ -1,7 +1,5 @@ -use crate::auth::Auth; +// use crate::auth::Auth; use crate::database::Database; -use crate::utils::services::idle_shutdown::IdleKicker; -use anyhow::{Context as _, Result}; use enclose::enclose; use std::net::SocketAddr; use std::sync::atomic::{AtomicU64, Ordering}; @@ -14,10 +12,9 @@ mod conn; mod handshake; mod session; -struct Server { - db_factory: Arc>, +struct Server { auth: Arc, - idle_kicker: Option, + // idle_kicker: Option, next_conn_id: AtomicU64, } diff --git a/libsqlx-server/src/http/admin.rs b/libsqlx-server/src/http/admin.rs index ff718674..2bac534c 100644 --- a/libsqlx-server/src/http/admin.rs +++ b/libsqlx-server/src/http/admin.rs @@ -197,6 +197,7 @@ async fn list_allocs( .handler() .store() .list_allocs() + .unwrap() .into_iter() .map(|cfg| AllocView { id: cfg.db_name }) .collect(); diff --git a/libsqlx-server/src/http/user/error.rs b/libsqlx-server/src/http/user/error.rs index 9aab9a71..81a9ea2b 100644 --- a/libsqlx-server/src/http/user/error.rs +++ b/libsqlx-server/src/http/user/error.rs @@ -11,6 +11,8 @@ pub enum UserApiError { InvalidHost, #[error("Database `{0}` doesn't exist")] UnknownDatabase(String), + #[error(transparent)] + LibsqlxServer(#[from] crate::error::Error), } impl UserApiError { @@ -19,6 +21,7 @@ impl UserApiError { UserApiError::MissingHost | UserApiError::InvalidHost | UserApiError::UnknownDatabase(_) => StatusCode::BAD_REQUEST, + UserApiError::LibsqlxServer(_) => StatusCode::INTERNAL_SERVER_ERROR, } } } diff --git a/libsqlx-server/src/http/user/extractors.rs b/libsqlx-server/src/http/user/extractors.rs index 582b0fd6..fc84900e 100644 --- a/libsqlx-server/src/http/user/extractors.rs +++ b/libsqlx-server/src/http/user/extractors.rs @@ -20,7 +20,7 @@ impl FromRequestParts> for Database { let Ok(host_str) = std::str::from_utf8(host.as_bytes()) else {return Err(UserApiError::MissingHost)}; let db_name = parse_host(host_str)?; let db_id = DatabaseId::from_name(db_name); - let Some(sender) = state.manager.schedule(db_id, state.bus.clone()).await else { return Err(UserApiError::UnknownDatabase(db_name.to_owned())) }; + let Some(sender) = state.manager.schedule(db_id, state.bus.clone()).await? else { return Err(UserApiError::UnknownDatabase(db_name.to_owned())) }; Ok(Database { sender }) } diff --git a/libsqlx-server/src/http/user/mod.rs b/libsqlx-server/src/http/user/mod.rs index c947fb8b..3653377b 100644 --- a/libsqlx-server/src/http/user/mod.rs +++ b/libsqlx-server/src/http/user/mod.rs @@ -1,12 +1,18 @@ use std::sync::Arc; +use axum::extract::State; +use axum::response::IntoResponse; use axum::routing::post; use axum::{Json, Router}; use color_eyre::Result; +use hyper::StatusCode; use hyper::server::accept::Accept; +use serde::Serialize; use tokio::io::{AsyncRead, AsyncWrite}; use crate::database::Database; +use crate::hrana; +use crate::hrana::error::HranaError; use crate::hrana::http::proto::{PipelineRequestBody, PipelineResponseBody}; use crate::linc::bus::Bus; use crate::manager::Manager; @@ -14,14 +20,38 @@ use crate::manager::Manager; mod error; mod extractors; +#[derive(Debug, Serialize)] +struct ErrorResponseBody { + pub message: String, + pub code: String, +} + +impl IntoResponse for HranaError { + fn into_response(self) -> axum::response::Response { + let (message, code) = match self.code() { + Some(code) => (self.to_string(), code.to_owned()), + None => ("internal error, please check the logs".to_owned(), "INTERNAL_ERROR".to_owned()), + }; + let resp = ErrorResponseBody { + message, + code, + }; + let mut resp = Json(resp).into_response(); + *resp.status_mut() = StatusCode::BAD_REQUEST; + resp + } +} + pub struct Config { pub manager: Arc, pub bus: Arc>>, + pub hrana_server: Arc, } struct UserApiState { manager: Arc, bus: Arc>>, + hrana_server: Arc, } pub async fn run_user_api(config: Config, listener: I) -> Result<()> @@ -32,6 +62,7 @@ where let state = UserApiState { manager: config.manager, bus: config.bus, + hrana_server: config.hrana_server, }; let app = Router::new() @@ -46,9 +77,10 @@ where } async fn handle_hrana_pipeline( + State(state): State>, db: Database, Json(req): Json, -) -> Json { - let resp = db.hrana_pipeline(req).await; - Json(resp.unwrap()) +) -> crate::Result, HranaError> { + let ret = hrana::http::handle_pipeline(&state.hrana_server, req, db).await?; + Ok(Json(ret)) } diff --git a/libsqlx-server/src/linc/bus.rs b/libsqlx-server/src/linc/bus.rs index 5072c8ae..c74ba267 100644 --- a/libsqlx-server/src/linc/bus.rs +++ b/libsqlx-server/src/linc/bus.rs @@ -36,11 +36,9 @@ impl Bus { } pub async fn incomming(self: &Arc, incomming: Inbound) { - self.handler.handle(self.clone(), incomming).await; - } - - pub fn send_queue(&self) -> &SendQueue { - &self.send_queue + if let Err(e) = self.handler.handle(self.clone(), incomming).await { + tracing::error!("error handling message: {e}") + } } pub fn connect(&self, node_id: NodeId) -> mpsc::UnboundedReceiver { @@ -48,27 +46,26 @@ impl Bus { self.peers.write().insert(node_id); self.send_queue.register(node_id) } - - pub fn disconnect(&self, node_id: NodeId) { - self.peers.write().remove(&node_id); - } } #[async_trait::async_trait] pub trait Dispatch: Send + Sync + 'static { - async fn dispatch(&self, msg: Outbound); + async fn dispatch(&self, msg: Outbound) -> crate::Result<()>; + /// id of the current node fn node_id(&self) -> NodeId; } #[async_trait::async_trait] impl Dispatch for Bus { - async fn dispatch(&self, msg: Outbound) { + async fn dispatch(&self, msg: Outbound) -> crate::Result<()> { assert!( msg.to != self.node_id(), "trying to send a message to ourself!" ); // This message is outbound. - self.send_queue.enqueue(msg).await; + self.send_queue.enqueue(msg).await?; + + Ok(()) } fn node_id(&self) -> NodeId { diff --git a/libsqlx-server/src/linc/connection.rs b/libsqlx-server/src/linc/connection.rs index b979c437..9fa41bb1 100644 --- a/libsqlx-server/src/linc/connection.rs +++ b/libsqlx-server/src/linc/connection.rs @@ -62,13 +62,24 @@ impl SendQueue { } } - pub async fn enqueue(&self, msg: Outbound) { + pub async fn enqueue(&self, msg: Outbound) -> crate::Result<()> { let sender = match self.senders.read().get(&msg.to) { Some(sender) => sender.clone(), - None => todo!("no queue"), + None => { + return Err(crate::error::Error::Internal(color_eyre::eyre::anyhow!( + "failed to deliver message: unknown node id `{}`", + msg.to + ))) + } }; - sender.send(msg.enveloppe).unwrap(); + sender.send(msg.enveloppe).map_err(|_| { + crate::error::Error::Internal(color_eyre::eyre::anyhow!( + "failed to deliver message: connection closed" + )) + })?; + + Ok(()) } pub fn register(&self, node_id: NodeId) -> mpsc::UnboundedReceiver { @@ -155,14 +166,22 @@ where } } }, - // TODO: pop send queue - Some(m) = self.send_queue.as_mut().unwrap().recv() => { - self.conn.feed(m).await.unwrap(); - // send as many as possible - while let Ok(m) = self.send_queue.as_mut().unwrap().try_recv() { - self.conn.feed(m).await.unwrap(); + Some(m) = self.send_queue.as_mut().expect("no send_queue in connected sate").recv() => { + let feed = || async { + self.conn.feed(m).await?; + // send as many as possible + while let Ok(m) = self.send_queue.as_mut().expect("no send_queue in connected sate").try_recv() { + self.conn.feed(m).await?; + } + self.conn.flush().await?; + + Ok(()) + }; + + if let Err(e) = feed().await { + tracing::error!("error flusing send queue for {}; closing connection", self.peer.unwrap()); + self.state = ConnectionState::CloseError(e) } - self.conn.flush().await.unwrap(); }, else => { self.state = ConnectionState::Close; diff --git a/libsqlx-server/src/linc/connection_manager.rs b/libsqlx-server/src/linc/connection_manager.rs deleted file mode 100644 index e69de29b..00000000 diff --git a/libsqlx-server/src/linc/handler.rs b/libsqlx-server/src/linc/handler.rs index 2d17ff96..410e4c24 100644 --- a/libsqlx-server/src/linc/handler.rs +++ b/libsqlx-server/src/linc/handler.rs @@ -6,7 +6,7 @@ use super::Inbound; #[async_trait::async_trait] pub trait Handler: Sized + Send + Sync + 'static { /// Handle inbound message - async fn handle(&self, bus: Arc, msg: Inbound); + async fn handle(&self, bus: Arc, msg: Inbound) -> crate::Result<()>; } #[cfg(test)] @@ -16,7 +16,8 @@ where F: Fn(Arc, Inbound) -> Fut + Send + Sync + 'static, Fut: std::future::Future + Send, { - async fn handle(&self, bus: Arc, msg: Inbound) { - (self)(bus, msg).await + async fn handle(&self, bus: Arc, msg: Inbound) -> crate::Result<()> { + (self)(bus, msg).await; + Ok(()) } } diff --git a/libsqlx-server/src/linc/mod.rs b/libsqlx-server/src/linc/mod.rs index 638f56e2..2ee07790 100644 --- a/libsqlx-server/src/linc/mod.rs +++ b/libsqlx-server/src/linc/mod.rs @@ -1,4 +1,4 @@ -use self::proto::{Enveloppe, Message}; +use self::proto::Enveloppe; pub mod bus; pub mod connection; @@ -11,7 +11,6 @@ pub mod server; pub type NodeId = u64; const CURRENT_PROTO_VERSION: u32 = 1; -const MAX_STREAM_MSG: usize = 64; #[derive(Debug)] pub struct Inbound { @@ -21,18 +20,6 @@ pub struct Inbound { pub enveloppe: Enveloppe, } -impl Inbound { - pub fn respond(&self, message: Message) -> Outbound { - Outbound { - to: self.from, - enveloppe: Enveloppe { - database_id: None, - message, - }, - } - } -} - #[derive(Debug)] pub struct Outbound { pub to: NodeId, diff --git a/libsqlx-server/src/main.rs b/libsqlx-server/src/main.rs index ff5c415a..6162df26 100644 --- a/libsqlx-server/src/main.rs +++ b/libsqlx-server/src/main.rs @@ -3,7 +3,6 @@ use std::path::{Path, PathBuf}; use std::sync::Arc; use clap::Parser; -use color_eyre::eyre::Result; use compactor::{run_compactor_loop, CompactionQueue}; use config::{AdminApiConfig, ClusterConfig, UserApiConfig}; use http::admin::run_admin_api; @@ -24,6 +23,7 @@ mod allocation; mod compactor; mod config; mod database; +mod error; mod hrana; mod http; mod linc; @@ -32,6 +32,8 @@ mod meta; mod replica_commit_store; mod snapshot_store; +pub type Result = std::result::Result; + #[derive(Debug, Parser)] struct Args { /// Path to the node configuration file @@ -40,10 +42,10 @@ struct Args { } async fn spawn_admin_api( - set: &mut JoinSet>, + set: &mut JoinSet>, config: &AdminApiConfig, bus: Arc>>, -) -> Result<()> { +) -> color_eyre::Result<()> { let admin_api_listener = TcpListener::bind(config.addr).await?; let fut = run_admin_api( http::admin::Config { bus }, @@ -55,14 +57,26 @@ async fn spawn_admin_api( } async fn spawn_user_api( - set: &mut JoinSet>, + set: &mut JoinSet>, config: &UserApiConfig, manager: Arc, bus: Arc>>, -) -> Result<()> { +) -> color_eyre::Result<()> { let user_api_listener = TcpListener::bind(config.addr).await?; + let hrana_server = Arc::new(hrana::http::Server::new(None)); + set.spawn({ + let hrana_server = hrana_server.clone(); + async move { + hrana_server.run_expire().await; + Ok(()) + } + }); set.spawn(run_user_api( - http::user::Config { manager, bus }, + http::user::Config { + manager, + bus, + hrana_server, + }, AddrIncoming::from_listener(user_api_listener)?, )); @@ -70,10 +84,10 @@ async fn spawn_user_api( } async fn spawn_cluster_networking( - set: &mut JoinSet>, + set: &mut JoinSet>, config: &ClusterConfig, bus: Arc>>, -) -> Result<()> { +) -> color_eyre::Result<()> { let server = linc::server::Server::new(bus.clone()); let listener = TcpListener::bind(config.addr).await?; @@ -102,8 +116,9 @@ async fn init_dirs(db_path: &Path) -> color_eyre::Result<()> { } #[tokio::main(flavor = "multi_thread", worker_threads = 10)] -async fn main() -> Result<()> { - init(); +async fn main() -> color_eyre::Result<()> { + init()?; + let args = Args::parse(); let config_str = read_to_string(args.config)?; let config: config::Config = toml::from_str(&config_str)?; @@ -124,8 +139,8 @@ async fn main() -> Result<()> { config.db_path.clone(), snapshot_store, )?); - let store = Arc::new(Store::new(env.clone())); - let replica_commit_store = Arc::new(ReplicaCommitStore::new(env.clone())); + let store = Arc::new(Store::new(env.clone())?); + let replica_commit_store = Arc::new(ReplicaCommitStore::new(env.clone())?); let manager = Arc::new(Manager::new( config.db_path.clone(), store.clone(), @@ -145,7 +160,7 @@ async fn main() -> Result<()> { Ok(()) } -fn init() { +fn init() -> color_eyre::Result<()> { let registry = tracing_subscriber::registry(); registry @@ -160,5 +175,7 @@ fn init() { ) .init(); - color_eyre::install().unwrap(); + color_eyre::install()?; + + Ok(()) } diff --git a/libsqlx-server/src/manager.rs b/libsqlx-server/src/manager.rs index 1b0ca7d1..686e6fca 100644 --- a/libsqlx-server/src/manager.rs +++ b/libsqlx-server/src/manager.rs @@ -9,7 +9,6 @@ use tokio::task::JoinSet; use crate::allocation::config::AllocConfig; use crate::allocation::{Allocation, AllocationMessage, Database}; use crate::compactor::CompactionQueue; -use crate::hrana; use crate::linc::bus::Dispatch; use crate::linc::handler::Handler; use crate::linc::Inbound; @@ -48,14 +47,14 @@ impl Manager { self: &Arc, database_id: DatabaseId, dispatcher: Arc, - ) -> Option> { + ) -> crate::Result>> { if let Some(sender) = self.cache.get(&database_id) { - return Some(sender.clone()); + return Ok(Some(sender.clone())); } - if let Some(config) = self.meta_store.meta(&database_id) { + if let Some(config) = self.meta_store.meta(&database_id)? { let path = self.db_path.join("dbs").join(database_id.to_string()); - tokio::fs::create_dir_all(&path).await.unwrap(); + tokio::fs::create_dir_all(&path).await?; let (alloc_sender, inbox) = mpsc::channel(MAX_ALLOC_MESSAGE_QUEUE_LEN); let alloc = Allocation { inbox, @@ -65,12 +64,11 @@ impl Manager { dispatcher.clone(), self.compaction_queue.clone(), self.replica_commit_store.clone(), - ), + )?, connections_futs: JoinSet::new(), next_conn_id: 0, max_concurrent_connections: config.max_conccurent_connection, - hrana_server: Arc::new(hrana::http::Server::new(None)), - dispatcher, // TODO: handle self URL? + dispatcher, db_name: config.db_name, connections: HashMap::new(), }; @@ -79,10 +77,10 @@ impl Manager { self.cache.insert(database_id, alloc_sender.clone()).await; - return Some(alloc_sender); + return Ok(Some(alloc_sender)); } - None + Ok(None) } pub async fn allocate( @@ -90,16 +88,21 @@ impl Manager { database_id: DatabaseId, meta: &AllocConfig, dispatcher: Arc, - ) { - self.store().allocate(&database_id, meta); - self.schedule(database_id, dispatcher).await; + ) -> crate::Result<()> { + self.store().allocate(&database_id, meta)?; + self.schedule(database_id, dispatcher).await?; + Ok(()) } - pub async fn deallocate(&self, database_id: DatabaseId) { - self.meta_store.deallocate(&database_id); + pub async fn deallocate(&self, database_id: DatabaseId) -> crate::Result<()> { + self.meta_store.deallocate(&database_id)?; self.cache.remove(&database_id).await; let db_path = self.db_path.join("dbs").join(database_id.to_string()); - tokio::fs::remove_dir_all(db_path).await.unwrap(); + if db_path.exists() { + tokio::fs::remove_dir_all(db_path).await?; + } + + Ok(()) } pub fn store(&self) -> &Store { @@ -109,13 +112,15 @@ impl Manager { #[async_trait::async_trait] impl Handler for Arc { - async fn handle(&self, bus: Arc, msg: Inbound) { - if let Some(sender) = self - .clone() - .schedule(msg.enveloppe.database_id.unwrap(), bus.clone()) - .await - { - let _ = sender.send(AllocationMessage::Inbound(msg)).await; + async fn handle(&self, bus: Arc, msg: Inbound) -> crate::Result<()> { + if let Some(database_id) = msg.enveloppe.database_id { + if let Some(sender) = self.clone().schedule(database_id, bus.clone()).await? { + sender + .send(AllocationMessage::Inbound(msg)) + .await + .map_err(|_| crate::error::Error::AllocationClosed)?; + } } + Ok(()) } } diff --git a/libsqlx-server/src/meta.rs b/libsqlx-server/src/meta.rs index 2436839b..baefaeaf 100644 --- a/libsqlx-server/src/meta.rs +++ b/libsqlx-server/src/meta.rs @@ -3,6 +3,7 @@ use std::mem::size_of; use heed::bytemuck::{Pod, Zeroable}; use heed_types::{OwnedType, SerdeBincode}; +use itertools::Itertools; use serde::{Deserialize, Serialize}; use sha3::digest::{ExtendableOutput, Update, XofReader}; use sha3::Shake128; @@ -52,64 +53,72 @@ impl AsRef<[u8]> for DatabaseId { } } +#[derive(Debug, thiserror::Error)] +pub enum AllocationError { + #[error("an allocation already exists for {0}")] + AlreadyExist(String), +} + impl Store { const ALLOC_CONFIG_DB_NAME: &'static str = "alloc_conf_db"; - pub fn new(env: heed::Env) -> Self { - let mut txn = env.write_txn().unwrap(); - let alloc_config_db = env - .create_database(&mut txn, Some(Self::ALLOC_CONFIG_DB_NAME)) - .unwrap(); - txn.commit().unwrap(); + pub fn new(env: heed::Env) -> crate::Result { + let mut txn = env.write_txn()?; + let alloc_config_db = env.create_database(&mut txn, Some(Self::ALLOC_CONFIG_DB_NAME))?; + txn.commit()?; - Self { + Ok(Self { env, alloc_config_db, - } + }) } - pub fn allocate(&self, id: &DatabaseId, meta: &AllocConfig) { - //TODO: Handle conflict + pub fn allocate(&self, id: &DatabaseId, meta: &AllocConfig) -> crate::Result<()> { block_in_place(|| { - let mut txn = self.env.write_txn().unwrap(); + let mut txn = self.env.write_txn()?; if self .alloc_config_db .lazily_decode_data() - .get(&txn, id) - .unwrap() + .get(&txn, id)? .is_some() { - panic!("alloc already exists"); + Err(AllocationError::AlreadyExist(meta.db_name.clone()))?; }; - self.alloc_config_db.put(&mut txn, id, meta).unwrap(); - txn.commit().unwrap(); - }); + + self.alloc_config_db.put(&mut txn, id, meta)?; + + txn.commit()?; + + Ok(()) + }) } - pub fn deallocate(&self, id: &DatabaseId) { + pub fn deallocate(&self, id: &DatabaseId) -> crate::Result<()> { block_in_place(|| { - let mut txn = self.env.write_txn().unwrap(); - self.alloc_config_db.delete(&mut txn, id).unwrap(); - txn.commit().unwrap(); - }); + let mut txn = self.env.write_txn()?; + self.alloc_config_db.delete(&mut txn, id)?; + txn.commit()?; + + Ok(()) + }) } - pub fn meta(&self, id: &DatabaseId) -> Option { + pub fn meta(&self, id: &DatabaseId) -> crate::Result> { block_in_place(|| { - let txn = self.env.read_txn().unwrap(); - self.alloc_config_db.get(&txn, id).unwrap() + let txn = self.env.read_txn()?; + Ok(self.alloc_config_db.get(&txn, id)?) }) } - pub fn list_allocs(&self) -> Vec { + pub fn list_allocs(&self) -> crate::Result> { block_in_place(|| { - let txn = self.env.read_txn().unwrap(); - self.alloc_config_db - .iter(&txn) - .unwrap() - .map(Result::unwrap) - .map(|x| x.1) - .collect() + let txn = self.env.read_txn()?; + let res = self + .alloc_config_db + .iter(&txn)? + .map(|x| x.map(|x| x.1)) + .try_collect()?; + Ok(res) }) } } diff --git a/libsqlx-server/src/replica_commit_store.rs b/libsqlx-server/src/replica_commit_store.rs index 18c0aeed..2598c3b0 100644 --- a/libsqlx-server/src/replica_commit_store.rs +++ b/libsqlx-server/src/replica_commit_store.rs @@ -11,24 +11,24 @@ pub struct ReplicaCommitStore { impl ReplicaCommitStore { const DB_NAME: &str = "replica-commit-store"; - pub fn new(env: heed::Env) -> Self { - let mut txn = env.write_txn().unwrap(); - let database = env.create_database(&mut txn, Some(Self::DB_NAME)).unwrap(); - txn.commit().unwrap(); + pub fn new(env: heed::Env) -> crate::Result { + let mut txn = env.write_txn()?; + let database = env.create_database(&mut txn, Some(Self::DB_NAME))?; + txn.commit()?; - Self { env, database } + Ok(Self { env, database }) } - pub fn commit(&self, database_id: DatabaseId, frame_no: FrameNo) { - let mut txn = self.env.write_txn().unwrap(); - self.database - .put(&mut txn, &database_id, &frame_no) - .unwrap(); - txn.commit().unwrap(); + pub fn commit(&self, database_id: DatabaseId, frame_no: FrameNo) -> crate::Result<()> { + let mut txn = self.env.write_txn()?; + self.database.put(&mut txn, &database_id, &frame_no)?; + txn.commit()?; + + Ok(()) } - pub fn get_commit_index(&self, database_id: DatabaseId) -> Option { - let txn = self.env.read_txn().unwrap(); - self.database.get(&txn, &database_id).unwrap() + pub fn get_commit_index(&self, database_id: DatabaseId) -> crate::Result> { + let txn = self.env.read_txn()?; + Ok(self.database.get(&txn, &database_id)?) } } diff --git a/libsqlx-server/src/snapshot_store.rs b/libsqlx-server/src/snapshot_store.rs index 32f7f0e9..73cbc0cb 100644 --- a/libsqlx-server/src/snapshot_store.rs +++ b/libsqlx-server/src/snapshot_store.rs @@ -10,6 +10,8 @@ use uuid::Uuid; use crate::{compactor::SnapshotFile, meta::DatabaseId}; +/// Equivalent to a u64, but stored in big-endian ordering. +/// Used for storing values whose bytes need to be lexically ordered. #[derive(Clone, Copy, Zeroable, Pod, Debug)] #[repr(transparent)] struct BEU64([u8; size_of::()]); @@ -48,8 +50,8 @@ pub struct SnapshotStore { impl SnapshotStore { const SNAPSHOT_STORE_NAME: &str = "snapshot-store-db"; - pub fn new(db_path: PathBuf, env: heed::Env) -> color_eyre::Result { - let mut txn = env.write_txn().unwrap(); + pub fn new(db_path: PathBuf, env: heed::Env) -> crate::Result { + let mut txn = env.write_txn()?; let database = env.create_database(&mut txn, Some(Self::SNAPSHOT_STORE_NAME))?; txn.commit()?; @@ -67,7 +69,7 @@ impl SnapshotStore { start_frame_no: FrameNo, end_frame_no: FrameNo, snapshot_id: Uuid, - ) { + ) -> crate::Result<()> { let key = SnapshotKey { database_id, start_frame_no: start_frame_no.into(), @@ -76,13 +78,19 @@ impl SnapshotStore { let data = SnapshotMeta { snapshot_id }; - block_in_place(|| self.database.put(txn, &key, &data).unwrap()); + block_in_place(|| self.database.put(txn, &key, &data))?; + + Ok(()) } /// Locate a snapshot for `database_id` that contains `frame_no` - pub fn locate(&self, database_id: DatabaseId, frame_no: FrameNo) -> Option { - let txn = self.env.read_txn().unwrap(); - // Snapshot keys being lexicographically ordered, looking for the first key less than of + pub fn locate( + &self, + database_id: DatabaseId, + frame_no: FrameNo, + ) -> crate::Result> { + let txn = self.env.read_txn()?; + // Snapshot keys are lexicographically ordered, looking for the first key less than of // equal to (db_id, frame_no, FrameNo::MAX) will always return the entry we're looking for // if it exists. let key = SnapshotKey { @@ -91,14 +99,10 @@ impl SnapshotStore { end_frame_no: u64::MAX.into(), }; - match self - .database - .get_lower_than_or_equal_to(&txn, &key) - .transpose()? - { - Ok((key, v)) => { + match self.database.get_lower_than_or_equal_to(&txn, &key)? { + Some((key, v)) => { if key.database_id != database_id { - return None; + return Ok(None); } else if frame_no >= key.start_frame_no.into() && frame_no <= key.end_frame_no.into() { @@ -107,22 +111,26 @@ impl SnapshotStore { u64::from(key.start_frame_no), u64::from(key.end_frame_no) ); - return Some(v); + return Ok(Some(v)); } else { - None + Ok(None) } } - Err(_) => todo!(), + None => Ok(None), } } - pub fn locate_file(&self, database_id: DatabaseId, frame_no: FrameNo) -> Option { - let meta = self.locate(database_id, frame_no)?; + pub fn locate_file( + &self, + database_id: DatabaseId, + frame_no: FrameNo, + ) -> crate::Result> { + let Some(meta) = self.locate(database_id, frame_no)? else { return Ok(None) }; let path = self .db_path .join("snapshots") .join(meta.snapshot_id.to_string()); - Some(SnapshotFile::open(&path).unwrap()) + Ok(Some(SnapshotFile::open(&path)?)) } } diff --git a/libsqlx/src/connection.rs b/libsqlx/src/connection.rs index e767073a..9c3fcdb0 100644 --- a/libsqlx/src/connection.rs +++ b/libsqlx/src/connection.rs @@ -24,11 +24,7 @@ pub struct DescribeCol { pub trait Connection { /// Executes a query program - fn execute_program( - &mut self, - pgm: &Program, - result_builder: Box, - ) -> crate::Result<()>; + fn execute_program(&mut self, pgm: &Program, result_builder: Box); /// Parse the SQL statement and return information about it. fn describe(&self, sql: String) -> crate::Result; @@ -39,11 +35,7 @@ where T: Connection, X: Connection, { - fn execute_program( - &mut self, - pgm: &Program, - result_builder: Box, - ) -> crate::Result<()> { + fn execute_program(&mut self, pgm: &Program, result_builder: Box) { match self { Either::Left(c) => c.execute_program(pgm, result_builder), Either::Right(c) => c.execute_program(pgm, result_builder), diff --git a/libsqlx/src/database/libsql/connection.rs b/libsqlx/src/database/libsql/connection.rs index 0ad8b780..9579dc94 100644 --- a/libsqlx/src/database/libsql/connection.rs +++ b/libsqlx/src/database/libsql/connection.rs @@ -249,12 +249,10 @@ fn eval_cond(cond: &Cond, results: &[bool]) -> Result { } impl Connection for LibsqlConnection { - fn execute_program( - &mut self, - pgm: &Program, - mut builder: Box, - ) -> crate::Result<()> { - self.run(pgm, &mut *builder) + fn execute_program(&mut self, pgm: &Program, mut builder: Box) { + if let Err(e) = self.run(pgm, &mut *builder) { + builder.finnalize_error(e.to_string()); + } } fn describe(&self, sql: String) -> crate::Result { diff --git a/libsqlx/src/database/libsql/injector/mod.rs b/libsqlx/src/database/libsql/injector/mod.rs index cbc9dc80..5318b997 100644 --- a/libsqlx/src/database/libsql/injector/mod.rs +++ b/libsqlx/src/database/libsql/injector/mod.rs @@ -18,7 +18,7 @@ mod headers; mod hook; pub type FrameBuffer = Arc>>; -pub type OnCommitCb = Arc; +pub type OnCommitCb = Arc bool + Send + Sync + 'static>; pub struct Injector { /// The injector is in a transaction state @@ -85,7 +85,7 @@ impl Injector { self.buffer.lock().push_back(frame); if frame_close_txn || self.buffer.lock().len() >= self.capacity { if !self.is_txn { - self.begin_txn(); + self.begin_txn()?; } return self.flush(); } @@ -135,14 +135,14 @@ impl Injector { fn commit(&mut self) { // TODO: error? - let _ = self.connection.execute("COMMIT", ()); + let _ = dbg!(self.connection.execute("COMMIT", ())); } - fn begin_txn(&mut self) { - self.connection.execute("BEGIN IMMEDIATE", ()).unwrap(); + fn begin_txn(&mut self) -> crate::Result<()> { + self.connection.execute("BEGIN IMMEDIATE", ())?; self.connection - .execute("CREATE TABLE __DUMMY__ (__dummy__)", ()) - .unwrap(); + .execute("CREATE TABLE __DUMMY__ (__dummy__)", ())?; + Ok(()) } } diff --git a/libsqlx/src/database/libsql/mod.rs b/libsqlx/src/database/libsql/mod.rs index 9582cf2c..3a1191f1 100644 --- a/libsqlx/src/database/libsql/mod.rs +++ b/libsqlx/src/database/libsql/mod.rs @@ -184,7 +184,7 @@ impl Database for LibsqlDatabase { } impl InjectableDatabase for LibsqlDatabase { - fn injector(&mut self) -> crate::Result> { + fn injector(&self) -> crate::Result> { Ok(Box::new(Injector::new( &self.db_path, self.ty.on_commit_cb.clone(), @@ -228,7 +228,7 @@ mod test { on_commit_cb: Arc::new(|_| ()), injector_buffer_capacity: 10, }; - let mut db = LibsqlDatabase::new(temp.path().to_path_buf(), replica); + let db = LibsqlDatabase::new(temp.path().to_path_buf(), replica); let mut conn = db.connect().unwrap(); let row: Arc>> = Default::default(); @@ -265,7 +265,7 @@ mod test { }, ); - let mut replica = LibsqlDatabase::new( + let replica = LibsqlDatabase::new( temp_replica.path().to_path_buf(), ReplicaType { on_commit_cb: Arc::new(|_| ()), diff --git a/libsqlx/src/database/libsql/replication_log/logger.rs b/libsqlx/src/database/libsql/replication_log/logger.rs index 38112c37..998fada7 100644 --- a/libsqlx/src/database/libsql/replication_log/logger.rs +++ b/libsqlx/src/database/libsql/replication_log/logger.rs @@ -711,20 +711,6 @@ impl LogFileHeader { } } -pub struct Generation { - pub id: Uuid, - pub start_index: u64, -} - -impl Generation { - fn new(start_index: u64) -> Self { - Self { - id: Uuid::new_v4(), - start_index, - } - } -} - pub trait LogCompactor: Sync + Send + 'static { /// returns whether the passed log file should be compacted. If this method returns true, /// compact should be called next. diff --git a/libsqlx/src/database/mod.rs b/libsqlx/src/database/mod.rs index 43fa0dac..868eaf7f 100644 --- a/libsqlx/src/database/mod.rs +++ b/libsqlx/src/database/mod.rs @@ -21,7 +21,7 @@ pub trait Database { } pub trait InjectableDatabase { - fn injector(&mut self) -> crate::Result>; + fn injector(&self) -> crate::Result>; } // Trait implemented by databases that support frame injection diff --git a/libsqlx/src/database/proxy/connection.rs b/libsqlx/src/database/proxy/connection.rs index a7638e56..4e433d7b 100644 --- a/libsqlx/src/database/proxy/connection.rs +++ b/libsqlx/src/database/proxy/connection.rs @@ -144,9 +144,7 @@ where // set the connection state to unknown before executing on the remote self.state.lock().state = State::Unknown; - self.conn - .execute_program(&self.pgm, Box::new(builder)) - .unwrap(); + self.conn.execute_program(&self.pgm, Box::new(builder)); Ok(false) } else { @@ -164,11 +162,7 @@ where R: Connection, W: Connection + Clone + Send + 'static, { - fn execute_program( - &mut self, - pgm: &Program, - builder: Box, - ) -> crate::Result<()> { + fn execute_program(&mut self, pgm: &Program, builder: Box) { if self.state.lock().state.is_idle() && pgm.is_read_only() { if let Some(frame_no) = self.state.lock().last_frame_no { (self.wait_frame_no_cb)(frame_no); @@ -183,9 +177,8 @@ where // We know that this program won't perform any writes. We attempt to run it on the // replica. If it leaves an open transaction, then this program is an interactive // transaction, so we rollback the replica, and execute again on the primary. - self.read_conn.execute_program(pgm, Box::new(builder))?; + self.read_conn.execute_program(pgm, Box::new(builder)); // rollback(&mut self.conn.read_db); - Ok(()) } else { // we set the state to unknown because until we have received from the actual // connection state from the primary. @@ -194,8 +187,7 @@ where builder, state: self.state.clone(), }; - self.write_conn.execute_program(pgm, Box::new(builder))?; - Ok(()) + self.write_conn.execute_program(pgm, Box::new(builder)); } } diff --git a/libsqlx/src/database/proxy/database.rs b/libsqlx/src/database/proxy/database.rs index adc82ed8..d0e7c5a0 100644 --- a/libsqlx/src/database/proxy/database.rs +++ b/libsqlx/src/database/proxy/database.rs @@ -42,7 +42,7 @@ impl InjectableDatabase for WriteProxyDatabase where RDB: InjectableDatabase, { - fn injector(&mut self) -> crate::Result> { + fn injector(&self) -> crate::Result> { self.read_db.injector() } }