From df66489f32c0e824d208318ce78cd1cabf4d1b28 Mon Sep 17 00:00:00 2001 From: Pukaphol Thienpreecha Date: Sun, 31 May 2026 13:09:22 -0700 Subject: [PATCH] feat(memory): add RedisMemoryService --- pyproject.toml | 2 + src/google/adk/cli/service_registry.py | 9 + src/google/adk/memory/__init__.py | 3 + src/google/adk/memory/redis_memory_service.py | 257 ++++++++++++ tests/unittests/cli/test_service_registry.py | 21 + .../memory/test_redis_memory_service.py | 387 ++++++++++++++++++ 6 files changed, 679 insertions(+) create mode 100644 src/google/adk/memory/redis_memory_service.py create mode 100644 tests/unittests/memory/test_redis_memory_service.py diff --git a/pyproject.toml b/pyproject.toml index c8f8a86a16..fd54eda08a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -90,6 +90,7 @@ optional-dependencies.all = [ "opentelemetry-resourcedetector-gcp>=1.9.0a0,<2", "pyarrow>=14", "python-dateutil>=2.9.0.post0,<3", + "redis>=5,<7", "sqlalchemy>=2,<3", "sqlalchemy-spanner>=1.14", ] @@ -177,6 +178,7 @@ optional-dependencies.mcp = [ ] optional-dependencies.otel-gcp = [ "opentelemetry-instrumentation-google-genai>=0.6b0,<1" ] +optional-dependencies.redis = [ "redis>=5,<7" ] optional-dependencies.slack = [ "slack-bolt>=1.22" ] optional-dependencies.test = [ "a2a-sdk>=0.3,<0.4", diff --git a/src/google/adk/cli/service_registry.py b/src/google/adk/cli/service_registry.py index 517222d932..7128c4a1ee 100644 --- a/src/google/adk/cli/service_registry.py +++ b/src/google/adk/cli/service_registry.py @@ -346,9 +346,18 @@ def agentengine_memory_factory(uri: str, **kwargs): ) return VertexAiMemoryBankService(**params) + def redis_memory_factory(uri: str, **kwargs): + from ..memory.redis_memory_service import RedisMemoryService + + kwargs_copy = kwargs.copy() + kwargs_copy.pop("agents_dir", None) + return RedisMemoryService(redis_url=uri, **kwargs_copy) + registry.register_memory_service("memory", memory_memory_factory) registry.register_memory_service("rag", rag_memory_factory) registry.register_memory_service("agentengine", agentengine_memory_factory) + for scheme in ["redis", "rediss"]: + registry.register_memory_service(scheme, redis_memory_factory) # -- A2A Task Store Services -- def memory_task_store_factory(uri: str, **kwargs: Any) -> Any: diff --git a/src/google/adk/memory/__init__.py b/src/google/adk/memory/__init__.py index 1361b34e36..b7337f0eb0 100644 --- a/src/google/adk/memory/__init__.py +++ b/src/google/adk/memory/__init__.py @@ -22,18 +22,21 @@ if TYPE_CHECKING: from .in_memory_memory_service import InMemoryMemoryService + from .redis_memory_service import RedisMemoryService from .vertex_ai_memory_bank_service import VertexAiMemoryBankService from .vertex_ai_rag_memory_service import VertexAiRagMemoryService __all__ = [ 'BaseMemoryService', 'InMemoryMemoryService', + 'RedisMemoryService', 'VertexAiMemoryBankService', 'VertexAiRagMemoryService', ] _LAZY_MEMBERS: dict[str, str] = { 'InMemoryMemoryService': 'in_memory_memory_service', + 'RedisMemoryService': 'redis_memory_service', 'VertexAiMemoryBankService': 'vertex_ai_memory_bank_service', 'VertexAiRagMemoryService': 'vertex_ai_rag_memory_service', } diff --git a/src/google/adk/memory/redis_memory_service.py b/src/google/adk/memory/redis_memory_service.py new file mode 100644 index 0000000000..7bb63d1734 --- /dev/null +++ b/src/google/adk/memory/redis_memory_service.py @@ -0,0 +1,257 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from collections.abc import Mapping +from collections.abc import Sequence +import hashlib +import json +import re +from typing import Any +from typing import TYPE_CHECKING +from urllib.parse import quote + +from google.genai import types +from typing_extensions import override + +from . import _utils +from .base_memory_service import BaseMemoryService +from .base_memory_service import SearchMemoryResponse +from .memory_entry import MemoryEntry + +if TYPE_CHECKING: + from ..events.event import Event + from ..sessions.session import Session + +_UNKNOWN_SESSION_ID = '__unknown_session_id__' + + +def _key_part(value: str) -> str: + return quote(value, safe='') + + +def _decode(value: Any) -> str: + if isinstance(value, bytes): + return value.decode('utf-8') + return str(value) + + +def _extract_words_lower(text: str) -> set[str]: + """Extracts words from a string and converts them to lowercase.""" + return set([word.lower() for word in re.findall(r'[A-Za-z]+', text)]) + + +def _content_to_text(content: types.Content) -> str: + return ' '.join([part.text for part in content.parts or [] if part.text]) + + +def _content_from_payload(payload: dict[str, Any]) -> types.Content: + return types.Content.model_validate(payload) + + +def _event_id(event: Event) -> str: + if event.id: + return event.id + content_text = _content_to_text(event.content) if event.content else '' + digest = hashlib.sha256( + f'{event.author}:{event.timestamp}:{content_text}'.encode('utf-8') + ).hexdigest() + return f'generated-{digest}' + + +def _event_to_payload( + event: Event, + *, + session_id: str, + custom_metadata: Mapping[str, object] | None = None, +) -> dict[str, Any]: + metadata = dict(custom_metadata or {}) + metadata.setdefault('session_id', session_id) + return { + 'id': _event_id(event), + 'author': event.author, + 'timestamp': _utils.format_timestamp(event.timestamp), + 'content': event.content.model_dump( + mode='json', by_alias=True, exclude_none=True + ), + 'custom_metadata': metadata, + } + + +class RedisMemoryService(BaseMemoryService): + """A Redis-backed memory service. + + This service mirrors InMemoryMemoryService's keyword search behavior while + keeping memory entries in Redis so they survive process restarts. + """ + + def __init__( + self, + redis_url: str | None = None, + *, + key_prefix: str = 'adk:memory:', + client: Any | None = None, + **redis_kwargs: Any, + ): + """Initializes the Redis memory service. + + Args: + redis_url: URL passed to redis.asyncio.from_url. Required when client is + not supplied. + key_prefix: Prefix for all Redis keys written by this service. + client: Optional async Redis-compatible client, mainly for tests. + **redis_kwargs: Extra keyword arguments forwarded to from_url. + """ + if client is None: + if redis_url is None: + raise ValueError('redis_url is required when client is not supplied.') + try: + from redis import asyncio as redis_asyncio + except ImportError as e: + from ..utils._dependency import missing_extra + + raise missing_extra('redis', 'redis') from e + client = redis_asyncio.from_url(redis_url, **redis_kwargs) + + self._client = client + self._key_prefix = key_prefix + + def _scope_prefix(self, app_name: str, user_id: str) -> str: + return ( + f'{self._key_prefix}{_key_part(app_name)}:{_key_part(user_id)}' + ) + + def _sessions_key(self, app_name: str, user_id: str) -> str: + return f'{self._scope_prefix(app_name, user_id)}:sessions' + + def _session_keys( + self, app_name: str, user_id: str, session_id: str + ) -> tuple[str, str]: + session_prefix = ( + f'{self._scope_prefix(app_name, user_id)}:{_key_part(session_id)}' + ) + return f'{session_prefix}:order', f'{session_prefix}:entries' + + async def _append_events( + self, + *, + app_name: str, + user_id: str, + session_id: str, + events: Sequence[Event], + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + events_to_add = [ + event for event in events if event.content and event.content.parts + ] + await self._client.sadd(self._sessions_key(app_name, user_id), session_id) + order_key, entries_key = self._session_keys(app_name, user_id, session_id) + + for event in events_to_add: + event_id = _event_id(event) + payload = _event_to_payload( + event, session_id=session_id, custom_metadata=custom_metadata + ) + was_added = await self._client.hsetnx( + entries_key, event_id, json.dumps(payload) + ) + if was_added: + await self._client.rpush(order_key, event_id) + + @override + async def add_session_to_memory(self, session: Session) -> None: + session_id = session.id or _UNKNOWN_SESSION_ID + order_key, entries_key = self._session_keys( + session.app_name, session.user_id, session_id + ) + await self._client.delete(order_key, entries_key) + await self._append_events( + app_name=session.app_name, + user_id=session.user_id, + session_id=session_id, + events=session.events, + ) + + @override + async def add_events_to_memory( + self, + *, + app_name: str, + user_id: str, + events: Sequence[Event], + session_id: str | None = None, + custom_metadata: Mapping[str, object] | None = None, + ) -> None: + await self._append_events( + app_name=app_name, + user_id=user_id, + session_id=session_id or _UNKNOWN_SESSION_ID, + events=events, + custom_metadata=custom_metadata, + ) + + @override + async def search_memory( + self, *, app_name: str, user_id: str, query: str + ) -> SearchMemoryResponse: + sessions_key = self._sessions_key(app_name, user_id) + session_ids = sorted([ + _decode(value) for value in await self._client.smembers(sessions_key) + ]) + words_in_query = _extract_words_lower(query) + response = SearchMemoryResponse() + + for session_id in session_ids: + order_key, entries_key = self._session_keys(app_name, user_id, session_id) + event_ids = [ + _decode(value) + for value in await self._client.lrange(order_key, 0, -1) + ] + for event_id in event_ids: + raw_payload = await self._client.hget(entries_key, event_id) + if raw_payload is None: + continue + payload = json.loads(_decode(raw_payload)) + content = _content_from_payload(payload['content']) + words_in_memory = _extract_words_lower(_content_to_text(content)) + if not words_in_memory: + continue + if any(query_word in words_in_memory for query_word in words_in_query): + response.memories.append( + MemoryEntry( + id=payload['id'], + content=content, + author=payload.get('author'), + timestamp=payload.get('timestamp'), + custom_metadata=payload.get('custom_metadata') or {}, + ) + ) + + return response + + async def close(self) -> None: + """Closes the Redis client if it exposes a close method.""" + close = getattr(self._client, 'aclose', None) + if close is None: + close = getattr(self._client, 'close', None) + if close is not None: + result = close() + if hasattr(result, '__await__'): + await result + + async def __aenter__(self) -> RedisMemoryService: + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: + await self.close() diff --git a/tests/unittests/cli/test_service_registry.py b/tests/unittests/cli/test_service_registry.py index 094c4ea428..47305a5e39 100644 --- a/tests/unittests/cli/test_service_registry.py +++ b/tests/unittests/cli/test_service_registry.py @@ -39,6 +39,9 @@ def mock_services(): patch( "google.adk.memory.vertex_ai_memory_bank_service.VertexAiMemoryBankService" ) as mock_agentengine_memory, + patch( + "google.adk.memory.redis_memory_service.RedisMemoryService" + ) as mock_redis_memory, ): yield { "vertex_session": mock_vertex_session, @@ -47,6 +50,7 @@ def mock_services(): "gcs_artifact": mock_gcs_artifact, "rag_memory": mock_rag_memory, "agentengine_memory": mock_agentengine_memory, + "redis_memory": mock_redis_memory, } @@ -172,6 +176,22 @@ def test_create_memory_service_memory(registry): assert isinstance(memory_service, InMemoryMemoryService) +def test_create_memory_service_redis(registry, mock_services): + registry.create_memory_service( + "redis://localhost:6379/0", agents_dir="/path/to/agents" + ) + mock_services["redis_memory"].assert_called_once_with( + redis_url="redis://localhost:6379/0" + ) + + +def test_create_memory_service_rediss(registry, mock_services): + registry.create_memory_service("rediss://localhost:6379/0") + mock_services["redis_memory"].assert_called_once_with( + redis_url="rediss://localhost:6379/0" + ) + + # Task Store Tests def test_create_task_store_memory(registry): from a2a.server.tasks import InMemoryTaskStore @@ -209,5 +229,6 @@ def test_unsupported_scheme(registry, mock_services): "gcs_artifact", "rag_memory", "agentengine_memory", + "redis_memory", ]: mock_services[service].assert_not_called() diff --git a/tests/unittests/memory/test_redis_memory_service.py b/tests/unittests/memory/test_redis_memory_service.py new file mode 100644 index 0000000000..d2c9a949ca --- /dev/null +++ b/tests/unittests/memory/test_redis_memory_service.py @@ -0,0 +1,387 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from google.adk.events.event import Event +from google.adk.memory.redis_memory_service import RedisMemoryService +from google.adk.sessions.session import Session +from google.genai import types +import pytest + +MOCK_APP_NAME = 'test-app' +MOCK_USER_ID = 'test-user' +MOCK_OTHER_USER_ID = 'another-user' + +MOCK_SESSION_1 = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id='session-1', + last_update_time=1000, + events=[ + Event( + id='event-1a', + invocation_id='inv-1', + author='user', + timestamp=12345, + content=types.Content( + parts=[types.Part(text='The ADK is a great toolkit.')] + ), + ), + Event( + id='event-1b', + invocation_id='inv-2', + author='user', + timestamp=12346, + ), + Event( + id='event-1c', + invocation_id='inv-3', + author='model', + timestamp=12347, + content=types.Content( + parts=[ + types.Part( + text='I agree. The Agent Development Kit (ADK) rocks!' + ) + ] + ), + ), + ], +) + +MOCK_SESSION_2 = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id='session-2', + last_update_time=2000, + events=[ + Event( + id='event-2a', + invocation_id='inv-4', + author='user', + timestamp=54321, + content=types.Content( + parts=[types.Part(text='I like to code in Python.')] + ), + ), + ], +) + +MOCK_SESSION_DIFFERENT_USER = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_OTHER_USER_ID, + id='session-3', + last_update_time=3000, + events=[ + Event( + id='event-3a', + invocation_id='inv-5', + author='user', + timestamp=60000, + content=types.Content(parts=[types.Part(text='This is a secret.')]), + ), + ], +) + +MOCK_SESSION_WITH_NO_EVENTS = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id='session-4', + last_update_time=4000, +) + + +class FakeAsyncRedis: + def __init__(self): + self.sets: dict[str, set[str]] = {} + self.lists: dict[str, list[str]] = {} + self.hashes: dict[str, dict[str, str]] = {} + self.closed = False + + async def sadd(self, key: str, *values: str) -> int: + values_set = self.sets.setdefault(key, set()) + old_len = len(values_set) + values_set.update(values) + return len(values_set) - old_len + + async def smembers(self, key: str) -> set[str]: + return set(self.sets.get(key, set())) + + async def delete(self, *keys: str) -> int: + deleted = 0 + for key in keys: + for store in (self.sets, self.lists, self.hashes): + if key in store: + del store[key] + deleted += 1 + return deleted + + async def hsetnx(self, key: str, field: str, value: str) -> int: + values = self.hashes.setdefault(key, {}) + if field in values: + return 0 + values[field] = value + return 1 + + async def rpush(self, key: str, value: str) -> int: + values = self.lists.setdefault(key, []) + values.append(value) + return len(values) + + async def lrange(self, key: str, start: int, end: int) -> list[str]: + values = self.lists.get(key, []) + if end == -1: + return values[start:] + return values[start : end + 1] + + async def hget(self, key: str, field: str) -> str | None: + return self.hashes.get(key, {}).get(field) + + async def aclose(self) -> None: + self.closed = True + + +def redis_memory_service() -> RedisMemoryService: + return RedisMemoryService(client=FakeAsyncRedis()) + + +@pytest.mark.asyncio +async def test_add_session_to_memory(): + memory_service = redis_memory_service() + + await memory_service.add_session_to_memory(MOCK_SESSION_1) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='ADK' + ) + + assert len(result.memories) == 2 + assert {memory.id for memory in result.memories} == {'event-1a', 'event-1c'} + + +@pytest.mark.asyncio +async def test_add_events_to_memory_with_explicit_events(): + memory_service = redis_memory_service() + + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION_1.app_name, + user_id=MOCK_SESSION_1.user_id, + session_id=MOCK_SESSION_1.id, + events=[MOCK_SESSION_1.events[0]], + ) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='toolkit' + ) + + assert len(result.memories) == 1 + assert result.memories[0].id == 'event-1a' + + +@pytest.mark.asyncio +async def test_add_events_to_memory_without_session_id_uses_default_bucket(): + memory_service = redis_memory_service() + + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION_1.app_name, + user_id=MOCK_SESSION_1.user_id, + events=[MOCK_SESSION_1.events[0]], + ) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='toolkit' + ) + + assert len(result.memories) == 1 + assert result.memories[0].custom_metadata['session_id'] + + +@pytest.mark.asyncio +async def test_add_events_to_memory_appends_without_replacing(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + new_event = Event( + id='event-1d', + invocation_id='inv-6', + author='user', + timestamp=12348, + content=types.Content(parts=[types.Part(text='A new fact.')]), + ) + + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION_1.app_name, + user_id=MOCK_SESSION_1.user_id, + session_id=MOCK_SESSION_1.id, + events=[new_event], + ) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='fact' + ) + + assert len(result.memories) == 1 + assert result.memories[0].id == 'event-1d' + + +@pytest.mark.asyncio +async def test_add_events_to_memory_deduplicates_event_ids(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + duplicate_event = Event( + id='event-1a', + invocation_id='inv-7', + author='user', + timestamp=12349, + content=types.Content(parts=[types.Part(text='Updated duplicate text.')]), + ) + + await memory_service.add_events_to_memory( + app_name=MOCK_SESSION_1.app_name, + user_id=MOCK_SESSION_1.user_id, + session_id=MOCK_SESSION_1.id, + events=[duplicate_event], + ) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='duplicate' + ) + + assert not result.memories + + +@pytest.mark.asyncio +async def test_add_session_replaces_existing_session_memory(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + replacement_session = Session( + app_name=MOCK_APP_NAME, + user_id=MOCK_USER_ID, + id=MOCK_SESSION_1.id, + last_update_time=5000, + events=[ + Event( + id='replacement', + invocation_id='inv-8', + author='user', + timestamp=12350, + content=types.Content(parts=[types.Part(text='Replacement')]), + ) + ], + ) + + await memory_service.add_session_to_memory(replacement_session) + old_result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='ADK' + ) + new_result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='Replacement' + ) + + assert not old_result.memories + assert len(new_result.memories) == 1 + assert new_result.memories[0].id == 'replacement' + + +@pytest.mark.asyncio +async def test_add_session_with_no_events_to_memory(): + memory_service = redis_memory_service() + + await memory_service.add_session_to_memory(MOCK_SESSION_WITH_NO_EVENTS) + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='anything' + ) + + assert not result.memories + + +@pytest.mark.asyncio +async def test_search_memory_simple_match(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + await memory_service.add_session_to_memory(MOCK_SESSION_2) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='Python' + ) + + assert len(result.memories) == 1 + assert result.memories[0].content.parts[0].text == 'I like to code in Python.' + assert result.memories[0].author == 'user' + + +@pytest.mark.asyncio +async def test_search_memory_case_insensitive_match(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='development' + ) + + assert len(result.memories) == 1 + assert ( + result.memories[0].content.parts[0].text + == 'I agree. The Agent Development Kit (ADK) rocks!' + ) + + +@pytest.mark.asyncio +async def test_search_memory_multiple_matches(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='How about ADK?' + ) + + assert len(result.memories) == 2 + texts = {memory.content.parts[0].text for memory in result.memories} + assert 'The ADK is a great toolkit.' in texts + assert 'I agree. The Agent Development Kit (ADK) rocks!' in texts + + +@pytest.mark.asyncio +async def test_search_memory_no_match(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='nonexistent' + ) + + assert not result.memories + + +@pytest.mark.asyncio +async def test_search_memory_is_scoped_by_user(): + memory_service = redis_memory_service() + await memory_service.add_session_to_memory(MOCK_SESSION_1) + await memory_service.add_session_to_memory(MOCK_SESSION_DIFFERENT_USER) + + result = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_USER_ID, query='secret' + ) + result_other_user = await memory_service.search_memory( + app_name=MOCK_APP_NAME, user_id=MOCK_OTHER_USER_ID, query='secret' + ) + + assert not result.memories + assert len(result_other_user.memories) == 1 + assert ( + result_other_user.memories[0].content.parts[0].text == 'This is a secret.' + ) + + +@pytest.mark.asyncio +async def test_close_closes_client(): + client = FakeAsyncRedis() + memory_service = RedisMemoryService(client=client) + + await memory_service.close() + + assert client.closed