From 76f51172cd2185293e923ed0ee9467bc6dbe916b Mon Sep 17 00:00:00 2001 From: Jazzcort Date: Thu, 11 Jun 2026 15:54:19 -0400 Subject: [PATCH] LCORE-1830: Implement Question Validity Safety Capability in Pydantic AI Implement an LLM-based guardrail that classifies user questions as on-topic (Kubernetes/OpenShift or customized topic) before the main agent processes them. Off-topic questions are short-circuited with a rejection message, bypassing the primary agent entirely. Includes unit tests. --- src/constants.py | 40 ++ src/models/config.py | 18 + .../capabilities/__init__.py | 10 + .../question_validity/__init__.py | 7 + .../question_validity/_capability.py | 132 +++++ .../capabilities/__init__.py | 1 + .../question_validity/__init__.py | 1 + .../question_validity/test_capability.py | 505 ++++++++++++++++++ 8 files changed, 714 insertions(+) create mode 100644 src/pydantic_ai_lightspeed/capabilities/__init__.py create mode 100644 src/pydantic_ai_lightspeed/capabilities/question_validity/__init__.py create mode 100644 src/pydantic_ai_lightspeed/capabilities/question_validity/_capability.py create mode 100644 tests/unit/pydantic_ai_lightspeed/capabilities/__init__.py create mode 100644 tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/__init__.py create mode 100644 tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/test_capability.py diff --git a/src/constants.py b/src/constants.py index cae458f51..8fa524602 100644 --- a/src/constants.py +++ b/src/constants.py @@ -248,6 +248,46 @@ "I cannot process this request due to policy restrictions." ) +# The Default model prompt and the default invalid question response for QuestionValidityConfig +DEFAULT_MODEL_PROMPT: Final[str] = """ +Instructions: +- You are a question classifying tool +- You are an expert in kubernetes and openshift +- Your job is to determine where or a user's question is related to kubernetes and/or openshift technologies and to provide a one-word response. +- If a question appears to be related to kubernetes or openshift technologies, answer with the word ${allowed}, otherwise answer with the word ${rejected}. +- Do not explain your answer, just provide the one-word response. Do not give any other response. +- If the given question is an empty string, answer with the word ${rejected} + + +Example Question: +Why is the sky blue? +Example Response: +${rejected} + +Example Question: +Why is the grass green? +Example Response: +${rejected} + +Example Question: +Why is sand yellow? +Example Response: +${rejected} + +Example Question: +Can you help configure my cluster to automatically scale? +Example Response: +${allowed} + +Question: +${message} +Response: +""" +DEFAULT_INVALID_QUESTION_RESPONSE: Final[str] = """ +Hi, I'm the OpenShift Lightspeed assistant, I can help you with questions about OpenShift, +please ask me a question related to OpenShift. +""" + # Placeholder slug used in responses when the server substituted its own # system prompt for the client's instructions. Avoids leaking the actual # server prompt back to the client. diff --git a/src/models/config.py b/src/models/config.py index 923d720f0..84ccf8f4f 100644 --- a/src/models/config.py +++ b/src/models/config.py @@ -2060,6 +2060,24 @@ class SkillsConfiguration(ConfigurationBase): ) +class QuestionValidityConfig(ConfigurationBase): + """Configuration for the question validity guardrail.""" + + model_id: str = Field( + ..., title="Model id", description="The model_id to use for the guard" + ) + model_prompt: str = Field( + default=constants.DEFAULT_MODEL_PROMPT, + title="Model prompt", + description="The default prompt sent to the LLM used to validate the Users' question.", + ) + invalid_question_response: str = Field( + default=constants.DEFAULT_INVALID_QUESTION_RESPONSE, + title="Invalid question response", + description="The default response when the Users' question is determined to be invalid.", + ) + + class Configuration(ConfigurationBase): """Global service configuration.""" diff --git a/src/pydantic_ai_lightspeed/capabilities/__init__.py b/src/pydantic_ai_lightspeed/capabilities/__init__.py new file mode 100644 index 000000000..eb73644d5 --- /dev/null +++ b/src/pydantic_ai_lightspeed/capabilities/__init__.py @@ -0,0 +1,10 @@ +"""Pluggable capabilities for pydantic-ai agents in Lightspeed. + +Provides safety, guardrail, and policy capabilities that hook into +pydantic-ai's AbstractCapability lifecycle to enforce constraints +before, during, or after agent runs. +""" + +from pydantic_ai_lightspeed.capabilities.question_validity import QuestionValidity + +__all__ = ["QuestionValidity"] diff --git a/src/pydantic_ai_lightspeed/capabilities/question_validity/__init__.py b/src/pydantic_ai_lightspeed/capabilities/question_validity/__init__.py new file mode 100644 index 000000000..7d722867f --- /dev/null +++ b/src/pydantic_ai_lightspeed/capabilities/question_validity/__init__.py @@ -0,0 +1,7 @@ +"""Question validity capability for agent input validation.""" + +from pydantic_ai_lightspeed.capabilities.question_validity._capability import ( + QuestionValidity, +) + +__all__ = ["QuestionValidity"] diff --git a/src/pydantic_ai_lightspeed/capabilities/question_validity/_capability.py b/src/pydantic_ai_lightspeed/capabilities/question_validity/_capability.py new file mode 100644 index 000000000..334ce7e38 --- /dev/null +++ b/src/pydantic_ai_lightspeed/capabilities/question_validity/_capability.py @@ -0,0 +1,132 @@ +"""Question validity capability for filtering off-topic user queries. + +This module implements a guardrail that classifies user questions as +Kubernetes/OpenShift-related or not (It can be customized to any +topic as well), using an LLM-based check before the main agent +processes the request. Invalid questions are rejected with a +predefined response, bypassing the primary agent entirely. +""" + +from __future__ import annotations + +from collections.abc import Sequence +from dataclasses import dataclass, field +from string import Template + +from pydantic_ai import AgentRunResult, RunContext +from pydantic_ai._agent_graph import GraphAgentState +from pydantic_ai.capabilities import AbstractCapability, WrapRunHandler +from pydantic_ai.direct import model_request +from pydantic_ai.messages import ModelRequest, TextContent, UserContent +from pydantic_ai.models import Model, infer_model + +from log import get_logger +from models.config import ( + QuestionValidityConfig, +) + +logger = get_logger(__name__) + +SUBJECT_REJECTED = "REJECTED" +SUBJECT_ALLOWED = "ALLOWED" + + +def _extract_message_str_from_user_content(user_content: Sequence[UserContent]) -> str: + """Extract and combine all text content into a string from a UserContent sequence. + + Parameters: + user_content: A sequence of user content items to extract text from. + + Returns: + A single string with all text content joined by newlines. + """ + str_arr: list[str] = [] + for c in user_content: + match c: + case str() as s: + str_arr.append(s) + case TextContent(content=c): + str_arr.append(c) + + return "\n".join(str_arr) + + +@dataclass +class QuestionValidity(AbstractCapability[None]): + """Block or modify user input based on a guardrail check. + + The guard function receives the user prompt and returns True if safe. + + Example: + ```python + from pydantic_ai import Agent + from pydantic_ai.models.openai import OpenAIResponsesModel + + model = OpenAIResponsesModel("gpt-4o-mini") + agent = Agent("openai:gpt-4.1", capabilities=[QuestionValidity(model)]) + ``` + """ + + config: QuestionValidityConfig + _model: Model = field(init=False) + + def __post_init__(self) -> None: + """Initialize the model instance from the configured model ID.""" + self._model = infer_model(self.config.model_id) + + def _build_prompt(self, message: str | Sequence[UserContent] | None) -> str: + """Build the classification prompt from the user message. + + Parameters: + message: The user input as a string, sequence of user content, or None. + + Returns: + The rendered prompt string ready to send to the validity model. + """ + match message: + case str() as s: + _message = s + case Sequence() as seq: + _message = _extract_message_str_from_user_content(seq) + case None: + _message = "" + + return Template(self.config.model_prompt).substitute( + message=_message, allowed=SUBJECT_ALLOWED, rejected=SUBJECT_REJECTED + ) + + async def wrap_run( + self, ctx: RunContext, *, handler: WrapRunHandler + ) -> AgentRunResult: + """Run the question validity check before delegating to the main agent. + + Sends the user prompt to the validity model for classification. + If the question is allowed, the handler proceeds normally. + Otherwise, a rejection response is returned and the main agent + is bypassed. + + Parameters: + ctx: The run context containing the user prompt and usage tracker. + handler: The handler that invokes the main agent run. + + Returns: + The agent run result, either from the main agent or a rejection. + """ + prompt = self._build_prompt(ctx.prompt) + + result = await model_request( + model=self._model, + messages=[ModelRequest.user_text_prompt(prompt)], + ) + + # Include token usage from the question validity request + ctx.usage.incr(result.usage) + + if result.text is not None and result.text.strip() == SUBJECT_ALLOWED: + return await handler() # proceed with the real run + + # short-circuit: return the rejection message with shield usage tracked + state = GraphAgentState(usage=ctx.usage) + return AgentRunResult( + output=self.config.invalid_question_response, _state=state + ) diff --git a/tests/unit/pydantic_ai_lightspeed/capabilities/__init__.py b/tests/unit/pydantic_ai_lightspeed/capabilities/__init__.py new file mode 100644 index 000000000..4e87fe7fd --- /dev/null +++ b/tests/unit/pydantic_ai_lightspeed/capabilities/__init__.py @@ -0,0 +1 @@ +"""Unit tests for pydantic_ai_lightspeed capabilities.""" diff --git a/tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/__init__.py b/tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/__init__.py new file mode 100644 index 000000000..d2381208e --- /dev/null +++ b/tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/__init__.py @@ -0,0 +1 @@ +"""Unit tests for question validity capability.""" diff --git a/tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/test_capability.py b/tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/test_capability.py new file mode 100644 index 000000000..b3bb7839d --- /dev/null +++ b/tests/unit/pydantic_ai_lightspeed/capabilities/question_validity/test_capability.py @@ -0,0 +1,505 @@ +"""Unit tests for pydantic_ai_lightspeed.capabilities.question_validity._capacity module.""" + +# pylint: disable=protected-access + +import pytest +from pydantic import ValidationError +from pydantic_ai import AgentRunResult, RunContext, UserError +from pydantic_ai.messages import ImageUrl, ModelResponse, TextContent, TextPart +from pydantic_ai.models.openai import OpenAIResponsesModel +from pydantic_ai.usage import RequestUsage, RunUsage +from pytest_mock import MockerFixture, MockType + +from constants import ( + DEFAULT_INVALID_QUESTION_RESPONSE, + DEFAULT_MODEL_PROMPT, +) +from models.config import ( + QuestionValidityConfig, +) +from pydantic_ai_lightspeed.capabilities.question_validity._capability import ( + SUBJECT_ALLOWED, + SUBJECT_REJECTED, + QuestionValidity, + _extract_message_str_from_user_content, +) + + +class TestExtractMessageStrFromUserContent: + """Tests for _extract_message_str_from_user_content helper.""" + + def test_extracts_plain_strings(self) -> None: + """Test extraction from a sequence of plain strings.""" + content = ["hello", "world"] + result = _extract_message_str_from_user_content(content) + assert result == "hello\nworld" + + def test_extracts_text_content(self) -> None: + """Test extraction from TextContent objects.""" + content = [TextContent(content="first"), TextContent(content="second")] + result = _extract_message_str_from_user_content(content) + assert result == "first\nsecond" + + def test_mixed_str_and_text_content(self) -> None: + """Test extraction from a mix of strings and TextContent.""" + content = ["plain", TextContent(content="rich")] + result = _extract_message_str_from_user_content(content) + assert result == "plain\nrich" + + def test_empty_sequence(self) -> None: + """Test extraction from an empty sequence.""" + result = _extract_message_str_from_user_content([]) + assert result == "" + + def test_single_string(self) -> None: + """Test extraction from a single-element sequence.""" + result = _extract_message_str_from_user_content(["only"]) + assert result == "only" + + def test_sequence_with_non_text_content(self) -> None: + """Test extraction from a single-element sequence.""" + result = _extract_message_str_from_user_content([ImageUrl("fake.png"), "keep"]) + assert result == "keep" + + +class TestQuestionValidityConfigInit: + """Tests for QuestionValidityConfig initialization.""" + + def test_default_model_prompt(self) -> None: + """Test that default model_prompt is used.""" + qv_config = QuestionValidityConfig(model_id="test") + + assert qv_config.model_prompt == DEFAULT_MODEL_PROMPT + + def test_default_invalid_question_response(self) -> None: + """Test that default invalid_question_response is used.""" + qv_config = QuestionValidityConfig(model_id="test") + + assert qv_config.invalid_question_response == DEFAULT_INVALID_QUESTION_RESPONSE + + def test_custom_model_prompt(self) -> None: + """Test that custom model_prompt can be provided.""" + qv_config = QuestionValidityConfig( + model_id="test", model_prompt="custom prompt ${message}" + ) + + assert qv_config.model_prompt == "custom prompt ${message}" + + def test_custom_invalid_response(self) -> None: + """Test that custom invalid_question_response can be provided.""" + qv_config = QuestionValidityConfig( + model_id="test", invalid_question_response="Nope!" + ) + + assert qv_config.invalid_question_response == "Nope!" + + def test_missing_model_id_raises_validation_error(self) -> None: + """Test that model_id is required.""" + with pytest.raises(ValidationError): + QuestionValidityConfig() # type: ignore[call-arg] + + def test_unknown_fields_rejected(self) -> None: + """Test that extra fields are rejected.""" + with pytest.raises(ValidationError): + QuestionValidityConfig(model_id="test", unknown_field="value") # type: ignore[call-arg] + + +class TestQuestionValidityInit: + """Tests for QuestionValidity dataclass initialization.""" + + def test_model_creation_with_model_id(self, monkeypatch) -> None: + """Test that a valid provider:model_id creates the correct model.""" + monkeypatch.setenv("OPENAI_API_KEY", "tasty") + config = QuestionValidityConfig(model_id="openai-responses:gpt-5.4") + qv = QuestionValidity(config=config) + assert isinstance(qv._model, OpenAIResponsesModel) + + def test_model_creation_with_invalid_provider(self) -> None: + """Test that an unknown provider raises ValueError.""" + with pytest.raises(ValueError, match="Unknown provider: invalid-provider"): + config = QuestionValidityConfig(model_id="invalid-provider:model") + QuestionValidity(config=config) + + def test_model_creation_with_invalid_model(self) -> None: + """Test that an unknown model raises UserError.""" + with pytest.raises(UserError, match="Unknown model: invalid-model"): + config = QuestionValidityConfig(model_id="invalid-model") + QuestionValidity(config=config) + + +class TestBuildPrompt: + """Tests for QuestionValidity._build_prompt method.""" + + @pytest.fixture(name="question_validity") + def question_validity_fixture(self) -> QuestionValidity: + """Create a QuestionValidity instance with a mock model.""" + config = QuestionValidityConfig(model_id="test") + return QuestionValidity(config=config) + + def test_string_input(self, question_validity: QuestionValidity) -> None: + """Test prompt building with a plain string input.""" + prompt = question_validity._build_prompt("How do I scale pods?") + + assert "How do I scale pods?" in prompt + assert SUBJECT_ALLOWED in prompt + assert SUBJECT_REJECTED in prompt + + def test_none_input(self, question_validity: QuestionValidity) -> None: + """Test prompt building with None input uses empty string.""" + prompt = question_validity._build_prompt(None) + + assert "Question:\n\nResponse:" in prompt + + def test_sequence_input(self, question_validity: QuestionValidity) -> None: + """Test prompt building with a sequence of UserContent.""" + content = ["What is a", TextContent(content="deployment?")] + + prompt = question_validity._build_prompt(content) + + assert "What is a\ndeployment?" in prompt + + def test_substitutes_allowed_and_rejected( + self, question_validity: QuestionValidity + ) -> None: + """Test that ALLOWED and REJECTED tokens are substituted.""" + prompt = question_validity._build_prompt("test") + + assert SUBJECT_ALLOWED in prompt + assert SUBJECT_REJECTED in prompt + assert "${allowed}" not in prompt + assert "${rejected}" not in prompt + assert "${message}" not in prompt + + def test_custom_prompt_template(self) -> None: + """Test with a custom prompt template.""" + config = QuestionValidityConfig( + model_id="test", + model_prompt="Is '${message}' valid? ${allowed}/${rejected}", + ) + qv = QuestionValidity(config=config) + + prompt = qv._build_prompt("my question") + + assert prompt == f"Is 'my question' valid? {SUBJECT_ALLOWED}/{SUBJECT_REJECTED}" + + +class TestWrapRun: + """Tests for QuestionValidity.wrap_run method.""" + + @pytest.fixture(name="mock_ctx") + def mock_ctx_fixture(self, mocker: MockerFixture) -> RunContext: + """Create a mock RunContext.""" + ctx = mocker.Mock(spec=RunContext) + ctx.prompt = "How do I create a pod?" + ctx.usage = RunUsage() + return ctx + + @pytest.fixture(name="mock_handler") + def mock_handler_fixture(self, mocker: MockerFixture) -> MockType: + """Create a mock WrapRunHandler.""" + handler = mocker.AsyncMock() + handler.return_value = mocker.Mock(spec=AgentRunResult) + return handler + + @pytest.mark.asyncio + async def test_allowed_question_calls_handler( + self, + mocker: MockerFixture, + mock_ctx: RunContext, + mock_handler: MockType, + ) -> None: + """Test that an allowed question proceeds to the handler.""" + mock_response = ModelResponse( + parts=[TextPart(content=SUBJECT_ALLOWED)], + usage=RequestUsage(input_tokens=10, output_tokens=1), + ) + mocker.patch( + "pydantic_ai_lightspeed.capabilities.question_validity._capability.model_request", + return_value=mock_response, + ) + + config = QuestionValidityConfig(model_id="test") + qv = QuestionValidity(config=config) + result = await qv.wrap_run(mock_ctx, handler=mock_handler) + + mock_handler.assert_awaited_once() + assert result == mock_handler.return_value + + @pytest.mark.asyncio + async def test_rejected_question_returns_rejection( + self, + mocker: MockerFixture, + mock_ctx: RunContext, + mock_handler: MockType, + ) -> None: + """Test that a rejected question short-circuits with rejection message.""" + mock_response = ModelResponse( + parts=[TextPart(content=SUBJECT_REJECTED)], + usage=RequestUsage(input_tokens=10, output_tokens=1), + ) + mocker.patch( + "pydantic_ai_lightspeed.capabilities.question_validity._capability.model_request", + return_value=mock_response, + ) + + config = QuestionValidityConfig(model_id="test") + qv = QuestionValidity(config=config) + result = await qv.wrap_run(mock_ctx, handler=mock_handler) + + mock_handler.assert_not_awaited() + assert isinstance(result, AgentRunResult) + assert result.output == DEFAULT_INVALID_QUESTION_RESPONSE + + @pytest.mark.asyncio + async def test_unexpected_response_treated_as_rejected( + self, + mocker: MockerFixture, + mock_ctx: RunContext, + mock_handler: MockType, + ) -> None: + """Test that an unexpected model response is treated as rejection.""" + mock_response = ModelResponse( + parts=[TextPart(content="I don't understand")], + usage=RequestUsage(input_tokens=10, output_tokens=5), + ) + mocker.patch( + "pydantic_ai_lightspeed.capabilities.question_validity._capability.model_request", + return_value=mock_response, + ) + + config = QuestionValidityConfig(model_id="test") + qv = QuestionValidity(config=config) + result = await qv.wrap_run(mock_ctx, handler=mock_handler) + + mock_handler.assert_not_awaited() + assert result.output == DEFAULT_INVALID_QUESTION_RESPONSE + + @pytest.mark.asyncio + @pytest.mark.parametrize( + "response_text", + [" ALLOWED", "ALLOWED ", " ALLOWED ", "ALLOWED\n"], + ids=["leading-space", "trailing-space", "both-spaces", "trailing-newline"], + ) + async def test_allowed_with_whitespace_still_accepted( + self, + mocker: MockerFixture, + mock_ctx: RunContext, + mock_handler: MockType, + response_text: str, + ) -> None: + """Test that ALLOWED with surrounding whitespace is still accepted.""" + mock_response = ModelResponse( + parts=[TextPart(content=response_text)], + usage=RequestUsage(input_tokens=10, output_tokens=1), + ) + mocker.patch( + "pydantic_ai_lightspeed.capabilities.question_validity._capability.model_request", + return_value=mock_response, + ) + + config = QuestionValidityConfig(model_id="test") + qv = QuestionValidity(config=config) + result = await qv.wrap_run(mock_ctx, handler=mock_handler) + + mock_handler.assert_awaited_once() + assert result == mock_handler.return_value + + @pytest.mark.asyncio + async def test_usage_is_incremented( + self, + mocker: MockerFixture, + mock_ctx: RunContext, + mock_handler: MockType, + ) -> None: + """Test that token usage from the validity check is tracked.""" + request_usage = RequestUsage(input_tokens=50, output_tokens=5) + mock_response = ModelResponse( + parts=[TextPart(content=SUBJECT_ALLOWED)], + usage=request_usage, + ) + mocker.patch( + "pydantic_ai_lightspeed.capabilities.question_validity._capability.model_request", + return_value=mock_response, + ) + + config = QuestionValidityConfig(model_id="test") + qv = QuestionValidity(config=config) + await qv.wrap_run(mock_ctx, handler=mock_handler) + + assert mock_ctx.usage.input_tokens == 50 + assert mock_ctx.usage.output_tokens == 5 + + @pytest.mark.asyncio + async def test_usage_is_incremented_on_rejection( + self, + mocker: MockerFixture, + mock_ctx: RunContext, + mock_handler: MockType, + ) -> None: + """Test that token usage is tracked even when question is rejected.""" + request_usage = RequestUsage(input_tokens=30, output_tokens=2) + mock_response = ModelResponse( + parts=[TextPart(content=SUBJECT_REJECTED)], + usage=request_usage, + ) + mocker.patch( + "pydantic_ai_lightspeed.capabilities.question_validity._capability.model_request", + return_value=mock_response, + ) + + config = QuestionValidityConfig(model_id="test") + qv = QuestionValidity(config=config) + await qv.wrap_run(mock_ctx, handler=mock_handler) + + assert mock_ctx.usage.input_tokens == 30 + assert mock_ctx.usage.output_tokens == 2 + + @pytest.mark.asyncio + async def test_rejection_result_contains_usage_in_state( + self, + mocker: MockerFixture, + mock_ctx: RunContext, + mock_handler: MockType, + ) -> None: + """Test that the rejection AgentRunResult state carries the usage.""" + request_usage = RequestUsage(input_tokens=20, output_tokens=3) + mock_response = ModelResponse( + parts=[TextPart(content=SUBJECT_REJECTED)], + usage=request_usage, + ) + mocker.patch( + "pydantic_ai_lightspeed.capabilities.question_validity._capability.model_request", + return_value=mock_response, + ) + + config = QuestionValidityConfig(model_id="test") + qv = QuestionValidity(config=config) + result = await qv.wrap_run(mock_ctx, handler=mock_handler) + + assert result._state.usage == mock_ctx.usage + + @pytest.mark.asyncio + async def test_custom_invalid_response( + self, + mocker: MockerFixture, + mock_ctx: RunContext, + mock_handler: MockType, + ) -> None: + """Test that a custom rejection message is used when set.""" + mock_response = ModelResponse( + parts=[TextPart(content=SUBJECT_REJECTED)], + usage=RequestUsage(), + ) + mocker.patch( + "pydantic_ai_lightspeed.capabilities.question_validity._capability.model_request", + return_value=mock_response, + ) + + config = QuestionValidityConfig( + model_id="test", invalid_question_response="Custom rejection." + ) + qv = QuestionValidity(config=config) + result = await qv.wrap_run(mock_ctx, handler=mock_handler) + + assert result.output == "Custom rejection." + + @pytest.mark.asyncio + async def test_model_request_receives_correct_prompt( + self, + mocker: MockerFixture, + mock_ctx: RunContext, + mock_handler: MockType, + ) -> None: + """Test that model_request is called with the built prompt.""" + mock_model_request = mocker.patch( + "pydantic_ai_lightspeed.capabilities.question_validity._capability.model_request", + return_value=ModelResponse( + parts=[TextPart(content=SUBJECT_ALLOWED)], + usage=RequestUsage(), + ), + ) + + config = QuestionValidityConfig(model_id="test") + qv = QuestionValidity(config=config) + await qv.wrap_run(mock_ctx, handler=mock_handler) + + call_kwargs = mock_model_request.call_args + assert call_kwargs.kwargs["model"] is qv._model + messages = call_kwargs.kwargs["messages"] + assert len(messages) == 1 + assert "How do I create a pod?" in str(messages[0]) + + @pytest.mark.asyncio + async def test_wrap_run_with_none_prompt( + self, + mocker: MockerFixture, + mock_handler: MockType, + ) -> None: + """Test wrap_run when ctx.prompt is None.""" + ctx = mocker.Mock(spec=RunContext) + ctx.prompt = None + ctx.usage = RunUsage() + + mock_response = ModelResponse( + parts=[TextPart(content=SUBJECT_REJECTED)], + usage=RequestUsage(), + ) + mocker.patch( + "pydantic_ai_lightspeed.capabilities.question_validity._capability.model_request", + return_value=mock_response, + ) + + config = QuestionValidityConfig(model_id="test") + qv = QuestionValidity(config=config) + result = await qv.wrap_run(ctx, handler=mock_handler) + + assert result.output == DEFAULT_INVALID_QUESTION_RESPONSE + + @pytest.mark.asyncio + async def test_wrap_run_propagates_model_request_error( + self, + mocker: MockerFixture, + mock_ctx: RunContext, + mock_handler: MockType, + ) -> None: + """Test that model_request exceptions propagate to the caller.""" + mocker.patch( + "pydantic_ai_lightspeed.capabilities.question_validity._capability.model_request", + side_effect=RuntimeError("connection failed"), + ) + + config = QuestionValidityConfig(model_id="test") + qv = QuestionValidity(config=config) + + with pytest.raises(RuntimeError, match="connection failed"): + await qv.wrap_run(mock_ctx, handler=mock_handler) + + mock_handler.assert_not_awaited() + + @pytest.mark.asyncio + async def test_wrap_run_with_sequence_prompt( + self, + mocker: MockerFixture, + mock_handler: MockType, + ) -> None: + """Test wrap_run when ctx.prompt is a Sequence[UserContent].""" + ctx = mocker.Mock(spec=RunContext) + ctx.prompt = ["How to", TextContent(content="scale a deployment?")] + ctx.usage = RunUsage() + + mock_model_request = mocker.patch( + "pydantic_ai_lightspeed.capabilities.question_validity._capability.model_request", + return_value=ModelResponse( + parts=[TextPart(content=SUBJECT_ALLOWED)], + usage=RequestUsage(), + ), + ) + + config = QuestionValidityConfig(model_id="test") + qv = QuestionValidity(config=config) + await qv.wrap_run(ctx, handler=mock_handler) + + messages = mock_model_request.call_args.kwargs["messages"] + prompt_str = str(messages[0]) + assert "How to" in prompt_str + assert "scale a deployment?" in prompt_str