1+ use serde:: { Deserialize , Serialize } ;
2+ use std:: collections:: HashMap ;
3+ use std:: sync:: Arc ;
4+ use tokio:: sync:: RwLock ;
5+ use uuid:: Uuid ;
6+
7+ use super :: websocket:: { ProgressWebSocket , WebSocketConfig } ;
8+
9+ #[ derive( Debug , Clone , PartialEq , Serialize , Deserialize ) ]
10+ #[ serde( rename_all = "snake_case" ) ]
11+ pub enum ChannelType {
12+ Upload ,
13+ Analysis ,
14+ Batch ,
15+ User ,
16+ }
17+
18+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
19+ #[ serde( tag = "type" , rename_all = "snake_case" ) ]
20+ pub enum ProgressEvent {
21+ Upload ( UploadProgressEvent ) ,
22+ Analysis ( AnalysisProgressEvent ) ,
23+ Batch ( BatchProgressEvent ) ,
24+ Error ( ErrorEvent ) ,
25+ Complete ( CompleteEvent ) ,
26+ }
27+
28+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
29+ pub struct UploadProgressEvent {
30+ pub document_id : Uuid ,
31+ pub file_name : String ,
32+ pub bytes_uploaded : u64 ,
33+ pub total_bytes : u64 ,
34+ pub progress : u8 ,
35+ pub message : String ,
36+ }
37+
38+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
39+ pub struct AnalysisProgressEvent {
40+ pub analysis_id : Uuid ,
41+ pub document_id : Uuid ,
42+ pub stage : AnalysisStage ,
43+ pub progress : u8 ,
44+ pub message : String ,
45+ pub details : Option < serde_json:: Value > ,
46+ }
47+
48+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
49+ #[ serde( rename_all = "snake_case" ) ]
50+ pub enum AnalysisStage {
51+ Queued ,
52+ Starting ,
53+ ExtractingText ,
54+ ChunkingText ,
55+ GeneratingEmbeddings ,
56+ AnalyzingContent ,
57+ Finalizing ,
58+ Completed ,
59+ }
60+
61+ impl AnalysisStage {
62+ pub fn display_name ( & self ) -> & str {
63+ match self {
64+ Self :: Queued => "Queued" ,
65+ Self :: Starting => "Starting" ,
66+ Self :: ExtractingText => "Extracting Text" ,
67+ Self :: ChunkingText => "Chunking Text" ,
68+ Self :: GeneratingEmbeddings => "Generating Embeddings" ,
69+ Self :: AnalyzingContent => "Analyzing Content" ,
70+ Self :: Finalizing => "Finalizing" ,
71+ Self :: Completed => "Completed" ,
72+ }
73+ }
74+
75+ pub fn is_terminal ( & self ) -> bool {
76+ matches ! ( self , Self :: Completed )
77+ }
78+ }
79+
80+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
81+ pub struct BatchProgressEvent {
82+ pub batch_id : Uuid ,
83+ pub total_files : usize ,
84+ pub completed_files : usize ,
85+ pub current_file : Option < String > ,
86+ pub overall_progress : u8 ,
87+ pub file_progress : HashMap < Uuid , FileProgress > ,
88+ }
89+
90+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
91+ pub struct FileProgress {
92+ pub document_id : Uuid ,
93+ pub file_name : String ,
94+ pub status : FileStatus ,
95+ pub progress : u8 ,
96+ pub message : String ,
97+ }
98+
99+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
100+ #[ serde( rename_all = "snake_case" ) ]
101+ pub enum FileStatus {
102+ Pending ,
103+ Uploading ,
104+ Processing ,
105+ Completed ,
106+ Failed ,
107+ }
108+
109+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
110+ pub struct ErrorEvent {
111+ pub id : Uuid ,
112+ pub error_type : ErrorType ,
113+ pub message : String ,
114+ pub details : Option < serde_json:: Value > ,
115+ }
116+
117+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
118+ #[ serde( rename_all = "snake_case" ) ]
119+ pub enum ErrorType {
120+ Upload ,
121+ Analysis ,
122+ System ,
123+ Network ,
124+ }
125+
126+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
127+ pub struct CompleteEvent {
128+ pub id : Uuid ,
129+ pub event_type : CompleteEventType ,
130+ pub message : String ,
131+ pub result : Option < serde_json:: Value > ,
132+ }
133+
134+ #[ derive( Debug , Clone , Serialize , Deserialize ) ]
135+ #[ serde( rename_all = "snake_case" ) ]
136+ pub enum CompleteEventType {
137+ Upload ,
138+ Analysis ,
139+ Batch ,
140+ }
141+
142+ pub struct ProgressTracker {
143+ events : Arc < RwLock < HashMap < Uuid , Vec < ProgressEvent > > > > ,
144+ websocket : Arc < RwLock < ProgressWebSocket > > ,
145+ active_subscriptions : Arc < RwLock < HashMap < Uuid , ChannelType > > > ,
146+ }
147+
148+ impl ProgressTracker {
149+ pub fn new ( config : WebSocketConfig , token : String ) -> Self {
150+ let websocket = ProgressWebSocket :: new ( config, token) ;
151+
152+ Self {
153+ events : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
154+ websocket : Arc :: new ( RwLock :: new ( websocket) ) ,
155+ active_subscriptions : Arc :: new ( RwLock :: new ( HashMap :: new ( ) ) ) ,
156+ }
157+ }
158+
159+ pub async fn connect ( & self ) -> anyhow:: Result < ( ) > {
160+ let mut ws = self . websocket . write ( ) . await ;
161+ ws. connect ( ) . await
162+ }
163+
164+ pub async fn track_upload ( & self , document_id : Uuid ) -> anyhow:: Result < ( ) > {
165+ let mut ws = self . websocket . write ( ) . await ;
166+ ws. subscribe_upload ( document_id) . await ?;
167+
168+ self . active_subscriptions
169+ . write ( )
170+ . await
171+ . insert ( document_id, ChannelType :: Upload ) ;
172+
173+ Ok ( ( ) )
174+ }
175+
176+ pub async fn track_analysis ( & self , analysis_id : Uuid ) -> anyhow:: Result < ( ) > {
177+ let mut ws = self . websocket . write ( ) . await ;
178+ ws. subscribe_analysis ( analysis_id) . await ?;
179+
180+ self . active_subscriptions
181+ . write ( )
182+ . await
183+ . insert ( analysis_id, ChannelType :: Analysis ) ;
184+
185+ Ok ( ( ) )
186+ }
187+
188+ pub async fn track_batch ( & self , batch_id : Uuid ) -> anyhow:: Result < ( ) > {
189+ let mut ws = self . websocket . write ( ) . await ;
190+ ws. subscribe_batch ( batch_id) . await ?;
191+
192+ self . active_subscriptions
193+ . write ( )
194+ . await
195+ . insert ( batch_id, ChannelType :: Batch ) ;
196+
197+ Ok ( ( ) )
198+ }
199+
200+ pub async fn get_events ( & self , id : Uuid ) -> Vec < ProgressEvent > {
201+ let events = self . events . read ( ) . await ;
202+ events. get ( & id) . cloned ( ) . unwrap_or_default ( )
203+ }
204+
205+ pub async fn get_latest_event ( & self , id : Uuid ) -> Option < ProgressEvent > {
206+ let events = self . events . read ( ) . await ;
207+ events. get ( & id) . and_then ( |e| e. last ( ) . cloned ( ) )
208+ }
209+
210+ pub async fn process_events ( & self ) -> anyhow:: Result < ( ) > {
211+ let mut ws = self . websocket . write ( ) . await ;
212+
213+ while let Some ( event) = ws. next_event ( ) . await {
214+ let id = match & event {
215+ ProgressEvent :: Upload ( e) => e. document_id ,
216+ ProgressEvent :: Analysis ( e) => e. analysis_id ,
217+ ProgressEvent :: Batch ( e) => e. batch_id ,
218+ ProgressEvent :: Error ( e) => e. id ,
219+ ProgressEvent :: Complete ( e) => e. id ,
220+ } ;
221+
222+ // Store event
223+ let mut events = self . events . write ( ) . await ;
224+ events. entry ( id) . or_insert_with ( Vec :: new) . push ( event. clone ( ) ) ;
225+
226+ // Clean up completed subscriptions
227+ if matches ! ( & event, ProgressEvent :: Complete ( _) | ProgressEvent :: Error ( _) ) {
228+ if let Some ( channel_type) = self . active_subscriptions . read ( ) . await . get ( & id) {
229+ ws. unsubscribe ( channel_type. clone ( ) , id) . await ?;
230+ self . active_subscriptions . write ( ) . await . remove ( & id) ;
231+ }
232+ }
233+ }
234+
235+ Ok ( ( ) )
236+ }
237+
238+ pub async fn start_processing ( self : Arc < Self > ) {
239+ tokio:: spawn ( async move {
240+ loop {
241+ if let Err ( e) = self . process_events ( ) . await {
242+ tracing:: error!( "Error processing events: {}" , e) ;
243+
244+ // Try to reconnect
245+ let mut ws = self . websocket . write ( ) . await ;
246+ if !ws. is_connected ( ) . await {
247+ if let Err ( e) = ws. handle_reconnect ( ) . await {
248+ tracing:: error!( "Failed to reconnect: {}" , e) ;
249+ tokio:: time:: sleep ( tokio:: time:: Duration :: from_secs ( 5 ) ) . await ;
250+ }
251+ }
252+ }
253+ }
254+ } ) ;
255+ }
256+
257+ pub async fn disconnect ( & self ) {
258+ let mut ws = self . websocket . write ( ) . await ;
259+ ws. disconnect ( ) . await ;
260+ }
261+ }
0 commit comments