-
Notifications
You must be signed in to change notification settings - Fork 94
LCORE-1830: Implement Question Validity Safety Capability in Pydantic AI #1913
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,10 @@ | ||
| """Pluggable capabilities for pydantic-ai agents in Lightspeed. | ||
|
|
||
| Provides safety, guardrail, and policy capabilities that hook into | ||
| pydantic-ai's AbstractCapability lifecycle to enforce constraints | ||
| before, during, or after agent runs. | ||
| """ | ||
|
|
||
| from pydantic_ai_lightspeed.capabilities.question_validity import QuestionValidity | ||
|
|
||
| __all__ = ["QuestionValidity"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,7 @@ | ||
| """Question validity capability for agent input validation.""" | ||
|
|
||
| from pydantic_ai_lightspeed.capabilities.question_validity._capability import ( | ||
| QuestionValidity, | ||
| ) | ||
|
|
||
| __all__ = ["QuestionValidity"] |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,132 @@ | ||
| """Question validity capability for filtering off-topic user queries. | ||
|
|
||
| This module implements a guardrail that classifies user questions as | ||
| Kubernetes/OpenShift-related or not (It can be customized to any | ||
| topic as well), using an LLM-based check before the main agent | ||
| processes the request. Invalid questions are rejected with a | ||
| predefined response, bypassing the primary agent entirely. | ||
| """ | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| from collections.abc import Sequence | ||
| from dataclasses import dataclass, field | ||
| from string import Template | ||
|
|
||
| from pydantic_ai import AgentRunResult, RunContext | ||
| from pydantic_ai._agent_graph import GraphAgentState | ||
| from pydantic_ai.capabilities import AbstractCapability, WrapRunHandler | ||
| from pydantic_ai.direct import model_request | ||
| from pydantic_ai.messages import ModelRequest, TextContent, UserContent | ||
| from pydantic_ai.models import Model, infer_model | ||
|
|
||
| from log import get_logger | ||
| from models.config import ( | ||
| QuestionValidityConfig, | ||
| ) | ||
|
|
||
| logger = get_logger(__name__) | ||
|
|
||
| SUBJECT_REJECTED = "REJECTED" | ||
| SUBJECT_ALLOWED = "ALLOWED" | ||
|
|
||
|
|
||
| def _extract_message_str_from_user_content(user_content: Sequence[UserContent]) -> str: | ||
| """Extract and combine all text content into a string from a UserContent sequence. | ||
|
|
||
| Parameters: | ||
| user_content: A sequence of user content items to extract text from. | ||
|
|
||
| Returns: | ||
| A single string with all text content joined by newlines. | ||
| """ | ||
| str_arr: list[str] = [] | ||
| for c in user_content: | ||
| match c: | ||
| case str() as s: | ||
| str_arr.append(s) | ||
| case TextContent(content=c): | ||
| str_arr.append(c) | ||
|
|
||
| return "\n".join(str_arr) | ||
|
|
||
|
|
||
| @dataclass | ||
| class QuestionValidity(AbstractCapability[None]): | ||
| """Block or modify user input based on a guardrail check. | ||
|
|
||
| The guard function receives the user prompt and returns True if safe. | ||
|
|
||
|
Comment on lines
+56
to
+59
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Update class docstring to match actual behavior. Line 106 describes a boolean guard-function contract, but this class performs LLM classification and short-circuiting with 🤖 Prompt for AI Agents |
||
| Example: | ||
| ```python | ||
| from pydantic_ai import Agent | ||
| from pydantic_ai.models.openai import OpenAIResponsesModel | ||
|
|
||
| model = OpenAIResponsesModel("gpt-4o-mini") | ||
| agent = Agent("openai:gpt-4.1", capabilities=[QuestionValidity(model)]) | ||
| ``` | ||
| """ | ||
|
|
||
| config: QuestionValidityConfig | ||
| _model: Model = field(init=False) | ||
|
|
||
| def __post_init__(self) -> None: | ||
| """Initialize the model instance from the configured model ID.""" | ||
| self._model = infer_model(self.config.model_id) | ||
|
|
||
| def _build_prompt(self, message: str | Sequence[UserContent] | None) -> str: | ||
| """Build the classification prompt from the user message. | ||
|
|
||
| Parameters: | ||
| message: The user input as a string, sequence of user content, or None. | ||
|
|
||
| Returns: | ||
| The rendered prompt string ready to send to the validity model. | ||
| """ | ||
| match message: | ||
| case str() as s: | ||
| _message = s | ||
| case Sequence() as seq: | ||
| _message = _extract_message_str_from_user_content(seq) | ||
| case None: | ||
| _message = "" | ||
|
|
||
| return Template(self.config.model_prompt).substitute( | ||
| message=_message, allowed=SUBJECT_ALLOWED, rejected=SUBJECT_REJECTED | ||
| ) | ||
|
|
||
| async def wrap_run( | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. For discussion: overriding
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. So just one thing that I would like to bring up for consideration. The output from Follow the agent run lifecycle It's just my thought. I'll follow whatever you guys think is the most suitable solution here. 😁
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ok this sounds reasonable. |
||
| self, ctx: RunContext, *, handler: WrapRunHandler | ||
| ) -> AgentRunResult: | ||
| """Run the question validity check before delegating to the main agent. | ||
|
|
||
| Sends the user prompt to the validity model for classification. | ||
| If the question is allowed, the handler proceeds normally. | ||
| Otherwise, a rejection response is returned and the main agent | ||
| is bypassed. | ||
|
|
||
| Parameters: | ||
| ctx: The run context containing the user prompt and usage tracker. | ||
| handler: The handler that invokes the main agent run. | ||
|
|
||
| Returns: | ||
| The agent run result, either from the main agent or a rejection. | ||
| """ | ||
| prompt = self._build_prompt(ctx.prompt) | ||
|
|
||
| result = await model_request( | ||
| model=self._model, | ||
| messages=[ModelRequest.user_text_prompt(prompt)], | ||
| ) | ||
|
asimurka marked this conversation as resolved.
|
||
|
|
||
| # Include token usage from the question validity request | ||
| ctx.usage.incr(result.usage) | ||
|
|
||
| if result.text is not None and result.text.strip() == SUBJECT_ALLOWED: | ||
| return await handler() # proceed with the real run | ||
|
|
||
| # short-circuit: return the rejection message with shield usage tracked | ||
| state = GraphAgentState(usage=ctx.usage) | ||
| return AgentRunResult( | ||
| output=self.config.invalid_question_response, _state=state | ||
| ) | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Unit tests for pydantic_ai_lightspeed capabilities.""" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1 @@ | ||
| """Unit tests for question validity capability.""" |
Uh oh!
There was an error while loading. Please reload this page.