Skip to content
This repository was archived by the owner on Oct 18, 2023. It is now read-only.

proxy writes #535

Merged
merged 1 commit into from
Jul 19, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
615 changes: 531 additions & 84 deletions libsqlx-server/src/allocation/mod.rs

Large diffs are not rendered by default.

23 changes: 10 additions & 13 deletions libsqlx-server/src/hrana/batch.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ use super::stmt::{proto_stmt_to_query, stmt_error_from_sqld_error};
use super::{proto, ProtocolError, Version};

use color_eyre::eyre::anyhow;
use libsqlx::Connection;
use libsqlx::analysis::Statement;
use libsqlx::program::{Cond, Program, Step};
use libsqlx::query::{Params, Query};
Expand Down Expand Up @@ -78,7 +77,7 @@ pub async fn execute_batch(
let fut = db
.exec(move |conn| -> color_eyre::Result<_> {
let (builder, ret) = HranaBatchProtoBuilder::new();
conn.execute_program(&pgm, builder)?;
conn.execute_program(&pgm, Box::new(builder))?;
Ok(ret)
})
.await??;
Expand Down Expand Up @@ -116,20 +115,18 @@ pub async fn execute_sequence(conn: &ConnectionHandle, pgm: Program) -> color_ey
.exec(move |conn| -> color_eyre::Result<_> {
let (snd, rcv) = oneshot::channel();
let builder = StepResultsBuilder::new(snd);
conn.execute_program(&pgm, builder)?;
conn.execute_program(&pgm, Box::new(builder))?;

Ok(rcv)
})
.await??;

fut.await?
.into_iter()
.try_for_each(|result| match result {
StepResult::Ok => Ok(()),
StepResult::Err(e) => match stmt_error_from_sqld_error(e) {
Ok(stmt_err) => Err(anyhow!(stmt_err)),
Err(sqld_err) => Err(anyhow!(sqld_err)),
},
StepResult::Skipped => Err(anyhow!("Statement in sequence was not executed")),
})
fut.await?.into_iter().try_for_each(|result| match result {
StepResult::Ok => Ok(()),
StepResult::Err(e) => match stmt_error_from_sqld_error(e) {
Ok(stmt_err) => Err(anyhow!(stmt_err)),
Err(sqld_err) => Err(anyhow!(sqld_err)),
},
StepResult::Skipped => Err(anyhow!("Statement in sequence was not executed")),
})
}
48 changes: 29 additions & 19 deletions libsqlx-server/src/hrana/http/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::sync::Arc;

use color_eyre::eyre::Context;
use futures::Future;
use parking_lot::Mutex;
use serde::{de::DeserializeOwned, Serialize};
use tokio::sync::oneshot;

use crate::allocation::ConnectionHandle;

Expand Down Expand Up @@ -47,31 +50,38 @@ fn handle_index() -> color_eyre::Result<hyper::Response<hyper::Body>> {
}

pub async fn handle_pipeline<F, Fut>(
server: &Server,
server: Arc<Server>,
req: PipelineRequestBody,
ret: oneshot::Sender<color_eyre::Result<PipelineResponseBody>>,
mk_conn: F,
) -> color_eyre::Result<PipelineResponseBody>
) -> color_eyre::Result<()>
where
F: FnOnce() -> Fut,
Fut: Future<Output = crate::Result<ConnectionHandle>>,
{
let mut stream_guard = stream::acquire(server, req.baton.as_deref(), mk_conn).await?;

let mut results = Vec::with_capacity(req.requests.len());
for request in req.requests.into_iter() {
let result = request::handle(&mut stream_guard, request)
.await
.context("Could not execute a request in pipeline")?;
results.push(result);
}

let resp_body = proto::PipelineResponseBody {
baton: stream_guard.release(),
base_url: server.self_url.clone(),
results,
};

Ok(resp_body)
let mut stream_guard = stream::acquire(server.clone(), req.baton.as_deref(), mk_conn).await?;

tokio::spawn(async move {
let f = async move {
let mut results = Vec::with_capacity(req.requests.len());
for request in req.requests.into_iter() {
let result = request::handle(&mut stream_guard, request)
.await
.context("Could not execute a request in pipeline")?;
results.push(result);
}

Ok(proto::PipelineResponseBody {
baton: stream_guard.release(),
base_url: server.self_url.clone(),
results,
})
};

let _ = ret.send(f.await);
});

Ok(())
}

