Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,10 @@ private Flowable<Event> run(
return Flowable.empty();
} else {
logger.debug("Continuing to next step of the flow.");
return run(spanContext, invocationContext, stepsCompleted + 1);
// Wait until the Runner has persisted this step's events so the next step's
// request is not built from a stale session (see PersistBarrier).
return PersistBarrier.awaitPersisted(invocationContext, eventList)
.andThen(run(spanContext, invocationContext, stepsCompleted + 1));
}
}));
}
Expand Down
131 changes: 131 additions & 0 deletions core/src/main/java/com/google/adk/flows/llmflows/PersistBarrier.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.adk.flows.llmflows;

import com.google.adk.agents.InvocationContext;
import com.google.adk.events.Event;
import com.google.common.annotations.VisibleForTesting;
import io.reactivex.rxjava3.core.Completable;
import io.reactivex.rxjava3.subjects.CompletableSubject;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

/**
* Lets {@link BaseLlmFlow}'s multi-step loop wait until the {@code Runner} -- the sole event
* persister -- has appended the current step's events, so the next step's request (built from
* {@code session.events()} by {@link Contents}) is not assembled from a stale session. The {@code
* Runner} calls {@link #markPersisted} (or {@link #markFailed}) after each append; the flow calls
* {@link #awaitPersisted} between steps. State lives in the per-invocation {@link
* InvocationContext#callbackContextData()} map, shared across the agent tree.
*
* <p>Each event id maps to a {@link CompletableSubject}: pending until its append finishes, then
* terminally completed or failed. The subject retains its terminal state, so {@code
* awaitPersisted}/{@code mark*} may happen in any order and a late await -- e.g. at a higher flow
* level across an agent transfer -- resolves immediately. If an append fails, the matching await
* fails with that error rather than blocking forever.
*
* <p>Thread-safe and lock-free: {@code markPersisted}/{@code markFailed} may run off-thread (async
* {@code appendEvent}) concurrently with {@code awaitPersisted}; {@link
* java.util.concurrent.ConcurrentHashMap#computeIfAbsent} hands both sides the same subject, which
* itself serializes its terminal signal against subscription.
*/
public final class PersistBarrier {

private static final String ENABLED_KEY = "com.google.adk.flows.llmflows.persistBarrier.enabled";
private static final String BARRIERS_KEY =
"com.google.adk.flows.llmflows.persistBarrier.barriers";

private PersistBarrier() {}

/**
* Marks that a {@code Runner} is driving this invocation and will resolve each appended event.
* Otherwise (flow run directly, e.g. unit tests) {@link #awaitPersisted} is a no-op, avoiding a
* deadlock waiting for a signal that never comes.
*/
public static void enable(InvocationContext context) {
context.callbackContextData().put(ENABLED_KEY, true);
}

/**
* Completes once every event in {@code events} has been {@link #markPersisted}, or fails if any
* was {@link #markFailed}; completes immediately if the barrier was never {@link #enable}d.
* Already-resolved events resolve immediately, so the order of {@code awaitPersisted}/{@code
* mark*} does not matter.
*/
public static Completable awaitPersisted(InvocationContext context, List<Event> events) {
Boolean enabled = (Boolean) context.callbackContextData().get(ENABLED_KEY);
if (enabled == null || !enabled) {
return Completable.complete();
}
Completable result = Completable.complete();
for (Event event : events) {
String eventId = event.id();
if (eventId != null) {
result = result.andThen(barrier(context, eventId));
}
}
return result;
}

/** Signals that the {@code Runner} persisted the event with the given id. */
public static void markPersisted(InvocationContext context, String eventId) {
if (eventId != null) {
barrier(context, eventId).onComplete();
}
}

/**
* Signals that persisting the event with the given id failed, so an await on it fails with {@code
* error} instead of blocking forever.
*/
public static void markFailed(InvocationContext context, String eventId, Throwable error) {
if (eventId != null) {
barrier(context, eventId).onError(error);
}
}

/**
* The per-event subject, created on first use. {@code computeIfAbsent} is atomic, so an awaiter
* and a concurrent mark share one subject regardless of order.
*/
private static CompletableSubject barrier(InvocationContext context, String eventId) {
return barriers(context).computeIfAbsent(eventId, unusedKey -> CompletableSubject.create());
}

/** Awaited-but-unresolved events; drains to 0 once a step's events are persisted or failed. */
@VisibleForTesting
static int pendingCount(InvocationContext context) {
int pending = 0;
for (CompletableSubject barrier : barriers(context).values()) {
if (!barrier.hasComplete() && !barrier.hasThrowable()) {
pending++;
}
}
return pending;
}

// Safe: BARRIERS_KEY only ever holds the Map created here.
@SuppressWarnings("unchecked")
private static Map<String, CompletableSubject> barriers(InvocationContext context) {
return (Map<String, CompletableSubject>)
context
.callbackContextData()
.computeIfAbsent(
BARRIERS_KEY, unusedKey -> new ConcurrentHashMap<String, CompletableSubject>());
}
}
14 changes: 14 additions & 0 deletions core/src/main/java/com/google/adk/runner/Runner.java
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import com.google.adk.artifacts.InMemoryArtifactService;
import com.google.adk.events.Event;
import com.google.adk.events.EventActions;
import com.google.adk.flows.llmflows.PersistBarrier;
import com.google.adk.memory.BaseMemoryService;
import com.google.adk.models.Model;
import com.google.adk.plugins.Plugin;
Expand Down Expand Up @@ -575,6 +576,9 @@ private Flowable<Event> runAgentWithUpdatedSession(
.content(content)
.build());

