Skip to content

Commit 5825633

Browse files
committed
feat: Add WebSocket support and batch document operations
- Introduce WebSocket functionality for real-time progress tracking during document uploads and analysis. - Implement batch operations for uploading multiple documents with progress feedback. - Enhance CLI with new batch command for managing document uploads and status checks. - Update API client to support filename overrides and improved document upload handling. - Add progress tracking utilities for better user experience during long-running operations.
1 parent 8b6a0d9 commit 5825633

File tree

12 files changed

+1468
-25
lines changed

12 files changed

+1468
-25
lines changed

Cargo.toml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ uuid = { version = "1.8", features = ["v4", "serde"] }
5757
base64 = "0.22"
5858
url = "2.5"
5959
comfy-table = "7.1"
60+
glob = "0.3"
6061

6162
# Security
6263
rpassword = "7.3"
@@ -65,6 +66,11 @@ sha2 = "0.10"
6566
# System utils
6667
open = "5.0"
6768

69+
# WebSocket support
70+
tokio-tungstenite = "0.23"
71+
futures-util = "0.3"
72+
backoff = { version = "0.4", features = ["tokio"] }
73+
6874
[dev-dependencies]
6975
mockito = "1.4"
7076
tempfile = "3.10"

src/api.rs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
pub mod analysis;
22
pub mod documents;
3+
pub mod progress;
4+
pub mod websocket;
35

