Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 10 additions & 6 deletions src/crates/core/src/service/remote_ssh/workspace_state.rs
Original file line number Diff line number Diff line change
Expand Up @@ -348,21 +348,25 @@ impl RemoteWorkspaceStateManager {
) -> PathBuf {
let remote_id = remote_connection_id
.map(str::trim)
.filter(|s| !s.is_empty());
if remote_id.is_none() {
.filter(|s| !s.is_empty())
.map(str::to_string);
let Some(remote_id) = remote_id else {
return PathBuf::from(workspace_path);
}
};
let path_norm = normalize_remote_workspace_path(workspace_path);
if let Some(host) = remote_ssh_host.map(str::trim).filter(|s| !s.is_empty()) {
return remote_workspace_session_mirror_dir(host, &path_norm);
}
if let Some(entry) = self.lookup_connection(workspace_path, remote_id).await {
if let Some(entry) = self
.lookup_connection(workspace_path, Some(remote_id.as_str()))
.await
{
if !entry.ssh_host.trim().is_empty() {
return remote_workspace_session_mirror_dir(&entry.ssh_host, &entry.remote_root);
}
return unresolved_remote_session_storage_dir(remote_id.unwrap(), &path_norm);
return unresolved_remote_session_storage_dir(&remote_id, &path_norm);
}
unresolved_remote_session_storage_dir(remote_id.unwrap(), &path_norm)
unresolved_remote_session_storage_dir(&remote_id, &path_norm)
}
}

Expand Down
155 changes: 126 additions & 29 deletions src/crates/services-integrations/src/remote_connect.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,12 @@ use bitfun_events::AgenticEvent;
use bitfun_runtime_ports::{
AgentInputAttachment, AgentSessionCreateRequest, AgentSubmissionRequest, AgentSubmissionSource,
};
use log::warn;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::sync::{Arc, RwLock, RwLockReadGuard, RwLockWriteGuard};

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
Expand Down Expand Up @@ -119,6 +120,32 @@ pub trait RemoteImageContextAdapter {
fn from_remote_image_context(context: RemoteImageContext) -> Self;
}

fn recover_read_lock<'a, T>(lock: &'a RwLock<T>, lock_name: &str) -> RwLockReadGuard<'a, T> {
match lock.read() {
Ok(guard) => guard,
Err(poisoned) => {
warn!(
"Remote session tracker {} lock was poisoned; recovering cached state",
lock_name
);
poisoned.into_inner()
}
}
}

fn recover_write_lock<'a, T>(lock: &'a RwLock<T>, lock_name: &str) -> RwLockWriteGuard<'a, T> {
match lock.write() {
Ok(guard) => guard,
Err(poisoned) => {
warn!(
"Remote session tracker {} lock was poisoned; recovering cached state",
lock_name
);
poisoned.into_inner()
}
}
}

pub fn build_remote_image_contexts(images: Option<&[ImageAttachment]>) -> Vec<RemoteImageContext> {
let Some(images) = images.filter(|images| !images.is_empty()) else {
return Vec::new();
Expand Down Expand Up @@ -1040,8 +1067,16 @@ impl RemoteSessionStateTracker {
self.version.fetch_add(1, Ordering::Relaxed);
}

fn state_read(&self) -> RwLockReadGuard<'_, TrackerState> {
recover_read_lock(&self.state, "state")
}

fn state_write(&self) -> RwLockWriteGuard<'_, TrackerState> {
recover_write_lock(&self.state, "state")
}

