chore: 添加虚拟环境到仓库

- 添加 backend_service/venv 虚拟环境
- 包含所有Python依赖包
- 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
2025-12-03 10:19:25 +08:00
parent a6c2027caa
commit c4f851d387
12655 changed files with 3009376 additions and 0 deletions

View File

@@ -0,0 +1,22 @@
# -*- coding: utf-8 -*-
"""The model module."""
from ._model_base import ChatModelBase
from ._model_response import ChatResponse
from ._dashscope_model import DashScopeChatModel
from ._openai_model import OpenAIChatModel
from ._anthropic_model import AnthropicChatModel
from ._ollama_model import OllamaChatModel
from ._gemini_model import GeminiChatModel
from ._trinity_model import TrinityChatModel
__all__ = [
"ChatModelBase",
"ChatResponse",
"DashScopeChatModel",
"OpenAIChatModel",
"AnthropicChatModel",
"OllamaChatModel",
"GeminiChatModel",
"TrinityChatModel",
]

View File

@@ -0,0 +1,507 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches, too-many-statements
"""The Anthropic API model classes."""
from datetime import datetime
from typing import (
Any,
AsyncGenerator,
TYPE_CHECKING,
List,
Literal,
Type,
)
from collections import OrderedDict
from pydantic import BaseModel
from ._model_base import ChatModelBase
from ._model_response import ChatResponse
from ._model_usage import ChatUsage
from .._logging import logger
from .._utils._common import (
_json_loads_with_repair,
_create_tool_from_base_model,
)
from ..message import TextBlock, ToolUseBlock, ThinkingBlock
from ..tracing import trace_llm
from ..types._json import JSONSerializableObject
if TYPE_CHECKING:
from anthropic.types.message import Message
from anthropic import AsyncStream
else:
Message = "anthropic.types.message.Message"
AsyncStream = "anthropic.AsyncStream"
class AnthropicChatModel(ChatModelBase):
"""The Anthropic model wrapper for AgentScope."""
def __init__(
self,
model_name: str,
api_key: str | None = None,
max_tokens: int = 2048,
stream: bool = True,
thinking: dict | None = None,
client_args: dict | None = None,
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
) -> None:
"""Initialize the Anthropic chat model.
Args:
model_name (`str`):
The model names.
api_key (`str`):
The anthropic API key.
stream (`bool`):
The streaming output or not
max_tokens (`int`):
Limit the maximum token count the model can generate.
thinking (`dict | None`, default `None`):
Configuration for Claude's internal reasoning process.
.. code-block:: python
:caption: Example of thinking
{
"type": "enabled" | "disabled",
"budget_tokens": 1024
}
client_args (`dict | None`, optional):
The extra keyword arguments to initialize the Anthropic client.
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments used in Gemini API generation,
e.g. `temperature`, `seed`.
"""
try:
import anthropic
except ImportError as e:
raise ImportError(
"Please install the `anthropic` package by running "
"`pip install anthropic`.",
) from e
super().__init__(model_name, stream)
self.client = anthropic.AsyncAnthropic(
api_key=api_key,
**(client_args or {}),
)
self.max_tokens = max_tokens
self.thinking = thinking
self.generate_kwargs = generate_kwargs or {}
@trace_llm
async def __call__(
self,
messages: list[dict[str, Any]],
tools: list[dict] | None = None,
tool_choice: Literal["auto", "none", "any", "required"]
| str
| None = None,
structured_model: Type[BaseModel] | None = None,
**generate_kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
"""Get the response from Anthropic chat completions API by the given
arguments.
Args:
messages (`list[dict]`):
A list of dictionaries, where `role` and `content` fields are
required, and `name` field is optional.
tools (`list[dict]`, default `None`):
The tools JSON schemas that in format of:
.. code-block:: python
:caption: Example of tools JSON schemas
[
{
"type": "function",
"function": {
"name": "xxx",
"description": "xxx",
"parameters": {
"type": "object",
"properties": {
"param1": {
"type": "string",
"description": "..."
},
# Add more parameters as needed
},
"required": ["param1"]
}
},
# More schemas here
]
tool_choice (`Literal["auto", "none", "any", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "any", "required", or specific tool
name. For more details, please refer to
https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output. When provided, the model will be forced
to return data that conforms to this schema by automatically
converting the BaseModel to a tool function and setting
`tool_choice` to enforce its usage. This enables structured
output generation.
.. note:: When `structured_model` is specified,
both `tools` and `tool_choice` parameters are ignored,
and the model will only perform structured output
generation without calling any other tools.
**generate_kwargs (`Any`):
The keyword arguments for Anthropic chat completions API,
e.g. `temperature`, `top_p`, etc. Please
refer to the Anthropic API documentation for more details.
Returns:
`ChatResponse | AsyncGenerator[ChatResponse, None]`:
The response from the Anthropic chat completions API."""
kwargs: dict[str, Any] = {
"model": self.model_name,
"max_tokens": self.max_tokens,
"stream": self.stream,
**self.generate_kwargs,
**generate_kwargs,
}
if self.thinking and "thinking" not in kwargs:
kwargs["thinking"] = self.thinking
if tools:
kwargs["tools"] = self._format_tools_json_schemas(tools)
if tool_choice:
self._validate_tool_choice(tool_choice, tools)
kwargs["tool_choice"] = self._format_tool_choice(tool_choice)
if structured_model:
if tools or tool_choice:
logger.warning(
"structured_model is provided. Both 'tools' and "
"'tool_choice' parameters will be overridden and "
"ignored. The model will only perform structured output "
"generation without calling any other tools.",
)
format_tool = _create_tool_from_base_model(structured_model)
kwargs["tools"] = self._format_tools_json_schemas(
[format_tool],
)
kwargs["tool_choice"] = self._format_tool_choice(
format_tool["function"]["name"],
)
# Extract the system message
if messages[0]["role"] == "system":
kwargs["system"] = messages[0]["content"]
messages = messages[1:]
kwargs["messages"] = messages
start_datetime = datetime.now()
response = await self.client.messages.create(**kwargs)
if self.stream:
return self._parse_anthropic_stream_completion_response(
start_datetime,
response,
structured_model,
)
# Non-streaming response
parsed_response = await self._parse_anthropic_completion_response(
start_datetime,
response,
structured_model,
)
return parsed_response
async def _parse_anthropic_completion_response(
self,
start_datetime: datetime,
response: Message,
structured_model: Type[BaseModel] | None = None,
) -> ChatResponse:
"""Given an Anthropic Message object, extract the content blocks and
usages from it.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`Message`):
Anthropic Message object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
ChatResponse (`ChatResponse`):
A ChatResponse object containing the content blocks and usage.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
content_blocks: List[ThinkingBlock | TextBlock | ToolUseBlock] = []
metadata = None
if hasattr(response, "content") and response.content:
for content_block in response.content:
if (
hasattr(content_block, "type")
and content_block.type == "thinking"
):
thinking_block = ThinkingBlock(
type="thinking",
thinking=content_block.thinking,
)
thinking_block["signature"] = content_block.signature
content_blocks.append(thinking_block)
elif hasattr(content_block, "text"):
content_blocks.append(
TextBlock(
type="text",
text=content_block.text,
),
)
elif (
hasattr(content_block, "type")
and content_block.type == "tool_use"
):
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=content_block.id,
name=content_block.name,
input=content_block.input,
),
)
if structured_model:
metadata = content_block.input
usage = None
if response.usage:
usage = ChatUsage(
input_tokens=response.usage.input_tokens,
output_tokens=response.usage.output_tokens,
time=(datetime.now() - start_datetime).total_seconds(),
)
parsed_response = ChatResponse(
content=content_blocks,
usage=usage,
metadata=metadata,
)
return parsed_response
async def _parse_anthropic_stream_completion_response(
self,
start_datetime: datetime,
response: AsyncStream,
structured_model: Type[BaseModel] | None = None,
) -> AsyncGenerator[ChatResponse, None]:
"""Given an Anthropic streaming response, extract the content blocks
and usages from it and yield ChatResponse objects.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`AsyncStream`):
Anthropic AsyncStream object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
`AsyncGenerator[ChatResponse, None]`:
An async generator that yields ChatResponse objects containing
the content blocks and usage information for each chunk in
the streaming response.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
usage = None
text_buffer = ""
thinking_buffer = ""
thinking_signature = ""
tool_calls = OrderedDict()
tool_call_buffers = {}
res = None
metadata = None
async for event in response:
content_changed = False
thinking_changed = False
if event.type == "message_start":
message = event.message
if message.usage:
usage = ChatUsage(
input_tokens=message.usage.input_tokens,
output_tokens=getattr(
message.usage,
"output_tokens",
0,
),
time=(datetime.now() - start_datetime).total_seconds(),
)
elif event.type == "content_block_start":
if event.content_block.type == "tool_use":
block_index = event.index
tool_block = event.content_block
tool_calls[block_index] = {
"type": "tool_use",
"id": tool_block.id,
"name": tool_block.name,
"input": "",
}
tool_call_buffers[block_index] = ""
content_changed = True
elif event.type == "content_block_delta":
block_index = event.index
delta = event.delta
if delta.type == "text_delta":
text_buffer += delta.text
content_changed = True
elif delta.type == "thinking_delta":
thinking_buffer += delta.thinking
thinking_changed = True
elif delta.type == "signature_delta":
thinking_signature = delta.signature
elif (
delta.type == "input_json_delta"
and block_index in tool_calls
):
tool_call_buffers[block_index] += delta.partial_json or ""
tool_calls[block_index]["input"] = tool_call_buffers[
block_index
]
content_changed = True
elif event.type == "message_delta":
if event.usage and usage:
usage.output_tokens = event.usage.output_tokens
if (thinking_changed or content_changed) and usage:
contents: list = []
if thinking_buffer:
thinking_block = ThinkingBlock(
type="thinking",
thinking=thinking_buffer,
)
thinking_block["signature"] = thinking_signature
contents.append(thinking_block)
if text_buffer:
contents.append(
TextBlock(
type="text",
text=text_buffer,
),
)
for block_index, tool_call in tool_calls.items():
input_str = tool_call["input"]
try:
input_obj = _json_loads_with_repair(input_str or "{}")
if not isinstance(input_obj, dict):
input_obj = {}
except Exception:
input_obj = {}
contents.append(
ToolUseBlock(
type=tool_call["type"],
id=tool_call["id"],
name=tool_call["name"],
input=input_obj,
),
)
if structured_model:
metadata = input_obj
if contents:
res = ChatResponse(
content=contents,
usage=usage,
metadata=metadata,
)
yield res
def _format_tools_json_schemas(
self,
schemas: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Format the JSON schemas of the tool functions to the format that
Anthropic API expects."""
formatted_schemas = []
for schema in schemas:
assert (
"function" in schema
), f"Invalid schema: {schema}, expect key 'function'."
assert "name" in schema["function"], (
f"Invalid schema: {schema}, "
"expect key 'name' in 'function' field."
)
formatted_schemas.append(
{
"name": schema["function"]["name"],
"description": schema["function"].get("description", ""),
"input_schema": schema["function"].get("parameters", {}),
},
)
return formatted_schemas
def _format_tool_choice(
self,
tool_choice: Literal["auto", "none", "any", "required"] | str | None,
) -> dict | None:
"""Format tool_choice parameter for API compatibility.
Args:
tool_choice (`Literal["auto", "none", "any", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "any", "required", or specific tool
name. For more details, please refer to
https://docs.anthropic.com/en/docs/agents-and-tools/tool-use/implement-tool-use
Returns:
`dict | None`:
The formatted tool choice configuration dict, or None if
tool_choice is None.
"""
if tool_choice is None:
return None
type_mapping = {
"auto": {"type": "auto"},
"none": {"type": "none"},
"any": {"type": "any"},
"required": {"type": "any"},
}
if tool_choice in type_mapping:
return type_mapping[tool_choice]
return {"type": "tool", "name": tool_choice}

View File

@@ -0,0 +1,524 @@
# -*- coding: utf-8 -*-
"""The dashscope API model classes."""
import collections
from datetime import datetime
from http import HTTPStatus
from typing import (
Any,
AsyncGenerator,
Generator,
Union,
TYPE_CHECKING,
List,
Literal,
Type,
)
from pydantic import BaseModel
from aioitertools import iter as giter
from ._model_base import ChatModelBase
from ._model_response import ChatResponse
from ._model_usage import ChatUsage
from .._utils._common import (
_json_loads_with_repair,
_create_tool_from_base_model,
)
from ..message import TextBlock, ToolUseBlock, ThinkingBlock
from ..tracing import trace_llm
from ..types import JSONSerializableObject
from .._logging import logger
if TYPE_CHECKING:
from dashscope.api_entities.dashscope_response import GenerationResponse
from dashscope.api_entities.dashscope_response import (
MultiModalConversationResponse,
)
else:
GenerationResponse = (
"dashscope.api_entities.dashscope_response.GenerationResponse"
)
MultiModalConversationResponse = (
"dashscope.api_entities.dashscope_response."
"MultiModalConversationResponse"
)
class DashScopeChatModel(ChatModelBase):
"""The DashScope chat model class, which unifies the Generation and
MultimodalConversation APIs into one method."""
def __init__(
self,
model_name: str,
api_key: str,
stream: bool = True,
enable_thinking: bool | None = None,
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
base_http_api_url: str | None = None,
) -> None:
"""Initialize the DashScope chat model.
Args:
model_name (`str`):
The model names.
api_key (`str`):
The dashscope API key.
stream (`bool`):
The streaming output or not
enable_thinking (`bool | None`, optional):
Enable thinking or not, only support Qwen3, QwQ, DeepSeek-R1.
Refer to `DashScope documentation
<https://help.aliyun.com/zh/model-studio/deep-thinking>`_
for more details.
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments used in DashScope API generation,
e.g. `temperature`, `seed`.
base_http_api_url (`str | None`, optional):
The base URL for DashScope API requests. If not provided,
the default base URL from the DashScope SDK will be used.
"""
if enable_thinking and not stream:
logger.info(
"In DashScope API, `stream` must be True when "
"`enable_thinking` is True. ",
)
stream = True
super().__init__(model_name, stream)
self.api_key = api_key
self.enable_thinking = enable_thinking
self.generate_kwargs = generate_kwargs or {}
if base_http_api_url is not None:
import dashscope
dashscope.base_http_api_url = base_http_api_url
@trace_llm
async def __call__(
self,
messages: list[dict[str, Any]],
tools: list[dict] | None = None,
tool_choice: Literal["auto", "none", "any", "required"]
| str
| None = None,
structured_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
"""Get the response from the dashscope
Generation/MultimodalConversation API by the given arguments.
.. note:: We unify the dashscope generation and multimodal conversation
APIs into one method, since they support similar arguments and share
the same functionality.
Args:
messages (`list[dict[str, Any]]`):
A list of dictionaries, where `role` and `content` fields are
required.
tools (`list[dict] | None`, default `None`):
The tools JSON schemas that the model can use.
tool_choice (`Literal["auto", "none", "any", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", or specific tool name.
For more details, please refer to
https://help.aliyun.com/zh/model-studio/qwen-function-calling
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output. When provided, the model will be forced
to return data that conforms to this schema by automatically
converting the BaseModel to a tool function and setting
`tool_choice` to enforce its usage. This enables structured
output generation.
.. note:: When `structured_model` is specified,
both `tools` and `tool_choice` parameters are ignored,
and the model will only perform structured output
generation without calling any other tools.
**kwargs (`Any`):
The keyword arguments for DashScope chat completions API,
e.g. `temperature`, `max_tokens`, `top_p`, etc. Please
refer to `DashScope documentation
<https://help.aliyun.com/zh/dashscope/developer-reference/api-details>`_
for more detailed arguments.
"""
import dashscope
# For qvq and qwen-vl models, the content field cannot be `None` or
# `[{"text": None}]`, so we need to convert it to an empty list.
if self.model_name.startswith("qvq") or "-vl" in self.model_name:
for msg in messages:
if msg["content"] is None or msg["content"] == [
{"text": None},
]:
msg["content"] = []
kwargs = {
"messages": messages,
"model": self.model_name,
"stream": self.stream,
**self.generate_kwargs,
**kwargs,
"result_format": "message",
# In agentscope, the `incremental_output` must be `True` when
# `self.stream` is True
"incremental_output": self.stream,
}
if tools:
kwargs["tools"] = self._format_tools_json_schemas(tools)
if tool_choice:
self._validate_tool_choice(tool_choice, tools)
kwargs["tool_choice"] = self._format_tool_choice(tool_choice)
if (
self.enable_thinking is not None
and "enable_thinking" not in kwargs
):
kwargs["enable_thinking"] = self.enable_thinking
if structured_model:
if tools or tool_choice:
logger.warning(
"structured_model is provided. Both 'tools' and "
"'tool_choice' parameters will be overridden and "
"ignored. The model will only perform structured output "
"generation without calling any other tools.",
)
format_tool = _create_tool_from_base_model(structured_model)
kwargs["tools"] = self._format_tools_json_schemas(
[format_tool],
)
kwargs["tool_choice"] = self._format_tool_choice(
format_tool["function"]["name"],
)
start_datetime = datetime.now()
if self.model_name.startswith("qvq") or "-vl" in self.model_name:
response = dashscope.MultiModalConversation.call(
api_key=self.api_key,
**kwargs,
)
else:
response = await dashscope.aigc.generation.AioGeneration.call(
api_key=self.api_key,
**kwargs,
)
if self.stream:
return self._parse_dashscope_stream_response(
start_datetime,
response,
structured_model,
)
parsed_response = await self._parse_dashscope_generation_response(
start_datetime,
response,
structured_model,
)
return parsed_response
# pylint: disable=too-many-branches
async def _parse_dashscope_stream_response(
self,
start_datetime: datetime,
response: Union[
AsyncGenerator[GenerationResponse, None],
Generator[MultiModalConversationResponse, None, None],
],
structured_model: Type[BaseModel] | None = None,
) -> AsyncGenerator[ChatResponse, Any]:
"""Given a DashScope streaming response generator, extract the content
blocks and usages from it and yield ChatResponse objects.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (
`Union[AsyncGenerator[GenerationResponse, None], Generator[ \
MultiModalConversationResponse, None, None]]`
):
DashScope streaming response generator (GenerationResponse or
MultiModalConversationResponse) to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
AsyncGenerator[ChatResponse, Any]:
An async generator that yields ChatResponse objects containing
the content blocks and usage information for each chunk in the
streaming response.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
acc_content, acc_thinking_content = "", ""
acc_tool_calls = collections.defaultdict(dict)
metadata = None
async for chunk in giter(response):
if chunk.status_code != HTTPStatus.OK:
raise RuntimeError(
f"Failed to get response from _ API: {chunk}",
)
message = chunk.output.choices[0].message
# Update reasoning content
if isinstance(message.get("reasoning_content"), str):
acc_thinking_content += message["reasoning_content"]
# Update text content
if isinstance(message.content, str):
acc_content += message.content
elif isinstance(message.content, list):
for item in message.content:
if isinstance(item, dict) and "text" in item:
acc_content += item["text"]
# Update tool calls
for tool_call in message.get("tool_calls", []):
index = tool_call.get("index", 0)
if "id" in tool_call and tool_call["id"] != acc_tool_calls[
index
].get("id"):
acc_tool_calls[index]["id"] = (
acc_tool_calls[index].get("id", "") + tool_call["id"]
)
if "function" in tool_call:
func = tool_call["function"]
if "name" in func:
acc_tool_calls[index]["name"] = (
acc_tool_calls[index].get("name", "")
+ func["name"]
)
if "arguments" in func:
acc_tool_calls[index]["arguments"] = (
acc_tool_calls[index].get("arguments", "")
+ func["arguments"]
)
# to content blocks
content_blocks: list[TextBlock | ToolUseBlock | ThinkingBlock] = []
if acc_thinking_content:
content_blocks.append(
ThinkingBlock(
type="thinking",
thinking=acc_thinking_content,
),
)
if acc_content:
content_blocks.append(
TextBlock(
type="text",
text=acc_content,
),
)
for tool_call in acc_tool_calls.values():
repaired_input = _json_loads_with_repair(
tool_call.get("arguments", "{}") or "{}",
)
if not isinstance(repaired_input, dict):
repaired_input = {}
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=tool_call.get("id", ""),
name=tool_call.get("name", ""),
input=repaired_input,
),
)
if structured_model:
metadata = repaired_input
usage = None
if chunk.usage:
usage = ChatUsage(
input_tokens=chunk.usage.input_tokens,
output_tokens=chunk.usage.output_tokens,
time=(datetime.now() - start_datetime).total_seconds(),
)
parsed_chunk = ChatResponse(
content=content_blocks,
usage=usage,
metadata=metadata,
)
yield parsed_chunk
async def _parse_dashscope_generation_response(
self,
start_datetime: datetime,
response: Union[
GenerationResponse,
MultiModalConversationResponse,
],
structured_model: Type[BaseModel] | None = None,
) -> ChatResponse:
"""Given a DashScope GenerationResponse object, extract the content
blocks and usages from it.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (
`Union[GenerationResponse, MultiModalConversationResponse]`
):
Dashscope GenerationResponse | MultiModalConversationResponse
object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
ChatResponse (`ChatResponse`):
A ChatResponse object containing the content blocks and usage.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
# Collect the content blocks from the response.
if response.status_code != 200:
raise RuntimeError(response)
content_blocks: List[TextBlock | ToolUseBlock] = []
metadata: dict | None = None
message = response.output.choices[0].message
content = message.get("content")
if response.output.choices[0].message.get("content") not in [
None,
"",
[],
]:
if isinstance(content, list):
for item in content:
if isinstance(item, dict) and "text" in item:
content_blocks.append(
TextBlock(
type="text",
text=item["text"],
),
)
else:
content_blocks.append(
TextBlock(
type="text",
text=content,
),
)
if message.get("tool_calls"):
for tool_call in message["tool_calls"]:
input_ = _json_loads_with_repair(
tool_call["function"].get(
"arguments",
"{}",
)
or "{}",
)
content_blocks.append(
ToolUseBlock(
type="tool_use",
name=tool_call["function"]["name"],
input=input_,
id=tool_call["id"],
),
)
if structured_model:
metadata = input_
# Usage information
usage = None
if response.usage:
usage = ChatUsage(
input_tokens=response.usage.input_tokens,
output_tokens=response.usage.output_tokens,
time=(datetime.now() - start_datetime).total_seconds(),
)
parsed_response = ChatResponse(
content=content_blocks,
usage=usage,
metadata=metadata,
)
return parsed_response
def _format_tools_json_schemas(
self,
schemas: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Format the tools JSON schema into required format for DashScope API.
Args:
schemas (`dict[str, dict[str, Any]]`):
The tools JSON schemas.
"""
# Check schemas format
for value in schemas:
if (
not isinstance(value, dict)
or "type" not in value
or value["type"] != "function"
or "function" not in value
):
raise ValueError(
f"Each schema must be a dict with 'type' as 'function' "
f"and 'function' key, got {value}",
)
return schemas
def _format_tool_choice(
self,
tool_choice: Literal["auto", "none", "any", "required"] | str | None,
) -> str | dict | None:
"""Format tool_choice parameter for API compatibility.
Args:
tool_choice (`Literal["auto", "none", "any", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", or specific tool name.
For more details, please refer to
https://help.aliyun.com/zh/model-studio/qwen-function-calling
Returns:
`dict | None`:
The formatted tool choice configuration dict, or None if
tool_choice is None.
"""
if tool_choice is None:
return None
if tool_choice in ["auto", "none"]:
return tool_choice
if tool_choice in ["any", "required"]:
logger.warning(
"tool_choice '%s' is not supported by DashScope API. "
"Supported options are 'auto', 'none', or specific function "
"name. Automatically using 'auto' instead.",
tool_choice,
)
return "auto"
return {"type": "function", "function": {"name": tool_choice}}

View File

@@ -0,0 +1,487 @@
# -*- coding: utf-8 -*-
# mypy: disable-error-code="dict-item"
"""The Google Gemini model in agentscope."""
from datetime import datetime
from typing import (
AsyncGenerator,
Any,
TYPE_CHECKING,
AsyncIterator,
Literal,
Type,
List,
)
from pydantic import BaseModel
from .._logging import logger
from .._utils._common import _json_loads_with_repair
from ..message import ToolUseBlock, TextBlock, ThinkingBlock
from ._model_usage import ChatUsage
from ._model_base import ChatModelBase
from ._model_response import ChatResponse
from ..tracing import trace_llm
from ..types import JSONSerializableObject
if TYPE_CHECKING:
from google.genai.types import GenerateContentResponse
else:
GenerateContentResponse = "google.genai.types.GenerateContentResponse"
class GeminiChatModel(ChatModelBase):
"""The Google Gemini chat model class in agentscope."""
def __init__(
self,
model_name: str,
api_key: str,
stream: bool = True,
thinking_config: dict | None = None,
client_args: dict = None,
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
) -> None:
"""Initialize the Gemini chat model.
Args:
model_name (`str`):
The name of the Gemini model to use, e.g. "gemini-2.5-flash".
api_key (`str`):
The API key for Google Gemini.
stream (`bool`, default `True`):
Whether to use streaming output or not.
thinking_config (`dict | None`, optional):
Thinking config, supported models are 2.5 Pro, 2.5 Flash, etc.
Refer to https://ai.google.dev/gemini-api/docs/thinking for
more details.
.. code-block:: python
:caption: Example of thinking_config
{
"include_thoughts": True, # enable thoughts or not
"thinking_budget": 1024 # Max tokens for reasoning
}
client_args (`dict`, default `None`):
The extra keyword arguments to initialize the OpenAI client.
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments used in Gemini API generation,
e.g. `temperature`, `seed`.
"""
try:
from google import genai
except ImportError as e:
raise ImportError(
"Please install gemini Python sdk with "
"`pip install -q -U google-genai`",
) from e
super().__init__(model_name, stream)
self.client = genai.Client(
api_key=api_key,
**(client_args or {}),
)
self.thinking_config = thinking_config
self.generate_kwargs = generate_kwargs or {}
@trace_llm
async def __call__(
self,
messages: list[dict],
tools: list[dict] | None = None,
tool_choice: Literal["auto", "none", "any", "required"]
| str
| None = None,
structured_model: Type[BaseModel] | None = None,
**config_kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
"""Call the Gemini model with the provided arguments.
Args:
messages (`list[dict[str, Any]]`):
A list of dictionaries, where `role` and `content` fields are
required.
tools (`list[dict] | None`, default `None`):
The tools JSON schemas that the model can use.
tool_choice (`Literal["auto", "none", "any", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "any", "required", or specific tool
name. For more details, please refer to
https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
.. note:: When `structured_model` is specified,
both `tools` and `tool_choice` parameters are ignored,
and the model will only perform structured output
generation without calling any other tools.
For more details, please refer to
https://ai.google.dev/gemini-api/docs/structured-output
**config_kwargs (`Any`):
The keyword arguments for Gemini chat completions API.
"""
config: dict = {
"thinking_config": self.thinking_config,
**self.generate_kwargs,
**config_kwargs,
}
if tools:
config["tools"] = self._format_tools_json_schemas(tools)
if tool_choice:
self._validate_tool_choice(tool_choice, tools)
config["tool_config"] = self._format_tool_choice(tool_choice)
if structured_model:
if tools or tool_choice:
logger.warning(
"structured_model is provided. Both 'tools' and "
"'tool_choice' parameters will be overridden and "
"ignored. The model will only perform structured output "
"generation without calling any other tools.",
)
config.pop("tools", None)
config.pop("tool_config", None)
config["response_mime_type"] = "application/json"
config["response_schema"] = structured_model
# Prepare the arguments for the Gemini API call
kwargs: dict[str, JSONSerializableObject] = {
"model": self.model_name,
"contents": messages,
"config": config,
}
start_datetime = datetime.now()
if self.stream:
response = await self.client.aio.models.generate_content_stream(
**kwargs,
)
return self._parse_gemini_stream_generation_response(
start_datetime,
response,
structured_model,
)
# non-streaming
response = await self.client.aio.models.generate_content(
**kwargs,
)
parsed_response = self._parse_gemini_generation_response(
start_datetime,
response,
structured_model,
)
return parsed_response
async def _parse_gemini_stream_generation_response(
self,
start_datetime: datetime,
response: AsyncIterator[GenerateContentResponse],
structured_model: Type[BaseModel] | None = None,
) -> AsyncGenerator[ChatResponse, None]:
"""Given a Gemini streaming generation response, extract the
content blocks and usages from it and yield ChatResponse objects.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`AsyncIterator[GenerateContentResponse]`):
Gemini GenerateContentResponse async iterator to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
`AsyncGenerator[ChatResponse, None]`:
An async generator that yields ChatResponse objects containing
the content blocks and usage information for each chunk in the
streaming response.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
text = ""
thinking = ""
metadata: dict | None = None
async for chunk in response:
content_block: list = []
# Thinking parts
if (
chunk.candidates
and chunk.candidates[0].content
and chunk.candidates[0].content.parts
):
for part in chunk.candidates[0].content.parts:
if part.thought and part.text:
thinking += part.text
# Text parts
if chunk.text:
text += chunk.text
if structured_model:
metadata = _json_loads_with_repair(text)
# Function calls
tool_calls = []
if chunk.function_calls:
for function_call in chunk.function_calls:
tool_calls.append(
ToolUseBlock(
type="tool_use",
id=function_call.id,
name=function_call.name,
input=function_call.args or {},
),
)
usage = None
if chunk.usage_metadata:
usage = ChatUsage(
input_tokens=chunk.usage_metadata.prompt_token_count,
output_tokens=chunk.usage_metadata.total_token_count
- chunk.usage_metadata.prompt_token_count,
time=(datetime.now() - start_datetime).total_seconds(),
)
if thinking:
content_block.append(
ThinkingBlock(
type="thinking",
thinking=thinking,
),
)
if text:
content_block.append(
TextBlock(
type="text",
text=text,
),
)
content_block.extend(
[
*tool_calls,
],
)
parsed_chunk = ChatResponse(
content=content_block,
usage=usage,
metadata=metadata,
)
yield parsed_chunk
def _parse_gemini_generation_response(
self,
start_datetime: datetime,
response: GenerateContentResponse,
structured_model: Type[BaseModel] | None = None,
) -> ChatResponse:
"""Given a Gemini chat completion response object, extract the content
blocks and usages from it.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`ChatCompletion`):
The OpenAI chat completion response object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
ChatResponse (`ChatResponse`):
A ChatResponse object containing the content blocks and usage.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
content_blocks: List[TextBlock | ToolUseBlock | ThinkingBlock] = []
metadata: dict | None = None
if (
response.candidates
and response.candidates[0].content
and response.candidates[0].content.parts
):
for part in response.candidates[0].content.parts:
if part.thought and part.text:
content_blocks.append(
ThinkingBlock(
type="thinking",
thinking=part.text,
),
)
if response.text:
content_blocks.append(
TextBlock(
type="text",
text=response.text,
),
)
if structured_model:
metadata = _json_loads_with_repair(response.text)
if response.function_calls:
for tool_call in response.function_calls:
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=tool_call.id,
name=tool_call.name,
input=tool_call.args or {},
),
)
if response.usage_metadata:
usage = ChatUsage(
input_tokens=response.usage_metadata.prompt_token_count,
output_tokens=response.usage_metadata.total_token_count
- response.usage_metadata.prompt_token_count,
time=(datetime.now() - start_datetime).total_seconds(),
)
else:
usage = None
return ChatResponse(
content=content_blocks,
usage=usage,
metadata=metadata,
)
def _format_tools_json_schemas(
self,
schemas: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Format the tools JSON schema into required format for Gemini API.
Args:
schemas (`dict[str, Any]`):
The tools JSON schemas.
Returns:
List[Dict[str, Any]]:
A list containing a dictionary with the
"function_declarations" key, which maps to a list of
function definitions.
Example:
.. code-block:: python
:caption: Example tool schemas of Gemini API
# Input JSON schema
schemas = [
{
'type': 'function',
'function': {
'name': 'execute_shell_command',
'description': 'xxx',
'parameters': {
'type': 'object',
'properties': {
'command': {
'type': 'string',
'description': 'xxx.'
},
'timeout': {
'type': 'integer',
'default': 300
}
},
'required': ['command']
}
}
}
]
# Output format (Gemini API expected):
[
{
'function_declarations': [
{
'name': 'execute_shell_command',
'description': 'xxx.',
'parameters': {
'type': 'object',
'properties': {
'command': {
'type': 'string',
'description': 'xxx.'
},
'timeout': {
'type': 'integer',
'default': 300
}
},
'required': ['command']
}
}
]
}
]
"""
return [
{
"function_declarations": [
_["function"] for _ in schemas if "function" in _
],
},
]
def _format_tool_choice(
self,
tool_choice: Literal["auto", "none", "any", "required"] | str | None,
) -> dict | None:
"""Format tool_choice parameter for API compatibility.
Args:
tool_choice (`Literal["auto", "none"] | str | None`, default \
`None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "any", "required", or specific tool
name.
For more details, please refer to
https://ai.google.dev/gemini-api/docs/function-calling?hl=en&example=meeting#function_calling_modes
Returns:
`dict | None`:
The formatted tool choice configuration dict, or None if
tool_choice is None.
"""
if tool_choice is None:
return None
mode_mapping = {
"auto": "AUTO",
"none": "NONE",
"any": "ANY",
"required": "ANY",
}
mode = mode_mapping.get(tool_choice)
if mode:
return {"function_calling_config": {"mode": mode}}
return {
"function_calling_config": {
"mode": "ANY",
"allowed_function_names": [tool_choice],
},
}

View File

@@ -0,0 +1,76 @@
# -*- coding: utf-8 -*-
"""The chat model base class."""
from abc import abstractmethod
from typing import AsyncGenerator, Any
from ._model_response import ChatResponse
TOOL_CHOICE_MODES = ["auto", "none", "any", "required"]
class ChatModelBase:
"""Base class for chat models."""
model_name: str
"""The model name"""
stream: bool
"""Is the model output streaming or not"""
def __init__(
self,
model_name: str,
stream: bool,
) -> None:
"""Initialize the chat model base class.
Args:
model_name (`str`):
The name of the model
stream (`bool`):
Whether the model output is streaming or not
"""
self.model_name = model_name
self.stream = stream
@abstractmethod
async def __call__(
self,
*args: Any,
**kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
pass
def _validate_tool_choice(
self,
tool_choice: str,
tools: list[dict] | None,
) -> None:
"""
Validate tool_choice parameter.
Args:
tool_choice (`str`):
Tool choice mode or function name
tools (`list[dict] | None`):
Available tools list
Raises:
TypeError: If tool_choice is not string
ValueError: If tool_choice is invalid
"""
if not isinstance(tool_choice, str):
raise TypeError(
f"tool_choice must be str, got {type(tool_choice)}",
)
if tool_choice in TOOL_CHOICE_MODES:
return
available_functions = [tool["function"]["name"] for tool in tools]
if tool_choice not in available_functions:
all_options = TOOL_CHOICE_MODES + available_functions
raise ValueError(
f"Invalid tool_choice '{tool_choice}'. "
f"Available options: {', '.join(sorted(all_options))}",
)

View File

@@ -0,0 +1,42 @@
# -*- coding: utf-8 -*-
"""The model response module."""
from dataclasses import dataclass, field
from typing import Literal, Sequence
from ._model_usage import ChatUsage
from .._utils._common import _get_timestamp
from .._utils._mixin import DictMixin
from ..message import (
TextBlock,
ToolUseBlock,
ThinkingBlock,
AudioBlock,
)
from ..types import JSONSerializableObject
@dataclass
class ChatResponse(DictMixin):
"""The response of chat models."""
content: Sequence[TextBlock | ToolUseBlock | ThinkingBlock | AudioBlock]
"""The content of the chat response, which can include text blocks,
tool use blocks, or thinking blocks."""
id: str = field(default_factory=lambda: _get_timestamp(True))
"""The unique identifier formatter """
created_at: str = field(default_factory=_get_timestamp)
"""When the response was created"""
type: Literal["chat"] = field(default_factory=lambda: "chat")
"""The type of the response, which is always 'chat'."""
usage: ChatUsage | None = field(default_factory=lambda: None)
"""The usage information of the chat response, if available."""
metadata: dict[str, JSONSerializableObject] | None = field(
default_factory=lambda: None,
)
"""The metadata of the chat response"""

View File

@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
"""The model usage class in agentscope."""
from dataclasses import dataclass, field
from typing import Literal
from .._utils._mixin import DictMixin
@dataclass
class ChatUsage(DictMixin):
"""The usage of a chat model API invocation."""
input_tokens: int
"""The number of input tokens."""
output_tokens: int
"""The number of output tokens."""
time: float
"""The time used in seconds."""
type: Literal["chat"] = field(default_factory=lambda: "chat")
"""The type of the usage, must be `chat`."""

View File

@@ -0,0 +1,345 @@
# -*- coding: utf-8 -*-
"""Model wrapper for Ollama models."""
from datetime import datetime
from typing import (
Any,
TYPE_CHECKING,
List,
AsyncGenerator,
AsyncIterator,
Literal,
Type,
)
from collections import OrderedDict
from pydantic import BaseModel
from . import ChatResponse
from ._model_base import ChatModelBase
from ._model_usage import ChatUsage
from .._logging import logger
from .._utils._common import _json_loads_with_repair
from ..message import ToolUseBlock, TextBlock, ThinkingBlock
from ..tracing import trace_llm
if TYPE_CHECKING:
from ollama._types import ChatResponse as OllamaChatResponse
else:
OllamaChatResponse = "ollama._types.ChatResponse"
class OllamaChatModel(ChatModelBase):
"""The Ollama chat model class in agentscope."""
def __init__(
self,
model_name: str,
stream: bool = False,
options: dict = None,
keep_alive: str = "5m",
enable_thinking: bool | None = None,
host: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize the Ollama chat model.
Args:
model_name (`str`):
The name of the model.
stream (`bool`, default `True`):
Streaming mode or not.
options (`dict`, default `None`):
Additional parameters to pass to the Ollama API. These can
include temperature etc.
keep_alive (`str`, default `"5m"`):
Duration to keep the model loaded in memory. The format is a
number followed by a unit suffix (s for seconds, m for minutes
, h for hours).
enable_thinking (`bool | None`, default `None`)
Whether enable thinking or not, only for models such as qwen3,
deepseek-r1, etc. For more details, please refer to
https://ollama.com/search?c=thinking
host (`str | None`, default `None`):
The host address of the Ollama server. If None, uses the
default address (typically http://localhost:11434).
**kwargs (`Any`):
Additional keyword arguments to pass to the base chat model
class.
"""
try:
import ollama
except ImportError as e:
raise ImportError(
"The package ollama is not found. Please install it by "
'running command `pip install "ollama>=0.1.7"`',
) from e
super().__init__(model_name, stream)
self.client = ollama.AsyncClient(
host=host,
**kwargs,
)
self.options = options
self.keep_alive = keep_alive
self.think = enable_thinking
@trace_llm
async def __call__(
self,
messages: list[dict[str, Any]],
tools: list[dict] | None = None,
tool_choice: Literal["auto", "none", "any", "required"]
| str
| None = None,
structured_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
"""Get the response from Ollama chat completions API by the given
arguments.
Args:
messages (`list[dict]`):
A list of dictionaries, where `role` and `content` fields are
required, and `name` field is optional.
tools (`list[dict]`, default `None`):
The tools JSON schemas that the model can use.
tool_choice (`Literal["auto", "none", "any", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "any", "required", or specific tool
name.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
**kwargs (`Any`):
The keyword arguments for Ollama chat completions API,
e.g. `think`etc. Please refer to the Ollama API
documentation for more details.
Returns:
`ChatResponse | AsyncGenerator[ChatResponse, None]`:
The response from the Ollama chat completions API.
"""
kwargs = {
"model": self.model_name,
"messages": messages,
"stream": self.stream,
"options": self.options,
"keep_alive": self.keep_alive,
**kwargs,
}
if self.think is not None and "think" not in kwargs:
kwargs["think"] = self.think
if tools:
kwargs["tools"] = self._format_tools_json_schemas(tools)
if tool_choice:
logger.warning("Ollama does not support tool_choice yet, ignored.")
if structured_model:
kwargs["format"] = structured_model.model_json_schema()
start_datetime = datetime.now()
response = await self.client.chat(**kwargs)
if self.stream:
return self._parse_ollama_stream_completion_response(
start_datetime,
response,
structured_model,
)
parsed_response = await self._parse_ollama_completion_response(
start_datetime,
response,
structured_model,
)
return parsed_response
async def _parse_ollama_stream_completion_response(
self,
start_datetime: datetime,
response: AsyncIterator[OllamaChatResponse],
structured_model: Type[BaseModel] | None = None,
) -> AsyncGenerator[ChatResponse, None]:
"""Given an Ollama streaming completion response, extract the
content blocks and usages from it and yield ChatResponse objects.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`AsyncIterator[OllamaChatResponse]`):
Ollama streaming response async iterator to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
AsyncGenerator[ChatResponse, None]:
An async generator that yields ChatResponse objects containing
the content blocks and usage information for each chunk in the
streaming response.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
accumulated_text = ""
acc_thinking_content = ""
tool_calls = OrderedDict() # Store tool calls
metadata: dict | None = None
async for chunk in response:
# Handle text content
msg = chunk.message
acc_thinking_content += msg.thinking or ""
accumulated_text += msg.content or ""
# Handle tool calls
for idx, tool_call in enumerate(msg.tool_calls or []):
function = tool_call.function
tool_id = f"{idx}_{function.name}"
tool_calls[tool_id] = {
"type": "tool_use",
"id": tool_id,
"name": function.name,
"input": function.arguments,
}
# Calculate usage statistics
current_time = (datetime.now() - start_datetime).total_seconds()
usage = ChatUsage(
input_tokens=getattr(chunk, "prompt_eval_count", 0) or 0,
output_tokens=getattr(chunk, "eval_count", 0) or 0,
time=current_time,
)
# Create content blocks
contents: list = []
if acc_thinking_content:
contents.append(
ThinkingBlock(
type="thinking",
thinking=acc_thinking_content,
),
)
if accumulated_text:
contents.append(TextBlock(type="text", text=accumulated_text))
if structured_model:
metadata = _json_loads_with_repair(accumulated_text)
# Add tool call blocks
for tool_call in tool_calls.values():
try:
input_data = tool_call["input"]
if isinstance(input_data, str):
input_data = _json_loads_with_repair(input_data)
contents.append(
ToolUseBlock(
type=tool_call["type"],
id=tool_call["id"],
name=tool_call["name"],
input=input_data,
),
)
except Exception as e:
print(f"Error parsing tool call input: {e}")
# Generate response when there's new content or at final chunk
if chunk.done and contents:
res = ChatResponse(
content=contents,
usage=usage,
metadata=metadata,
)
yield res
async def _parse_ollama_completion_response(
self,
start_datetime: datetime,
response: OllamaChatResponse,
structured_model: Type[BaseModel] | None = None,
) -> ChatResponse:
"""Given an Ollama chat completion response object, extract the content
blocks and usages from it.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`OllamaChatResponse`):
Ollama OllamaChatResponse object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
`ChatResponse`:
A ChatResponse object containing the content blocks and usage.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
content_blocks: List[TextBlock | ToolUseBlock | ThinkingBlock] = []
metadata: dict | None = None
if response.message.thinking:
content_blocks.append(
ThinkingBlock(
type="thinking",
thinking=response.message.thinking,
),
)
if response.message.content:
content_blocks.append(
TextBlock(
type="text",
text=response.message.content,
),
)
if structured_model:
metadata = _json_loads_with_repair(
response.message.content,
)
for idx, tool_call in enumerate(response.message.tool_calls or []):
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=f"{idx}_{tool_call.function.name}",
name=tool_call.function.name,
input=tool_call.function.arguments,
),
)
usage = None
if "prompt_eval_count" in response and "eval_count" in response:
usage = ChatUsage(
input_tokens=response.get("prompt_eval_count", 0),
output_tokens=response.get("eval_count", 0),
time=(datetime.now() - start_datetime).total_seconds(),
)
parsed_response = ChatResponse(
content=content_blocks,
usage=usage,
metadata=metadata,
)
return parsed_response
def _format_tools_json_schemas(
self,
schemas: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Format the tools JSON schemas to the Ollama format."""
return schemas

View File

@@ -0,0 +1,545 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""OpenAI Chat model class."""
from datetime import datetime
from typing import (
Any,
TYPE_CHECKING,
List,
AsyncGenerator,
Literal,
Type,
)
from collections import OrderedDict
from pydantic import BaseModel
from . import ChatResponse
from ._model_base import ChatModelBase
from ._model_usage import ChatUsage
from .._logging import logger
from .._utils._common import _json_loads_with_repair
from ..message import (
ToolUseBlock,
TextBlock,
ThinkingBlock,
AudioBlock,
Base64Source,
)
from ..tracing import trace_llm
from ..types import JSONSerializableObject
if TYPE_CHECKING:
from openai.types.chat import ChatCompletion
from openai import AsyncStream
else:
ChatCompletion = "openai.types.chat.ChatCompletion"
AsyncStream = "openai.types.chat.AsyncStream"
def _format_audio_data_for_qwen_omni(messages: list[dict]) -> None:
"""Qwen-omni uses OpenAI-compatible API but requires different audio
data format than OpenAI with "data:;base64," prefix.
Refer to `Qwen-omni documentation
<https://bailian.console.aliyun.com/?tab=doc#/doc/?type=model&url=2867839>`_
for more details.
Args:
messages (`list[dict]`):
The list of message dictionaries from OpenAI formatter.
"""
for msg in messages:
if isinstance(msg.get("content"), list):
for block in msg["content"]:
if (
isinstance(block, dict)
and "input_audio" in block
and isinstance(block["input_audio"].get("data"), str)
):
if not block["input_audio"]["data"].startswith("http"):
block["input_audio"]["data"] = (
"data:;base64," + block["input_audio"]["data"]
)
class OpenAIChatModel(ChatModelBase):
"""The OpenAI chat model class."""
def __init__(
self,
model_name: str,
api_key: str | None = None,
stream: bool = True,
reasoning_effort: Literal["low", "medium", "high"] | None = None,
organization: str = None,
client_args: dict = None,
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
) -> None:
"""Initialize the openai client.
Args:
model_name (`str`, default `None`):
The name of the model to use in OpenAI API.
api_key (`str`, default `None`):
The API key for OpenAI API. If not specified, it will
be read from the environment variable `OPENAI_API_KEY`.
stream (`bool`, default `True`):
Whether to use streaming output or not.
reasoning_effort (`Literal["low", "medium", "high"] | None`, \
optional):
Reasoning effort, supported for o3, o4, etc. Please refer to
`OpenAI documentation
<https://platform.openai.com/docs/guides/reasoning?api-mode=chat>`_
for more details.
organization (`str`, default `None`):
The organization ID for OpenAI API. If not specified, it will
be read from the environment variable `OPENAI_ORGANIZATION`.
client_args (`dict`, default `None`):
The extra keyword arguments to initialize the OpenAI client.
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
The extra keyword arguments used in OpenAI API generation,
e.g. `temperature`, `seed`.
"""
super().__init__(model_name, stream)
import openai
self.client = openai.AsyncClient(
api_key=api_key,
organization=organization,
**(client_args or {}),
)
self.reasoning_effort = reasoning_effort
self.generate_kwargs = generate_kwargs or {}
@trace_llm
async def __call__(
self,
messages: list[dict],
tools: list[dict] | None = None,
tool_choice: Literal["auto", "none", "any", "required"]
| str
| None = None,
structured_model: Type[BaseModel] | None = None,
**kwargs: Any,
) -> ChatResponse | AsyncGenerator[ChatResponse, None]:
"""Get the response from OpenAI chat completions API by the given
arguments.
Args:
messages (`list[dict]`):
A list of dictionaries, where `role` and `content` fields are
required, and `name` field is optional.
tools (`list[dict]`, default `None`):
The tools JSON schemas that the model can use.
tool_choice (`Literal["auto", "none", "any", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "any", "required", or specific tool
name. For more details, please refer to
https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output. When provided, the model will be forced
to return data that conforms to this schema by automatically
converting the BaseModel to a tool function and setting
`tool_choice` to enforce its usage. This enables structured
output generation.
.. note:: When `structured_model` is specified,
both `tools` and `tool_choice` parameters are ignored,
and the model will only perform structured output
generation without calling any other tools.
For more details, please refer to the `official document
<https://platform.openai.com/docs/guides/structured-outputs>`_
**kwargs (`Any`):
The keyword arguments for OpenAI chat completions API,
e.g. `temperature`, `max_tokens`, `top_p`, etc. Please
refer to the OpenAI API documentation for more details.
Returns:
`ChatResponse | AsyncGenerator[ChatResponse, None]`:
The response from the OpenAI chat completions API.
"""
# checking messages
if not isinstance(messages, list):
raise ValueError(
"OpenAI `messages` field expected type `list`, "
f"got `{type(messages)}` instead.",
)
if not all("role" in msg and "content" in msg for msg in messages):
raise ValueError(
"Each message in the 'messages' list must contain a 'role' "
"and 'content' key for OpenAI API.",
)
# Qwen-omni requires different base64 audio format from openai
if "omni" in self.model_name.lower():
_format_audio_data_for_qwen_omni(messages)
kwargs = {
"model": self.model_name,
"messages": messages,
"stream": self.stream,
**self.generate_kwargs,
**kwargs,
}
if self.reasoning_effort and "reasoning_effort" not in kwargs:
kwargs["reasoning_effort"] = self.reasoning_effort
if tools:
kwargs["tools"] = self._format_tools_json_schemas(tools)
if tool_choice:
self._validate_tool_choice(tool_choice, tools)
kwargs["tool_choice"] = self._format_tool_choice(tool_choice)
if self.stream:
kwargs["stream_options"] = {"include_usage": True}
start_datetime = datetime.now()
if structured_model:
if tools or tool_choice:
logger.warning(
"structured_model is provided. Both 'tools' and "
"'tool_choice' parameters will be overridden and "
"ignored. The model will only perform structured output "
"generation without calling any other tools.",
)
kwargs.pop("stream", None)
kwargs.pop("tools", None)
kwargs.pop("tool_choice", None)
kwargs["response_format"] = structured_model
if not self.stream:
response = await self.client.chat.completions.parse(**kwargs)
else:
response = self.client.chat.completions.stream(**kwargs)
return self._parse_openai_stream_response(
start_datetime,
response,
structured_model,
)
else:
response = await self.client.chat.completions.create(**kwargs)
if self.stream:
return self._parse_openai_stream_response(
start_datetime,
response,
structured_model,
)
# Non-streaming response
parsed_response = self._parse_openai_completion_response(
start_datetime,
response,
structured_model,
)
return parsed_response
async def _parse_openai_stream_response(
self,
start_datetime: datetime,
response: AsyncStream,
structured_model: Type[BaseModel] | None = None,
) -> AsyncGenerator[ChatResponse, None]:
"""Given an OpenAI streaming completion response, extract the content
blocks and usages from it and yield ChatResponse objects.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`AsyncStream`):
OpenAI AsyncStream object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
`AsyncGenerator[ChatResponse, None]`:
An async generator that yields ChatResponse objects containing
the content blocks and usage information for each chunk in
the streaming response.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
usage, res = None, None
text = ""
thinking = ""
audio = ""
tool_calls = OrderedDict()
metadata: dict | None = None
contents: List[
TextBlock | ToolUseBlock | ThinkingBlock | AudioBlock
] = []
async with response as stream:
async for item in stream:
if structured_model:
if item.type != "chunk":
continue
chunk = item.chunk
else:
chunk = item
if chunk.usage:
usage = ChatUsage(
input_tokens=chunk.usage.prompt_tokens,
output_tokens=chunk.usage.completion_tokens,
time=(datetime.now() - start_datetime).total_seconds(),
)
if not chunk.choices:
if usage and contents:
res = ChatResponse(
content=contents,
usage=usage,
metadata=metadata,
)
yield res
continue
choice = chunk.choices[0]
thinking += (
getattr(choice.delta, "reasoning_content", None) or ""
)
text += choice.delta.content or ""
if (
hasattr(choice.delta, "audio")
and "data" in choice.delta.audio
):
audio += choice.delta.audio["data"]
if (
hasattr(choice.delta, "audio")
and "transcript" in choice.delta.audio
):
text += choice.delta.audio["transcript"]
for tool_call in choice.delta.tool_calls or []:
if tool_call.index in tool_calls:
if tool_call.function.arguments is not None:
tool_calls[tool_call.index][
"input"
] += tool_call.function.arguments
else:
tool_calls[tool_call.index] = {
"type": "tool_use",
"id": tool_call.id,
"name": tool_call.function.name,
"input": tool_call.function.arguments or "",
}
contents = []
if thinking:
contents.append(
ThinkingBlock(
type="thinking",
thinking=thinking,
),
)
if audio:
media_type = self.generate_kwargs.get("audio", {}).get(
"format",
"wav",
)
contents.append(
AudioBlock(
type="audio",
source=Base64Source(
data=audio,
media_type=f"audio/{media_type}",
type="base64",
),
),
)
if text:
contents.append(
TextBlock(
type="text",
text=text,
),
)
if structured_model:
metadata = _json_loads_with_repair(text)
for tool_call in tool_calls.values():
contents.append(
ToolUseBlock(
type=tool_call["type"],
id=tool_call["id"],
name=tool_call["name"],
input=_json_loads_with_repair(
tool_call["input"] or "{}",
),
),
)
if not contents:
continue
res = ChatResponse(
content=contents,
usage=usage,
metadata=metadata,
)
yield res
def _parse_openai_completion_response(
self,
start_datetime: datetime,
response: ChatCompletion,
structured_model: Type[BaseModel] | None = None,
) -> ChatResponse:
"""Given an OpenAI chat completion response object, extract the content
blocks and usages from it.
Args:
start_datetime (`datetime`):
The start datetime of the response generation.
response (`ChatCompletion`):
OpenAI ChatCompletion object to parse.
structured_model (`Type[BaseModel] | None`, default `None`):
A Pydantic BaseModel class that defines the expected structure
for the model's output.
Returns:
ChatResponse (`ChatResponse`):
A ChatResponse object containing the content blocks and usage.
.. note::
If `structured_model` is not `None`, the expected structured output
will be stored in the metadata of the `ChatResponse`.
"""
content_blocks: List[
TextBlock | ToolUseBlock | ThinkingBlock | AudioBlock
] = []
metadata: dict | None = None
if response.choices:
choice = response.choices[0]
if (
hasattr(choice.message, "reasoning_content")
and choice.message.reasoning_content is not None
):
content_blocks.append(
ThinkingBlock(
type="thinking",
thinking=response.choices[0].message.reasoning_content,
),
)
if choice.message.content:
content_blocks.append(
TextBlock(
type="text",
text=response.choices[0].message.content,
),
)
if choice.message.audio:
media_type = self.generate_kwargs.get("audio", {}).get(
"format",
"mp3",
)
content_blocks.append(
AudioBlock(
type="audio",
source=Base64Source(
data=choice.message.audio.data,
media_type=f"audio/{media_type}",
type="base64",
),
),
)
if choice.message.audio.transcript:
content_blocks.append(
TextBlock(
type="text",
text=choice.message.audio.transcript,
),
)
for tool_call in choice.message.tool_calls or []:
content_blocks.append(
ToolUseBlock(
type="tool_use",
id=tool_call.id,
name=tool_call.function.name,
input=_json_loads_with_repair(
tool_call.function.arguments,
),
),
)
if structured_model:
metadata = choice.message.parsed.model_dump()
usage = None
if response.usage:
usage = ChatUsage(
input_tokens=response.usage.prompt_tokens,
output_tokens=response.usage.completion_tokens,
time=(datetime.now() - start_datetime).total_seconds(),
)
parsed_response = ChatResponse(
content=content_blocks,
usage=usage,
metadata=metadata,
)
return parsed_response
def _format_tools_json_schemas(
self,
schemas: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Format the tools JSON schemas to the OpenAI format."""
return schemas
def _format_tool_choice(
self,
tool_choice: Literal["auto", "none", "any", "required"] | str | None,
) -> str | dict | None:
"""Format tool_choice parameter for API compatibility.
Args:
tool_choice (`Literal["auto", "none", "any", "required"] | str \
| None`, default `None`):
Controls which (if any) tool is called by the model.
Can be "auto", "none", "any", "required", or specific tool
name. For more details, please refer to
https://platform.openai.com/docs/api-reference/responses/create#responses_create-tool_choice
Returns:
`dict | None`:
The formatted tool choice configuration dict, or None if
tool_choice is None.
"""
if tool_choice is None:
return None
mode_mapping = {
"auto": "auto",
"none": "none",
"any": "required",
"required": "required",
}
if tool_choice in mode_mapping:
return mode_mapping[tool_choice]
return {"type": "function", "function": {"name": tool_choice}}

View File

@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
"""A model class for RL Training with Trinity-RFT."""
from typing import (
Optional,
TYPE_CHECKING,
)
from ._openai_model import OpenAIChatModel
from ..types import JSONSerializableObject
if TYPE_CHECKING:
from openai import AsyncOpenAI
else:
AsyncOpenAI = "openai.AsyncOpenAI"
class TrinityChatModel(OpenAIChatModel):
"""A model class for RL Training with Trinity-RFT."""
def __init__(
self,
openai_async_client: AsyncOpenAI,
generate_kwargs: dict[str, JSONSerializableObject] | None = None,
enable_thinking: Optional[bool] = None,
) -> None:
"""Initialize the Trinity model class.
Args:
openai_async_client (`AsyncOpenAI`):
The OpenAI async client instance provided by Trinity-RFT.
generate_kwargs (`dict[str, JSONSerializableObject] | None`, \
optional):
Additional keyword arguments to pass to the model's generate
method. Defaults to None.
enable_thinking (`bool`, optional):
Whether to enable the model's thinking capability. Only
applicable for Qwen3 series models. Defaults to None.
"""
model_name = getattr(openai_async_client, "model_path", None)
if model_name is None:
raise ValueError(
"The provided openai_async_client does not have a "
"`model_path` attribute. Please ensure you are using "
"the instance provided by Trinity-RFT.",
)
super().__init__(
model_name=model_name,
api_key="EMPTY",
generate_kwargs=generate_kwargs,
stream=False, # RL training does not support streaming
)
if enable_thinking is not None:
if "chat_template_kwargs" not in self.generate_kwargs:
self.generate_kwargs["chat_template_kwargs"] = {}
assert isinstance(
self.generate_kwargs["chat_template_kwargs"],
dict,
), "chat_template_kwargs must be a dictionary."
self.generate_kwargs["chat_template_kwargs"][
"enable_thinking"
] = enable_thinking
# change the client instance to the provided one
self.client = openai_async_client