async fn read_request_json<T: DeserializeOwned>(
Expand Down
4 changes: 2 additions & 2 deletions libsqlx-server/src/hrana/http/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ pub enum StreamResponseError {
}

pub async fn handle(
stream_guard: &mut stream::Guard<'_>,
stream_guard: &mut stream::Guard,
request: proto::StreamRequest,
) -> color_eyre::Result<proto::StreamResult> {
let result = match try_handle(stream_guard, request).await {
Expand All @@ -31,7 +31,7 @@ pub async fn handle(
}

async fn try_handle(
stream_guard: &mut stream::Guard<'_>,
stream_guard: &mut stream::Guard,
request: proto::StreamRequest,
) -> color_eyre::Result<proto::StreamResponse> {
Ok(match request {
Expand Down
19 changes: 10 additions & 9 deletions libsqlx-server/src/hrana/http/stream.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::cmp::Reverse;
use std::collections::{HashMap, VecDeque};
use std::pin::Pin;
use std::sync::Arc;
use std::{future, mem, task};

use base64::prelude::{Engine as _, BASE64_STANDARD_NO_PAD};
Expand Down Expand Up @@ -67,8 +68,8 @@ struct Stream {
/// Guard object that is used to access a stream from the outside. The guard makes sure that the
/// stream's entry in [`ServerStreamState::handles`] is either removed or replaced with
/// [`Handle::Available`] after the guard goes out of scope.
pub struct Guard<'srv> {
server: &'srv Server,
pub struct Guard {
server: Arc<Server>,
/// The guarded stream. This is only set to `None` in the destructor.
stream: Option<Box<Stream>>,
/// If set to `true`, the destructor will release the stream for further use (saving it as
Expand Down Expand Up @@ -101,18 +102,18 @@ impl ServerStreamState {

/// Acquire a guard to a new or existing stream. If baton is `Some`, we try to look up the stream,
/// otherwise we create a new stream.
pub async fn acquire<'srv, F, Fut>(
server: &'srv Server,
pub async fn acquire<F, Fut>(
server: Arc<Server>,
baton: Option<&str>,
mk_conn: F,
) -> color_eyre::Result<Guard<'srv>>
) -> color_eyre::Result<Guard>
where
F: FnOnce() -> Fut,
Fut: Future<Output = crate::Result<ConnectionHandle>>,
{
let stream = match baton {
Some(baton) => {
let (stream_id, baton_seq) = decode_baton(server, baton)?;
let (stream_id, baton_seq) = decode_baton(&server, baton)?;

let mut state = server.stream_state.lock();
let handle = state.handles.get_mut(&stream_id);
Expand Down Expand Up @@ -182,7 +183,7 @@ where
})
}

impl<'srv> Guard<'srv> {
impl Guard {
pub fn get_db(&self) -> Result<&ConnectionHandle, ProtocolError> {
let stream = self.stream.as_ref().unwrap();
stream.conn.as_ref().ok_or(ProtocolError::BatonStreamClosed)
Expand Down Expand Up @@ -211,7 +212,7 @@ impl<'srv> Guard<'srv> {
if stream.conn.is_some() {
self.release = true; // tell destructor to make the stream available again
Some(encode_baton(
self.server,
&self.server,
stream.stream_id,
stream.baton_seq,
))
Expand All @@ -221,7 +222,7 @@ impl<'srv> Guard<'srv> {
}
}

impl<'srv> Drop for Guard<'srv> {
impl Drop for Guard {
fn drop(&mut self) {
let stream = self.stream.take().unwrap();
let stream_id = stream.stream_id;
Expand Down
67 changes: 36 additions & 31 deletions libsqlx-server/src/hrana/result_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,16 +11,22 @@ use super::proto;

pub struct SingleStatementBuilder {
builder: StatementBuilder,
ret: oneshot::Sender<Result<proto::StmtResult, libsqlx::error::Error>>,
ret: Option<oneshot::Sender<Result<proto::StmtResult, libsqlx::error::Error>>>,
}

impl SingleStatementBuilder {
pub fn new() -> (Self, oneshot::Receiver<Result<proto::StmtResult, libsqlx::error::Error>>) {
pub fn new() -> (
Self,
oneshot::Receiver<Result<proto::StmtResult, libsqlx::error::Error>>,
) {
let (ret, rcv) = oneshot::channel();
(Self {
builder: StatementBuilder::default(),
ret,
}, rcv)
(
Self {
builder: StatementBuilder::default(),
ret: Some(ret),
},
rcv,
)
}
}

Expand All @@ -38,7 +44,8 @@ impl ResultBuilder for SingleStatementBuilder {
affected_row_count: u64,
last_insert_rowid: Option<i64>,
) -> Result<(), QueryResultBuilderError> {
self.builder.finish_step(affected_row_count, last_insert_rowid)
self.builder
.finish_step(affected_row_count, last_insert_rowid)
}

fn step_error(&mut self, error: libsqlx::error::Error) -> Result<(), QueryResultBuilderError> {
Expand All @@ -61,19 +68,16 @@ impl ResultBuilder for SingleStatementBuilder {
}

fn finnalize(
self,
&mut self,
_is_txn: bool,
_frame_no: Option<FrameNo>,
) -> Result<bool, QueryResultBuilderError>
where Self: Sized
{
let res = self.builder.into_ret();
let _ = self.ret.send(res);
) -> Result<bool, QueryResultBuilderError> {
let res = self.builder.take_ret();
let _ = self.ret.take().unwrap().send(res);
Ok(true)
}
}


#[derive(Debug, Default)]
struct StatementBuilder {
has_step: bool,
Expand Down Expand Up @@ -191,12 +195,12 @@ impl StatementBuilder {
Ok(())
}

pub fn into_ret(self) -> Result<proto::StmtResult, libsqlx::error::Error> {
match self.err {
pub fn take_ret(&mut self) -> Result<proto::StmtResult, libsqlx::error::Error> {
match self.err.take() {
Some(err) => Err(err),
None => Ok(proto::StmtResult {
cols: self.cols,
rows: self.rows,
cols: std::mem::take(&mut self.cols),
rows: std::mem::take(&mut self.rows),
affected_row_count: self.affected_row_count,
last_insert_rowid: self.last_insert_rowid,
}),
Expand Down Expand Up @@ -262,23 +266,24 @@ pub struct HranaBatchProtoBuilder {
current_size: u64,
max_response_size: u64,
step_empty: bool,
ret: oneshot::Sender<proto::BatchResult>
ret: oneshot::Sender<proto::BatchResult>,
}

impl HranaBatchProtoBuilder {
pub fn new() -> (Self, oneshot::Receiver<proto::BatchResult>) {
let (ret, rcv) = oneshot::channel();
(Self {
step_results: Vec::new(),
step_errors: Vec::new(),
stmt_builder: StatementBuilder::default(),
current_size: 0,
max_response_size: u64::MAX,
step_empty: false,
ret,
},
rcv)

(
Self {
step_results: Vec::new(),
step_errors: Vec::new(),
stmt_builder: StatementBuilder::default(),
current_size: 0,
max_response_size: u64::MAX,
step_empty: false,
ret,
},
rcv,
)
}
pub fn into_ret(self) -> proto::BatchResult {
proto::BatchResult {
Expand Down Expand Up @@ -314,7 +319,7 @@ impl ResultBuilder for HranaBatchProtoBuilder {
max_response_size: self.max_response_size - self.current_size,
..Default::default()
};
match std::mem::replace(&mut self.stmt_builder, new_builder).into_ret() {
match std::mem::replace(&mut self.stmt_builder, new_builder).take_ret() {
Ok(res) => {
self.step_results.push((!self.step_empty).then_some(res));
self.step_errors.push(None);
Expand Down
3 changes: 1 addition & 2 deletions libsqlx-server/src/hrana/stmt.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ use std::collections::HashMap;
use color_eyre::eyre::{anyhow, bail};
use libsqlx::analysis::Statement;
use libsqlx::query::{Params, Query, Value};
use libsqlx::Connection;

use super::result_builder::SingleStatementBuilder;
use super::{proto, ProtocolError, Version};
Expand Down Expand Up @@ -52,7 +51,7 @@ pub async fn execute_stmt(
.exec(move |conn| -> color_eyre::Result<_> {
let (builder, ret) = SingleStatementBuilder::new();
let pgm = libsqlx::program::Program::from_queries(std::iter::once(query));
conn.execute_program(&pgm, builder)?;
conn.execute_program(&pgm, Box::new(builder))?;

Ok(ret)
})
Expand Down
2 changes: 1 addition & 1 deletion libsqlx-server/src/linc/connection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ where
self.conn.feed(m).await.unwrap();
}
self.conn.flush().await.unwrap();
}
},
else => {
self.state = ConnectionState::Close;
}
Expand Down
Loading