Skip to content

Commit c857ccc

Browse files
committed
Implemented automatic websocket reloading, acking messages and updating state,
and some styles
1 parent bd20f77 commit c857ccc

File tree

10 files changed

+733
-4643
lines changed

10 files changed

+733
-4643
lines changed

crates/turborepo-ui/src/wui/mod.rs

+90-9
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//! Web UI for Turborepo. Creates a WebSocket server that can be subscribed to
22
//! by a web client to display the status of tasks.
33
4-
use std::io::Write;
4+
use std::{collections::HashSet, io::Write};
55

66
use axum::{
77
extract::{
@@ -13,8 +13,9 @@ use axum::{
1313
routing::get,
1414
Router,
1515
};
16-
use serde::Serialize;
16+
use serde::{Deserialize, Serialize};
1717
use thiserror::Error;
18+
use tokio::select;
1819
use tower_http::cors::{Any, CorsLayer};
1920
use tracing::log::warn;
2021

@@ -29,6 +30,10 @@ pub enum Error {
2930
Server(#[from] std::io::Error),
3031
#[error("failed to start websocket server: {0}")]
3132
WebSocket(#[source] axum::Error),
33+
#[error("failed to serialize message: {0}")]
34+
Serde(#[from] serde_json::Error),
35+
#[error("failed to send message")]
36+
Send(#[from] axum::Error),
3237
}
3338

3439
#[derive(Clone)]
@@ -98,6 +103,7 @@ impl UISender for WebUISender {
98103
// Specific events that the websocket server can send to the client,
99104
// not all the `Event` types from the TUI
100105
#[derive(Debug, Clone, Serialize)]
106+
#[serde(tag = "type", content = "payload")]
101107
pub enum WebUIEvent {
102108
StartTask {
103109
task: String,
@@ -122,14 +128,37 @@ pub enum WebUIEvent {
122128
Stop,
123129
}
124130

131+
#[derive(Debug, Clone, Serialize)]
132+
pub struct ServerMessage<'a> {
133+
pub id: u32,
134+
#[serde(flatten)]
135+
pub payload: &'a WebUIEvent,
136+
}
137+
138+
#[derive(Debug, Clone, Serialize, Deserialize)]
139+
#[serde(tag = "type", content = "payload")]
140+
pub enum ClientMessage {
141+
/// Acknowledges the receipt of a message.
142+
/// If we don't receive an ack, we will resend the message
143+
Ack { id: u32 },
144+
/// Asks for all messages from the given id onwards
145+
CatchUp { start_id: u32 },
146+
}
147+
125148
struct AppState {
126149
rx: tokio::sync::broadcast::Receiver<WebUIEvent>,
150+
acks: HashSet<u32>,
151+
messages: Vec<(WebUIEvent, u32)>,
152+
current_id: u32,
127153
}
128154

129155
impl Clone for AppState {
130156
fn clone(&self) -> Self {
131157
Self {
132158
rx: self.rx.resubscribe(),
159+
acks: self.acks.clone(),
160+
messages: self.messages.clone(),
161+
current_id: self.current_id,
133162
}
134163
}
135164
}
@@ -138,13 +167,60 @@ async fn handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> impl In
138167
ws.on_upgrade(|socket| handle_socket(socket, state))
139168
}
140169

141-
async fn handle_socket(mut socket: WebSocket, state: AppState) {
170+
async fn handle_socket(socket: WebSocket, state: AppState) {
171+
if let Err(e) = handle_socket_inner(socket, state).await {
172+
warn!("error handling socket: {e}");
173+
}
174+
}
175+
176+
async fn handle_socket_inner(mut socket: WebSocket, state: AppState) -> Result<(), Error> {
142177
let mut state = state.clone();
143-
while let Ok(event) = state.rx.recv().await {
144-
let message_payload = serde_json::to_string(&event).unwrap();
145-
if socket.send(Message::Text(message_payload)).await.is_err() {
146-
// client disconnected
147-
return;
178+
let mut interval = tokio::time::interval(std::time::Duration::from_millis(100));
179+
loop {
180+
select! {
181+
biased;
182+
Ok(event) = state.rx.recv() => {
183+
let id = state.current_id;
184+
state.current_id += 1;
185+
let message_payload = serde_json::to_string(&ServerMessage {
186+
id,
187+
payload: &event
188+
})?;
189+
190+
state.messages.push((event, id));
191+
println!("1");
192+
socket.send(Message::Text(message_payload)).await?;
193+
}
194+
// Every 100ms, check if we need to resend any messages
195+
_ = interval.tick() => {
196+
for (event, id) in &state.messages {
197+
if !state.acks.contains(&id) {
198+
let message_payload = serde_json::to_string(event).unwrap();
199+
println!("2");
200+
socket.send(Message::Text(message_payload)).await?;
201+
}
202+
};
203+
}
204+
message = socket.recv() => {
205+
if let Some(Ok(message)) = message {
206+
let message_payload = message.into_text()?;
207+
if message_payload.is_empty() {
208+
continue;
209+
}
210+
if let Ok(event) = serde_json::from_str::<ClientMessage>(&message_payload) {
211+
match event {
212+
ClientMessage::Ack { id } => {
213+
state.acks.insert(id);
214+
}
215+
ClientMessage::CatchUp { start_id } => {
216+
// TODO: implement
217+
}
218+
}
219+
} else {
220+
warn!("failed to deserialize message from client: {message_payload}");
221+
}
222+
}
223+
},
148224
}
149225
}
150226
}
@@ -161,7 +237,12 @@ pub async fn start_ws_server(
161237
let app = Router::new()
162238
.route("/ws", get(handler))
163239
.layer(cors)
164-
.with_state(AppState { rx });
240+
.with_state(AppState {
241+
rx,
242+
acks: HashSet::new(),
243+
messages: Vec::new(),
244+
current_id: 0,
245+
});
165246

166247
let listener = tokio::net::TcpListener::bind("127.0.0.1:1337").await?;
167248
println!("Web UI listening on port 1337");

0 commit comments

Comments
 (0)