diff --git a/Cargo.toml b/Cargo.toml index 8731b28..3cd1593 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -15,7 +15,7 @@ base64 = "0.21.0" num-traits = "0.2.15" serde_json = "1.0.91" worker = { version = "0.0.12", optional = true } -spin-sdk = { version = "1.1.0", git = "https://github.com/fermyon/spin", tag = "v1.1.0", optional = true } +spin-sdk = { version = "1", git = "https://github.com/fermyon/spin", tag = "v1.1.0", optional = true } http = { version = "0.2", optional = true } bytes = { version = "1.4.0", optional = true } anyhow = "1.0.69" diff --git a/src/workers.rs b/src/workers.rs index 304a051..b55b8b4 100644 --- a/src/workers.rs +++ b/src/workers.rs @@ -6,12 +6,64 @@ use worker::*; use crate::{BatchResult, ResultSet, Statement}; +#[derive(Debug)] +pub struct WebSocketClient { + socket: WebSocket, + next_reqid: std::sync::atomic::AtomicI32, +} + +#[derive(Debug)] +pub struct HttpClient { + url: String, + auth: String, +} + +#[derive(Debug)] +pub enum ClientInner { + WebSocket(WebSocketClient), + Http(HttpClient), +} + +impl WebSocketClient { + fn send_request(&self, request: proto::Request) -> Result<()> { + // NOTICE: we effectively allow concurrency of 1 here, until we implement + // id allocation and request tracking + let request_id = self + .next_reqid + .fetch_add(1, std::sync::atomic::Ordering::Relaxed); + let msg = proto::ClientMsg::Request { + request_id, + request, + }; + + self.socket.send(&msg) + } + + async fn recv_response(event_stream: &mut EventStream<'_>) -> Result { + use futures_util::StreamExt; + + // NOTICE: we're effectively synchronously waiting for the response here + if let Some(event) = event_stream.next().await { + match event? 
{ + WebsocketEvent::Message(msg) => { + let stmt_result: proto::ServerMsg = msg.json::()?; + Ok(stmt_result) + } + WebsocketEvent::Close(msg) => { + Err(Error::RustError(format!("connection closed: {msg:?}"))) + } + } + } else { + Err(Error::RustError("no response".to_string())) + } + } +} + /// Database client. This is the main structure used to /// communicate with the database. #[derive(Debug)] pub struct Client { - socket: WebSocket, - next_reqid: std::sync::atomic::AtomicI32, + pub inner: ClientInner, } impl Client { @@ -25,15 +77,26 @@ impl Client { let url = url.into(); // Auto-update the URL to start with https://. // It will get updated to wss via Workers API automatically - let url = if !url.contains("://") { - format!("https://{}", &url) + let (url, is_websocket) = if !url.contains("://") { + (format!("https://{}", &url), true) } else if let Some(url) = url.strip_prefix("libsql://") { - "https://".to_owned() + url + ("https://".to_owned() + url, true) } else if let Some(url) = url.strip_prefix("wss://") { - "https://".to_owned() + url + ("https://".to_owned() + url, true) + } else if let Some(url) = url.strip_prefix("ws://") { + ("https://".to_owned() + url, true) } else { - url + (url, false) }; + + if !is_websocket { + let inner = ClientInner::Http(HttpClient { + url: url.clone(), + auth: token.clone(), + }); + return Ok(Self { inner }); + } + let url = url::Url::parse(&url) .context("Failed to parse URL") .map_err(|e| Error::from(format!("{e}")))?; @@ -72,15 +135,16 @@ impl Client { // TODO: they could be pipelined with the first request to save latency. // For that, we need to keep the event stream open in the Client, // but that's tricky with the borrow checker. 
- Self::recv_response(&mut event_stream).await?; - Self::recv_response(&mut event_stream).await?; + WebSocketClient::recv_response(&mut event_stream).await?; + WebSocketClient::recv_response(&mut event_stream).await?; tracing::debug!("Stream opened"); drop(event_stream); - Ok(Self { + let inner = ClientInner::WebSocket(WebSocketClient { socket, next_reqid: std::sync::atomic::AtomicI32::new(1), - }) + }); + Ok(Self { inner }) } /// Creates a database client from a `Config` object. @@ -135,99 +199,104 @@ impl Client { .await } - fn send_request(&self, request: proto::Request) -> Result<()> { - // NOTICE: we effective allow concurrency of 1 here, until we implement - // id allocation andfMe request tracking - let request_id = self - .next_reqid - .fetch_add(1, std::sync::atomic::Ordering::Relaxed); - let msg = proto::ClientMsg::Request { - request_id, - request, - }; - - self.socket.send(&msg) - } - - async fn recv_response(event_stream: &mut EventStream<'_>) -> Result { - use futures_util::StreamExt; - - // NOTICE: we're effectively synchronously waiting for the response here - if let Some(event) = event_stream.next().await { - match event? 
{ - WebsocketEvent::Message(msg) => { - let stmt_result: proto::ServerMsg = msg.json::()?; - Ok(stmt_result) - } - WebsocketEvent::Close(msg) => { - Err(Error::RustError(format!("connection closed: {msg:?}"))) - } - } - } else { - Err(Error::RustError("no response".to_string())) - } - } - async fn raw_batch( &self, stmts: impl IntoIterator>, ) -> Result { - let mut batch = proto::Batch::new(); + match &self.inner { + ClientInner::WebSocket(ws) => { + let mut batch = proto::Batch::new(); - for stmt in stmts.into_iter() { - let stmt: Statement = stmt.into(); - let mut hrana_stmt = proto::Stmt::new(stmt.sql, true); - for param in stmt.args { - hrana_stmt.bind(param); - } - batch.step(None, hrana_stmt); - } + for stmt in stmts.into_iter() { + let stmt: Statement = stmt.into(); + let mut hrana_stmt = proto::Stmt::new(stmt.sql, true); + for param in stmt.args { + hrana_stmt.bind(param); + } + batch.step(None, hrana_stmt); + } - let mut event_stream = self.socket.events()?; - - // NOTICE: if we want to support concurrent requests, we need to - // actually start managing stream ids - self.send_request(proto::Request::Batch(proto::BatchReq { - stream_id: 0, - batch, - }))?; - - match Self::recv_response(&mut event_stream).await? { - proto::ServerMsg::ResponseOk { - request_id: _, - response: proto::Response::Batch(proto::BatchResp { result }), - } => Ok(result), - proto::ServerMsg::ResponseError { - request_id: _, - error, - } => Err(Error::RustError(format!("{error}"))), - _ => Err(Error::RustError("unexpected response".to_string())), + let mut event_stream = ws.socket.events()?; + + // NOTICE: if we want to support concurrent requests, we need to + // actually start managing stream ids + ws.send_request(proto::Request::Batch(proto::BatchReq { + stream_id: 0, + batch, + }))?; + + match WebSocketClient::recv_response(&mut event_stream).await? 
{ + proto::ServerMsg::ResponseOk { + request_id: _, + response: proto::Response::Batch(proto::BatchResp { result }), + } => Ok(result), + proto::ServerMsg::ResponseError { + request_id: _, + error, + } => Err(Error::RustError(format!("{error}"))), + _ => Err(Error::RustError("unexpected response".to_string())), + } + } + ClientInner::Http(http) => { + let mut headers = Headers::new(); + headers.append("Authorization", &http.auth).ok(); + let (body, stmts_count) = crate::client::statements_to_string(stmts); + let request_init = RequestInit { + body: Some(wasm_bindgen::JsValue::from_str(&body)), + headers, + cf: CfProperties::new(), + method: Method::Post, + redirect: RequestRedirect::Follow, + }; + let req = Request::new_with_init(&http.url, &request_init)?; + let mut response = Fetch::Request(req).send().await?; + if response.status_code() != 200 { + return Err(worker::Error::from(format!("{}", response.status_code()))); + } + let resp: String = response.text().await?; + let response_json: serde_json::Value = serde_json::from_str(&resp)?; + crate::client::http_json_to_batch_result(response_json, stmts_count).map_err(|e| { + worker::Error::from(format!("Error: {} ({:?})", e, request_init.body)) + }) + } } } async fn execute(&self, stmt: impl Into) -> Result { - let stmt: Statement = stmt.into(); - let mut hrana_stmt = proto::Stmt::new(stmt.sql, true); - for param in stmt.args { - hrana_stmt.bind(param); - } - - let mut event_stream = self.socket.events()?; - - self.send_request(proto::Request::Execute(proto::ExecuteReq { - stream_id: 0, - stmt: hrana_stmt, - }))?; - match Self::recv_response(&mut event_stream).await? 
{ - proto::ServerMsg::ResponseOk { - request_id: _, - response: proto::Response::Execute(proto::ExecuteResp { result }), - } => Ok(ResultSet::from(result)), - proto::ServerMsg::ResponseError { - request_id: _, - error, - } => Err(Error::RustError(format!("{error}"))), - _ => Err(Error::RustError("unexpected response".to_string())), + match &self.inner { + ClientInner::WebSocket(ws) => { + let stmt: Statement = stmt.into(); + let mut hrana_stmt = proto::Stmt::new(stmt.sql, true); + for param in stmt.args { + hrana_stmt.bind(param); + } + + let mut event_stream = ws.socket.events()?; + + ws.send_request(proto::Request::Execute(proto::ExecuteReq { + stream_id: 0, + stmt: hrana_stmt, + }))?; + match WebSocketClient::recv_response(&mut event_stream).await? { + proto::ServerMsg::ResponseOk { + request_id: _, + response: proto::Response::Execute(proto::ExecuteResp { result }), + } => Ok(ResultSet::from(result)), + proto::ServerMsg::ResponseError { + request_id: _, + error, + } => Err(Error::RustError(format!("{error}"))), + _ => Err(Error::RustError("unexpected response".to_string())), + } + }, + ClientInner::Http(http) => { + let results = self.raw_batch(std::iter::once(stmt)).await?; + match (results.step_results.first(), results.step_errors.first()) { + (Some(Some(result)), Some(None)) => Ok(ResultSet::from(result.clone())), + (Some(None), Some(Some(err))) => Err(Error::RustError(err.message.clone())), + _ => unreachable!(), + } + } } } }