Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions core/src/agent_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -594,6 +594,15 @@ impl AgentSession {
.collect()
}

/// Cancel an in-flight delegated subagent task by id. Returns `true`
/// when a cancellation token was found and fired, `false` when the
/// task id is unknown or the task has already finished. The eventual
/// `SubagentEnd` from the cancelled child loop won't downgrade the
/// terminal status — it stays `Cancelled`.
pub async fn cancel_subagent_task(&self, task_id: &str) -> bool {
self.subagent_tasks.cancel(task_id).await
}

/// Return a snapshot of the session's conversation history.
pub fn history(&self) -> Vec<Message> {
SessionView::from_session(self).history()
Expand Down
6 changes: 6 additions & 0 deletions core/src/agent_api/capabilities.rs
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ pub(super) struct SessionCapabilities {
pub(super) context_providers: Vec<Arc<dyn ContextProvider>>,
pub(super) skill_registry: Arc<SkillRegistry>,
pub(super) agent_registry: Arc<AgentRegistry>,
pub(super) subagent_tasks: Arc<crate::subagent_task_tracker::InMemorySubagentTaskTracker>,
}

pub(super) fn build_session_capabilities(input: SessionCapabilityInput<'_>) -> SessionCapabilities {
Expand All @@ -60,12 +61,14 @@ pub(super) fn build_session_capabilities(input: SessionCapabilityInput<'_>) -> S
.set_search_config(search_config.clone());
}

let subagent_tasks = Arc::new(crate::subagent_task_tracker::InMemorySubagentTaskTracker::new());
let agent_registry = register_task_capability(
input.code_config,
input.opts,
input.workspace,
Arc::clone(&input.llm_client),
&tool_executor,
Arc::clone(&subagent_tasks),
);

// Register generate_object tool (structured JSON output)
Expand All @@ -90,6 +93,7 @@ pub(super) fn build_session_capabilities(input: SessionCapabilityInput<'_>) -> S
context_providers,
skill_registry,
agent_registry,
subagent_tasks,
}
}

Expand Down Expand Up @@ -136,6 +140,7 @@ fn register_task_capability(
workspace: &Path,
llm_client: Arc<dyn LlmClient>,
tool_executor: &Arc<ToolExecutor>,
subagent_tasks: Arc<crate::subagent_task_tracker::InMemorySubagentTaskTracker>,
) -> Arc<AgentRegistry> {
use crate::child_run::ChildRunContext;
use crate::subagent::load_agents_from_dir;
Expand Down Expand Up @@ -177,6 +182,7 @@ fn register_task_capability(
workspace.display().to_string(),
opts.mcp_manager.clone(),
Some(parent_context),
Some(subagent_tasks),
);
registry
}
Expand Down
3 changes: 2 additions & 1 deletion core/src/agent_api/session_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,7 @@ pub(super) fn build_agent_session(
let tool_defs = capabilities.tool_defs;
let context_providers = capabilities.context_providers;
let effective_registry = capabilities.skill_registry;
let subagent_tasks = capabilities.subagent_tasks;

let prompt_slots = opts
.prompt_slots
Expand Down Expand Up @@ -219,7 +220,7 @@ pub(super) fn build_agent_session(
cancel_token: Arc::new(tokio::sync::Mutex::new(None)),
current_run_id: Arc::new(tokio::sync::Mutex::new(None)),
run_store: Arc::new(crate::run::InMemoryRunStore::new()),
subagent_tasks: Arc::new(crate::subagent_task_tracker::InMemorySubagentTaskTracker::new()),
subagent_tasks,
active_tools: Arc::new(tokio::sync::RwLock::new(HashMap::new())),
trace_sink,
verification_reports: Arc::new(RwLock::new(Vec::new())),
Expand Down
56 changes: 56 additions & 0 deletions core/src/agent_api/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2468,3 +2468,59 @@ async fn subagent_tasks_scope_to_parent_session() {
assert!(session_b.subagent_tasks().await.is_empty());
assert!(session_b.subagent_task("task-from-a").await.is_none());
}

#[tokio::test]
async fn cancel_subagent_task_marks_snapshot_cancelled() {
use super::runtime_events::RuntimeEventSink;
use crate::agent::AgentEvent;
use crate::subagent_task_tracker::SubagentStatus;
use tokio_util::sync::CancellationToken;

let agent = Agent::from_config(test_config()).await.unwrap();
let session = agent.session("/tmp/test-ws-subagent-cancel", None).unwrap();
let run = session
.run_store
.create_run(session.session_id(), "parent")
.await;
let sink = RuntimeEventSink::from_session(&session, &run.id);

let task_id = "task-to-cancel".to_string();
sink.observe(&AgentEvent::SubagentStart {
task_id: task_id.clone(),
session_id: format!("task-run-{}", task_id),
parent_session_id: session.session_id().to_string(),
agent: "explore".to_string(),
description: "long task".to_string(),
})
.await;

// Simulate what TaskExecutor would do: register a cancellation token
// for this in-flight task so the public API has something to fire.
let token = CancellationToken::new();
session
.subagent_tasks
.register_canceller(&task_id, token.clone())
.await;

assert!(session.cancel_subagent_task(&task_id).await);
assert!(token.is_cancelled());

let snap = session.subagent_task(&task_id).await.unwrap();
assert_eq!(snap.status, SubagentStatus::Cancelled);

// A late SubagentEnd from the cancelled child must not downgrade.
sink.observe(&AgentEvent::SubagentEnd {
task_id: task_id.clone(),
session_id: format!("task-run-{}", task_id),
agent: "explore".to_string(),
output: "Task cancelled by caller".to_string(),
success: false,
})
.await;
let snap = session.subagent_task(&task_id).await.unwrap();
assert_eq!(snap.status, SubagentStatus::Cancelled);

// Cancelling again or against an unknown id is a no-op.
assert!(!session.cancel_subagent_task(&task_id).await);
assert!(!session.cancel_subagent_task("task-unknown").await);
}
116 changes: 111 additions & 5 deletions core/src/subagent_task_tracker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,15 @@ use crate::agent::AgentEvent;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use tokio::sync::RwLock;
use tokio_util::sync::CancellationToken;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SubagentStatus {
Running,
Completed,
Failed,
Cancelled,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
Expand Down Expand Up @@ -47,13 +49,52 @@ pub struct SubagentTaskSnapshot {
#[derive(Debug, Default)]
pub struct InMemorySubagentTaskTracker {
tasks: RwLock<HashMap<String, SubagentTaskSnapshot>>,
cancellers: RwLock<HashMap<String, CancellationToken>>,
}

impl InMemorySubagentTaskTracker {
pub fn new() -> Self {
Self::default()
}

/// Register a `CancellationToken` for a running task so callers can
/// trigger cancellation through `cancel(task_id)`. The task executor
/// is expected to remove the entry on exit via `clear_canceller`.
pub async fn register_canceller(&self, task_id: &str, token: CancellationToken) {
self.cancellers
.write()
.await
.insert(task_id.to_string(), token);
}

pub async fn clear_canceller(&self, task_id: &str) {
self.cancellers.write().await.remove(task_id);
}

/// Fire the registered token and mark the snapshot as `Cancelled`.
/// Returns `true` if a token was found (caller can interpret as
/// "cancellation initiated"), `false` if the task id was unknown or
/// the task already finished. The eventual `SubagentEnd` event won't
/// overwrite the Cancelled status — see `record_event`.
pub async fn cancel(&self, task_id: &str) -> bool {
let token = self.cancellers.write().await.remove(task_id);
match token {
Some(token) => {
token.cancel();
let now = now_ms();
let mut tasks = self.tasks.write().await;
if let Some(entry) = tasks.get_mut(task_id) {
if entry.status == SubagentStatus::Running {
entry.status = SubagentStatus::Cancelled;
entry.updated_ms = now;
}
}
true
}
None => false,
}
}

/// Apply a single agent event to the tracker. Non-subagent events are ignored.
pub async fn record_event(&self, event: &AgentEvent) {
match event {
Expand Down Expand Up @@ -148,11 +189,16 @@ impl InMemorySubagentTaskTracker {
success: None,
progress: Vec::new(),
});
entry.status = if *success {
SubagentStatus::Completed
} else {
SubagentStatus::Failed
};
// Preserve a pre-set Cancelled status (set by `cancel()`)
// — a late SubagentEnd from the cancelled child loop is
// expected and must not downgrade the terminal state.
if entry.status != SubagentStatus::Cancelled {
entry.status = if *success {
SubagentStatus::Completed
} else {
SubagentStatus::Failed
};
}
entry.updated_ms = now;
entry.finished_ms = Some(now);
entry.output = Some(output.clone());
Expand Down Expand Up @@ -333,4 +379,64 @@ mod tests {
.await;
assert!(tracker.list().await.is_empty());
}

#[tokio::test]
async fn cancel_fires_token_and_marks_snapshot_cancelled() {
let tracker = InMemorySubagentTaskTracker::new();
tracker
.record_event(&start_event("task-c", "parent", "child"))
.await;

let token = CancellationToken::new();
tracker.register_canceller("task-c", token.clone()).await;
assert!(!token.is_cancelled());

let fired = tracker.cancel("task-c").await;
assert!(fired, "cancel should report success");
assert!(token.is_cancelled(), "registered token should be triggered");

let snap = tracker.get("task-c").await.unwrap();
assert_eq!(snap.status, SubagentStatus::Cancelled);
}

#[tokio::test]
async fn cancel_returns_false_for_unknown_task() {
let tracker = InMemorySubagentTaskTracker::new();
assert!(!tracker.cancel("task-does-not-exist").await);
}

#[tokio::test]
async fn late_subagent_end_does_not_downgrade_cancelled_status() {
let tracker = InMemorySubagentTaskTracker::new();
tracker
.record_event(&start_event("task-d", "parent", "child"))
.await;
let token = CancellationToken::new();
tracker.register_canceller("task-d", token).await;
assert!(tracker.cancel("task-d").await);

// The cancelled child loop will still emit a (likely failed)
// SubagentEnd. The terminal status should remain Cancelled.
tracker
.record_event(&end_event("task-d", "child", false))
.await;
let snap = tracker.get("task-d").await.unwrap();
assert_eq!(snap.status, SubagentStatus::Cancelled);
assert!(snap.finished_ms.is_some());
assert_eq!(snap.success, Some(false));
}

#[tokio::test]
async fn clear_canceller_disarms_future_cancel_calls() {
let tracker = InMemorySubagentTaskTracker::new();
tracker
.record_event(&start_event("task-e", "parent", "child"))
.await;
let token = CancellationToken::new();
tracker.register_canceller("task-e", token.clone()).await;
tracker.clear_canceller("task-e").await;

assert!(!tracker.cancel("task-e").await);
assert!(!token.is_cancelled());
}
}
16 changes: 15 additions & 1 deletion core/src/tools/builtin/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -93,21 +93,32 @@ pub fn register_task(
agent_registry: Arc<crate::subagent::AgentRegistry>,
workspace: String,
) {
register_task_with_mcp(registry, llm_client, agent_registry, workspace, None, None);
register_task_with_mcp(
registry,
llm_client,
agent_registry,
workspace,
None,
None,
None,
);
}

/// Register the task delegation tools with optional MCP manager and parent context.
///
/// When `mcp_manager` is provided, delegated child sessions will have access
/// to all MCP tools from connected servers.
/// When `parent_context` is provided, child runs inherit parent capabilities.
/// When `subagent_tracker` is provided, each task registers a
/// `CancellationToken` against it so callers can cancel by `task_id`.
pub fn register_task_with_mcp(
registry: &Arc<ToolRegistry>,
llm_client: Arc<dyn crate::llm::LlmClient>,
agent_registry: Arc<crate::subagent::AgentRegistry>,
workspace: String,
mcp_manager: Option<Arc<crate::mcp::manager::McpManager>>,
parent_context: Option<crate::child_run::ChildRunContext>,
subagent_tracker: Option<Arc<crate::subagent_task_tracker::InMemorySubagentTaskTracker>>,
) {
use crate::tools::task::{ParallelTaskTool, TaskExecutor, TaskTool};
let mut executor = match mcp_manager {
Expand All @@ -117,6 +128,9 @@ pub fn register_task_with_mcp(
if let Some(ctx) = parent_context {
executor = executor.with_parent_context(ctx);
}
if let Some(tracker) = subagent_tracker {
executor = executor.with_subagent_tracker(tracker);
}
let executor = Arc::new(executor);
registry.register_builtin(Arc::new(TaskTool::new(Arc::clone(&executor))));
registry.register_builtin(Arc::new(ParallelTaskTool::new(Arc::clone(&executor))));
Expand Down
Loading
Loading