Skip to content

Commit e0ecd3d

Browse files
committed
wip: workers with HTTP or WebSocket
1 parent 3f085b4 commit e0ecd3d

File tree

2 files changed

+164
-95
lines changed

2 files changed

+164
-95
lines changed

Cargo.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ base64 = "0.21.0"
1515
num-traits = "0.2.15"
1616
serde_json = "1.0.91"
1717
worker = { version = "0.0.12", optional = true }
18-
spin-sdk = { version = "1.1.0", git = "https://github.com/fermyon/spin", tag = "v1.1.0", optional = true }
18+
spin-sdk = { version = "1", git = "https://github.com/fermyon/spin", tag = "v1.1.0", optional = true }
1919
http = { version = "0.2", optional = true }
2020
bytes = { version = "1.4.0", optional = true }
2121
anyhow = "1.0.69"

src/workers.rs

+163-94
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,64 @@ use worker::*;
66

77
use crate::{BatchResult, ResultSet, Statement};
88

9+
#[derive(Debug)]
10+
pub struct WebSocketClient {
11+
socket: WebSocket,
12+
next_reqid: std::sync::atomic::AtomicI32,
13+
}
14+
15+
#[derive(Debug)]
16+
pub struct HttpClient {
17+
url: String,
18+
auth: String,
19+
}
20+
21+
#[derive(Debug)]
22+
pub enum ClientInner {
23+
WebSocket(WebSocketClient),
24+
Http(HttpClient),
25+
}
26+
27+
impl WebSocketClient {
28+
fn send_request(&self, request: proto::Request) -> Result<()> {
29+
// NOTICE: we effective allow concurrency of 1 here, until we implement
30+
// id allocation andfMe request tracking
31+
let request_id = self
32+
.next_reqid
33+
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
34+
let msg = proto::ClientMsg::Request {
35+
request_id,
36+
request,
37+
};
38+
39+
self.socket.send(&msg)
40+
}
41+
42+
async fn recv_response(event_stream: &mut EventStream<'_>) -> Result<proto::ServerMsg> {
43+
use futures_util::StreamExt;
44+
45+
// NOTICE: we're effectively synchronously waiting for the response here
46+
if let Some(event) = event_stream.next().await {
47+
match event? {
48+
WebsocketEvent::Message(msg) => {
49+
let stmt_result: proto::ServerMsg = msg.json::<proto::ServerMsg>()?;
50+
Ok(stmt_result)
51+
}
52+
WebsocketEvent::Close(msg) => {
53+
Err(Error::RustError(format!("connection closed: {msg:?}")))
54+
}
55+
}
56+
} else {
57+
Err(Error::RustError("no response".to_string()))
58+
}
59+
}
60+
}
61+
962
/// Database client. This is the main structure used to
1063
/// communicate with the database.
1164
#[derive(Debug)]
1265
pub struct Client {
13-
socket: WebSocket,
14-
next_reqid: std::sync::atomic::AtomicI32,
66+
pub inner: ClientInner,
1567
}
1668

