chore: add virtual environment to the repository
- Add the backend_service/venv virtual environment - includes all Python dependency packages - Note: the virtual environment is about 393 MB and contains 12,655 files
@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
"""The model module."""

from ._model_base import ChatModelBase
from ._model_response import ChatResponse
from ._dashscope_model import DashScopeChatModel
from ._openai_model import OpenAIChatModel
from ._anthropic_model import AnthropicChatModel
from ._ollama_model import OllamaChatModel
from ._gemini_model import GeminiChatModel
from ._trinity_model import TrinityChatModel

__all__ = [
    "ChatModelBase",
    "ChatResponse",
    "DashScopeChatModel",
    "OpenAIChatModel",
    "AnthropicChatModel",
    "OllamaChatModel",
    "GeminiChatModel",
    "TrinityChatModel",
]
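For orientation, a minimal usage sketch of the classes re-exported above; it assumes the package is importable as `agentscope.model` and that a valid DashScope API key is available, so treat the import path, model name, and key as placeholders rather than confirmed details.

.. code-block:: python
    :caption: Hedged usage sketch of the re-exported model classes

    import asyncio

    from agentscope.model import DashScopeChatModel  # assumed import path


    async def main() -> None:
        # Non-streaming call for simplicity; stream=True would return an
        # async generator of ChatResponse chunks instead.
        model = DashScopeChatModel(
            model_name="qwen-max",  # placeholder model name
            api_key="YOUR_DASHSCOPE_API_KEY",  # placeholder key
            stream=False,
        )
        response = await model(
            messages=[{"role": "user", "content": "Hello!"}],
        )
        print(response.content)


    asyncio.run(main())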
@@ -0,0 +1,507 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches, too-many-statements
|
||||
"""The Anthropic API model classes."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
TYPE_CHECKING,
|
||||
List,
|
||||
Literal,
|
||||
Type,
|
||||
)
|
||||
from collections import OrderedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._model_base import ChatModelBase
|
||||
from ._model_response import ChatResponse
|
||||
from ._model_usage import ChatUsage
|
||||
from .._logging import logger
|
||||
from .._utils._common import (
|
||||
_json_loads_with_repair,
|
||||
_create_tool_from_base_model,
|
||||
)
|
||||
from ..message import TextBlock, ToolUseBlock, ThinkingBlock
|
||||
from ..tracing import trace_llm
|
||||
from ..types._json import JSONSerializableObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from anthropic.types.message import Message
|
||||
from anthropic import AsyncStream
|
||||
else:
|
||||
Message = "anthropic.types.message.Message"
|
||||
AsyncStream = "anthropic.AsyncStream"
|
||||
|
||||
|
||||
class AnthropicChatModel(ChatModelBase):
|
||||
"""The Anthropic model wrapper for AgentScope."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str | None = None,
|
||||
max_tokens: int = 2048,
|
||||
stream: bool = True,
|
||||
thinking: dict | None = None,
|
||||
client_args: dict | None = None,
|
||||
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the Anthropic chat model.
|
||||
|
||||
Args:
|
||||
model_name (`str`):
|
||||
The model name.
|
||||
api_key (`str`):
|
||||
The anthropic API key.
|
||||
stream (`bool`):
|
||||
Whether to use streaming output or not.
|
||||
max_tokens (`int`):
|
||||
Limit the maximum token count the model can generate.
|
||||
thinking (`dict | None`, default `None`):
|
||||
Configuration for Claude's internal reasoning process.
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Example of thinking
|
||||
|
||||
{
|
||||
"type": "enabled" | "disabled",
|
||||
"budget_tokens": 1024
|
||||
}
|
||||
|
||||
client_args (`dict | None`, optional):
|
||||
The extra keyword arguments to initialize the Anthropic client.
|
||||
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
The extra keyword arguments used in Anthropic API generation,
|
||||
e.g. `temperature`, `seed`.
|
||||
"""
|
||||
|
||||
try:
|
||||
import anthropic
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install the `anthropic` package by running "
|
||||
"`pip install anthropic`.",
|
||||
) from e
|
||||
|
||||
super().__init__(model_name, stream)
|
||||
|
||||
self.client = anthropic.AsyncAnthropic(
|
||||
api_key=api_key,
|
||||
**(client_args or {}),
|
||||
)
|
||||
self.max_tokens = max_tokens
|
||||
self.thinking = thinking
|
||||
self.generate_kwargs = generate_kwargs or {}
|
||||
|
||||
@trace_llm
|
||||
async def __call__(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: Literal["auto", "none", "any", "required"]
|
||||
| str
|
||||
| None = None,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
**generate_kwargs: Any,
|
||||
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
|
||||
"""Get the response from Anthropic chat completions API by the given
|
||||
arguments.
|
||||
|
||||
Args:
|
||||
messages (`list[dict]`):
|
||||
A list of dictionaries, where `role` and `content` fields are
|
||||
required, and `name` field is optional.
|
||||
tools (`list[dict]`, default `None`):
|
||||
The tools JSON schemas that in format of:
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Example of tools JSON schemas
|
||||
|
||||
[
|
||||
{
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "xxx",
|
||||
"description": "xxx",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"param1": {
|
||||
"type": "string",
|
||||
"description": "..."
|
||||
},
|
||||
# Add more parameters as needed
|
||||
},
|
||||
"required": ["param1"]
|
||||
}
|
||||
},
|
||||
# More schemas here
|
||||
]
|
||||
|
||||
tool_choice (`Literal["auto", "none", "any", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", "any", "required", or specific tool
|
||||
name. For more details, please refer to
|
||||
https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output. When provided, the model will be forced
|
||||
to return data that conforms to this schema by automatically
|
||||
converting the BaseModel to a tool function and setting
|
||||
`tool_choice` to enforce its usage. This enables structured
|
||||
output generation.
|
||||
|
||||
.. note:: When `structured_model` is specified,
|
||||
both `tools` and `tool_choice` parameters are ignored,
|
||||
and the model will only perform structured output
|
||||
generation without calling any other tools.
|
||||
|
||||
**generate_kwargs (`Any`):
|
||||
The keyword arguments for Anthropic chat completions API,
|
||||
e.g. `temperature`, `top_p`, etc. Please
|
||||
refer to the Anthropic API documentation for more details.
|
||||
|
||||
Returns:
|
||||
`ChatResponse | AsyncGenerator[ChatResponse, None]`:
|
||||
The response from the Anthropic chat completions API."""
|
||||
|
||||
kwargs: dict[str, Any] = {
|
||||
"model": self.model_name,
|
||||
"max_tokens": self.max_tokens,
|
||||
"stream": self.stream,
|
||||
**self.generate_kwargs,
|
||||
**generate_kwargs,
|
||||
}
|
||||
if self.thinking and "thinking" not in kwargs:
|
||||
kwargs["thinking"] = self.thinking
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._format_tools_json_schemas(tools)
|
||||
|
||||
if tool_choice:
|
||||
self._validate_tool_choice(tool_choice, tools)
|
||||
kwargs["tool_choice"] = self._format_tool_choice(tool_choice)
|
||||
|
||||
if structured_model:
|
||||
if tools or tool_choice:
|
||||
logger.warning(
|
||||
"structured_model is provided. Both 'tools' and "
|
||||
"'tool_choice' parameters will be overridden and "
|
||||
"ignored. The model will only perform structured output "
|
||||
"generation without calling any other tools.",
|
||||
)
|
||||
format_tool = _create_tool_from_base_model(structured_model)
|
||||
kwargs["tools"] = self._format_tools_json_schemas(
|
||||
[format_tool],
|
||||
)
|
||||
kwargs["tool_choice"] = self._format_tool_choice(
|
||||
format_tool["function"]["name"],
|
||||
)
|
||||
|
||||
# Extract the system message
|
||||
if messages[0]["role"] == "system":
|
||||
kwargs["system"] = messages[0]["content"]
|
||||
messages = messages[1:]
|
||||
|
||||
kwargs["messages"] = messages
|
||||
|
||||
start_datetime = datetime.now()
|
||||
|
||||
response = await self.client.messages.create(**kwargs)
|
||||
|
||||
if self.stream:
|
||||
return self._parse_anthropic_stream_completion_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
# Non-streaming response
|
||||
parsed_response = await self._parse_anthropic_completion_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
async def _parse_anthropic_completion_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: Message,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> ChatResponse:
|
||||
"""Given an Anthropic Message object, extract the content blocks and
|
||||
usages from it.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`Message`):
|
||||
Anthropic Message object to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
ChatResponse (`ChatResponse`):
|
||||
A ChatResponse object containing the content blocks and usage.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
content_blocks: List[ThinkingBlock | TextBlock | ToolUseBlock] = []
|
||||
metadata = None
|
||||
|
||||
if hasattr(response, "content") and response.content:
|
||||
for content_block in response.content:
|
||||
if (
|
||||
hasattr(content_block, "type")
|
||||
and content_block.type == "thinking"
|
||||
):
|
||||
thinking_block = ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=content_block.thinking,
|
||||
)
|
||||
thinking_block["signature"] = content_block.signature
|
||||
content_blocks.append(thinking_block)
|
||||
|
||||
elif hasattr(content_block, "text"):
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=content_block.text,
|
||||
),
|
||||
)
|
||||
|
||||
elif (
|
||||
hasattr(content_block, "type")
|
||||
and content_block.type == "tool_use"
|
||||
):
|
||||
content_blocks.append(
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
id=content_block.id,
|
||||
name=content_block.name,
|
||||
input=content_block.input,
|
||||
),
|
||||
)
|
||||
if structured_model:
|
||||
metadata = content_block.input
|
||||
|
||||
usage = None
|
||||
if response.usage:
|
||||
usage = ChatUsage(
|
||||
input_tokens=response.usage.input_tokens,
|
||||
output_tokens=response.usage.output_tokens,
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
)
|
||||
|
||||
parsed_response = ChatResponse(
|
||||
content=content_blocks,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
async def _parse_anthropic_stream_completion_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: AsyncStream,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> AsyncGenerator[ChatResponse, None]:
|
||||
"""Given an Anthropic streaming response, extract the content blocks
|
||||
and usages from it and yield ChatResponse objects.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`AsyncStream`):
|
||||
Anthropic AsyncStream object to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
`AsyncGenerator[ChatResponse, None]`:
|
||||
An async generator that yields ChatResponse objects containing
|
||||
the content blocks and usage information for each chunk in
|
||||
the streaming response.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
|
||||
usage = None
|
||||
text_buffer = ""
|
||||
thinking_buffer = ""
|
||||
thinking_signature = ""
|
||||
tool_calls = OrderedDict()
|
||||
tool_call_buffers = {}
|
||||
res = None
|
||||
metadata = None
|
||||
|
||||
async for event in response:
|
||||
content_changed = False
|
||||
thinking_changed = False
|
||||
|
||||
if event.type == "message_start":
|
||||
message = event.message
|
||||
if message.usage:
|
||||
usage = ChatUsage(
|
||||
input_tokens=message.usage.input_tokens,
|
||||
output_tokens=getattr(
|
||||
message.usage,
|
||||
"output_tokens",
|
||||
0,
|
||||
),
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
)
|
||||
|
||||
elif event.type == "content_block_start":
|
||||
if event.content_block.type == "tool_use":
|
||||
block_index = event.index
|
||||
tool_block = event.content_block
|
||||
tool_calls[block_index] = {
|
||||
"type": "tool_use",
|
||||
"id": tool_block.id,
|
||||
"name": tool_block.name,
|
||||
"input": "",
|
||||
}
|
||||
tool_call_buffers[block_index] = ""
|
||||
content_changed = True
|
||||
|
||||
elif event.type == "content_block_delta":
|
||||
block_index = event.index
|
||||
delta = event.delta
|
||||
if delta.type == "text_delta":
|
||||
text_buffer += delta.text
|
||||
content_changed = True
|
||||
elif delta.type == "thinking_delta":
|
||||
thinking_buffer += delta.thinking
|
||||
thinking_changed = True
|
||||
elif delta.type == "signature_delta":
|
||||
thinking_signature = delta.signature
|
||||
elif (
|
||||
delta.type == "input_json_delta"
|
||||
and block_index in tool_calls
|
||||
):
|
||||
tool_call_buffers[block_index] += delta.partial_json or ""
|
||||
tool_calls[block_index]["input"] = tool_call_buffers[
|
||||
block_index
|
||||
]
|
||||
content_changed = True
|
||||
|
||||
elif event.type == "message_delta":
|
||||
if event.usage and usage:
|
||||
usage.output_tokens = event.usage.output_tokens
|
||||
|
||||
if (thinking_changed or content_changed) and usage:
|
||||
contents: list = []
|
||||
if thinking_buffer:
|
||||
thinking_block = ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=thinking_buffer,
|
||||
)
|
||||
thinking_block["signature"] = thinking_signature
|
||||
contents.append(thinking_block)
|
||||
if text_buffer:
|
||||
contents.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=text_buffer,
|
||||
),
|
||||
)
|
||||
for block_index, tool_call in tool_calls.items():
|
||||
input_str = tool_call["input"]
|
||||
try:
|
||||
input_obj = _json_loads_with_repair(input_str or "{}")
|
||||
if not isinstance(input_obj, dict):
|
||||
input_obj = {}
|
||||
|
||||
except Exception:
|
||||
input_obj = {}
|
||||
|
||||
contents.append(
|
||||
ToolUseBlock(
|
||||
type=tool_call["type"],
|
||||
id=tool_call["id"],
|
||||
name=tool_call["name"],
|
||||
input=input_obj,
|
||||
),
|
||||
)
|
||||
if structured_model:
|
||||
metadata = input_obj
|
||||
if contents:
|
||||
res = ChatResponse(
|
||||
content=contents,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
yield res
|
||||
|
||||
def _format_tools_json_schemas(
|
||||
self,
|
||||
schemas: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format the JSON schemas of the tool functions to the format that
|
||||
Anthropic API expects."""
|
||||
formatted_schemas = []
|
||||
for schema in schemas:
|
||||
assert (
|
||||
"function" in schema
|
||||
), f"Invalid schema: {schema}, expect key 'function'."
|
||||
|
||||
assert "name" in schema["function"], (
|
||||
f"Invalid schema: {schema}, "
|
||||
"expect key 'name' in 'function' field."
|
||||
)
|
||||
|
||||
formatted_schemas.append(
|
||||
{
|
||||
"name": schema["function"]["name"],
|
||||
"description": schema["function"].get("description", ""),
|
||||
"input_schema": schema["function"].get("parameters", {}),
|
||||
},
|
||||
)
|
||||
|
||||
return formatted_schemas
|
||||
|
||||
def _format_tool_choice(
|
||||
self,
|
||||
tool_choice: Literal["auto", "none", "any", "required"] | str | None,
|
||||
) -> dict | None:
|
||||
"""Format tool_choice parameter for API compatibility.
|
||||
|
||||
Args:
|
||||
tool_choice (`Literal["auto", "none", "any", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", "any", "required", or specific tool
|
||||
name. For more details, please refer to
|
||||
https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use
|
||||
Returns:
|
||||
`dict | None`:
|
||||
The formatted tool choice configuration dict, or None if
|
||||
tool_choice is None.
|
||||
"""
|
||||
if tool_choice is None:
|
||||
return None
|
||||
|
||||
type_mapping = {
|
||||
"auto": {"type": "auto"},
|
||||
"none": {"type": "none"},
|
||||
"any": {"type": "any"},
|
||||
"required": {"type": "any"},
|
||||
}
|
||||
if tool_choice in type_mapping:
|
||||
return type_mapping[tool_choice]
|
||||
|
||||
return {"type": "tool", "name": tool_choice}
|
||||
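A minimal sketch of the `structured_model` path documented in `__call__` above; it assumes the `anthropic` package is installed and `ANTHROPIC_API_KEY` is exported in the environment, and the Pydantic schema, model name, and prompt are illustrative only.

.. code-block:: python
    :caption: Hedged sketch of structured output with AnthropicChatModel

    import asyncio

    from pydantic import BaseModel

    from agentscope.model import AnthropicChatModel  # assumed import path


    class CityInfo(BaseModel):
        """Hypothetical schema used only for this example."""

        city: str
        population: int


    async def main() -> None:
        model = AnthropicChatModel(
            model_name="claude-sonnet-4-20250514",  # placeholder model name
            stream=False,
        )
        # structured_model is converted into a forced tool call, so the
        # parsed fields are expected in response.metadata.
        response = await model(
            messages=[{"role": "user", "content": "Tell me about Paris."}],
            structured_model=CityInfo,
        )
        print(response.metadata)


    asyncio.run(main())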
@@ -0,0 +1,524 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The dashscope API model classes."""
|
||||
import collections
|
||||
from datetime import datetime
|
||||
from http import HTTPStatus
|
||||
from typing import (
|
||||
Any,
|
||||
AsyncGenerator,
|
||||
Generator,
|
||||
Union,
|
||||
TYPE_CHECKING,
|
||||
List,
|
||||
Literal,
|
||||
Type,
|
||||
)
|
||||
from pydantic import BaseModel
|
||||
from aioitertools import iter as giter
|
||||
|
||||
from ._model_base import ChatModelBase
|
||||
from ._model_response import ChatResponse
|
||||
from ._model_usage import ChatUsage
|
||||
from .._utils._common import (
|
||||
_json_loads_with_repair,
|
||||
_create_tool_from_base_model,
|
||||
)
|
||||
from ..message import TextBlock, ToolUseBlock, ThinkingBlock
|
||||
from ..tracing import trace_llm
|
||||
from ..types import JSONSerializableObject
|
||||
from .._logging import logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dashscope.api_entities.dashscope_response import GenerationResponse
|
||||
from dashscope.api_entities.dashscope_response import (
|
||||
MultiModalConversationResponse,
|
||||
)
|
||||
else:
|
||||
GenerationResponse = (
|
||||
"dashscope.api_entities.dashscope_response.GenerationResponse"
|
||||
)
|
||||
MultiModalConversationResponse = (
|
||||
"dashscope.api_entities.dashscope_response."
|
||||
"MultiModalConversationResponse"
|
||||
)
|
||||
|
||||
|
||||
class DashScopeChatModel(ChatModelBase):
|
||||
"""The DashScope chat model class, which unifies the Generation and
|
||||
MultimodalConversation APIs into one method."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
stream: bool = True,
|
||||
enable_thinking: bool | None = None,
|
||||
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
base_http_api_url: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize the DashScope chat model.
|
||||
|
||||
Args:
|
||||
model_name (`str`):
|
||||
The model name.
|
||||
api_key (`str`):
|
||||
The dashscope API key.
|
||||
stream (`bool`):
|
||||
Whether to use streaming output or not.
|
||||
enable_thinking (`bool | None`, optional):
|
||||
Whether to enable thinking or not; only supported by models such as Qwen3, QwQ, and DeepSeek-R1.
|
||||
Refer to `DashScope documentation
|
||||
<https://help.aliyun.com/zh/model-studio/deep-thinking>`_
|
||||
for more details.
|
||||
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
The extra keyword arguments used in DashScope API generation,
|
||||
e.g. `temperature`, `seed`.
|
||||
base_http_api_url (`str | None`, optional):
|
||||
The base URL for DashScope API requests. If not provided,
|
||||
the default base URL from the DashScope SDK will be used.
|
||||
"""
|
||||
if enable_thinking and not stream:
|
||||
logger.info(
|
||||
"In DashScope API, `stream` must be True when "
|
||||
"`enable_thinking` is True. ",
|
||||
)
|
||||
stream = True
|
||||
|
||||
super().__init__(model_name, stream)
|
||||
|
||||
self.api_key = api_key
|
||||
self.enable_thinking = enable_thinking
|
||||
self.generate_kwargs = generate_kwargs or {}
|
||||
|
||||
if base_http_api_url is not None:
|
||||
import dashscope
|
||||
|
||||
dashscope.base_http_api_url = base_http_api_url
|
||||
|
||||
@trace_llm
|
||||
async def __call__(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: Literal["auto", "none", "any", "required"]
|
||||
| str
|
||||
| None = None,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
|
||||
"""Get the response from the dashscope
|
||||
Generation/MultimodalConversation API by the given arguments.
|
||||
|
||||
.. note:: We unify the dashscope generation and multimodal conversation
|
||||
APIs into one method, since they support similar arguments and share
|
||||
the same functionality.
|
||||
|
||||
Args:
|
||||
messages (`list[dict[str, Any]]`):
|
||||
A list of dictionaries, where `role` and `content` fields are
|
||||
required.
|
||||
tools (`list[dict] | None`, default `None`):
|
||||
The tools JSON schemas that the model can use.
|
||||
tool_choice (`Literal["auto", "none", "any", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", or specific tool name.
|
||||
For more details, please refer to
|
||||
https://help.aliyun.com/zh/model-studio/qwen-function-calling
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output. When provided, the model will be forced
|
||||
to return data that conforms to this schema by automatically
|
||||
converting the BaseModel to a tool function and setting
|
||||
`tool_choice` to enforce its usage. This enables structured
|
||||
output generation.
|
||||
|
||||
.. note:: When `structured_model` is specified,
|
||||
both `tools` and `tool_choice` parameters are ignored,
|
||||
and the model will only perform structured output
|
||||
generation without calling any other tools.
|
||||
|
||||
**kwargs (`Any`):
|
||||
The keyword arguments for DashScope chat completions API,
|
||||
e.g. `temperature`, `max_tokens`, `top_p`, etc. Please
|
||||
refer to `DashScope documentation
|
||||
<https://help.aliyun.com/zh/dashscope/developer-reference/api-details>`_
|
||||
for more detailed arguments.
|
||||
"""
|
||||
import dashscope
|
||||
|
||||
# For qvq and qwen-vl models, the content field cannot be `None` or
|
||||
# `[{"text": None}]`, so we need to convert it to an empty list.
|
||||
if self.model_name.startswith("qvq") or "-vl" in self.model_name:
|
||||
for msg in messages:
|
||||
if msg["content"] is None or msg["content"] == [
|
||||
{"text": None},
|
||||
]:
|
||||
msg["content"] = []
|
||||
|
||||
kwargs = {
|
||||
"messages": messages,
|
||||
"model": self.model_name,
|
||||
"stream": self.stream,
|
||||
**self.generate_kwargs,
|
||||
**kwargs,
|
||||
"result_format": "message",
|
||||
# In agentscope, the `incremental_output` must be `True` when
|
||||
# `self.stream` is True
|
||||
"incremental_output": self.stream,
|
||||
}
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._format_tools_json_schemas(tools)
|
||||
|
||||
if tool_choice:
|
||||
self._validate_tool_choice(tool_choice, tools)
|
||||
kwargs["tool_choice"] = self._format_tool_choice(tool_choice)
|
||||
|
||||
if (
|
||||
self.enable_thinking is not None
|
||||
and "enable_thinking" not in kwargs
|
||||
):
|
||||
kwargs["enable_thinking"] = self.enable_thinking
|
||||
|
||||
if structured_model:
|
||||
if tools or tool_choice:
|
||||
logger.warning(
|
||||
"structured_model is provided. Both 'tools' and "
|
||||
"'tool_choice' parameters will be overridden and "
|
||||
"ignored. The model will only perform structured output "
|
||||
"generation without calling any other tools.",
|
||||
)
|
||||
format_tool = _create_tool_from_base_model(structured_model)
|
||||
kwargs["tools"] = self._format_tools_json_schemas(
|
||||
[format_tool],
|
||||
)
|
||||
kwargs["tool_choice"] = self._format_tool_choice(
|
||||
format_tool["function"]["name"],
|
||||
)
|
||||
|
||||
start_datetime = datetime.now()
|
||||
if self.model_name.startswith("qvq") or "-vl" in self.model_name:
|
||||
response = dashscope.MultiModalConversation.call(
|
||||
api_key=self.api_key,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
else:
|
||||
response = await dashscope.aigc.generation.AioGeneration.call(
|
||||
api_key=self.api_key,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
if self.stream:
|
||||
return self._parse_dashscope_stream_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
parsed_response = await self._parse_dashscope_generation_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
# pylint: disable=too-many-branches
|
||||
async def _parse_dashscope_stream_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: Union[
|
||||
AsyncGenerator[GenerationResponse, None],
|
||||
Generator[MultiModalConversationResponse, None, None],
|
||||
],
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> AsyncGenerator[ChatResponse, Any]:
|
||||
"""Given a DashScope streaming response generator, extract the content
|
||||
blocks and usages from it and yield ChatResponse objects.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (
|
||||
`Union[AsyncGenerator[GenerationResponse, None], Generator[ \
|
||||
MultiModalConversationResponse, None, None]]`
|
||||
):
|
||||
DashScope streaming response generator (GenerationResponse or
|
||||
MultiModalConversationResponse) to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[ChatResponse, Any]:
|
||||
An async generator that yields ChatResponse objects containing
|
||||
the content blocks and usage information for each chunk in the
|
||||
streaming response.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
acc_content, acc_thinking_content = "", ""
|
||||
acc_tool_calls = collections.defaultdict(dict)
|
||||
metadata = None
|
||||
|
||||
async for chunk in giter(response):
|
||||
if chunk.status_code != HTTPStatus.OK:
|
||||
raise RuntimeError(
|
||||
f"Failed to get response from _ API: {chunk}",
|
||||
)
|
||||
|
||||
message = chunk.output.choices[0].message
|
||||
|
||||
# Update reasoning content
|
||||
if isinstance(message.get("reasoning_content"), str):
|
||||
acc_thinking_content += message["reasoning_content"]
|
||||
|
||||
# Update text content
|
||||
if isinstance(message.content, str):
|
||||
acc_content += message.content
|
||||
elif isinstance(message.content, list):
|
||||
for item in message.content:
|
||||
if isinstance(item, dict) and "text" in item:
|
||||
acc_content += item["text"]
|
||||
|
||||
# Update tool calls
|
||||
for tool_call in message.get("tool_calls", []):
|
||||
index = tool_call.get("index", 0)
|
||||
|
||||
if "id" in tool_call and tool_call["id"] != acc_tool_calls[
|
||||
index
|
||||
].get("id"):
|
||||
acc_tool_calls[index]["id"] = (
|
||||
acc_tool_calls[index].get("id", "") + tool_call["id"]
|
||||
)
|
||||
|
||||
if "function" in tool_call:
|
||||
func = tool_call["function"]
|
||||
if "name" in func:
|
||||
acc_tool_calls[index]["name"] = (
|
||||
acc_tool_calls[index].get("name", "")
|
||||
+ func["name"]
|
||||
)
|
||||
|
||||
if "arguments" in func:
|
||||
acc_tool_calls[index]["arguments"] = (
|
||||
acc_tool_calls[index].get("arguments", "")
|
||||
+ func["arguments"]
|
||||
)
|
||||
|
||||
# to content blocks
|
||||
content_blocks: list[TextBlock | ToolUseBlock | ThinkingBlock] = []
|
||||
if acc_thinking_content:
|
||||
content_blocks.append(
|
||||
ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=acc_thinking_content,
|
||||
),
|
||||
)
|
||||
|
||||
if acc_content:
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=acc_content,
|
||||
),
|
||||
)
|
||||
|
||||
for tool_call in acc_tool_calls.values():
|
||||
repaired_input = _json_loads_with_repair(
|
||||
tool_call.get("arguments", "{}") or "{}",
|
||||
)
|
||||
|
||||
if not isinstance(repaired_input, dict):
|
||||
repaired_input = {}
|
||||
|
||||
content_blocks.append(
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
id=tool_call.get("id", ""),
|
||||
name=tool_call.get("name", ""),
|
||||
input=repaired_input,
|
||||
),
|
||||
)
|
||||
|
||||
if structured_model:
|
||||
metadata = repaired_input
|
||||
|
||||
usage = None
|
||||
if chunk.usage:
|
||||
usage = ChatUsage(
|
||||
input_tokens=chunk.usage.input_tokens,
|
||||
output_tokens=chunk.usage.output_tokens,
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
)
|
||||
|
||||
parsed_chunk = ChatResponse(
|
||||
content=content_blocks,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
yield parsed_chunk
|
||||
|
||||
async def _parse_dashscope_generation_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: Union[
|
||||
GenerationResponse,
|
||||
MultiModalConversationResponse,
|
||||
],
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> ChatResponse:
|
||||
"""Given a DashScope GenerationResponse object, extract the content
|
||||
blocks and usages from it.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (
|
||||
`Union[GenerationResponse, MultiModalConversationResponse]`
|
||||
):
|
||||
Dashscope GenerationResponse | MultiModalConversationResponse
|
||||
object to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
ChatResponse (`ChatResponse`):
|
||||
A ChatResponse object containing the content blocks and usage.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
# Collect the content blocks from the response.
|
||||
if response.status_code != 200:
|
||||
raise RuntimeError(response)
|
||||
|
||||
content_blocks: List[TextBlock | ToolUseBlock] = []
|
||||
metadata: dict | None = None
|
||||
|
||||
message = response.output.choices[0].message
|
||||
content = message.get("content")
|
||||
|
||||
if response.output.choices[0].message.get("content") not in [
|
||||
None,
|
||||
"",
|
||||
[],
|
||||
]:
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
if isinstance(item, dict) and "text" in item:
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=item["text"],
|
||||
),
|
||||
)
|
||||
else:
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=content,
|
||||
),
|
||||
)
|
||||
|
||||
if message.get("tool_calls"):
|
||||
for tool_call in message["tool_calls"]:
|
||||
input_ = _json_loads_with_repair(
|
||||
tool_call["function"].get(
|
||||
"arguments",
|
||||
"{}",
|
||||
)
|
||||
or "{}",
|
||||
)
|
||||
content_blocks.append(
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
name=tool_call["function"]["name"],
|
||||
input=input_,
|
||||
id=tool_call["id"],
|
||||
),
|
||||
)
|
||||
|
||||
if structured_model:
|
||||
metadata = input_
|
||||
|
||||
# Usage information
|
||||
usage = None
|
||||
if response.usage:
|
||||
usage = ChatUsage(
|
||||
input_tokens=response.usage.input_tokens,
|
||||
output_tokens=response.usage.output_tokens,
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
)
|
||||
|
||||
parsed_response = ChatResponse(
|
||||
content=content_blocks,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
def _format_tools_json_schemas(
|
||||
self,
|
||||
schemas: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format the tools JSON schema into required format for DashScope API.
|
||||
|
||||
Args:
|
||||
schemas (`list[dict[str, Any]]`):
|
||||
The tools JSON schemas.
|
||||
"""
|
||||
# Check schemas format
|
||||
for value in schemas:
|
||||
if (
|
||||
not isinstance(value, dict)
|
||||
or "type" not in value
|
||||
or value["type"] != "function"
|
||||
or "function" not in value
|
||||
):
|
||||
raise ValueError(
|
||||
f"Each schema must be a dict with 'type' as 'function' "
|
||||
f"and 'function' key, got {value}",
|
||||
)
|
||||
|
||||
return schemas
|
||||
|
||||
def _format_tool_choice(
|
||||
self,
|
||||
tool_choice: Literal["auto", "none", "any", "required"] | str | None,
|
||||
) -> str | dict | None:
|
||||
"""Format tool_choice parameter for API compatibility.
|
||||
|
||||
Args:
|
||||
tool_choice (`Literal["auto", "none", "any", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", or specific tool name.
|
||||
For more details, please refer to
|
||||
https://help.aliyun.com/zh/model-studio/qwen-function-calling
|
||||
Returns:
|
||||
`dict | None`:
|
||||
The formatted tool choice configuration dict, or None if
|
||||
tool_choice is None.
|
||||
"""
|
||||
if tool_choice is None:
|
||||
return None
|
||||
if tool_choice in ["auto", "none"]:
|
||||
return tool_choice
|
||||
if tool_choice in ["any", "required"]:
|
||||
logger.warning(
|
||||
"tool_choice '%s' is not supported by DashScope API. "
|
||||
"Supported options are 'auto', 'none', or specific function "
|
||||
"name. Automatically using 'auto' instead.",
|
||||
tool_choice,
|
||||
)
|
||||
return "auto"
|
||||
return {"type": "function", "function": {"name": tool_choice}}
|
||||
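Because `stream=True` makes `__call__` return an async generator of accumulated `ChatResponse` chunks (see `_parse_dashscope_stream_response` above), a consumer typically keeps only the last chunk. The sketch below is hedged in the same way as the earlier ones, with a placeholder model name and API key.

.. code-block:: python
    :caption: Hedged sketch of consuming the DashScope streaming output

    import asyncio

    from agentscope.model import DashScopeChatModel  # assumed import path


    async def main() -> None:
        model = DashScopeChatModel(
            model_name="qwen-plus",  # placeholder model name
            api_key="YOUR_DASHSCOPE_API_KEY",  # placeholder key
            stream=True,
        )
        stream = await model(
            messages=[{"role": "user", "content": "Write a haiku."}],
        )
        last = None
        # Each chunk carries the accumulated content blocks so far, so the
        # final chunk already holds the complete answer.
        async for chunk in stream:
            last = chunk
        if last is not None:
            print(last.content)


    asyncio.run(main())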
@@ -0,0 +1,487 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# mypy: disable-error-code="dict-item"
|
||||
"""The Google Gemini model in agentscope."""
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
AsyncGenerator,
|
||||
Any,
|
||||
TYPE_CHECKING,
|
||||
AsyncIterator,
|
||||
Literal,
|
||||
Type,
|
||||
List,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .._logging import logger
|
||||
from .._utils._common import _json_loads_with_repair
|
||||
from ..message import ToolUseBlock, TextBlock, ThinkingBlock
|
||||
from ._model_usage import ChatUsage
|
||||
from ._model_base import ChatModelBase
|
||||
from ._model_response import ChatResponse
|
||||
from ..tracing import trace_llm
|
||||
from ..types import JSONSerializableObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google.genai.types import GenerateContentResponse
|
||||
else:
|
||||
GenerateContentResponse = "google.genai.types.GenerateContentResponse"
|
||||
|
||||
|
||||
class GeminiChatModel(ChatModelBase):
|
||||
"""The Google Gemini chat model class in agentscope."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str,
|
||||
stream: bool = True,
|
||||
thinking_config: dict | None = None,
|
||||
client_args: dict | None = None,
|
||||
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the Gemini chat model.
|
||||
|
||||
Args:
|
||||
model_name (`str`):
|
||||
The name of the Gemini model to use, e.g. "gemini-2.5-flash".
|
||||
api_key (`str`):
|
||||
The API key for Google Gemini.
|
||||
stream (`bool`, default `True`):
|
||||
Whether to use streaming output or not.
|
||||
thinking_config (`dict | None`, optional):
|
||||
Thinking config, supported models are 2.5 Pro, 2.5 Flash, etc.
|
||||
Refer to https://ai.google.dev/gemini-api/docs/thinking for
|
||||
more details.
|
||||
|
||||
.. code-block:: python
|
||||
:caption: Example of thinking_config
|
||||
|
||||
{
|
||||
"include_thoughts": True, # enable thoughts or not
|
||||
"thinking_budget": 1024 # Max tokens for reasoning
|
||||
}
|
||||
|
||||
client_args (`dict`, default `None`):
|
||||
The extra keyword arguments to initialize the Gemini client.
|
||||
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
The extra keyword arguments used in Gemini API generation,
|
||||
e.g. `temperature`, `seed`.
|
||||
"""
|
||||
try:
|
||||
from google import genai
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Please install gemini Python sdk with "
|
||||
"`pip install -q -U google-genai`",
|
||||
) from e
|
||||
|
||||
super().__init__(model_name, stream)
|
||||
|
||||
self.client = genai.Client(
|
||||
api_key=api_key,
|
||||
**(client_args or {}),
|
||||
)
|
||||
self.thinking_config = thinking_config
|
||||
self.generate_kwargs = generate_kwargs or {}
|
||||
|
||||
@trace_llm
|
||||
async def __call__(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: Literal["auto", "none", "any", "required"]
|
||||
| str
|
||||
| None = None,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
**config_kwargs: Any,
|
||||
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
|
||||
"""Call the Gemini model with the provided arguments.
|
||||
|
||||
Args:
|
||||
messages (`list[dict[str, Any]]`):
|
||||
A list of dictionaries, where `role` and `content` fields are
|
||||
required.
|
||||
tools (`list[dict] | None`, default `None`):
|
||||
The tools JSON schemas that the model can use.
|
||||
tool_choice (`Literal["auto", "none", "any", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", "any", "required", or specific tool
|
||||
name. For more details, please refer to
|
||||
https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
.. note:: When `structured_model` is specified,
|
||||
both `tools` and `tool_choice` parameters are ignored,
|
||||
and the model will only perform structured output
|
||||
generation without calling any other tools.
|
||||
|
||||
For more details, please refer to
|
||||
https://ai.google.dev/gemini-api/docs/structured-output
|
||||
|
||||
**config_kwargs (`Any`):
|
||||
The keyword arguments for Gemini chat completions API.
|
||||
"""
|
||||
|
||||
config: dict = {
|
||||
"thinking_config": self.thinking_config,
|
||||
**self.generate_kwargs,
|
||||
**config_kwargs,
|
||||
}
|
||||
|
||||
if tools:
|
||||
config["tools"] = self._format_tools_json_schemas(tools)
|
||||
|
||||
if tool_choice:
|
||||
self._validate_tool_choice(tool_choice, tools)
|
||||
config["tool_config"] = self._format_tool_choice(tool_choice)
|
||||
|
||||
if structured_model:
|
||||
if tools or tool_choice:
|
||||
logger.warning(
|
||||
"structured_model is provided. Both 'tools' and "
|
||||
"'tool_choice' parameters will be overridden and "
|
||||
"ignored. The model will only perform structured output "
|
||||
"generation without calling any other tools.",
|
||||
)
|
||||
config.pop("tools", None)
|
||||
config.pop("tool_config", None)
|
||||
config["response_mime_type"] = "application/json"
|
||||
config["response_schema"] = structured_model
|
||||
|
||||
# Prepare the arguments for the Gemini API call
|
||||
kwargs: dict[str, JSONSerializableObject] = {
|
||||
"model": self.model_name,
|
||||
"contents": messages,
|
||||
"config": config,
|
||||
}
|
||||
|
||||
start_datetime = datetime.now()
|
||||
if self.stream:
|
||||
response = await self.client.aio.models.generate_content_stream(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
return self._parse_gemini_stream_generation_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
# non-streaming
|
||||
response = await self.client.aio.models.generate_content(
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
parsed_response = self._parse_gemini_generation_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
async def _parse_gemini_stream_generation_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: AsyncIterator[GenerateContentResponse],
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> AsyncGenerator[ChatResponse, None]:
|
||||
"""Given a Gemini streaming generation response, extract the
|
||||
content blocks and usages from it and yield ChatResponse objects.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`AsyncIterator[GenerateContentResponse]`):
|
||||
Gemini GenerateContentResponse async iterator to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
`AsyncGenerator[ChatResponse, None]`:
|
||||
An async generator that yields ChatResponse objects containing
|
||||
the content blocks and usage information for each chunk in the
|
||||
streaming response.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
|
||||
text = ""
|
||||
thinking = ""
|
||||
metadata: dict | None = None
|
||||
async for chunk in response:
|
||||
content_block: list = []
|
||||
|
||||
# Thinking parts
|
||||
if (
|
||||
chunk.candidates
|
||||
and chunk.candidates[0].content
|
||||
and chunk.candidates[0].content.parts
|
||||
):
|
||||
for part in chunk.candidates[0].content.parts:
|
||||
if part.thought and part.text:
|
||||
thinking += part.text
|
||||
|
||||
# Text parts
|
||||
if chunk.text:
|
||||
text += chunk.text
|
||||
if structured_model:
|
||||
metadata = _json_loads_with_repair(text)
|
||||
|
||||
# Function calls
|
||||
tool_calls = []
|
||||
if chunk.function_calls:
|
||||
for function_call in chunk.function_calls:
|
||||
tool_calls.append(
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
id=function_call.id,
|
||||
name=function_call.name,
|
||||
input=function_call.args or {},
|
||||
),
|
||||
)
|
||||
|
||||
usage = None
|
||||
if chunk.usage_metadata:
|
||||
usage = ChatUsage(
|
||||
input_tokens=chunk.usage_metadata.prompt_token_count,
|
||||
output_tokens=chunk.usage_metadata.total_token_count
|
||||
- chunk.usage_metadata.prompt_token_count,
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
)
|
||||
|
||||
if thinking:
|
||||
content_block.append(
|
||||
ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=thinking,
|
||||
),
|
||||
)
|
||||
|
||||
if text:
|
||||
content_block.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=text,
|
||||
),
|
||||
)
|
||||
|
||||
content_block.extend(
|
||||
[
|
||||
*tool_calls,
|
||||
],
|
||||
)
|
||||
|
||||
parsed_chunk = ChatResponse(
|
||||
content=content_block,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
yield parsed_chunk
|
||||
|
||||
def _parse_gemini_generation_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: GenerateContentResponse,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> ChatResponse:
|
||||
"""Given a Gemini chat completion response object, extract the content
|
||||
blocks and usages from it.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`GenerateContentResponse`):
The Gemini GenerateContentResponse object to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
ChatResponse (`ChatResponse`):
|
||||
A ChatResponse object containing the content blocks and usage.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
content_blocks: List[TextBlock | ToolUseBlock | ThinkingBlock] = []
|
||||
metadata: dict | None = None
|
||||
|
||||
if (
|
||||
response.candidates
|
||||
and response.candidates[0].content
|
||||
and response.candidates[0].content.parts
|
||||
):
|
||||
for part in response.candidates[0].content.parts:
|
||||
if part.thought and part.text:
|
||||
content_blocks.append(
|
||||
ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=part.text,
|
||||
),
|
||||
)
|
||||
|
||||
if response.text:
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=response.text,
|
||||
),
|
||||
)
|
||||
if structured_model:
|
||||
metadata = _json_loads_with_repair(response.text)
|
||||
|
||||
if response.function_calls:
|
||||
for tool_call in response.function_calls:
|
||||
content_blocks.append(
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
name=tool_call.name,
|
||||
input=tool_call.args or {},
|
||||
),
|
||||
)
|
||||
|
||||
if response.usage_metadata:
|
||||
usage = ChatUsage(
|
||||
input_tokens=response.usage_metadata.prompt_token_count,
|
||||
output_tokens=response.usage_metadata.total_token_count
|
||||
- response.usage_metadata.prompt_token_count,
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
)
|
||||
|
||||
else:
|
||||
usage = None
|
||||
|
||||
return ChatResponse(
|
||||
content=content_blocks,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _format_tools_json_schemas(
|
||||
self,
|
||||
schemas: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format the tools JSON schema into required format for Gemini API.
|
||||
|
||||
Args:
|
||||
schemas (`list[dict[str, Any]]`):
|
||||
The tools JSON schemas.
|
||||
|
||||
Returns:
|
||||
List[Dict[str, Any]]:
|
||||
A list containing a dictionary with the
|
||||
"function_declarations" key, which maps to a list of
|
||||
function definitions.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
:caption: Example tool schemas of Gemini API
|
||||
|
||||
# Input JSON schema
|
||||
schemas = [
|
||||
{
|
||||
'type': 'function',
|
||||
'function': {
|
||||
'name': 'execute_shell_command',
|
||||
'description': 'xxx',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'command': {
|
||||
'type': 'string',
|
||||
'description': 'xxx.'
|
||||
},
|
||||
'timeout': {
|
||||
'type': 'integer',
|
||||
'default': 300
|
||||
}
|
||||
},
|
||||
'required': ['command']
|
||||
}
|
||||
}
|
||||
}
|
||||
]
|
||||
|
||||
# Output format (Gemini API expected):
|
||||
[
|
||||
{
|
||||
'function_declarations': [
|
||||
{
|
||||
'name': 'execute_shell_command',
|
||||
'description': 'xxx.',
|
||||
'parameters': {
|
||||
'type': 'object',
|
||||
'properties': {
|
||||
'command': {
|
||||
'type': 'string',
|
||||
'description': 'xxx.'
|
||||
},
|
||||
'timeout': {
|
||||
'type': 'integer',
|
||||
'default': 300
|
||||
}
|
||||
},
|
||||
'required': ['command']
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
"""
|
||||
return [
|
||||
{
|
||||
"function_declarations": [
|
||||
_["function"] for _ in schemas if "function" in _
|
||||
],
|
||||
},
|
||||
]
|
||||
|
||||
def _format_tool_choice(
|
||||
self,
|
||||
tool_choice: Literal["auto", "none", "any", "required"] | str | None,
|
||||
) -> dict | None:
|
||||
"""Format tool_choice parameter for API compatibility.
|
||||
|
||||
Args:
|
||||
tool_choice (`Literal["auto", "none"] | str | None`, default \
|
||||
`None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", "any", "required", or specific tool
|
||||
name.
|
||||
For more details, please refer to
|
||||
https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes
|
||||
Returns:
|
||||
`dict | None`:
|
||||
The formatted tool choice configuration dict, or None if
|
||||
tool_choice is None.
|
||||
"""
|
||||
if tool_choice is None:
|
||||
return None
|
||||
|
||||
mode_mapping = {
|
||||
"auto": "AUTO",
|
||||
"none": "NONE",
|
||||
"any": "ANY",
|
||||
"required": "ANY",
|
||||
}
|
||||
mode = mode_mapping.get(tool_choice)
|
||||
if mode:
|
||||
return {"function_calling_config": {"mode": mode}}
|
||||
return {
|
||||
"function_calling_config": {
|
||||
"mode": "ANY",
|
||||
"allowed_function_names": [tool_choice],
|
||||
},
|
||||
}
|
||||
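To make the `function_declarations` conversion performed by `_format_tools_json_schemas` concrete, here is a small self-contained sketch of the same transformation; the tool name and parameters are made up for illustration.

.. code-block:: python
    :caption: Sketch of the OpenAI-style to Gemini tool-schema conversion

    from typing import Any


    def to_gemini_tools(schemas: list[dict[str, Any]]) -> list[dict[str, Any]]:
        # Keep only the "function" payloads and wrap them in a single
        # "function_declarations" entry, mirroring the method above.
        return [
            {
                "function_declarations": [
                    s["function"] for s in schemas if "function" in s
                ],
            },
        ]


    openai_style = [
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # hypothetical tool
                "description": "Look up the weather for a city.",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        },
    ]

    print(to_gemini_tools(openai_style))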
@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
"""The chat model base class."""

from abc import abstractmethod
from typing import AsyncGenerator, Any

from ._model_response import ChatResponse

TOOL_CHOICE_MODES = ["auto", "none", "any", "required"]


class ChatModelBase:
    """Base class for chat models."""

    model_name: str
    """The model name"""

    stream: bool
    """Is the model output streaming or not"""

    def __init__(
        self,
        model_name: str,
        stream: bool,
    ) -> None:
        """Initialize the chat model base class.

        Args:
            model_name (`str`):
                The name of the model
            stream (`bool`):
                Whether the model output is streaming or not
        """
        self.model_name = model_name
        self.stream = stream

    @abstractmethod
    async def __call__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
        pass

    def _validate_tool_choice(
        self,
        tool_choice: str,
        tools: list[dict] | None,
    ) -> None:
        """
        Validate the tool_choice parameter.

        Args:
            tool_choice (`str`):
                Tool choice mode or function name
            tools (`list[dict] | None`):
                Available tools list
        Raises:
            TypeError: If tool_choice is not a string
            ValueError: If tool_choice is invalid
        """
        if not isinstance(tool_choice, str):
            raise TypeError(
                f"tool_choice must be str, got {type(tool_choice)}",
            )
        if tool_choice in TOOL_CHOICE_MODES:
            return

        # Guard against `tools` being None when a specific function name is
        # given, so the membership check below cannot raise a TypeError.
        available_functions = [
            tool["function"]["name"] for tool in (tools or [])
        ]

        if tool_choice not in available_functions:
            all_options = TOOL_CHOICE_MODES + available_functions
            raise ValueError(
                f"Invalid tool_choice '{tool_choice}'. "
                f"Available options: {', '.join(sorted(all_options))}",
            )
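The `_validate_tool_choice` rule above accepts either one of the four mode keywords or a function name that appears in `tools`; the following standalone sketch restates the same check with a hypothetical tool list.

.. code-block:: python
    :caption: Sketch of the tool_choice validation rule

    TOOL_CHOICE_MODES = ["auto", "none", "any", "required"]

    tools = [
        {"type": "function", "function": {"name": "get_weather"}},  # hypothetical
    ]


    def is_valid_tool_choice(tool_choice: str) -> bool:
        # Valid when it is a mode keyword or a known tool name.
        if tool_choice in TOOL_CHOICE_MODES:
            return True
        names = [tool["function"]["name"] for tool in tools]
        return tool_choice in names


    print(is_valid_tool_choice("auto"))         # True
    print(is_valid_tool_choice("get_weather"))  # True
    print(is_valid_tool_choice("fly_to_moon"))  # False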
@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
"""The model response module."""

from dataclasses import dataclass, field
from typing import Literal, Sequence

from ._model_usage import ChatUsage
from .._utils._common import _get_timestamp
from .._utils._mixin import DictMixin
from ..message import (
    TextBlock,
    ToolUseBlock,
    ThinkingBlock,
    AudioBlock,
)
from ..types import JSONSerializableObject


@dataclass
class ChatResponse(DictMixin):
    """The response of chat models."""

    content: Sequence[TextBlock | ToolUseBlock | ThinkingBlock | AudioBlock]
    """The content of the chat response, which can include text blocks,
    tool use blocks, thinking blocks, or audio blocks."""

    id: str = field(default_factory=lambda: _get_timestamp(True))
    """The unique identifier of the response."""

    created_at: str = field(default_factory=_get_timestamp)
    """When the response was created"""

    type: Literal["chat"] = field(default_factory=lambda: "chat")
    """The type of the response, which is always 'chat'."""

    usage: ChatUsage | None = field(default_factory=lambda: None)
    """The usage information of the chat response, if available."""

    metadata: dict[str, JSONSerializableObject] | None = field(
        default_factory=lambda: None,
    )
    """The metadata of the chat response"""
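A hedged sketch of constructing a `ChatResponse` by hand (for example in a unit test); it assumes `TextBlock` is the dict-style block used throughout this module, that the import paths below exist, and that `ChatUsage` remains importable from its private module since it is not re-exported in `__all__`.

.. code-block:: python
    :caption: Hedged sketch of building a ChatResponse manually

    from agentscope.message import TextBlock              # assumed import path
    from agentscope.model import ChatResponse             # assumed import path
    from agentscope.model._model_usage import ChatUsage   # private module

    # Assemble a response with one text block and explicit usage numbers.
    response = ChatResponse(
        content=[TextBlock(type="text", text="Hello!")],
        usage=ChatUsage(input_tokens=12, output_tokens=3, time=0.42),
    )

    print(response.type)     # expected: "chat"
    print(response.content)  # expected: one text block containing "Hello!"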
@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
"""The model usage class in agentscope."""
from dataclasses import dataclass, field
from typing import Literal

from .._utils._mixin import DictMixin


@dataclass
class ChatUsage(DictMixin):
    """The usage of a chat model API invocation."""

    input_tokens: int
    """The number of input tokens."""

    output_tokens: int
    """The number of output tokens."""

    time: float
    """The time used in seconds."""

    type: Literal["chat"] = field(default_factory=lambda: "chat")
    """The type of the usage, must be `chat`."""
@@ -0,0 +1,345 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Model wrapper for Ollama models."""
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
TYPE_CHECKING,
|
||||
List,
|
||||
AsyncGenerator,
|
||||
AsyncIterator,
|
||||
Literal,
|
||||
Type,
|
||||
)
|
||||
from collections import OrderedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from . import ChatResponse
|
||||
from ._model_base import ChatModelBase
|
||||
from ._model_usage import ChatUsage
|
||||
from .._logging import logger
|
||||
from .._utils._common import _json_loads_with_repair
|
||||
from ..message import ToolUseBlock, TextBlock, ThinkingBlock
|
||||
from ..tracing import trace_llm
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ollama._types import ChatResponse as OllamaChatResponse
|
||||
else:
|
||||
OllamaChatResponse = "ollama._types.ChatResponse"
|
||||
|
||||
|
||||
class OllamaChatModel(ChatModelBase):
|
||||
"""The Ollama chat model class in agentscope."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
stream: bool = False,
|
||||
options: dict | None = None,
|
||||
keep_alive: str = "5m",
|
||||
enable_thinking: bool | None = None,
|
||||
host: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the Ollama chat model.
|
||||
|
||||
Args:
|
||||
model_name (`str`):
|
||||
The name of the model.
|
||||
stream (`bool`, default `False`):
|
||||
Streaming mode or not.
|
||||
options (`dict`, default `None`):
|
||||
Additional parameters to pass to the Ollama API. These can
|
||||
include temperature etc.
|
||||
keep_alive (`str`, default `"5m"`):
|
||||
Duration to keep the model loaded in memory. The format is a
|
||||
number followed by a unit suffix (s for seconds, m for minutes
|
||||
, h for hours).
|
||||
enable_thinking (`bool | None`, default `None`):
|
||||
Whether to enable thinking or not, only for models such as qwen3,
|
||||
deepseek-r1, etc. For more details, please refer to
|
||||
https://ollama.com/search?c=thinking
|
||||
host (`str | None`, default `None`):
|
||||
The host address of the Ollama server. If None, uses the
|
||||
default address (typically http://localhost:11434).
|
||||
**kwargs (`Any`):
|
||||
Additional keyword arguments to pass to the base chat model
|
||||
class.
|
||||
"""
|
||||
|
||||
try:
|
||||
import ollama
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"The package ollama is not found. Please install it by "
|
||||
'running command `pip install "ollama>=0.1.7"`',
|
||||
) from e
|
||||
|
||||
super().__init__(model_name, stream)
|
||||
|
||||
self.client = ollama.AsyncClient(
|
||||
host=host,
|
||||
**kwargs,
|
||||
)
|
||||
self.options = options
|
||||
self.keep_alive = keep_alive
|
||||
self.think = enable_thinking
|
||||
|
||||
@trace_llm
|
||||
async def __call__(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: Literal["auto", "none", "any", "required"]
|
||||
| str
|
||||
| None = None,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
|
||||
"""Get the response from Ollama chat completions API by the given
|
||||
arguments.
|
||||
|
||||
Args:
|
||||
messages (`list[dict]`):
|
||||
A list of dictionaries, where `role` and `content` fields are
|
||||
required, and `name` field is optional.
|
||||
tools (`list[dict]`, default `None`):
|
||||
The tools JSON schemas that the model can use.
|
||||
tool_choice (`Literal["auto", "none", "any", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", "any", "required", or specific tool
|
||||
name.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
**kwargs (`Any`):
|
||||
The keyword arguments for Ollama chat completions API,
|
||||
e.g. `think`, etc. Please refer to the Ollama API
|
||||
documentation for more details.
|
||||
|
||||
Returns:
|
||||
`ChatResponse | AsyncGenerator[ChatResponse, None]`:
|
||||
The response from the Ollama chat completions API.
|
||||
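A usage sketch under assumptions: the model tag `qwen3:8b`, the prompt,
and the `Answer` schema below are illustrative, not part of the API.

.. code-block:: python
    :caption: Example of calling the model with structured output

    import asyncio

    from pydantic import BaseModel

    class Answer(BaseModel):
        city: str
        reason: str

    async def main() -> None:
        model = OllamaChatModel(model_name="qwen3:8b")
        response = await model(
            messages=[{"role": "user", "content": "Pick a city to visit."}],
            structured_model=Answer,
        )
        print(response.metadata)  # Parsed fields of the Answer schema

    asyncio.run(main())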
"""
|
||||
|
||||
kwargs = {
|
||||
"model": self.model_name,
|
||||
"messages": messages,
|
||||
"stream": self.stream,
|
||||
"options": self.options,
|
||||
"keep_alive": self.keep_alive,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
if self.think is not None and "think" not in kwargs:
|
||||
kwargs["think"] = self.think
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._format_tools_json_schemas(tools)
|
||||
|
||||
if tool_choice:
|
||||
logger.warning("Ollama does not support tool_choice yet, ignored.")
|
||||
|
||||
if structured_model:
|
||||
kwargs["format"] = structured_model.model_json_schema()
|
||||
|
||||
start_datetime = datetime.now()
|
||||
response = await self.client.chat(**kwargs)
|
||||
|
||||
if self.stream:
|
||||
return self._parse_ollama_stream_completion_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
parsed_response = await self._parse_ollama_completion_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
async def _parse_ollama_stream_completion_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: AsyncIterator[OllamaChatResponse],
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> AsyncGenerator[ChatResponse, None]:
|
||||
"""Given an Ollama streaming completion response, extract the
|
||||
content blocks and usages from it and yield ChatResponse objects.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`AsyncIterator[OllamaChatResponse]`):
|
||||
Ollama streaming response async iterator to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
AsyncGenerator[ChatResponse, None]:
|
||||
An async generator that yields ChatResponse objects containing
|
||||
the content blocks and usage information for each chunk in the
|
||||
streaming response.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
|
||||
"""
|
||||
accumulated_text = ""
|
||||
acc_thinking_content = ""
|
||||
tool_calls = OrderedDict() # Store tool calls
|
||||
metadata: dict | None = None
|
||||
|
||||
async for chunk in response:
|
||||
# Handle text content
|
||||
msg = chunk.message
|
||||
acc_thinking_content += msg.thinking or ""
|
||||
accumulated_text += msg.content or ""
|
||||
|
||||
# Handle tool calls
|
||||
for idx, tool_call in enumerate(msg.tool_calls or []):
|
||||
function = tool_call.function
|
||||
tool_id = f"{idx}_{function.name}"
|
||||
tool_calls[tool_id] = {
|
||||
"type": "tool_use",
|
||||
"id": tool_id,
|
||||
"name": function.name,
|
||||
"input": function.arguments,
|
||||
}
|
||||
# Calculate usage statistics
|
||||
current_time = (datetime.now() - start_datetime).total_seconds()
|
||||
usage = ChatUsage(
|
||||
input_tokens=getattr(chunk, "prompt_eval_count", 0) or 0,
|
||||
output_tokens=getattr(chunk, "eval_count", 0) or 0,
|
||||
time=current_time,
|
||||
)
|
||||
# Create content blocks
|
||||
contents: list = []
|
||||
|
||||
if acc_thinking_content:
|
||||
contents.append(
|
||||
ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=acc_thinking_content,
|
||||
),
|
||||
)
|
||||
|
||||
if accumulated_text:
|
||||
contents.append(TextBlock(type="text", text=accumulated_text))
|
||||
if structured_model:
|
||||
metadata = _json_loads_with_repair(accumulated_text)
|
||||
|
||||
# Add tool call blocks
|
||||
for tool_call in tool_calls.values():
|
||||
try:
|
||||
input_data = tool_call["input"]
|
||||
if isinstance(input_data, str):
|
||||
input_data = _json_loads_with_repair(input_data)
|
||||
contents.append(
|
||||
ToolUseBlock(
|
||||
type=tool_call["type"],
|
||||
id=tool_call["id"],
|
||||
name=tool_call["name"],
|
||||
input=input_data,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
print(f"Error parsing tool call input: {e}")
|
||||
|
||||
# Yield the accumulated response once the final chunk arrives and content exists
|
||||
if chunk.done and contents:
|
||||
res = ChatResponse(
|
||||
content=contents,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
yield res
|
||||
|
||||
async def _parse_ollama_completion_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: OllamaChatResponse,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> ChatResponse:
|
||||
"""Given an Ollama chat completion response object, extract the content
|
||||
blocks and usages from it.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`OllamaChatResponse`):
|
||||
The Ollama chat response object to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
`ChatResponse`:
|
||||
A ChatResponse object containing the content blocks and usage.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
content_blocks: List[TextBlock | ToolUseBlock | ThinkingBlock] = []
|
||||
metadata: dict | None = None
|
||||
|
||||
if response.message.thinking:
|
||||
content_blocks.append(
|
||||
ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=response.message.thinking,
|
||||
),
|
||||
)
|
||||
|
||||
if response.message.content:
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=response.message.content,
|
||||
),
|
||||
)
|
||||
if structured_model:
|
||||
metadata = _json_loads_with_repair(
|
||||
response.message.content,
|
||||
)
|
||||
|
||||
for idx, tool_call in enumerate(response.message.tool_calls or []):
|
||||
content_blocks.append(
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
id=f"{idx}_{tool_call.function.name}",
|
||||
name=tool_call.function.name,
|
||||
input=tool_call.function.arguments,
|
||||
),
|
||||
)
|
||||
|
||||
usage = None
|
||||
if "prompt_eval_count" in response and "eval_count" in response:
|
||||
usage = ChatUsage(
|
||||
input_tokens=response.get("prompt_eval_count", 0),
|
||||
output_tokens=response.get("eval_count", 0),
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
)
|
||||
|
||||
parsed_response = ChatResponse(
|
||||
content=content_blocks,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
def _format_tools_json_schemas(
|
||||
self,
|
||||
schemas: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format the tools JSON schemas to the Ollama format."""
|
||||
return schemas
|
||||
@@ -0,0 +1,545 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""OpenAI Chat model class."""
|
||||
from datetime import datetime
|
||||
from typing import (
|
||||
Any,
|
||||
TYPE_CHECKING,
|
||||
List,
|
||||
AsyncGenerator,
|
||||
Literal,
|
||||
Type,
|
||||
)
|
||||
from collections import OrderedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from . import ChatResponse
|
||||
from ._model_base import ChatModelBase
|
||||
from ._model_usage import ChatUsage
|
||||
from .._logging import logger
|
||||
from .._utils._common import _json_loads_with_repair
|
||||
from ..message import (
|
||||
ToolUseBlock,
|
||||
TextBlock,
|
||||
ThinkingBlock,
|
||||
AudioBlock,
|
||||
Base64Source,
|
||||
)
|
||||
from ..tracing import trace_llm
|
||||
from ..types import JSONSerializableObject
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai.types.chat import ChatCompletion
|
||||
from openai import AsyncStream
|
||||
else:
|
||||
ChatCompletion = "openai.types.chat.ChatCompletion"
|
||||
AsyncStream = "openai.types.chat.AsyncStream"
|
||||
|
||||
|
||||
def _format_audio_data_for_qwen_omni(messages: list[dict]) -> None:
|
||||
"""Qwen-omni uses OpenAI-compatible API but requires different audio
|
||||
audio data format than OpenAI, with a "data:;base64," prefix.
|
||||
Refer to `Qwen-omni documentation
|
||||
<https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2867839>`_
|
||||
for more details.
|
||||
|
||||
Args:
|
||||
messages (`list[dict]`):
|
||||
The list of message dictionaries from the OpenAI formatter.
|
||||
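A small sketch of the in-place rewrite this helper performs (the base64
payload below is a placeholder assumption):

.. code-block:: python
    :caption: Example of the audio data rewrite

    messages = [
        {
            "role": "user",
            "content": [
                {
                    "type": "input_audio",
                    "input_audio": {"data": "<base64-bytes>", "format": "wav"},
                },
            ],
        },
    ]
    _format_audio_data_for_qwen_omni(messages)
    # The data field now reads "data:;base64,<base64-bytes>"; values
    # starting with "http" are left untouched.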
"""
|
||||
for msg in messages:
|
||||
if isinstance(msg.get("content"), list):
|
||||
for block in msg["content"]:
|
||||
if (
|
||||
isinstance(block, dict)
|
||||
and "input_audio" in block
|
||||
and isinstance(block["input_audio"].get("data"), str)
|
||||
):
|
||||
if not block["input_audio"]["data"].startswith("http"):
|
||||
block["input_audio"]["data"] = (
|
||||
"data:;base64," + block["input_audio"]["data"]
|
||||
)
|
||||
|
||||
|
||||
class OpenAIChatModel(ChatModelBase):
|
||||
"""The OpenAI chat model class."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_name: str,
|
||||
api_key: str | None = None,
|
||||
stream: bool = True,
|
||||
reasoning_effort: Literal["low", "medium", "high"] | None = None,
|
||||
organization: str | None = None,
|
||||
client_args: dict | None = None,
|
||||
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the openai client.
|
||||
|
||||
Args:
|
||||
model_name (`str`):
|
||||
The name of the model to use in OpenAI API.
|
||||
api_key (`str`, default `None`):
|
||||
The API key for OpenAI API. If not specified, it will
|
||||
be read from the environment variable `OPENAI_API_KEY`.
|
||||
stream (`bool`, default `True`):
|
||||
Whether to use streaming output or not.
|
||||
reasoning_effort (`Literal["low", "medium", "high"] | None`, \
|
||||
optional):
|
||||
Reasoning effort, supported for o3, o4, etc. Please refer to
|
||||
`OpenAI documentation
|
||||
<https://platform.openai.com/docs/guides/reasoning?api-mode=chat>`_
|
||||
for more details.
|
||||
organization (`str`, default `None`):
|
||||
The organization ID for OpenAI API. If not specified, it will
|
||||
be read from the environment variable `OPENAI_ORGANIZATION`.
|
||||
client_args (`dict`, default `None`):
|
||||
The extra keyword arguments to initialize the OpenAI client.
|
||||
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
The extra keyword arguments used in OpenAI API generation,
|
||||
e.g. `temperature`, `seed`.
|
||||
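A minimal construction sketch (the model name and generation settings
below are illustrative assumptions):

.. code-block:: python
    :caption: Example of creating an OpenAI chat model

    model = OpenAIChatModel(
        model_name="gpt-4o-mini",
        api_key="sk-...",  # or set the OPENAI_API_KEY environment variable
        stream=True,
        generate_kwargs={"temperature": 0.3, "seed": 42},
    )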
"""
|
||||
|
||||
super().__init__(model_name, stream)
|
||||
|
||||
import openai
|
||||
|
||||
self.client = openai.AsyncClient(
|
||||
api_key=api_key,
|
||||
organization=organization,
|
||||
**(client_args or {}),
|
||||
)
|
||||
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.generate_kwargs = generate_kwargs or {}
|
||||
|
||||
@trace_llm
|
||||
async def __call__(
|
||||
self,
|
||||
messages: list[dict],
|
||||
tools: list[dict] | None = None,
|
||||
tool_choice: Literal["auto", "none", "any", "required"]
|
||||
| str
|
||||
| None = None,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
|
||||
"""Get the response from OpenAI chat completions API by the given
|
||||
arguments.
|
||||
|
||||
Args:
|
||||
messages (`list[dict]`):
|
||||
A list of dictionaries, where `role` and `content` fields are
|
||||
required, and `name` field is optional.
|
||||
tools (`list[dict]`, default `None`):
|
||||
The tools JSON schemas that the model can use.
|
||||
tool_choice (`Literal["auto", "none", "any", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", "any", "required", or specific tool
|
||||
name. For more details, please refer to
|
||||
https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output. When provided, the model will be forced
|
||||
to return data that conforms to this schema by automatically
|
||||
converting the BaseModel to a tool function and setting
|
||||
`tool_choice` to enforce its usage. This enables structured
|
||||
output generation.
|
||||
|
||||
.. note:: When `structured_model` is specified,
|
||||
both `tools` and `tool_choice` parameters are ignored,
|
||||
and the model will only perform structured output
|
||||
generation without calling any other tools.
|
||||
|
||||
For more details, please refer to the `official document
|
||||
<https://platform.openai.com/docs/guides/structured-outputs>`_
|
||||
|
||||
**kwargs (`Any`):
|
||||
The keyword arguments for OpenAI chat completions API,
|
||||
e.g. `temperature`, `max_tokens`, `top_p`, etc. Please
|
||||
refer to the OpenAI API documentation for more details.
|
||||
|
||||
Returns:
|
||||
`ChatResponse | AsyncGenerator[ChatResponse, None]`:
|
||||
The response from the OpenAI chat completions API.
|
||||
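A usage sketch under assumptions: the model name, prompt, and `Movie`
schema below are illustrative, and `OPENAI_API_KEY` is assumed to be set.

.. code-block:: python
    :caption: Example of structured output generation

    import asyncio

    from pydantic import BaseModel

    class Movie(BaseModel):
        title: str
        year: int

    async def main() -> None:
        model = OpenAIChatModel(model_name="gpt-4o-mini", stream=False)
        response = await model(
            messages=[{"role": "user", "content": "Recommend one movie."}],
            structured_model=Movie,
        )
        print(response.metadata)  # Validated fields of the Movie schema

    asyncio.run(main())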
"""
|
||||
|
||||
# checking messages
|
||||
if not isinstance(messages, list):
|
||||
raise ValueError(
|
||||
"OpenAI `messages` field expected type `list`, "
|
||||
f"got `{type(messages)}` instead.",
|
||||
)
|
||||
if not all("role" in msg and "content" in msg for msg in messages):
|
||||
raise ValueError(
|
||||
"Each message in the 'messages' list must contain a 'role' "
|
||||
"and 'content' key for OpenAI API.",
|
||||
)
|
||||
|
||||
# Qwen-omni requires different base64 audio format from openai
|
||||
if "omni" in self.model_name.lower():
|
||||
_format_audio_data_for_qwen_omni(messages)
|
||||
|
||||
kwargs = {
|
||||
"model": self.model_name,
|
||||
"messages": messages,
|
||||
"stream": self.stream,
|
||||
**self.generate_kwargs,
|
||||
**kwargs,
|
||||
}
|
||||
if self.reasoning_effort and "reasoning_effort" not in kwargs:
|
||||
kwargs["reasoning_effort"] = self.reasoning_effort
|
||||
|
||||
if tools:
|
||||
kwargs["tools"] = self._format_tools_json_schemas(tools)
|
||||
|
||||
if tool_choice:
|
||||
self._validate_tool_choice(tool_choice, tools)
|
||||
kwargs["tool_choice"] = self._format_tool_choice(tool_choice)
|
||||
|
||||
if self.stream:
|
||||
kwargs["stream_options"] = {"include_usage": True}
|
||||
|
||||
start_datetime = datetime.now()
|
||||
|
||||
if structured_model:
|
||||
if tools or tool_choice:
|
||||
logger.warning(
|
||||
"structured_model is provided. Both 'tools' and "
|
||||
"'tool_choice' parameters will be overridden and "
|
||||
"ignored. The model will only perform structured output "
|
||||
"generation without calling any other tools.",
|
||||
)
|
||||
kwargs.pop("stream", None)
|
||||
kwargs.pop("tools", None)
|
||||
kwargs.pop("tool_choice", None)
|
||||
kwargs["response_format"] = structured_model
|
||||
if not self.stream:
|
||||
response = await self.client.chat.completions.parse(**kwargs)
|
||||
else:
|
||||
response = self.client.chat.completions.stream(**kwargs)
|
||||
return self._parse_openai_stream_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
else:
|
||||
response = await self.client.chat.completions.create(**kwargs)
|
||||
|
||||
if self.stream:
|
||||
return self._parse_openai_stream_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
# Non-streaming response
|
||||
parsed_response = self._parse_openai_completion_response(
|
||||
start_datetime,
|
||||
response,
|
||||
structured_model,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
async def _parse_openai_stream_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: AsyncStream,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> AsyncGenerator[ChatResponse, None]:
|
||||
"""Given an OpenAI streaming completion response, extract the content
|
||||
blocks and usages from it and yield ChatResponse objects.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`AsyncStream`):
|
||||
OpenAI AsyncStream object to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
`AsyncGenerator[ChatResponse, None]`:
|
||||
An async generator that yields ChatResponse objects containing
|
||||
the content blocks and usage information for each chunk in
|
||||
the streaming response.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
usage, res = None, None
|
||||
text = ""
|
||||
thinking = ""
|
||||
audio = ""
|
||||
tool_calls = OrderedDict()
|
||||
metadata: dict | None = None
|
||||
contents: List[
|
||||
TextBlock | ToolUseBlock | ThinkingBlock | AudioBlock
|
||||
] = []
|
||||
|
||||
async with response as stream:
|
||||
async for item in stream:
|
||||
if structured_model:
|
||||
if item.type != "chunk":
|
||||
continue
|
||||
chunk = item.chunk
|
||||
else:
|
||||
chunk = item
|
||||
|
||||
if chunk.usage:
|
||||
usage = ChatUsage(
|
||||
input_tokens=chunk.usage.prompt_tokens,
|
||||
output_tokens=chunk.usage.completion_tokens,
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
)
|
||||
|
||||
if not chunk.choices:
|
||||
if usage and contents:
|
||||
res = ChatResponse(
|
||||
content=contents,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
yield res
|
||||
continue
|
||||
|
||||
choice = chunk.choices[0]
|
||||
|
||||
thinking += (
|
||||
getattr(choice.delta, "reasoning_content", None) or ""
|
||||
)
|
||||
text += choice.delta.content or ""
|
||||
|
||||
if (
|
||||
hasattr(choice.delta, "audio")
|
||||
and "data" in choice.delta.audio
|
||||
):
|
||||
audio += choice.delta.audio["data"]
|
||||
if (
|
||||
hasattr(choice.delta, "audio")
|
||||
and "transcript" in choice.delta.audio
|
||||
):
|
||||
text += choice.delta.audio["transcript"]
|
||||
|
||||
for tool_call in choice.delta.tool_calls or []:
|
||||
if tool_call.index in tool_calls:
|
||||
if tool_call.function.arguments is not None:
|
||||
tool_calls[tool_call.index][
|
||||
"input"
|
||||
] += tool_call.function.arguments
|
||||
|
||||
else:
|
||||
tool_calls[tool_call.index] = {
|
||||
"type": "tool_use",
|
||||
"id": tool_call.id,
|
||||
"name": tool_call.function.name,
|
||||
"input": tool_call.function.arguments or "",
|
||||
}
|
||||
|
||||
contents = []
|
||||
|
||||
if thinking:
|
||||
contents.append(
|
||||
ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=thinking,
|
||||
),
|
||||
)
|
||||
|
||||
if audio:
|
||||
media_type = self.generate_kwargs.get("audio", {}).get(
|
||||
"format",
|
||||
"wav",
|
||||
)
|
||||
contents.append(
|
||||
AudioBlock(
|
||||
type="audio",
|
||||
source=Base64Source(
|
||||
data=audio,
|
||||
media_type=f"audio/{media_type}",
|
||||
type="base64",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
if text:
|
||||
contents.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=text,
|
||||
),
|
||||
)
|
||||
|
||||
if structured_model:
|
||||
metadata = _json_loads_with_repair(text)
|
||||
|
||||
for tool_call in tool_calls.values():
|
||||
contents.append(
|
||||
ToolUseBlock(
|
||||
type=tool_call["type"],
|
||||
id=tool_call["id"],
|
||||
name=tool_call["name"],
|
||||
input=_json_loads_with_repair(
|
||||
tool_call["input"] or "{}",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
if not contents:
|
||||
continue
|
||||
|
||||
res = ChatResponse(
|
||||
content=contents,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
yield res
|
||||
|
||||
def _parse_openai_completion_response(
|
||||
self,
|
||||
start_datetime: datetime,
|
||||
response: ChatCompletion,
|
||||
structured_model: Type[BaseModel] | None = None,
|
||||
) -> ChatResponse:
|
||||
"""Given an OpenAI chat completion response object, extract the content
|
||||
blocks and usages from it.
|
||||
|
||||
Args:
|
||||
start_datetime (`datetime`):
|
||||
The start datetime of the response generation.
|
||||
response (`ChatCompletion`):
|
||||
OpenAI ChatCompletion object to parse.
|
||||
structured_model (`Type[BaseModel] | None`, default `None`):
|
||||
A Pydantic BaseModel class that defines the expected structure
|
||||
for the model's output.
|
||||
|
||||
Returns:
|
||||
ChatResponse (`ChatResponse`):
|
||||
A ChatResponse object containing the content blocks and usage.
|
||||
|
||||
.. note::
|
||||
If `structured_model` is not `None`, the expected structured output
|
||||
will be stored in the metadata of the `ChatResponse`.
|
||||
"""
|
||||
content_blocks: List[
|
||||
TextBlock | ToolUseBlock | ThinkingBlock | AudioBlock
|
||||
] = []
|
||||
metadata: dict | None = None
|
||||
|
||||
if response.choices:
|
||||
choice = response.choices[0]
|
||||
if (
|
||||
hasattr(choice.message, "reasoning_content")
|
||||
and choice.message.reasoning_content is not None
|
||||
):
|
||||
content_blocks.append(
|
||||
ThinkingBlock(
|
||||
type="thinking",
|
||||
thinking=response.choices[0].message.reasoning_content,
|
||||
),
|
||||
)
|
||||
|
||||
if choice.message.content:
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=response.choices[0].message.content,
|
||||
),
|
||||
)
|
||||
if choice.message.audio:
|
||||
media_type = self.generate_kwargs.get("audio", {}).get(
|
||||
"format",
|
||||
"mp3",
|
||||
)
|
||||
content_blocks.append(
|
||||
AudioBlock(
|
||||
type="audio",
|
||||
source=Base64Source(
|
||||
data=choice.message.audio.data,
|
||||
media_type=f"audio/{media_type}",
|
||||
type="base64",
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
if choice.message.audio.transcript:
|
||||
content_blocks.append(
|
||||
TextBlock(
|
||||
type="text",
|
||||
text=choice.message.audio.transcript,
|
||||
),
|
||||
)
|
||||
|
||||
for tool_call in choice.message.tool_calls or []:
|
||||
content_blocks.append(
|
||||
ToolUseBlock(
|
||||
type="tool_use",
|
||||
id=tool_call.id,
|
||||
name=tool_call.function.name,
|
||||
input=_json_loads_with_repair(
|
||||
tool_call.function.arguments,
|
||||
),
|
||||
),
|
||||
)
|
||||
|
||||
if structured_model:
|
||||
metadata = choice.message.parsed.model_dump()
|
||||
|
||||
usage = None
|
||||
if response.usage:
|
||||
usage = ChatUsage(
|
||||
input_tokens=response.usage.prompt_tokens,
|
||||
output_tokens=response.usage.completion_tokens,
|
||||
time=(datetime.now() - start_datetime).total_seconds(),
|
||||
)
|
||||
|
||||
parsed_response = ChatResponse(
|
||||
content=content_blocks,
|
||||
usage=usage,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
return parsed_response
|
||||
|
||||
def _format_tools_json_schemas(
|
||||
self,
|
||||
schemas: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format the tools JSON schemas to the OpenAI format."""
|
||||
return schemas
|
||||
|
||||
def _format_tool_choice(
|
||||
self,
|
||||
tool_choice: Literal["auto", "none", "any", "required"] | str | None,
|
||||
) -> str | dict | None:
|
||||
"""Format tool_choice parameter for API compatibility.
|
||||
|
||||
Args:
|
||||
tool_choice (`Literal["auto", "none", "any", "required"] | str \
|
||||
| None`, default `None`):
|
||||
Controls which (if any) tool is called by the model.
|
||||
Can be "auto", "none", "any", "required", or specific tool
|
||||
name. For more details, please refer to
|
||||
https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice
|
||||
Returns:
|
||||
`str | dict | None`:
|
||||
The formatted tool choice configuration dict, or None if
|
||||
tool_choice is None.
|
||||
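A short sketch of the mapping (the tool name `get_weather` is an
illustrative assumption):

.. code-block:: python
    :caption: Example of tool_choice formatting

    model._format_tool_choice("any")
    # -> "required"
    model._format_tool_choice("get_weather")
    # -> {"type": "function", "function": {"name": "get_weather"}}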
"""
|
||||
if tool_choice is None:
|
||||
return None
|
||||
mode_mapping = {
|
||||
"auto": "auto",
|
||||
"none": "none",
|
||||
"any": "required",
|
||||
"required": "required",
|
||||
}
|
||||
if tool_choice in mode_mapping:
|
||||
return mode_mapping[tool_choice]
|
||||
return {"type": "function", "function": {"name": tool_choice}}
|
||||
@@ -0,0 +1,63 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""A model class for RL Training with Trinity-RFT."""
|
||||
from typing import (
|
||||
Optional,
|
||||
TYPE_CHECKING,
|
||||
)
|
||||
from ._openai_model import OpenAIChatModel
|
||||
from ..types import JSONSerializableObject
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from openai import AsyncOpenAI
|
||||
else:
|
||||
AsyncOpenAI = "openai.AsyncOpenAI"
|
||||
|
||||
|
||||
class TrinityChatModel(OpenAIChatModel):
|
||||
"""A model class for RL Training with Trinity-RFT."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
openai_async_client: AsyncOpenAI,
|
||||
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
|
||||
enable_thinking: Optional[bool] = None,
|
||||
) -> None:
|
||||
"""Initialize the Trinity model class.
|
||||
|
||||
Args:
|
||||
openai_async_client (`AsyncOpenAI`):
|
||||
The OpenAI async client instance provided by Trinity-RFT.
|
||||
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
|
||||
optional):
|
||||
Additional keyword arguments to pass to the model's generate
|
||||
method. Defaults to None.
|
||||
enable_thinking (`bool`, optional):
|
||||
Whether to enable the model's thinking capability. Only
|
||||
applicable for Qwen3 series models. Defaults to None.
|
||||
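A construction sketch: `openai_client` stands for the AsyncOpenAI
instance handed over by Trinity-RFT (it must carry a `model_path`
attribute); the keyword values below are illustrative assumptions.

.. code-block:: python
    :caption: Example of wrapping a Trinity-RFT client

    model = TrinityChatModel(
        openai_async_client=openai_client,
        generate_kwargs={"temperature": 1.0},
        enable_thinking=False,
    )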
"""
|
||||
model_name = getattr(openai_async_client, "model_path", None)
|
||||
if model_name is None:
|
||||
raise ValueError(
|
||||
"The provided openai_async_client does not have a "
|
||||
"`model_path` attribute. Please ensure you are using "
|
||||
"the instance provided by Trinity-RFT.",
|
||||
)
|
||||
super().__init__(
|
||||
model_name=model_name,
|
||||
api_key="EMPTY",
|
||||
generate_kwargs=generate_kwargs,
|
||||
stream=False, # RL training does not support streaming
|
||||
)
|
||||
if enable_thinking is not None:
|
||||
if "chat_template_kwargs" not in self.generate_kwargs:
|
||||
self.generate_kwargs["chat_template_kwargs"] = {}
|
||||
assert isinstance(
|
||||
self.generate_kwargs["chat_template_kwargs"],
|
||||
dict,
|
||||
), "chat_template_kwargs must be a dictionary."
|
||||
self.generate_kwargs["chat_template_kwargs"][
|
||||
"enable_thinking"
|
||||
] = enable_thinking
|
||||
# change the client instance to the provided one
|
||||
self.client = openai_async_client
|
||||