Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 22 additions & 4 deletions src/google/adk/models/lite_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1883,6 +1883,7 @@ async def _get_completion_inputs(
Optional[List[Dict]],
Optional[Dict[str, Any]],
Optional[Dict],
Optional[str],
]:
"""Converts an LlmRequest to litellm inputs and extracts generation params.

Expand All @@ -1891,8 +1892,8 @@ async def _get_completion_inputs(
model: The model string to use for determining provider-specific behavior.

Returns:
The litellm inputs (message list, tool dictionary, response format and
generation params).
The litellm inputs (message list, tool dictionary, response format,
generation params, and tool_choice).
"""
_ensure_litellm_imported()

Expand Down Expand Up @@ -1967,7 +1968,21 @@ async def _get_completion_inputs(
if not generation_params:
generation_params = None

return messages, tools, response_format, generation_params
# 5. Extract tool_choice from tool_config
tool_choice: Optional[str] = None
if (
llm_request.config
and llm_request.config.tool_config
and llm_request.config.tool_config.function_calling_config
):
mode = llm_request.config.tool_config.function_calling_config.mode
if mode == types.FunctionCallingConfigMode.ANY:
tool_choice = "required"
elif mode == types.FunctionCallingConfigMode.NONE:
tool_choice = "none"
# AUTO → None (provider default)

return messages, tools, response_format, generation_params, tool_choice


def _build_function_declaration_log(
Expand Down Expand Up @@ -2228,7 +2243,7 @@ async def generate_content_async(
logger.debug(_build_request_log(llm_request))

effective_model = llm_request.model or self.model
messages, tools, response_format, generation_params = (
messages, tools, response_format, generation_params, tool_choice = (
await _get_completion_inputs(llm_request, effective_model)
)
normalized_messages = _normalize_ollama_chat_messages(
Expand Down Expand Up @@ -2260,6 +2275,9 @@ async def generate_content_async(
if generation_params:
completion_args.update(generation_params)

if tool_choice is not None:
completion_args["tool_choice"] = tool_choice

if stream:
text = ""
reasoning_parts: List[types.Part] = []
Expand Down
32 changes: 24 additions & 8 deletions src/google/adk/tools/agent_tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,14 +217,30 @@ async def run_async(
input_schema = _get_input_schema(self.agent)
if input_schema:
input_value = input_schema.model_validate(args)
content = types.Content(
role='user',
parts=[
types.Part.from_text(
text=input_value.model_dump_json(exclude_none=True)
)
],
)
json_payload = input_value.model_dump_json(exclude_none=True)
output_schema = _get_output_schema(self.agent)
if output_schema:
# Single-shot structured output mode: pass raw JSON, no ReAct wrapper.
content = types.Content(
role='user',
parts=[types.Part.from_text(text=json_payload)],
)
else:
# Tool-calling mode: wrap with ReAct-style prompt.
content = types.Content(
role='user',
parts=[
types.Part.from_text(
text=(
'Process the following structured request. Use your'
' available tools as needed to gather information or'
' perform actions before producing the final'
' response.\n\nRequest:\n'
+ json_payload
)
)
],
)
else:
content = types.Content(
role='user',
Expand Down
Loading