1769
impl Client {
@@ -25,15 +77,26 @@ impl Client {
2577
let url = url.into();
2678
// Auto-update the URL to start with https://.
2779
// It will get updated to wss via Workers API automatically
28-
let url = if !url.contains("://") {
29-
format!("https://{}", &url)
80+
let (url, is_websocket) = if !url.contains("://") {
81+
(format!("https://{}", &url), true)
3082
} else if let Some(url) = url.strip_prefix("libsql://") {
31-
"https://".to_owned() + url
83+
("https://".to_owned() + url, true)
3284
} else if let Some(url) = url.strip_prefix("wss://") {
33-
"https://".to_owned() + url
85+
("https://".to_owned() + url, true)
86+
} else if let Some(url) = url.strip_prefix("ws://") {
87+
("https://".to_owned() + url, true)
3488
} else {
35-
url
89+
(url, false)
3690
};
91+
92+
if !is_websocket {
93+
let inner = ClientInner::Http(HttpClient {
94+
url: url.clone(),
95+
auth: token.clone(),
96+
});
97+
return Ok(Self { inner });
98+
}
99+
37100
let url = url::Url::parse(&url)
38101
.context("Failed to parse URL")
39102
.map_err(|e| Error::from(format!("{e}")))?;
@@ -72,15 +135,16 @@ impl Client {
72135
// TODO: they could be pipelined with the first request to save latency.
73136
// For that, we need to keep the event stream open in the Client,
74137
// but that's tricky with the borrow checker.
75-
Self::recv_response(&mut event_stream).await?;
76-
Self::recv_response(&mut event_stream).await?;
138+
WebSocketClient::recv_response(&mut event_stream).await?;
139+
WebSocketClient::recv_response(&mut event_stream).await?;
77140

78141
tracing::debug!("Stream opened");
79142
drop(event_stream);
80-
Ok(Self {
143+
let inner = ClientInner::WebSocket(WebSocketClient {
81144
socket,
82145
next_reqid: std::sync::atomic::AtomicI32::new(1),
83-
})
146+
});
147+
Ok(Self { inner })
84148
}
85149

86150
/// Creates a database client from a `Config` object.
@@ -135,99 +199,104 @@ impl Client {
135199
.await
136200
}
137201

138-
fn send_request(&self, request: proto::Request) -> Result<()> {
139-
// NOTICE: we effective allow concurrency of 1 here, until we implement
140-
// id allocation andfMe request tracking
141-
let request_id = self
142-
.next_reqid
143-
.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
144-
let msg = proto::ClientMsg::Request {
145-
request_id,
146-
request,
147-
};
148-
149-
self.socket.send(&msg)
150-
}
151-
152-
async fn recv_response(event_stream: &mut EventStream<'_>) -> Result<proto::ServerMsg> {
153-
use futures_util::StreamExt;
154-
155-
// NOTICE: we're effectively synchronously waiting for the response here
156-
if let Some(event) = event_stream.next().await {
157-
match event? {
158-
WebsocketEvent::Message(msg) => {
159-
let stmt_result: proto::ServerMsg = msg.json::<proto::ServerMsg>()?;
160-
Ok(stmt_result)
161-
}
162-
WebsocketEvent::Close(msg) => {
163-
Err(Error::RustError(format!("connection closed: {msg:?}")))
164-
}
165-
}
166-
} else {
167-
Err(Error::RustError("no response".to_string()))
168-
}
169-
}
170-
171202
async fn raw_batch(
172203
&self,
173204
stmts: impl IntoIterator<Item = impl Into<Statement>>,
174205
) -> Result<BatchResult> {
175-
let mut batch = proto::Batch::new();
206+
match &self.inner {
207+
ClientInner::WebSocket(ws) => {
208+
let mut batch = proto::Batch::new();
176209

177-
for stmt in stmts.into_iter() {
178-
let stmt: Statement = stmt.into();
179-
let mut hrana_stmt = proto::Stmt::new(stmt.sql, true);
180-
for param in stmt.args {
181-
hrana_stmt.bind(param);
182-
}
183-
batch.step(None, hrana_stmt);
184-
}
210+
for stmt in stmts.into_iter() {
211+
let stmt: Statement = stmt.into();
212+
let mut hrana_stmt = proto::Stmt::new(stmt.sql, true);
213+
for param in stmt.args {
214+
hrana_stmt.bind(param);
215+
}
216+
batch.step(None, hrana_stmt);
217+
}
185218

186-
let mut event_stream = self.socket.events()?;
187-
188-
// NOTICE: if we want to support concurrent requests, we need to
189-
// actually start managing stream ids
190-
self.send_request(proto::Request::Batch(proto::BatchReq {
191-
stream_id: 0,
192-
batch,
193-
}))?;
194-
195-
match Self::recv_response(&mut event_stream).await? {
196-
proto::ServerMsg::ResponseOk {
197-
request_id: _,
198-
response: proto::Response::Batch(proto::BatchResp { result }),
199-
} => Ok(result),
200-
proto::ServerMsg::ResponseError {
201-
request_id: _,
202-
error,
203-
} => Err(Error::RustError(format!("{error}"))),
204-
_ => Err(Error::RustError("unexpected response".to_string())),
219+
let mut event_stream = ws.socket.events()?;
220+
221+
// NOTICE: if we want to support concurrent requests, we need to
222+
// actually start managing stream ids
223+
ws.send_request(proto::Request::Batch(proto::BatchReq {
224+
stream_id: 0,
225+
batch,
226+
}))?;
227+
228+
match WebSocketClient::recv_response(&mut event_stream).await? {
229+
proto::ServerMsg::ResponseOk {
230+
request_id: _,
231+
response: proto::Response::Batch(proto::BatchResp { result }),
232+
} => Ok(result),
233+
proto::ServerMsg::ResponseError {
234+
request_id: _,
235+
error,
236+
} => Err(Error::RustError(format!("{error}"))),
237+
_ => Err(Error::RustError("unexpected response".to_string())),
238+
}
239+
}
240+
ClientInner::Http(http) => {
241+
let mut headers = Headers::new();
242+
headers.append("Authorization", &http.auth).ok();
243+
let (body, stmts_count) = crate::client::statements_to_string(stmts);
244+
let request_init = RequestInit {
245+
body: Some(wasm_bindgen::JsValue::from_str(&body)),
246+
headers,
247+
cf: CfProperties::new(),
248+
method: Method::Post,
249+
redirect: RequestRedirect::Follow,
250+
};
251+
let req = Request::new_with_init(&http.url, &request_init)?;
252+
let mut response = Fetch::Request(req).send().await?;
253+
if response.status_code() != 200 {
254+
return Err(worker::Error::from(format!("{}", response.status_code())));
255+
}
256+
let resp: String = response.text().await?;
257+
let response_json: serde_json::Value = serde_json::from_str(&resp)?;
258+
crate::client::http_json_to_batch_result(response_json, stmts_count).map_err(|e| {
259+
worker::Error::from(format!("Error: {} ({:?})", e, request_init.body))
260+
})
261+
}
205262
}
206263
}
207264

208265
async fn execute(&self, stmt: impl Into<Statement>) -> Result<ResultSet> {
209-
let stmt: Statement = stmt.into();
210-
let mut hrana_stmt = proto::Stmt::new(stmt.sql, true);
211-
for param in stmt.args {
212-
hrana_stmt.bind(param);
213-
}
214-
215-
let mut event_stream = self.socket.events()?;
216-
217-
self.send_request(proto::Request::Execute(proto::ExecuteReq {
218-
stream_id: 0,
219-
stmt: hrana_stmt,
220-
}))?;
221-
match Self::recv_response(&mut event_stream).await? {
222-
proto::ServerMsg::ResponseOk {
223-
request_id: _,
224-
response: proto::Response::Execute(proto::ExecuteResp { result }),
225-
} => Ok(ResultSet::from(result)),
226-
proto::ServerMsg::ResponseError {
227-
request_id: _,
228-
error,
229-
} => Err(Error::RustError(format!("{error}"))),
230-
_ => Err(Error::RustError("unexpected response".to_string())),
266+
match &self.inner {
267+
ClientInner::WebSocket(ws) => {
268+
let stmt: Statement = stmt.into();
269+
let mut hrana_stmt = proto::Stmt::new(stmt.sql, true);
270+
for param in stmt.args {
271+
hrana_stmt.bind(param);
272+
}
273+
274+
let mut event_stream = ws.socket.events()?;
275+
276+
ws.send_request(proto::Request::Execute(proto::ExecuteReq {
277+
stream_id: 0,
278+
stmt: hrana_stmt,
279+
}))?;
280+
match WebSocketClient::recv_response(&mut event_stream).await? {
281+
proto::ServerMsg::ResponseOk {
282+
request_id: _,
283+
response: proto::Response::Execute(proto::ExecuteResp { result }),
284+
} => Ok(ResultSet::from(result)),
285+
proto::ServerMsg::ResponseError {
286+
request_id: _,
287+
error,
288+
} => Err(Error::RustError(format!("{error}"))),
289+
_ => Err(Error::RustError("unexpected response".to_string())),
290+
}
291+
},
292+
ClientInner::Http(http) => {
293+
let results = self.raw_batch(std::iter::once(stmt)).await?;
294+
match (results.step_results.first(), results.step_errors.first()) {
295+
(Some(Some(result)), Some(None)) => Ok(ResultSet::from(result.clone())),
296+
(Some(None), Some(Some(err))) => Err(anyhow::anyhow!(err.message.clone())),
297+
_ => unreachable!(),
298+
}
299+
}
231300
}
232301
}
233302
}

0 commit comments

Comments
 (0)