46
use crate::auth::AuthManager;
57
use crate::config::Config;
@@ -51,7 +53,7 @@ impl ApiClient {
5153
println!("📤 Uploading document...");
5254
let document = self
5355
.document_client
54-
.upload_document(file_path, &token, category, None)
56+
.upload_document(file_path, &token, category, None, None)
5557
.await?;
5658

5759
// Start analysis
@@ -187,6 +189,7 @@ impl ApiClient {
187189
file_path: &Path,
188190
category: Option<DocumentCategory>,
189191
description: Option<String>,
192+
filename_override: Option<String>,
190193
) -> Result<DocumentResponse> {
191194
let token = self
192195
.auth_manager
@@ -195,7 +198,7 @@ impl ApiClient {
195198
.context("Authentication required. Please run 'kanuni auth login' first.")?;
196199

197200
self.document_client
198-
.upload_document(file_path, &token, category, description)
201+
.upload_document(file_path, &token, category, description, filename_override)
199202
.await
200203
}
201204
}

src/api/documents.rs

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use uuid::Uuid;
1010
#[derive(Debug, Clone, Serialize)]
1111
pub struct UploadDocumentRequest {
1212
pub filename: String,
13+
pub filename_override: Option<String>, // Optional filename override
1314
pub category: Option<DocumentCategory>,
1415
pub description: Option<String>,
1516
pub tags: Option<Vec<String>>,
@@ -130,6 +131,7 @@ impl DocumentClient {
130131
token: &str,
131132
category: Option<DocumentCategory>,
132133
description: Option<String>,
134+
filename_override: Option<String>,
133135
) -> Result<DocumentResponse> {
134136
// Read file metadata
135137
let metadata = fs::metadata(file_path).context("Failed to read file metadata")?;
@@ -157,6 +159,7 @@ impl DocumentClient {
157159
// Step 1: Request upload URL
158160
let upload_request = UploadDocumentRequest {
159161
filename: filename.clone(),
162+
filename_override, // Use the provided filename override
160163
category,
161164
description,
162165
tags: None,
@@ -455,7 +458,24 @@ impl DocumentClient {
455458
let output_file = if let Some(path) = output_path {
456459
path.to_path_buf()
457460
} else {
458-
Path::new(&document.filename).to_path_buf()
461+
// Add extension based on mime_type if filename doesn't have one
462+
let mut filename = document.filename.clone();
463+
if !filename.contains('.') {
464+
let extension = match document.mime_type.as_str() {
465+
"application/pdf" => ".pdf",
466+
"application/vnd.openxmlformats-officedocument.wordprocessingml.document" => ".docx",
467+
"application/msword" => ".doc",
468+
"text/plain" => ".txt",
469+
"text/rtf" | "application/rtf" => ".rtf",
470+
"image/png" => ".png",
471+
"image/jpeg" => ".jpg",
472+
_ => "",
473+
};
474+
if !extension.is_empty() {
475+
filename.push_str(extension);
476+
}
477+
}
478+
Path::new(&filename).to_path_buf()
459479
};
460480

461481
// Save to file

src/api/progress.rs

Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
use serde::{Deserialize, Serialize};
2+
use std::collections::HashMap;
3+
use std::sync::Arc;
4+
use tokio::sync::RwLock;
5+
use uuid::Uuid;
6+
7+
use super::websocket::{ProgressWebSocket, WebSocketConfig};
8+
9+
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10+
#[serde(rename_all = "snake_case")]
11+
pub enum ChannelType {
12+
Upload,
13+
Analysis,
14+
Batch,
15+
User,
16+
}
17+
18+
#[derive(Debug, Clone, Serialize, Deserialize)]
19+
#[serde(tag = "type", rename_all = "snake_case")]
20+
pub enum ProgressEvent {
21+
Upload(UploadProgressEvent),
22+
Analysis(AnalysisProgressEvent),
23+
Batch(BatchProgressEvent),
24+
Error(ErrorEvent),
25+
Complete(CompleteEvent),
26+
}
27+
28+
#[derive(Debug, Clone, Serialize, Deserialize)]
29+
pub struct UploadProgressEvent {
30+
pub document_id: Uuid,
31+
pub file_name: String,
32+
pub bytes_uploaded: u64,
33+
pub total_bytes: u64,
34+
pub progress: u8,
35+
pub message: String,
36+
}
37+
38+
#[derive(Debug, Clone, Serialize, Deserialize)]
39+
pub struct AnalysisProgressEvent {
40+
pub analysis_id: Uuid,
41+
pub document_id: Uuid,
42+
pub stage: AnalysisStage,
43+
pub progress: u8,
44+
pub message: String,
45+
pub details: Option<serde_json::Value>,
46+
}
47+
48+
#[derive(Debug, Clone, Serialize, Deserialize)]
49+
#[serde(rename_all = "snake_case")]
50+
pub enum AnalysisStage {
51+
Queued,
52+
Starting,
53+
ExtractingText,
54+
ChunkingText,
55+
GeneratingEmbeddings,
56+
AnalyzingContent,
57+
Finalizing,
58+
Completed,
59+
}
60+
61+
impl AnalysisStage {
62+
pub fn display_name(&self) -> &str {
63+
match self {
64+
Self::Queued => "Queued",
65+
Self::Starting => "Starting",
66+
Self::ExtractingText => "Extracting Text",
67+
Self::ChunkingText => "Chunking Text",
68+
Self::GeneratingEmbeddings => "Generating Embeddings",
69+
Self::AnalyzingContent => "Analyzing Content",
70+
Self::Finalizing => "Finalizing",
71+
Self::Completed => "Completed",
72+
}
73+
}
74+
75+
pub fn is_terminal(&self) -> bool {
76+
matches!(self, Self::Completed)
77+
}
78+
}
79+
80+
#[derive(Debug, Clone, Serialize, Deserialize)]
81+
pub struct BatchProgressEvent {
82+
pub batch_id: Uuid,
83+
pub total_files: usize,
84+
pub completed_files: usize,
85+
pub current_file: Option<String>,
86+
pub overall_progress: u8,
87+
pub file_progress: HashMap<Uuid, FileProgress>,
88+
}
89+
90+
#[derive(Debug, Clone, Serialize, Deserialize)]
91+
pub struct FileProgress {
92+
pub document_id: Uuid,
93+
pub file_name: String,
94+
pub status: FileStatus,
95+
pub progress: u8,
96+
pub message: String,
97+
}
98+
99+
#[derive(Debug, Clone, Serialize, Deserialize)]
100+
#[serde(rename_all = "snake_case")]
101+
pub enum FileStatus {
102+
Pending,
103+
Uploading,
104+
Processing,
105+
Completed,
106+
Failed,
107+
}
108+
109+
#[derive(Debug, Clone, Serialize, Deserialize)]
110+
pub struct ErrorEvent {
111+
pub id: Uuid,
112+
pub error_type: ErrorType,
113+
pub message: String,
114+
pub details: Option<serde_json::Value>,
115+
}
116+
117+
#[derive(Debug, Clone, Serialize, Deserialize)]
118+
#[serde(rename_all = "snake_case")]
119+
pub enum ErrorType {
120+
Upload,
121+
Analysis,
122+
System,
123+
Network,
124+
}
125+
126+
#[derive(Debug, Clone, Serialize, Deserialize)]
127+
pub struct CompleteEvent {
128+
pub id: Uuid,
129+
pub event_type: CompleteEventType,
130+
pub message: String,
131+
pub result: Option<serde_json::Value>,
132+
}
133+
134+
#[derive(Debug, Clone, Serialize, Deserialize)]
135+
#[serde(rename_all = "snake_case")]
136+
pub enum CompleteEventType {
137+
Upload,
138+
Analysis,
139+
Batch,
140+
}
141+
142+
pub struct ProgressTracker {
143+
events: Arc<RwLock<HashMap<Uuid, Vec<ProgressEvent>>>>,
144+
websocket: Arc<RwLock<ProgressWebSocket>>,
145+
active_subscriptions: Arc<RwLock<HashMap<Uuid, ChannelType>>>,
146+
}
147+
148+
impl ProgressTracker {
149+
pub fn new(config: WebSocketConfig, token: String) -> Self {
150+
let websocket = ProgressWebSocket::new(config, token);
151+
152+
Self {
153+
events: Arc::new(RwLock::new(HashMap::new())),
154+
websocket: Arc::new(RwLock::new(websocket)),
155+
active_subscriptions: Arc::new(RwLock::new(HashMap::new())),
156+
}
157+
}
158+
159+
pub async fn connect(&self) -> anyhow::Result<()> {
160+
let mut ws = self.websocket.write().await;
161+
ws.connect().await
162+
}
163+
164+
pub async fn track_upload(&self, document_id: Uuid) -> anyhow::Result<()> {
165+
let mut ws = self.websocket.write().await;
166+
ws.subscribe_upload(document_id).await?;
167+
168+
self.active_subscriptions
169+
.write()
170+
.await
171+
.insert(document_id, ChannelType::Upload);
172+
173+
Ok(())
174+
}
175+
176+
pub async fn track_analysis(&self, analysis_id: Uuid) -> anyhow::Result<()> {
177+
let mut ws = self.websocket.write().await;
178+
ws.subscribe_analysis(analysis_id).await?;
179+
180+
self.active_subscriptions
181+
.write()
182+
.await
183+
.insert(analysis_id, ChannelType::Analysis);
184+
185+
Ok(())
186+
}
187+
188+
pub async fn track_batch(&self, batch_id: Uuid) -> anyhow::Result<()> {
189+
let mut ws = self.websocket.write().await;
190+
ws.subscribe_batch(batch_id).await?;
191+
192+
self.active_subscriptions
193+
.write()
194+
.await
195+
.insert(batch_id, ChannelType::Batch);
196+
197+
Ok(())
198+
}
199+
200+
pub async fn get_events(&self, id: Uuid) -> Vec<ProgressEvent> {
201+
let events = self.events.read().await;
202+
events.get(&id).cloned().unwrap_or_default()
203+
}
204+
205+
pub async fn get_latest_event(&self, id: Uuid) -> Option<ProgressEvent> {
206+
let events = self.events.read().await;
207+
events.get(&id).and_then(|e| e.last().cloned())
208+
}
209+
210+
pub async fn process_events(&self) -> anyhow::Result<()> {
211+
let mut ws = self.websocket.write().await;
212+
213+
while let Some(event) = ws.next_event().await {
214+
let id = match &event {
215+
ProgressEvent::Upload(e) => e.document_id,
216+
ProgressEvent::Analysis(e) => e.analysis_id,
217+
ProgressEvent::Batch(e) => e.batch_id,
218+
ProgressEvent::Error(e) => e.id,
219+
ProgressEvent::Complete(e) => e.id,
220+
};
221+
222+
// Store event
223+
let mut events = self.events.write().await;
224+
events.entry(id).or_insert_with(Vec::new).push(event.clone());
225+
226+
// Clean up completed subscriptions
227+
if matches!(&event, ProgressEvent::Complete(_) | ProgressEvent::Error(_)) {
228+
if let Some(channel_type) = self.active_subscriptions.read().await.get(&id) {
229+
ws.unsubscribe(channel_type.clone(), id).await?;
230+
self.active_subscriptions.write().await.remove(&id);
231+
}
232+
}
233+
}
234+
235+
Ok(())
236+
}
237+
238+
pub async fn start_processing(self: Arc<Self>) {
239+
tokio::spawn(async move {
240+
loop {
241+
if let Err(e) = self.process_events().await {
242+
tracing::error!("Error processing events: {}", e);
243+
244+
// Try to reconnect
245+
let mut ws = self.websocket.write().await;
246+
if !ws.is_connected().await {
247+
if let Err(e) = ws.handle_reconnect().await {
248+
tracing::error!("Failed to reconnect: {}", e);
249+
tokio::time::sleep(tokio::time::Duration::from_secs(5)).await;
250+
}
251+
}
252+
}
253+
}
254+
});
255+
}
256+
257+
pub async fn disconnect(&self) {
258+
let mut ws = self.websocket.write().await;
259+
ws.disconnect().await;
260+
}
261+
}

0 commit comments

Comments
 (0)