// Let BaseLlmFlow block each step until this Runner has persisted the prior step's events.
PersistBarrier.enable(contextWithUpdatedSession);

// Agent execution
Flowable<Event> agentEvents =
contextWithUpdatedSession
Expand All @@ -584,6 +588,16 @@ private Flowable<Event> runAgentWithUpdatedSession(
agentEvent ->
this.sessionService
.appendEvent(updatedSession, agentEvent)
// Release (or fail) BaseLlmFlow's wait for this step; the Runner stays the
// sole appendEvent caller (see PersistBarrier).
.doOnSuccess(
unusedEvent ->
PersistBarrier.markPersisted(
contextWithUpdatedSession, agentEvent.id()))
.doOnError(
error ->
PersistBarrier.markFailed(
contextWithUpdatedSession, agentEvent.id(), error))
.flatMap(
registeredEvent -> {
// TODO: remove this hack after deprecating runAsync with Session.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,215 @@
/*
* Copyright 2025 Google LLC
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package com.google.adk.flows.llmflows;

import static com.google.common.truth.Truth.assertThat;
import static org.mockito.Mockito.mock;

import com.google.adk.agents.BaseAgent;
import com.google.adk.agents.InvocationContext;
import com.google.adk.events.Event;
import com.google.adk.sessions.BaseSessionService;
import com.google.adk.sessions.Session;
import com.google.common.collect.ImmutableList;
import io.reactivex.rxjava3.observers.TestObserver;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import org.junit.Before;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

@RunWith(JUnit4.class)
public final class PersistBarrierTest {

private InvocationContext context;

@Before
public void setUp() {
context =
InvocationContext.builder()
.sessionService(mock(BaseSessionService.class))
.invocationId("inv-1")
.agent(mock(BaseAgent.class))
.session(Session.builder("s").build())
.build();
}

private static Event event(String id) {
return Event.builder().id(id).author("agent").build();
}

@Test
public void awaitBeforeMark_completesOnMark_andDrainsPending() {
PersistBarrier.enable(context);

TestObserver<Void> observer =
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test();

observer.assertNotComplete();
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(1);

PersistBarrier.markPersisted(context, "e1");

observer.assertComplete();
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
}

@Test
public void markBeforeAwait_completesImmediately_noPending() {
PersistBarrier.enable(context);

PersistBarrier.markPersisted(context, "e1");
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test().assertComplete();

assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
}

@Test
public void sameEventAwaitedTwice_secondAwaitStillCompletes_andNothingLingers() {
// Mirrors an agent transfer: a sub-agent event is awaited by both the sub-agent and parent
// flows but persisted once; the second await must still complete.
PersistBarrier.enable(context);

TestObserver<Void> subLevel =
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test();
PersistBarrier.markPersisted(context, "e1");
subLevel.assertComplete();
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);

PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test().assertComplete();
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
}

@Test
public void multiEventStep_completesOnlyAfterAllMarked() {
PersistBarrier.enable(context);

TestObserver<Void> observer =
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"), event("e2"))).test();

PersistBarrier.markPersisted(context, "e1");
observer.assertNotComplete();

PersistBarrier.markPersisted(context, "e2");
observer.assertComplete();
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
}

@Test
public void markFailedBeforeAwait_awaitFails() {
PersistBarrier.enable(context);
RuntimeException error = new RuntimeException("append failed");

PersistBarrier.markFailed(context, "e1", error);
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test().assertError(error);

assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
}

@Test
public void awaitBeforeMarkFailed_awaitFails() {
PersistBarrier.enable(context);
RuntimeException error = new RuntimeException("append failed");

TestObserver<Void> observer =
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test();
observer.assertNotComplete();

PersistBarrier.markFailed(context, "e1", error);

observer.assertError(error);
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
}

@Test
public void stepWithOneFailedEvent_awaitFails() {
// A step's await fails if any of its events fails to persist, so the next step does not run.
PersistBarrier.enable(context);
RuntimeException error = new RuntimeException("append failed");

TestObserver<Void> observer =
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"), event("e2"))).test();

PersistBarrier.markPersisted(context, "e1");
PersistBarrier.markFailed(context, "e2", error);

observer.assertError(error);
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
}

@Test
public void notEnabled_awaitIsNoOp() {
// No enable(): flow runs without a Runner, so await must not block forever.
PersistBarrier.awaitPersisted(context, ImmutableList.of(event("e1"))).test().assertComplete();
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
}

@Test
public void concurrentAwaitAndMark_allComplete_andDrain() throws Exception {
// awaitPersisted (flow thread) and markPersisted (async appendEvent thread) race on each id;
// none may be stranded and every subject must be dropped.
PersistBarrier.enable(context);
int eventCount = 1000;
List<String> ids = new ArrayList<>();
for (int i = 0; i < eventCount; i++) {
ids.add("e" + i);
}
List<TestObserver<Void>> observers = Collections.synchronizedList(new ArrayList<>());
CountDownLatch start = new CountDownLatch(1);

Thread awaiter =
new Thread(
() -> {
awaitQuietly(start);
for (String id : ids) {
observers.add(
PersistBarrier.awaitPersisted(context, ImmutableList.of(event(id))).test());
}
});
Thread marker =
new Thread(
() -> {
awaitQuietly(start);
for (String id : ids) {
PersistBarrier.markPersisted(context, id);
}
});

awaiter.start();
marker.start();
start.countDown();
awaiter.join();
marker.join();

for (TestObserver<Void> observer : observers) {
observer.assertComplete();
}
assertThat(PersistBarrier.pendingCount(context)).isEqualTo(0);
}

private static void awaitQuietly(CountDownLatch latch) {
try {
latch.await();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new AssertionError(e);
}
}
}
Loading
Loading