pub fn snapshot_active_turn(&self) -> Option<ActiveTurnSnapshot> {
let state = self.state.read().unwrap();
let state = self.state_read();
let has_items = !state.active_items.is_empty();
state.turn_id.as_ref().map(|turn_id| ActiveTurnSnapshot {
turn_id: turn_id.clone(),
Expand All @@ -1067,27 +1102,27 @@ impl RemoteSessionStateTracker {
}

pub fn session_state(&self) -> String {
self.state.read().unwrap().session_state.clone()
self.state_read().session_state.clone()
}

pub fn title(&self) -> String {
self.state.read().unwrap().title.clone()
self.state_read().title.clone()
}

pub fn turn_status(&self) -> String {
self.state.read().unwrap().turn_status.clone()
self.state_read().turn_status.clone()
}

pub fn accumulated_text(&self) -> String {
self.state.read().unwrap().accumulated_text.clone()
self.state_read().accumulated_text.clone()
}

pub fn accumulated_thinking(&self) -> String {
self.state.read().unwrap().accumulated_thinking.clone()
self.state_read().accumulated_thinking.clone()
}

pub fn is_turn_finished(&self) -> bool {
let state = self.state.read().unwrap();
let state = self.state_read();
state.turn_id.is_some()
&& matches!(
state.turn_status.as_str(),
Expand All @@ -1096,7 +1131,7 @@ impl RemoteSessionStateTracker {
}

pub fn initialize_active_turn(&self, turn_id: String) {
let mut state = self.state.write().unwrap();
let mut state = self.state_write();
if state.turn_id.is_none() {
state.turn_id = Some(turn_id);
state.turn_status = "active".to_string();
Expand All @@ -1107,7 +1142,7 @@ impl RemoteSessionStateTracker {
}

pub fn finalize_completed_turn(&self) {
let mut state = self.state.write().unwrap();
let mut state = self.state_write();
if matches!(
state.turn_status.as_str(),
"completed" | "failed" | "cancelled"
Expand All @@ -1121,11 +1156,11 @@ impl RemoteSessionStateTracker {
}

pub fn is_persistence_dirty(&self) -> bool {
self.state.read().unwrap().persistence_dirty
self.state_read().persistence_dirty
}

pub fn mark_persistence_clean(&self) {
self.state.write().unwrap().persistence_dirty = false;
self.state_write().persistence_dirty = false;
}

fn find_mergeable_item(
Expand Down Expand Up @@ -1230,7 +1265,7 @@ impl RemoteSessionStateTracker {
return;
}

let mut state = self.state.write().unwrap();
let mut state = self.state_write();
state
.linked_subagent_sessions
.insert(session_id.clone(), parent_session_id.clone());
Expand All @@ -1245,9 +1280,7 @@ impl RemoteSessionStateTracker {
AE::TextChunk { session_id, .. }
| AE::ThinkingChunk { session_id, .. }
| AE::ToolEvent { session_id, .. } => self
.state
.read()
.unwrap()
.state_read()
.linked_subagent_sessions
.get(session_id)
.is_some_and(|parent_session_id| parent_session_id == &self.target_session_id),
Expand All @@ -1264,7 +1297,7 @@ impl RemoteSessionStateTracker {
match event {
AE::TextChunk { text, .. } => {
let subagent_marker = if is_subagent { Some(true) } else { None };
let mut state = self.state.write().unwrap();
let mut state = self.state_write();
if !is_subagent {
state.accumulated_text.push_str(text);
}
Expand All @@ -1290,7 +1323,7 @@ impl RemoteSessionStateTracker {
} => {
let clean = content.replace("</thinking>", "").replace("<thinking>", "");
let subagent_marker = if is_subagent { Some(true) } else { None };
let mut state = self.state.write().unwrap();
let mut state = self.state_write();
if !is_subagent {
state.accumulated_thinking.push_str(&clean);
}
Expand Down Expand Up @@ -1336,7 +1369,7 @@ impl RemoteSessionStateTracker {
.unwrap_or("")
.to_string();

let mut state = self.state.write().unwrap();
let mut state = self.state_write();
let allow_name_fallback = tool_id.is_empty() && !tool_name.is_empty();
let mut pending_tool_event: Option<TrackerEvent> = None;
match event_type {
Expand Down Expand Up @@ -1507,7 +1540,7 @@ impl RemoteSessionStateTracker {
}
}
AE::DialogTurnStarted { turn_id, .. } if is_direct => {
let mut state = self.state.write().unwrap();
let mut state = self.state_write();
state.turn_id = Some(turn_id.clone());
state.turn_status = "active".to_string();
state.accumulated_text.clear();
Expand All @@ -1521,7 +1554,7 @@ impl RemoteSessionStateTracker {
self.bump_version();
}
AE::DialogTurnCompleted { turn_id, .. } if is_direct => {
let mut state = self.state.write().unwrap();
let mut state = self.state_write();
state.turn_status = "completed".to_string();
state.session_state = "idle".to_string();
state.persistence_dirty = true;
Expand All @@ -1532,7 +1565,7 @@ impl RemoteSessionStateTracker {
});
}
AE::DialogTurnFailed { turn_id, error, .. } if is_direct => {
let mut state = self.state.write().unwrap();
let mut state = self.state_write();
state.turn_status = "failed".to_string();
state.session_state = "idle".to_string();
state.persistence_dirty = true;
Expand All @@ -1544,7 +1577,7 @@ impl RemoteSessionStateTracker {
});
}
AE::DialogTurnCancelled { turn_id, .. } if is_direct => {
let mut state = self.state.write().unwrap();
let mut state = self.state_write();
state.turn_status = "cancelled".to_string();
state.session_state = "idle".to_string();
state.persistence_dirty = true;
Expand All @@ -1555,19 +1588,19 @@ impl RemoteSessionStateTracker {
});
}
AE::ModelRoundStarted { round_index, .. } if is_direct => {
let mut state = self.state.write().unwrap();
let mut state = self.state_write();
state.round_index = *round_index;
drop(state);
self.bump_version();
}
AE::SessionStateChanged { new_state, .. } if is_direct => {
let mut state = self.state.write().unwrap();
let mut state = self.state_write();
state.session_state = new_state.clone();
drop(state);
self.bump_version();
}
AE::SessionTitleGenerated { title, .. } if is_direct => {
let mut state = self.state.write().unwrap();
let mut state = self.state_write();
state.title = title.clone();
drop(state);
self.bump_version();
Expand Down Expand Up @@ -1598,6 +1631,18 @@ impl RemoteSessionTrackerRegistry {
Self::default()
}

fn trackers_read(
&self,
) -> RwLockReadGuard<'_, HashMap<String, Arc<RemoteSessionStateTracker>>> {
recover_read_lock(&self.state_trackers, "registry")
}

fn trackers_write(
&self,
) -> RwLockWriteGuard<'_, HashMap<String, Arc<RemoteSessionStateTracker>>> {
recover_write_lock(&self.state_trackers, "registry")
}

pub fn ensure_tracker_with_host<H: RemoteSessionTrackerHost>(
&self,
session_id: &str,
Expand All @@ -1608,7 +1653,7 @@ impl RemoteSessionTrackerRegistry {
}

let tracker = {
let mut trackers = self.state_trackers.write().unwrap();
let mut trackers = self.trackers_write();
if let Some(tracker) = trackers.get(session_id) {
return tracker.clone();
}
Expand All @@ -1626,22 +1671,74 @@ impl RemoteSessionTrackerRegistry {
}

pub fn get_tracker(&self, session_id: &str) -> Option<Arc<RemoteSessionStateTracker>> {
self.state_trackers.read().unwrap().get(session_id).cloned()
self.trackers_read().get(session_id).cloned()
}

pub fn remove_tracker_with_host<H: RemoteSessionTrackerHost>(
&self,
session_id: &str,
host: &H,
) -> Option<Arc<RemoteSessionStateTracker>> {
let removed = self.state_trackers.write().unwrap().remove(session_id);
let removed = self.trackers_write().remove(session_id);
if removed.is_some() {
host.unsubscribe_tracker(session_id);
}
removed
}
}

#[cfg(test)]
mod tests {
use super::*;
use std::panic::{catch_unwind, AssertUnwindSafe};
use std::sync::Arc;

struct NoopTrackerHost;

impl RemoteSessionTrackerHost for NoopTrackerHost {
fn subscribe_tracker(&self, _session_id: &str, _tracker: Arc<RemoteSessionStateTracker>) {}

fn unsubscribe_tracker(&self, _session_id: &str) {}

fn active_turn_id(&self, _session_id: &str) -> Option<String> {
None
}
}

#[test]
fn tracker_state_lock_recovers_after_poisoning() {
let tracker = Arc::new(RemoteSessionStateTracker::new("session-1".to_string()));
let poisoned_tracker = tracker.clone();

let _ = catch_unwind(AssertUnwindSafe(move || {
let _guard = poisoned_tracker.state.write().unwrap();
panic!("poison tracker state lock");
}));

assert_eq!(tracker.session_state(), "idle");

tracker.initialize_active_turn("turn-1".to_string());
assert_eq!(tracker.session_state(), "running");
assert_eq!(tracker.turn_status(), "active");
}

#[test]
fn tracker_registry_lock_recovers_after_poisoning() {
let registry = Arc::new(RemoteSessionTrackerRegistry::new());
let poisoned_registry = registry.clone();

let _ = catch_unwind(AssertUnwindSafe(move || {
let _guard = poisoned_registry.state_trackers.write().unwrap();
panic!("poison tracker registry lock");
}));

let host = NoopTrackerHost;
let tracker = registry.ensure_tracker_with_host("session-1", &host);
assert_eq!(tracker.session_state(), "idle");
assert!(registry.get_tracker("session-1").is_some());
}
}

pub fn should_send_remote_model_catalog(
current_model_catalog: Option<&RemoteModelCatalog>,
known_model_catalog_version: Option<u64>,
Expand Down
Loading