1
1
//! Web UI for Turborepo. Creates a WebSocket server that can be subscribed to
2
2
//! by a web client to display the status of tasks.
3
3
4
- use std:: io:: Write ;
4
+ use std:: { collections :: HashSet , io:: Write } ;
5
5
6
6
use axum:: {
7
7
extract:: {
@@ -13,8 +13,9 @@ use axum::{
13
13
routing:: get,
14
14
Router ,
15
15
} ;
16
- use serde:: Serialize ;
16
+ use serde:: { Deserialize , Serialize } ;
17
17
use thiserror:: Error ;
18
+ use tokio:: select;
18
19
use tower_http:: cors:: { Any , CorsLayer } ;
19
20
use tracing:: log:: warn;
20
21
@@ -29,6 +30,10 @@ pub enum Error {
29
30
Server ( #[ from] std:: io:: Error ) ,
30
31
#[ error( "failed to start websocket server: {0}" ) ]
31
32
WebSocket ( #[ source] axum:: Error ) ,
33
+ #[ error( "failed to serialize message: {0}" ) ]
34
+ Serde ( #[ from] serde_json:: Error ) ,
35
+ #[ error( "failed to send message" ) ]
36
+ Send ( #[ from] axum:: Error ) ,
32
37
}
33
38
34
39
#[ derive( Clone ) ]
@@ -98,6 +103,7 @@ impl UISender for WebUISender {
98
103
// Specific events that the websocket server can send to the client,
99
104
// not all the `Event` types from the TUI
100
105
#[ derive( Debug , Clone , Serialize ) ]
106
+ #[ serde( tag = "type" , content = "payload" ) ]
101
107
pub enum WebUIEvent {
102
108
StartTask {
103
109
task : String ,
@@ -122,14 +128,37 @@ pub enum WebUIEvent {
122
128
Stop ,
123
129
}
124
130
131
+ #[ derive( Debug , Clone , Serialize ) ]
132
+ pub struct ServerMessage < ' a > {
133
+ pub id : u32 ,
134
+ #[ serde( flatten) ]
135
+ pub payload : & ' a WebUIEvent ,
136
+ }
137
+
138
+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
139
+ #[ serde( tag = "type" , content = "payload" ) ]
140
+ pub enum ClientMessage {
141
+ /// Acknowledges the receipt of a message.
142
+ /// If we don't receive an ack, we will resend the message
143
+ Ack { id : u32 } ,
144
+ /// Asks for all messages from the given id onwards
145
+ CatchUp { start_id : u32 } ,
146
+ }
147
+
125
148
struct AppState {
126
149
rx : tokio:: sync:: broadcast:: Receiver < WebUIEvent > ,
150
+ acks : HashSet < u32 > ,
151
+ messages : Vec < ( WebUIEvent , u32 ) > ,
152
+ current_id : u32 ,
127
153
}
128
154
129
155
impl Clone for AppState {
130
156
fn clone ( & self ) -> Self {
131
157
Self {
132
158
rx : self . rx . resubscribe ( ) ,
159
+ acks : self . acks . clone ( ) ,
160
+ messages : self . messages . clone ( ) ,
161
+ current_id : self . current_id ,
133
162
}
134
163
}
135
164
}
@@ -138,13 +167,60 @@ async fn handler(ws: WebSocketUpgrade, State(state): State<AppState>) -> impl In
138
167
ws. on_upgrade ( |socket| handle_socket ( socket, state) )
139
168
}
140
169
141
- async fn handle_socket ( mut socket : WebSocket , state : AppState ) {
170
+ async fn handle_socket ( socket : WebSocket , state : AppState ) {
171
+ if let Err ( e) = handle_socket_inner ( socket, state) . await {
172
+ warn ! ( "error handling socket: {e}" ) ;
173
+ }
174
+ }
175
+
176
+ async fn handle_socket_inner ( mut socket : WebSocket , state : AppState ) -> Result < ( ) , Error > {
142
177
let mut state = state. clone ( ) ;
143
- while let Ok ( event) = state. rx . recv ( ) . await {
144
- let message_payload = serde_json:: to_string ( & event) . unwrap ( ) ;
145
- if socket. send ( Message :: Text ( message_payload) ) . await . is_err ( ) {
146
- // client disconnected
147
- return ;
178
+ let mut interval = tokio:: time:: interval ( std:: time:: Duration :: from_millis ( 100 ) ) ;
179
+ loop {
180
+ select ! {
181
+ biased;
182
+ Ok ( event) = state. rx. recv( ) => {
183
+ let id = state. current_id;
184
+ state. current_id += 1 ;
185
+ let message_payload = serde_json:: to_string( & ServerMessage {
186
+ id,
187
+ payload: & event
188
+ } ) ?;
189
+
190
+ state. messages. push( ( event, id) ) ;
191
+ println!( "1" ) ;
192
+ socket. send( Message :: Text ( message_payload) ) . await ?;
193
+ }
194
+ // Every 100ms, check if we need to resend any messages
195
+ _ = interval. tick( ) => {
196
+ for ( event, id) in & state. messages {
197
+ if !state. acks. contains( & id) {
198
+ let message_payload = serde_json:: to_string( event) . unwrap( ) ;
199
+ println!( "2" ) ;
200
+ socket. send( Message :: Text ( message_payload) ) . await ?;
201
+ }
202
+ } ;
203
+ }
204
+ message = socket. recv( ) => {
205
+ if let Some ( Ok ( message) ) = message {
206
+ let message_payload = message. into_text( ) ?;
207
+ if message_payload. is_empty( ) {
208
+ continue ;
209
+ }
210
+ if let Ok ( event) = serde_json:: from_str:: <ClientMessage >( & message_payload) {
211
+ match event {
212
+ ClientMessage :: Ack { id } => {
213
+ state. acks. insert( id) ;
214
+ }
215
+ ClientMessage :: CatchUp { start_id } => {
216
+ // TODO: implement
217
+ }
218
+ }
219
+ } else {
220
+ warn!( "failed to deserialize message from client: {message_payload}" ) ;
221
+ }
222
+ }
223
+ } ,
148
224
}
149
225
}
150
226
}
@@ -161,7 +237,12 @@ pub async fn start_ws_server(
161
237
let app = Router :: new ( )
162
238
. route ( "/ws" , get ( handler) )
163
239
. layer ( cors)
164
- . with_state ( AppState { rx } ) ;
240
+ . with_state ( AppState {
241
+ rx,
242
+ acks : HashSet :: new ( ) ,
243
+ messages : Vec :: new ( ) ,
244
+ current_id : 0 ,
245
+ } ) ;
165
246
166
247
let listener = tokio:: net:: TcpListener :: bind ( "127.0.0.1:1337" ) . await ?;
167
248
println ! ( "Web UI listening on port 1337" ) ;
0 commit comments