From 46f053e806d37da0991b1a40a76651205a7480c8 Mon Sep 17 00:00:00 2001 From: wgqqqqq Date: Mon, 25 May 2026 11:08:49 +0800 Subject: [PATCH] Fix remote workspace task crash --- .../src/service/remote_ssh/workspace_state.rs | 16 +- .../src/remote_connect.rs | 155 ++++++++++++++---- 2 files changed, 136 insertions(+), 35 deletions(-) diff --git a/src/crates/core/src/service/remote_ssh/workspace_state.rs b/src/crates/core/src/service/remote_ssh/workspace_state.rs index f1eb4efab..9c5d8e7c7 100644 --- a/src/crates/core/src/service/remote_ssh/workspace_state.rs +++ b/src/crates/core/src/service/remote_ssh/workspace_state.rs @@ -348,21 +348,25 @@ impl RemoteWorkspaceStateManager { ) -> PathBuf { let remote_id = remote_connection_id .map(str::trim) - .filter(|s| !s.is_empty()); - if remote_id.is_none() { + .filter(|s| !s.is_empty()) + .map(str::to_string); + let Some(remote_id) = remote_id else { return PathBuf::from(workspace_path); - } + }; let path_norm = normalize_remote_workspace_path(workspace_path); if let Some(host) = remote_ssh_host.map(str::trim).filter(|s| !s.is_empty()) { return remote_workspace_session_mirror_dir(host, &path_norm); } - if let Some(entry) = self.lookup_connection(workspace_path, remote_id).await { + if let Some(entry) = self + .lookup_connection(workspace_path, Some(remote_id.as_str())) + .await + { if !entry.ssh_host.trim().is_empty() { return remote_workspace_session_mirror_dir(&entry.ssh_host, &entry.remote_root); } - return unresolved_remote_session_storage_dir(remote_id.unwrap(), &path_norm); + return unresolved_remote_session_storage_dir(&remote_id, &path_norm); } - unresolved_remote_session_storage_dir(remote_id.unwrap(), &path_norm) + unresolved_remote_session_storage_dir(&remote_id, &path_norm) } } diff --git a/src/crates/services-integrations/src/remote_connect.rs b/src/crates/services-integrations/src/remote_connect.rs index 65d224f8a..47a3b63ec 100644 --- a/src/crates/services-integrations/src/remote_connect.rs +++ b/src/crates/services-integrations/src/remote_connect.rs @@ -8,11 +8,12 @@ use bitfun_events::AgenticEvent; use bitfun_runtime_ports::{ AgentInputAttachment, AgentSessionCreateRequest, AgentSubmissionRequest, AgentSubmissionSource, }; +use log::warn; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::path::{Path, PathBuf}; use std::sync::atomic::{AtomicU64, Ordering}; -use std::sync::{Arc, RwLock}; +use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] #[serde(rename_all = "snake_case")] @@ -119,6 +120,32 @@ pub trait RemoteImageContextAdapter { fn from_remote_image_context(context: RemoteImageContext) -> Self; } +fn recover_read_lock<'a, T>(lock: &'a RwLock, lock_name: &str) -> RwLockReadGuard<'a, T> { + match lock.read() { + Ok(guard) => guard, + Err(poisoned) => { + warn!( + "Remote session tracker {} lock was poisoned; recovering cached state", + lock_name + ); + poisoned.into_inner() + } + } +} + +fn recover_write_lock<'a, T>(lock: &'a RwLock, lock_name: &str) -> RwLockWriteGuard<'a, T> { + match lock.write() { + Ok(guard) => guard, + Err(poisoned) => { + warn!( + "Remote session tracker {} lock was poisoned; recovering cached state", + lock_name + ); + poisoned.into_inner() + } + } +} + pub fn build_remote_image_contexts(images: Option<&[ImageAttachment]>) -> Vec { let Some(images) = images.filter(|images| !images.is_empty()) else { return Vec::new(); @@ -1040,8 +1067,16 @@ impl RemoteSessionStateTracker { self.version.fetch_add(1, Ordering::Relaxed); } + fn state_read(&self) -> RwLockReadGuard<'_, TrackerState> { + recover_read_lock(&self.state, "state") + } + + fn state_write(&self) -> RwLockWriteGuard<'_, TrackerState> { + recover_write_lock(&self.state, "state") + } + pub fn snapshot_active_turn(&self) -> Option { - let state = self.state.read().unwrap(); + let state = self.state_read(); let has_items = !state.active_items.is_empty(); state.turn_id.as_ref().map(|turn_id| ActiveTurnSnapshot { turn_id: turn_id.clone(), @@ -1067,27 +1102,27 @@ impl RemoteSessionStateTracker { } pub fn session_state(&self) -> String { - self.state.read().unwrap().session_state.clone() + self.state_read().session_state.clone() } pub fn title(&self) -> String { - self.state.read().unwrap().title.clone() + self.state_read().title.clone() } pub fn turn_status(&self) -> String { - self.state.read().unwrap().turn_status.clone() + self.state_read().turn_status.clone() } pub fn accumulated_text(&self) -> String { - self.state.read().unwrap().accumulated_text.clone() + self.state_read().accumulated_text.clone() } pub fn accumulated_thinking(&self) -> String { - self.state.read().unwrap().accumulated_thinking.clone() + self.state_read().accumulated_thinking.clone() } pub fn is_turn_finished(&self) -> bool { - let state = self.state.read().unwrap(); + let state = self.state_read(); state.turn_id.is_some() && matches!( state.turn_status.as_str(), @@ -1096,7 +1131,7 @@ impl RemoteSessionStateTracker { } pub fn initialize_active_turn(&self, turn_id: String) { - let mut state = self.state.write().unwrap(); + let mut state = self.state_write(); if state.turn_id.is_none() { state.turn_id = Some(turn_id); state.turn_status = "active".to_string(); @@ -1107,7 +1142,7 @@ impl RemoteSessionStateTracker { } pub fn finalize_completed_turn(&self) { - let mut state = self.state.write().unwrap(); + let mut state = self.state_write(); if matches!( state.turn_status.as_str(), "completed" | "failed" | "cancelled" @@ -1121,11 +1156,11 @@ impl RemoteSessionStateTracker { } pub fn is_persistence_dirty(&self) -> bool { - self.state.read().unwrap().persistence_dirty + self.state_read().persistence_dirty } pub fn mark_persistence_clean(&self) { - self.state.write().unwrap().persistence_dirty = false; + self.state_write().persistence_dirty = false; } fn find_mergeable_item( @@ -1230,7 +1265,7 @@ impl RemoteSessionStateTracker { return; } - let mut state = self.state.write().unwrap(); + let mut state = self.state_write(); state .linked_subagent_sessions .insert(session_id.clone(), parent_session_id.clone()); @@ -1245,9 +1280,7 @@ impl RemoteSessionStateTracker { AE::TextChunk { session_id, .. } | AE::ThinkingChunk { session_id, .. } | AE::ToolEvent { session_id, .. } => self - .state - .read() - .unwrap() + .state_read() .linked_subagent_sessions .get(session_id) .is_some_and(|parent_session_id| parent_session_id == &self.target_session_id), @@ -1264,7 +1297,7 @@ impl RemoteSessionStateTracker { match event { AE::TextChunk { text, .. } => { let subagent_marker = if is_subagent { Some(true) } else { None }; - let mut state = self.state.write().unwrap(); + let mut state = self.state_write(); if !is_subagent { state.accumulated_text.push_str(text); } @@ -1290,7 +1323,7 @@ impl RemoteSessionStateTracker { } => { let clean = content.replace("", "").replace("", ""); let subagent_marker = if is_subagent { Some(true) } else { None }; - let mut state = self.state.write().unwrap(); + let mut state = self.state_write(); if !is_subagent { state.accumulated_thinking.push_str(&clean); } @@ -1336,7 +1369,7 @@ impl RemoteSessionStateTracker { .unwrap_or("") .to_string(); - let mut state = self.state.write().unwrap(); + let mut state = self.state_write(); let allow_name_fallback = tool_id.is_empty() && !tool_name.is_empty(); let mut pending_tool_event: Option = None; match event_type { @@ -1507,7 +1540,7 @@ impl RemoteSessionStateTracker { } } AE::DialogTurnStarted { turn_id, .. } if is_direct => { - let mut state = self.state.write().unwrap(); + let mut state = self.state_write(); state.turn_id = Some(turn_id.clone()); state.turn_status = "active".to_string(); state.accumulated_text.clear(); @@ -1521,7 +1554,7 @@ impl RemoteSessionStateTracker { self.bump_version(); } AE::DialogTurnCompleted { turn_id, .. } if is_direct => { - let mut state = self.state.write().unwrap(); + let mut state = self.state_write(); state.turn_status = "completed".to_string(); state.session_state = "idle".to_string(); state.persistence_dirty = true; @@ -1532,7 +1565,7 @@ impl RemoteSessionStateTracker { }); } AE::DialogTurnFailed { turn_id, error, .. } if is_direct => { - let mut state = self.state.write().unwrap(); + let mut state = self.state_write(); state.turn_status = "failed".to_string(); state.session_state = "idle".to_string(); state.persistence_dirty = true; @@ -1544,7 +1577,7 @@ impl RemoteSessionStateTracker { }); } AE::DialogTurnCancelled { turn_id, .. } if is_direct => { - let mut state = self.state.write().unwrap(); + let mut state = self.state_write(); state.turn_status = "cancelled".to_string(); state.session_state = "idle".to_string(); state.persistence_dirty = true; @@ -1555,19 +1588,19 @@ impl RemoteSessionStateTracker { }); } AE::ModelRoundStarted { round_index, .. } if is_direct => { - let mut state = self.state.write().unwrap(); + let mut state = self.state_write(); state.round_index = *round_index; drop(state); self.bump_version(); } AE::SessionStateChanged { new_state, .. } if is_direct => { - let mut state = self.state.write().unwrap(); + let mut state = self.state_write(); state.session_state = new_state.clone(); drop(state); self.bump_version(); } AE::SessionTitleGenerated { title, .. } if is_direct => { - let mut state = self.state.write().unwrap(); + let mut state = self.state_write(); state.title = title.clone(); drop(state); self.bump_version(); @@ -1598,6 +1631,18 @@ impl RemoteSessionTrackerRegistry { Self::default() } + fn trackers_read( + &self, + ) -> RwLockReadGuard<'_, HashMap>> { + recover_read_lock(&self.state_trackers, "registry") + } + + fn trackers_write( + &self, + ) -> RwLockWriteGuard<'_, HashMap>> { + recover_write_lock(&self.state_trackers, "registry") + } + pub fn ensure_tracker_with_host( &self, session_id: &str, @@ -1608,7 +1653,7 @@ impl RemoteSessionTrackerRegistry { } let tracker = { - let mut trackers = self.state_trackers.write().unwrap(); + let mut trackers = self.trackers_write(); if let Some(tracker) = trackers.get(session_id) { return tracker.clone(); } @@ -1626,7 +1671,7 @@ impl RemoteSessionTrackerRegistry { } pub fn get_tracker(&self, session_id: &str) -> Option> { - self.state_trackers.read().unwrap().get(session_id).cloned() + self.trackers_read().get(session_id).cloned() } pub fn remove_tracker_with_host( @@ -1634,7 +1679,7 @@ impl RemoteSessionTrackerRegistry { session_id: &str, host: &H, ) -> Option> { - let removed = self.state_trackers.write().unwrap().remove(session_id); + let removed = self.trackers_write().remove(session_id); if removed.is_some() { host.unsubscribe_tracker(session_id); } @@ -1642,6 +1687,58 @@ impl RemoteSessionTrackerRegistry { } } +#[cfg(test)] +mod tests { + use super::*; + use std::panic::{catch_unwind, AssertUnwindSafe}; + use std::sync::Arc; + + struct NoopTrackerHost; + + impl RemoteSessionTrackerHost for NoopTrackerHost { + fn subscribe_tracker(&self, _session_id: &str, _tracker: Arc) {} + + fn unsubscribe_tracker(&self, _session_id: &str) {} + + fn active_turn_id(&self, _session_id: &str) -> Option { + None + } + } + + #[test] + fn tracker_state_lock_recovers_after_poisoning() { + let tracker = Arc::new(RemoteSessionStateTracker::new("session-1".to_string())); + let poisoned_tracker = tracker.clone(); + + let _ = catch_unwind(AssertUnwindSafe(move || { + let _guard = poisoned_tracker.state.write().unwrap(); + panic!("poison tracker state lock"); + })); + + assert_eq!(tracker.session_state(), "idle"); + + tracker.initialize_active_turn("turn-1".to_string()); + assert_eq!(tracker.session_state(), "running"); + assert_eq!(tracker.turn_status(), "active"); + } + + #[test] + fn tracker_registry_lock_recovers_after_poisoning() { + let registry = Arc::new(RemoteSessionTrackerRegistry::new()); + let poisoned_registry = registry.clone(); + + let _ = catch_unwind(AssertUnwindSafe(move || { + let _guard = poisoned_registry.state_trackers.write().unwrap(); + panic!("poison tracker registry lock"); + })); + + let host = NoopTrackerHost; + let tracker = registry.ensure_tracker_with_host("session-1", &host); + assert_eq!(tracker.session_state(), "idle"); + assert!(registry.get_tracker("session-1").is_some()); + } +} + pub fn should_send_remote_model_catalog( current_model_catalog: Option<&RemoteModelCatalog>, known_model_catalog_version: Option,