chore: add virtual environment to the repository

- Add the backend_service/venv virtual environment
- Includes all Python dependency packages
- Note: the virtual environment is approximately 393 MB and contains 12,655 files
2025-12-03 10:19:25 +08:00
parent a6c2027caa
commit c4f851d387
12655 changed files with 3009376 additions and 0 deletions

…/agentscope/formatter/__init__.py

@@ -0,0 +1,46 @@
# -*- coding: utf-8 -*-
"""The formatter module in agentscope."""
from ._formatter_base import FormatterBase
from ._truncated_formatter_base import TruncatedFormatterBase
from ._dashscope_formatter import (
DashScopeChatFormatter,
DashScopeMultiAgentFormatter,
)
from ._anthropic_formatter import (
AnthropicChatFormatter,
AnthropicMultiAgentFormatter,
)
from ._openai_formatter import (
OpenAIChatFormatter,
OpenAIMultiAgentFormatter,
)
from ._gemini_formatter import (
GeminiChatFormatter,
GeminiMultiAgentFormatter,
)
from ._ollama_formatter import (
OllamaChatFormatter,
OllamaMultiAgentFormatter,
)
from ._deepseek_formatter import (
DeepSeekChatFormatter,
DeepSeekMultiAgentFormatter,
)
__all__ = [
"FormatterBase",
"TruncatedFormatterBase",
"DashScopeChatFormatter",
"DashScopeMultiAgentFormatter",
"OpenAIChatFormatter",
"OpenAIMultiAgentFormatter",
"AnthropicChatFormatter",
"AnthropicMultiAgentFormatter",
"GeminiChatFormatter",
"GeminiMultiAgentFormatter",
"OllamaChatFormatter",
"OllamaMultiAgentFormatter",
"DeepSeekChatFormatter",
"DeepSeekMultiAgentFormatter",
]
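
Taken together, these exports pair one single-agent chat formatter with one multi-agent formatter per provider. A minimal usage sketch (illustrative only, assuming agentscope's `Msg(name, content, role)` constructor and an async entry point):

import asyncio

from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import Msg

async def main() -> None:
    msgs = [
        Msg("system", "You are a helpful assistant.", "system"),
        Msg("user", "What is the capital of France?", "user"),
    ]
    # Each formatter exposes an async `format` that returns plain dicts
    # ready to be passed to the corresponding chat-completion API.
    formatted = await OpenAIChatFormatter().format(msgs)
    print(formatted)

asyncio.run(main())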

…/agentscope/formatter/_anthropic_formatter.py

@@ -0,0 +1,250 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The Anthropic formatter module."""
from typing import Any
from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from ..message import Msg, TextBlock, ImageBlock, ToolUseBlock, ToolResultBlock
from ..token import TokenCounterBase
class AnthropicChatFormatter(TruncatedFormatterBase):
"""Formatter for Anthropic messages."""
    support_tools_api: bool = True
    """Whether the tools API is supported"""
    support_multiagent: bool = False
    """Whether multi-agent conversations are supported"""
    support_vision: bool = True
    """Whether vision data is supported"""
supported_blocks: list[type] = [
TextBlock,
# Multimodal
ImageBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
async def _format(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Format message objects into Anthropic API format.
Args:
msgs (`list[Msg]`):
The list of message objects to format.
Returns:
`list[dict[str, Any]]`:
The formatted messages as a list of dictionaries.
.. note:: Anthropic suggests always passing all previous thinking
blocks back to the API in subsequent calls to maintain reasoning
continuity. For more details, please refer to
`Anthropic's documentation
<https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#preserving-thinking-blocks>`_.
"""
self.assert_list_of_msgs(msgs)
messages: list[dict] = []
for index, msg in enumerate(msgs):
content_blocks = []
for block in msg.get_content_blocks():
typ = block.get("type")
if typ in ["thinking", "text", "image"]:
content_blocks.append({**block})
elif typ == "tool_use":
content_blocks.append(
{
"id": block.get("id"),
"type": "tool_use",
"name": block.get("name"),
"input": block.get("input", {}),
},
)
elif typ == "tool_result":
output = block.get("output")
if output is None:
content_value = [{"type": "text", "text": None}]
elif isinstance(output, list):
content_value = output
else:
content_value = [{"type": "text", "text": str(output)}]
messages.append(
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": block.get("id"),
"content": content_value,
},
],
},
)
else:
logger.warning(
"Unsupported block type %s in the message, skipped.",
typ,
)
            # Claude only allows the first message to be a system message
if msg.role == "system" and index != 0:
role = "user"
else:
role = msg.role
msg_anthropic = {
"role": role,
"content": content_blocks or None,
}
            # Skip the message when both content and tool_calls are empty
if msg_anthropic["content"] or msg_anthropic.get("tool_calls"):
messages.append(msg_anthropic)
return messages
class AnthropicMultiAgentFormatter(TruncatedFormatterBase):
"""
Anthropic formatter for multi-agent conversations, where more than
a user and an agent are involved.
"""
    support_tools_api: bool = True
    """Whether the tools API is supported"""
    support_multiagent: bool = True
    """Whether multi-agent conversations are supported"""
    support_vision: bool = True
    """Whether vision data is supported"""
supported_blocks: list[type] = [
TextBlock,
# Multimodal
ImageBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
def __init__(
self,
conversation_history_prompt: str = (
"# Conversation History\n"
"The content between <history></history> tags contains "
"your conversation history\n"
),
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the DashScope multi-agent formatter.
Args:
conversation_history_prompt (`str`):
The prompt to use for the conversation history section.
"""
super().__init__(token_counter=token_counter, max_tokens=max_tokens)
self.conversation_history_prompt = conversation_history_prompt
async def _format_tool_sequence(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Given a sequence of tool call/result messages, format them into
the required format for the Anthropic API."""
return await AnthropicChatFormatter().format(msgs)
async def _format_agent_message(
self,
msgs: list[Msg],
is_first: bool = True,
) -> list[dict[str, Any]]:
"""Given a sequence of messages without tool calls/results, format
them into the required format for the Anthropic API."""
if is_first:
conversation_history_prompt = self.conversation_history_prompt
else:
conversation_history_prompt = ""
# Format into required Anthropic format
formatted_msgs: list[dict] = []
# Collect the multimodal files
conversation_blocks: list = []
accumulated_text = []
for msg in msgs:
for block in msg.get_content_blocks():
if block["type"] == "text":
accumulated_text.append(f"{msg.name}: {block['text']}")
elif block["type"] == "image":
# Handle the accumulated text as a single block
if accumulated_text:
conversation_blocks.append(
{
"text": "\n".join(accumulated_text),
"type": "text",
},
)
accumulated_text.clear()
conversation_blocks.append({**block})
if accumulated_text:
conversation_blocks.append(
{
"text": "\n".join(accumulated_text),
"type": "text",
},
)
if conversation_blocks:
if conversation_blocks[0].get("text"):
conversation_blocks[0]["text"] = (
conversation_history_prompt
+ "<history>\n"
+ conversation_blocks[0]["text"]
)
else:
conversation_blocks.insert(
0,
{
"type": "text",
"text": conversation_history_prompt + "<history>\n",
},
)
if conversation_blocks[-1].get("text"):
conversation_blocks[-1]["text"] += "\n</history>"
else:
conversation_blocks.append(
{"type": "text", "text": "</history>"},
)
if conversation_blocks:
formatted_msgs.append(
{
"role": "user",
"content": conversation_blocks,
},
)
return formatted_msgs
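
To illustrate the <history></history> wrapping that _format_agent_message performs above, a hedged sketch (again assuming the Msg(name, content, role) constructor; the exact leading text comes from conversation_history_prompt):

import asyncio

from agentscope.formatter import AnthropicMultiAgentFormatter
from agentscope.message import Msg

async def main() -> None:
    msgs = [
        Msg("Alice", "Hi Bob, how is the report going?", "assistant"),
        Msg("Bob", "Almost done, one section left.", "assistant"),
    ]
    formatted = await AnthropicMultiAgentFormatter().format(msgs)
    # Expected: a single user message whose text block starts with the
    # conversation history prompt plus "<history>", prefixes each line with
    # the speaker's name, and ends with "</history>".
    print(formatted[0]["content"][0]["text"])

asyncio.run(main())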

…/agentscope/formatter/_dashscope_formatter.py

@@ -0,0 +1,426 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The dashscope formatter module."""
import json
import os.path
from typing import Any
from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from .._utils._common import _is_accessible_local_file
from ..message import (
Msg,
TextBlock,
ImageBlock,
AudioBlock,
ToolUseBlock,
ToolResultBlock,
)
from ..token import TokenCounterBase
def _reformat_messages(
messages: list[dict[str, Any]],
) -> list[dict[str, Any]]:
"""Reformat the content to be compatible with HuggingFaceTokenCounter.
This function processes a list of messages and converts multi-part
text content into single string content when all parts are plain text.
This is necessary for compatibility with HuggingFaceTokenCounter which
expects simple string content rather than structured content with
multiple parts.
Args:
messages (list[dict[str, Any]]):
A list of message dictionaries where each message may contain a
"content" field. The content can be either:
- A string (unchanged)
- A list of content items, where each item is a dict that may
contain "text", "type", and other fields
Returns:
list[dict[str, Any]]:
A list of reformatted messages. For messages where all content
items are plain text (have "text" field and either no "type"
field or "type" == "text"), the content list is converted to a
single newline-joined string. Other messages remain unchanged.
Example:
.. code-block:: python
# Case 1: All text content - will be converted
messages = [
{
"role": "user",
"content": [
{"text": "Hello", "type": "text"},
{"text": "World", "type": "text"}
]
}
]
result = _reformat_messages(messages)
print(result[0]["content"])
# Output: "Hello\nWorld"
# Case 2: Mixed content - will remain unchanged
messages = [
{
"role": "user",
"content": [
{"text": "Hello", "type": "text"},
{"image_url": "...", "type": "image"}
]
}
]
result = _reformat_messages(messages) # remain unchanged
print(type(result[0]["content"]))
# Output: <class 'list'>
"""
for message in messages:
content = message.get("content", [])
is_all_text = True
texts = []
for item in content:
if not isinstance(item, dict) or "text" not in item:
is_all_text = False
break
if "type" in item and item["type"] != "text":
is_all_text = False
break
if item["text"]:
texts.append(item["text"])
if is_all_text and texts:
message["content"] = "\n".join(texts)
return messages
class DashScopeChatFormatter(TruncatedFormatterBase):
"""Formatter for DashScope messages."""
    support_tools_api: bool = True
    """Whether the tools API is supported"""
    support_multiagent: bool = False
    """Whether multi-agent conversations are supported"""
    support_vision: bool = True
    """Whether vision data is supported"""
supported_blocks: list[type] = [
TextBlock,
ImageBlock,
AudioBlock,
ToolUseBlock,
ToolResultBlock,
]
async def _format(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Format message objects into DashScope API format.
Args:
msgs (`list[Msg]`):
The list of message objects to format.
Returns:
`list[dict[str, Any]]`:
The formatted messages as a list of dictionaries.
"""
self.assert_list_of_msgs(msgs)
formatted_msgs: list[dict] = []
for msg in msgs:
content_blocks = []
tool_calls = []
for block in msg.get_content_blocks():
typ = block.get("type")
if typ == "text":
content_blocks.append(
{
"text": block.get("text"),
},
)
elif typ in ["image", "audio"]:
source = block["source"]
if source["type"] == "url":
url = source["url"]
if _is_accessible_local_file(url):
content_blocks.append(
{typ: "file://" + os.path.abspath(url)},
)
else:
# treat as web url
content_blocks.append({typ: url})
elif source["type"] == "base64":
media_type = source["media_type"]
base64_data = source["data"]
content_blocks.append(
{typ: f"data:{media_type};base64,{base64_data}"},
)
else:
raise NotImplementedError(
f"Unsupported source type '{source.get('type')}' "
f"for {typ} block.",
)
elif typ == "tool_use":
tool_calls.append(
{
"id": block.get("id"),
"type": "function",
"function": {
"name": block.get("name"),
"arguments": json.dumps(
block.get("input", {}),
ensure_ascii=False,
),
},
},
)
elif typ == "tool_result":
formatted_msgs.append(
{
"role": "tool",
"tool_call_id": block.get("id"),
"content": self.convert_tool_result_to_string(
block.get("output"), # type: ignore[arg-type]
),
"name": block.get("name"),
},
)
else:
logger.warning(
"Unsupported block type %s in the message, skipped.",
typ,
)
msg_dashscope = {
"role": msg.role,
"content": content_blocks or [{"text": None}],
}
if tool_calls:
msg_dashscope["tool_calls"] = tool_calls
if msg_dashscope["content"] != [
{"text": None},
] or msg_dashscope.get(
"tool_calls",
):
formatted_msgs.append(msg_dashscope)
return _reformat_messages(formatted_msgs)
class DashScopeMultiAgentFormatter(TruncatedFormatterBase):
"""DashScope formatter for multi-agent conversations, where more than
a user and an agent are involved.
    .. note:: This formatter will combine previous messages (except tool
        calls/results) into a history section within a user message, using
        the conversation history prompt.
.. note:: For tool calls/results, they will be presented as separate
messages as required by the DashScope API. Therefore, the tool calls/
results messages are expected to be placed at the end of the input
messages.
    .. tip:: Stating the assistant's name in the system prompt is very
        important in multi-agent conversations, so that the LLM knows which
        role it is playing.
"""
    support_tools_api: bool = True
    """Whether the tools API is supported"""
    support_multiagent: bool = True
    """Whether multi-agent conversations are supported"""
    support_vision: bool = True
    """Whether vision data is supported"""
supported_blocks: list[type] = [
TextBlock,
# Multimodal
ImageBlock,
AudioBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
def __init__(
self,
conversation_history_prompt: str = (
"# Conversation History\n"
"The content between <history></history> tags contains "
"your conversation history\n"
),
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the DashScope multi-agent formatter.
Args:
conversation_history_prompt (`str`):
The prompt to use for the conversation history section.
token_counter (`TokenCounterBase | None`, optional):
The token counter used for truncation.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If `None`, no truncation will be applied.
"""
super().__init__(token_counter=token_counter, max_tokens=max_tokens)
self.conversation_history_prompt = conversation_history_prompt
async def _format_tool_sequence(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Given a sequence of tool call/result messages, format them into
the required format for the DashScope API.
Args:
msgs (`list[Msg]`):
The list of messages containing tool calls/results to format.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries formatted for the DashScope API.
"""
return await DashScopeChatFormatter().format(msgs)
async def _format_agent_message(
self,
msgs: list[Msg],
is_first: bool = True,
) -> list[dict[str, Any]]:
"""Given a sequence of messages without tool calls/results, format
them into a user message with conversation history tags. For the
first agent message, it will include the conversation history prompt.
Args:
msgs (`list[Msg]`):
A list of Msg objects to be formatted.
is_first (`bool`, defaults to `True`):
Whether this is the first agent message in the conversation.
If `True`, the conversation history prompt will be included.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries formatted for the DashScope API.
"""
if is_first:
conversation_history_prompt = self.conversation_history_prompt
else:
conversation_history_prompt = ""
# Format into required DashScope format
formatted_msgs: list[dict] = []
# Collect the multimodal files
conversation_blocks = []
accumulated_text = []
for msg in msgs:
for block in msg.get_content_blocks():
if block["type"] == "text":
accumulated_text.append(f"{msg.name}: {block['text']}")
elif block["type"] in ["image", "audio"]:
# Handle the accumulated text as a single block
if accumulated_text:
conversation_blocks.append(
{"text": "\n".join(accumulated_text)},
)
accumulated_text.clear()
if block["source"]["type"] == "url":
url = block["source"]["url"]
if _is_accessible_local_file(url):
conversation_blocks.append(
{
block["type"]: "file://"
+ os.path.abspath(url),
},
)
else:
conversation_blocks.append({block["type"]: url})
elif block["source"]["type"] == "base64":
media_type = block["source"]["media_type"]
base64_data = block["source"]["data"]
conversation_blocks.append(
{
block[
"type"
]: f"data:{media_type};base64,{base64_data}",
},
)
else:
logger.warning(
"Unsupported block type %s in the message, "
"skipped.",
block["type"],
)
if accumulated_text:
conversation_blocks.append({"text": "\n".join(accumulated_text)})
if conversation_blocks:
if conversation_blocks[0].get("text"):
conversation_blocks[0]["text"] = (
conversation_history_prompt
+ "<history>\n"
+ conversation_blocks[0]["text"]
)
else:
conversation_blocks.insert(
0,
{
"text": conversation_history_prompt + "<history>\n",
},
)
if conversation_blocks[-1].get("text"):
conversation_blocks[-1]["text"] += "\n</history>"
else:
conversation_blocks.append({"text": "</history>"})
formatted_msgs.append(
{
"role": "user",
"content": conversation_blocks,
},
)
return _reformat_messages(formatted_msgs)
async def _format_system_message(
self,
msg: Msg,
) -> dict[str, Any]:
"""Format system message for DashScope API."""
return {
"role": "system",
"content": msg.get_text_content(),
}

…/agentscope/formatter/_deepseek_formatter.py

@@ -0,0 +1,250 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The DeepSeek formatter module."""
import json
from typing import Any
from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from ..message import Msg, TextBlock, ToolUseBlock, ToolResultBlock
from ..token import TokenCounterBase
class DeepSeekChatFormatter(TruncatedFormatterBase):
"""Formatter for DeepSeek messages."""
    support_tools_api: bool = True
    """Whether the tools API is supported"""
    support_multiagent: bool = False
    """Whether multi-agent conversations are supported"""
    support_vision: bool = False
    """Whether vision data is supported"""
supported_blocks: list[type] = [
TextBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
async def _format(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Format message objects into DeepSeek API format.
Args:
msgs (`list[Msg]`):
The list of message objects to format.
Returns:
`list[dict[str, Any]]`:
The formatted messages as a list of dictionaries.
"""
self.assert_list_of_msgs(msgs)
messages: list[dict] = []
for msg in msgs:
content_blocks: list = []
tool_calls = []
for block in msg.get_content_blocks():
typ = block.get("type")
if typ == "text":
content_blocks.append({**block})
elif typ == "tool_use":
tool_calls.append(
{
"id": block.get("id"),
"type": "function",
"function": {
"name": block.get("name"),
"arguments": json.dumps(
block.get("input", {}),
ensure_ascii=False,
),
},
},
)
elif typ == "tool_result":
messages.append(
{
"role": "tool",
"tool_call_id": block.get("id"),
"content": self.convert_tool_result_to_string(
block.get("output"), # type: ignore[arg-type]
),
"name": block.get("name"),
},
)
else:
logger.warning(
"Unsupported block type %s in the message, skipped.",
typ,
)
content_msg = "\n".join(
content.get("text", "") for content in content_blocks
)
msg_deepseek = {
"role": msg.role,
"content": content_msg or None,
}
if tool_calls:
msg_deepseek["tool_calls"] = tool_calls
if msg_deepseek["content"] or msg_deepseek.get("tool_calls"):
messages.append(msg_deepseek)
return messages
class DeepSeekMultiAgentFormatter(TruncatedFormatterBase):
"""
DeepSeek formatter for multi-agent conversations, where more than
a user and an agent are involved.
"""
    support_tools_api: bool = True
    """Whether the tools API is supported"""
    support_multiagent: bool = True
    """Whether multi-agent conversations are supported"""
    support_vision: bool = False
    """Whether vision data is supported"""
supported_blocks: list[type] = [
TextBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
def __init__(
self,
conversation_history_prompt: str = (
"# Conversation History\n"
"The content between <history></history> tags contains "
"your conversation history\n"
),
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the DeepSeek multi-agent formatter.
Args:
conversation_history_prompt (`str`):
The prompt to use for the conversation history section.
token_counter (`TokenCounterBase | None`, optional):
A token counter instance used to count tokens in the messages.
If not provided, the formatter will format the messages
without considering token limits.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If not provided, the formatter will not truncate
the messages.
"""
super().__init__(token_counter=token_counter, max_tokens=max_tokens)
self.conversation_history_prompt = conversation_history_prompt
async def _format_tool_sequence(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Given a sequence of tool call/result messages, format them into
the required format for the DeepSeek API.
Args:
msgs (`list[Msg]`):
The list of messages containing tool calls/results to format.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries formatted for the DeepSeek API.
"""
return await DeepSeekChatFormatter().format(msgs)
async def _format_agent_message(
self,
msgs: list[Msg],
is_first: bool = True,
) -> list[dict[str, Any]]:
"""Given a sequence of messages without tool calls/results, format
them into the required format for the DeepSeek API.
Args:
msgs (`list[Msg]`):
A list of Msg objects to be formatted.
is_first (`bool`, defaults to `True`):
Whether this is the first agent message in the conversation.
If `True`, the conversation history prompt will be included.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries formatted for the DeepSeek API.
"""
if is_first:
conversation_history_prompt = self.conversation_history_prompt
else:
conversation_history_prompt = ""
# Format into required DeepSeek format
formatted_msgs: list[dict] = []
conversation_blocks: list = []
accumulated_text = []
for msg in msgs:
for block in msg.get_content_blocks():
if block["type"] == "text":
accumulated_text.append(f"{msg.name}: {block['text']}")
if accumulated_text:
conversation_blocks.append(
{"text": "\n".join(accumulated_text)},
)
if conversation_blocks:
if conversation_blocks[0].get("text"):
conversation_blocks[0]["text"] = (
conversation_history_prompt
+ "<history>\n"
+ conversation_blocks[0]["text"]
)
else:
conversation_blocks.insert(
0,
{
"text": conversation_history_prompt + "<history>\n",
},
)
if conversation_blocks[-1].get("text"):
conversation_blocks[-1]["text"] += "\n</history>"
else:
conversation_blocks.append({"text": "</history>"})
conversation_blocks_text = "\n".join(
conversation_block.get("text", "")
for conversation_block in conversation_blocks
)
user_message = {
"role": "user",
"content": conversation_blocks_text,
}
if conversation_blocks:
formatted_msgs.append(user_message)
return formatted_msgs
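
Since DeepSeek is text-only here (support_vision is False), the multi-agent formatter flattens the whole exchange into one plain string. A small sketch of the expected shape (assuming the Msg(name, content, role) constructor):

import asyncio

from agentscope.formatter import DeepSeekMultiAgentFormatter
from agentscope.message import Msg

async def main() -> None:
    msgs = [
        Msg("Alice", "What's the weather like today?", "assistant"),
        Msg("Bob", "Sunny, around 25 degrees.", "assistant"),
    ]
    formatted = await DeepSeekMultiAgentFormatter().format(msgs)
    # Expected: one {"role": "user", "content": <str>} message whose string
    # wraps the "Alice: ..." and "Bob: ..." lines in <history></history> tags.
    print(formatted[0]["content"])

asyncio.run(main())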

…/agentscope/formatter/_formatter_base.py

@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
"""The formatter module."""
from abc import abstractmethod
from typing import Any, List
from .._utils._common import _save_base64_data
from ..message import Msg, AudioBlock, ImageBlock, TextBlock
class FormatterBase:
"""The base class for formatters."""
@abstractmethod
async def format(self, *args: Any, **kwargs: Any) -> list[dict[str, Any]]:
"""Format the Msg objects to a list of dictionaries that satisfy the
API requirements."""
@staticmethod
def assert_list_of_msgs(msgs: list[Msg]) -> None:
"""Assert that the input is a list of Msg objects.
Args:
msgs (`list[Msg]`):
A list of Msg objects to be validated.
"""
if not isinstance(msgs, list):
raise TypeError("Input must be a list of Msg objects.")
for msg in msgs:
if not isinstance(msg, Msg):
raise TypeError(
f"Expected Msg object, got {type(msg)} instead.",
)
@staticmethod
def convert_tool_result_to_string(
output: str | List[TextBlock | ImageBlock | AudioBlock],
) -> str:
"""Turn the tool result list into a textual output to be compatible
with the LLM API that doesn't support multimodal data.
Args:
output (`str | List[TextBlock | ImageBlock | AudioBlock]`):
The output of the tool response, including text and multimodal
data like images and audio.
Returns:
`str`:
A string representation of the tool result, with text blocks
concatenated and multimodal data represented by file paths
or URLs.
"""
if isinstance(output, str):
return output
textual_output = []
for block in output:
assert isinstance(block, dict) and "type" in block, (
f"Invalid block: {block}, a TextBlock, ImageBlock, or "
f"AudioBlock is expected."
)
if block["type"] == "text":
textual_output.append(block["text"])
elif block["type"] in ["image", "audio", "video"]:
assert "source" in block, (
f"Invalid {block['type']} block: {block}, 'source' key "
"is required."
)
source = block["source"]
# Save the image locally and return the file path
if source["type"] == "url":
textual_output.append(
f"The returned {block['type']} can be found "
f"at: {source['url']}",
)
elif source["type"] == "base64":
path_temp_file = _save_base64_data(
source["media_type"],
source["data"],
)
textual_output.append(
f"The returned {block['type']} can be found "
f"at: {path_temp_file}",
)
else:
                    raise ValueError(
                        f"Invalid {block['type']} source: {block['source']}, "
                        "expected 'url' or 'base64'.",
                    )
else:
raise ValueError(
f"Unsupported block type: {block['type']}, "
"expected 'text', 'image', 'audio', or 'video'.",
)
if len(textual_output) == 1:
return textual_output[0]
else:
return "\n".join("- " + _ for _ in textual_output)

…/agentscope/formatter/_gemini_formatter.py

@@ -0,0 +1,405 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""Google gemini API formatter in agentscope."""
import base64
import os
from typing import Any
from urllib.parse import urlparse
from ._truncated_formatter_base import TruncatedFormatterBase
from .._utils._common import _get_bytes_from_web_url
from ..message import (
Msg,
TextBlock,
ImageBlock,
AudioBlock,
ToolUseBlock,
ToolResultBlock,
VideoBlock,
)
from .._logging import logger
from ..token import TokenCounterBase
def _to_gemini_inline_data(url: str) -> dict:
"""Convert url into the Gemini API required format."""
parsed_url = urlparse(url)
extension = url.split(".")[-1].lower()
# Pre-calculate media type from extension (image/audio/video).
typ = None
for k, v in GeminiChatFormatter.supported_extensions.items():
if extension in v:
typ = k
break
if not os.path.exists(url) and parsed_url.scheme != "":
# Web url
if typ is None:
raise TypeError(
f"Unsupported file extension: {extension}, expected "
f"{GeminiChatFormatter.supported_extensions}",
)
data = _get_bytes_from_web_url(url)
return {
"data": data,
"mime_type": f"{typ}/{extension}",
}
elif os.path.exists(url):
# Local file
if typ is None:
raise TypeError(
f"Unsupported file extension: {extension}, expected "
f"{GeminiChatFormatter.supported_extensions}",
)
with open(url, "rb") as f:
data = base64.b64encode(f.read()).decode("utf-8")
return {
"data": data,
"mime_type": f"{typ}/{extension}",
}
    raise ValueError(
        f"The URL `{url}` is not a valid media URL or local file.",
    )
class GeminiChatFormatter(TruncatedFormatterBase):
"""The formatter for Google Gemini API."""
    support_tools_api: bool = True
    """Whether the tools API is supported"""
    support_multiagent: bool = False
    """Whether multi-agent conversations are supported"""
    support_vision: bool = True
    """Whether vision data is supported"""
supported_blocks: list[type] = [
TextBlock,
# Multimodal
ImageBlock,
VideoBlock,
AudioBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
supported_extensions: dict[str, list[str]] = {
"image": ["png", "jpeg", "webp", "heic", "heif"],
"video": [
"mp4",
"mpeg",
"mov",
"avi",
"x-flv",
"mpg",
"webm",
"wmv",
"3gpp",
],
"audio": ["mp3", "wav", "aiff", "aac", "ogg", "flac"],
}
async def _format(
self,
msgs: list[Msg],
) -> list[dict]:
"""Format message objects into Gemini API required format."""
self.assert_list_of_msgs(msgs)
messages: list = []
for msg in msgs:
parts = []
for block in msg.get_content_blocks():
typ = block.get("type")
if typ == "text":
parts.append(
{
"text": block.get("text"),
},
)
elif typ == "tool_use":
parts.append(
{
"function_call": {
"id": block["id"],
"name": block["name"],
"args": block["input"],
},
},
)
elif typ == "tool_result":
text_output = self.convert_tool_result_to_string(
block["output"], # type: ignore[arg-type]
)
messages.append(
{
"role": "user",
"parts": [
{
"function_response": {
"id": block["id"],
"name": block["name"],
"response": {
"output": text_output,
},
},
},
],
},
)
elif typ in ["image", "audio", "video"]:
if block["source"]["type"] == "base64":
media_type = block["source"]["media_type"]
base64_data = block["source"]["data"]
parts.append(
{
"inline_data": {
"data": base64_data,
"mime_type": media_type,
},
},
)
elif block["source"]["type"] == "url":
parts.append(
{
"inline_data": _to_gemini_inline_data(
block["source"]["url"],
),
},
)
else:
logger.warning(
"Unsupported block type: %s in the message, skipped. ",
typ,
)
role = "model" if msg.role == "assistant" else "user"
if parts:
messages.append(
{
"role": role,
"parts": parts,
},
)
return messages
class GeminiMultiAgentFormatter(TruncatedFormatterBase):
"""The multi-agent formatter for Google Gemini API, where more than a
user and an agent are involved.
    .. note:: This formatter will combine previous messages (except tool
        calls/results) into a history section within a user message, using
        the conversation history prompt.
.. note:: For tool calls/results, they will be presented as separate
messages as required by the Gemini API. Therefore, the tool calls/
results messages are expected to be placed at the end of the input
messages.
    .. tip:: Stating the assistant's name in the system prompt is very
        important in multi-agent conversations, so that the LLM knows which
        role it is playing.
"""
    support_tools_api: bool = True
    """Whether the tools API is supported"""
    support_multiagent: bool = True
    """Whether multi-agent conversations are supported"""
    support_vision: bool = True
    """Whether vision data is supported"""
supported_blocks: list[type] = [
TextBlock,
# Multimodal
ImageBlock,
VideoBlock,
AudioBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
def __init__(
self,
conversation_history_prompt: str = (
"# Conversation History\n"
"The content between <history></history> tags contains "
"your conversation history\n"
),
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the Gemini multi-agent formatter.
Args:
conversation_history_prompt (`str`):
The prompt to be used for the conversation history section.
token_counter (`TokenCounterBase | None`, optional):
The token counter used for truncation.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If `None`, no truncation will be applied.
"""
super().__init__(token_counter=token_counter, max_tokens=max_tokens)
self.conversation_history_prompt = conversation_history_prompt
async def _format_system_message(
self,
msg: Msg,
) -> dict[str, Any]:
"""Format system message for the Gemini API."""
return {
"role": "user",
"parts": [
{
"text": msg.get_text_content(),
},
],
}
async def _format_tool_sequence(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Given a sequence of tool call/result messages, format them into
the required format for the Gemini API.
Args:
msgs (`list[Msg]`):
The list of messages containing tool calls/results to format.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries formatted for the Gemini API.
"""
return await GeminiChatFormatter().format(msgs)
async def _format_agent_message(
self,
msgs: list[Msg],
is_first: bool = True,
) -> list[dict[str, Any]]:
"""Given a sequence of messages without tool calls/results, format
them into the required format for the Gemini API.
Args:
msgs (`list[Msg]`):
A list of Msg objects to be formatted.
is_first (`bool`, defaults to `True`):
Whether this is the first agent message in the conversation.
If `True`, the conversation history prompt will be included.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries formatted for the Gemini API.
"""
if is_first:
conversation_history_prompt = self.conversation_history_prompt
else:
conversation_history_prompt = ""
# Format into Gemini API required format
formatted_msgs: list = []
# Collect the multimodal files
conversation_parts: list = []
accumulated_text = []
for msg in msgs:
for block in msg.get_content_blocks():
if block["type"] == "text":
accumulated_text.append(f"{msg.name}: {block['text']}")
elif block["type"] in ["image", "video", "audio"]:
                    # Handle the accumulated text as a single part, if any
if accumulated_text:
conversation_parts.append(
{
"text": "\n".join(accumulated_text),
},
)
accumulated_text.clear()
# handle the multimodal data
if block["source"]["type"] == "url":
conversation_parts.append(
{
"inline_data": _to_gemini_inline_data(
block["source"]["url"],
),
},
)
elif block["source"]["type"] == "base64":
media_type = block["source"]["media_type"]
base64_data = block["source"]["data"]
conversation_parts.append(
{
"inline_data": {
"data": base64_data,
"mime_type": media_type,
},
},
)
if accumulated_text:
conversation_parts.append(
{
"text": "\n".join(accumulated_text),
},
)
# Add prompt and <history></history> tags around conversation history
if conversation_parts:
if conversation_parts[0].get("text"):
conversation_parts[0]["text"] = (
conversation_history_prompt
+ "<history>"
+ conversation_parts[0]["text"]
)
else:
conversation_parts.insert(
0,
{"text": conversation_history_prompt + "<history>"},
)
if conversation_parts[-1].get("text"):
conversation_parts[-1]["text"] += "\n</history>"
else:
conversation_parts.append(
{"text": "</history>"},
)
formatted_msgs.append(
{
"role": "user",
"parts": conversation_parts,
},
)
return formatted_msgs
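
The supported_extensions table drives the mime-type inference in _to_gemini_inline_data. A sketch that replicates the lookup without touching the network or disk (infer_mime_type is a hypothetical helper written for illustration):

from agentscope.formatter import GeminiChatFormatter

def infer_mime_type(url: str) -> str:
    """Replicate the extension-to-mime-type lookup used above."""
    extension = url.split(".")[-1].lower()
    for typ, extensions in GeminiChatFormatter.supported_extensions.items():
        if extension in extensions:
            return f"{typ}/{extension}"
    raise TypeError(f"Unsupported file extension: {extension}")

print(infer_mime_type("https://example.com/clip.mp4"))  # video/mp4
print(infer_mime_type("narration.wav"))  # audio/wav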

…/agentscope/formatter/_ollama_formatter.py

@@ -0,0 +1,320 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The Ollama formatter module."""
import base64
import os
from typing import Any
from urllib.parse import urlparse
from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from .._utils._common import _get_bytes_from_web_url
from ..message import Msg, TextBlock, ImageBlock, ToolUseBlock, ToolResultBlock
from ..token import TokenCounterBase
def _convert_ollama_image_url_to_base64_data(url: str) -> str:
"""Convert image url to base64."""
parsed_url = urlparse(url)
if not os.path.exists(url) and parsed_url.scheme != "":
# Web url
data = _get_bytes_from_web_url(url)
return data
if os.path.exists(url):
# Local file
with open(url, "rb") as f:
data = base64.b64encode(f.read()).decode("utf-8")
return data
raise ValueError(
f"The URL `{url}` is not a valid image URL or local file.",
)
class OllamaChatFormatter(TruncatedFormatterBase):
"""Formatter for Ollama messages."""
    support_tools_api: bool = True
    """Whether the tools API is supported"""
    support_multiagent: bool = False
    """Whether multi-agent conversations are supported"""
    support_vision: bool = True
    """Whether vision data is supported"""
supported_blocks: list[type] = [
TextBlock,
# Multimodal
ImageBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
async def _format(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Format message objects into Ollama API format.
Args:
msgs (`list[Msg]`):
The list of message objects to format.
Returns:
`list[dict[str, Any]]`:
The formatted messages as a list of dictionaries.
"""
self.assert_list_of_msgs(msgs)
messages: list[dict] = []
for msg in msgs:
content_blocks: list = []
tool_calls = []
images = []
for block in msg.get_content_blocks():
typ = block.get("type")
if typ == "text":
content_blocks.append({**block})
elif typ == "tool_use":
tool_calls.append(
{
"id": block.get("id"),
"type": "function",
"function": {
"name": block.get("name"),
"arguments": block.get("input", {}),
},
},
)
elif typ == "tool_result":
messages.append(
{
"role": "tool",
"tool_call_id": block.get("id"),
"content": self.convert_tool_result_to_string(
block.get("output"), # type: ignore[arg-type]
),
"name": block.get("name"),
},
)
elif typ == "image":
source_type = block["source"]["type"]
if source_type == "url":
images.append(
_convert_ollama_image_url_to_base64_data(
block["source"]["url"],
),
)
elif source_type == "base64":
images.append(block["source"]["data"])
else:
logger.warning(
"Unsupported block type %s in the message, skipped.",
typ,
)
content_msg = "\n".join(
content.get("text", "") for content in content_blocks
)
msg_ollama = {
"role": msg.role,
"content": content_msg or None,
}
if tool_calls:
msg_ollama["tool_calls"] = tool_calls
if images:
msg_ollama["images"] = images
if msg_ollama["content"] or msg_ollama.get("tool_calls"):
messages.append(msg_ollama)
return messages
class OllamaMultiAgentFormatter(TruncatedFormatterBase):
"""
Ollama formatter for multi-agent conversations, where more than
a user and an agent are involved.
"""
    support_tools_api: bool = True
    """Whether the tools API is supported"""
    support_multiagent: bool = True
    """Whether multi-agent conversations are supported"""
    support_vision: bool = True
    """Whether vision data is supported"""
supported_blocks: list[type] = [
TextBlock,
# Multimodal
ImageBlock,
# Tool use
ToolUseBlock,
ToolResultBlock,
]
"""The list of supported message blocks"""
def __init__(
self,
conversation_history_prompt: str = (
"# Conversation History\n"
"The content between <history></history> tags contains "
"your conversation history\n"
),
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the Ollama multi-agent formatter.
Args:
conversation_history_prompt (`str`):
The prompt to use for the conversation history section.
token_counter (`TokenCounterBase | None`, optional):
The token counter used for truncation.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If `None`, no truncation will be applied.
"""
super().__init__(token_counter=token_counter, max_tokens=max_tokens)
self.conversation_history_prompt = conversation_history_prompt
async def _format_system_message(
self,
msg: Msg,
) -> dict[str, Any]:
"""Format system message for the Ollama API."""
return {
"role": "system",
"content": msg.get_text_content(),
}
async def _format_tool_sequence(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Given a sequence of tool call/result messages, format them into
the required format for the Ollama API.
Args:
msgs (`list[Msg]`):
The list of messages containing tool calls/results to format.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries formatted for the Ollama API.
"""
return await OllamaChatFormatter().format(msgs)
async def _format_agent_message(
self,
msgs: list[Msg],
is_first: bool = True,
) -> list[dict[str, Any]]:
"""Given a sequence of messages without tool calls/results, format
them into the required format for the Ollama API.
Args:
msgs (`list[Msg]`):
A list of Msg objects to be formatted.
is_first (`bool`, defaults to `True`):
Whether this is the first agent message in the conversation.
If `True`, the conversation history prompt will be included.
Returns:
`list[dict[str, Any]]`:
                A list of dictionaries formatted for the Ollama API.
"""
if is_first:
conversation_history_prompt = self.conversation_history_prompt
else:
conversation_history_prompt = ""
# Format into required Ollama format
formatted_msgs: list[dict] = []
# Collect the multimodal files
conversation_blocks: list = []
accumulated_text = []
images = []
for msg in msgs:
for block in msg.get_content_blocks():
if block["type"] == "text":
accumulated_text.append(f"{msg.name}: {block['text']}")
elif block["type"] == "image":
# Handle the accumulated text as a single block
source = block["source"]
if accumulated_text:
conversation_blocks.append(
{"text": "\n".join(accumulated_text)},
)
accumulated_text.clear()
if source["type"] == "url":
images.append(
_convert_ollama_image_url_to_base64_data(
source["url"],
),
)
elif source["type"] == "base64":
images.append(source["data"])
conversation_blocks.append({**block})
if accumulated_text:
conversation_blocks.append(
{"text": "\n".join(accumulated_text)},
)
if conversation_blocks:
if conversation_blocks[0].get("text"):
conversation_blocks[0]["text"] = (
conversation_history_prompt
+ "<history>\n"
+ conversation_blocks[0]["text"]
)
else:
conversation_blocks.insert(
0,
{
"text": conversation_history_prompt + "<history>\n",
},
)
if conversation_blocks[-1].get("text"):
conversation_blocks[-1]["text"] += "\n</history>"
else:
conversation_blocks.append({"text": "</history>"})
conversation_blocks_text = "\n".join(
conversation_block.get("text", "")
for conversation_block in conversation_blocks
)
user_message = {
"role": "user",
"content": conversation_blocks_text,
}
if images:
user_message["images"] = images
if conversation_blocks:
formatted_msgs.append(user_message)
return formatted_msgs
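
Unlike the other providers, Ollama carries images as a top-level "images" list of raw base64 strings rather than as inline content parts. A sketch (the Base64Source TypedDict comes from agentscope.message; the payload is a truncated placeholder):

import asyncio

from agentscope.formatter import OllamaChatFormatter
from agentscope.message import Base64Source, ImageBlock, Msg, TextBlock

async def main() -> None:
    msg = Msg(
        "user",
        [
            TextBlock(type="text", text="What is in this picture?"),
            ImageBlock(
                type="image",
                source=Base64Source(
                    type="base64",
                    media_type="image/png",
                    data="iVBORw0KGgo...",  # placeholder, not a real image
                ),
            ),
        ],
        "user",
    )
    formatted = await OllamaChatFormatter().format([msg])
    # Expected: content holds the plain text, while the raw base64 payload
    # travels in a separate top-level "images" list.
    print(sorted(formatted[0].keys()))  # ['content', 'images', 'role']

asyncio.run(main())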

…/agentscope/formatter/_openai_formatter.py

@@ -0,0 +1,414 @@
# -*- coding: utf-8 -*-
# pylint: disable=too-many-branches
"""The OpenAI formatter for agentscope."""
import base64
import json
import os
from typing import Any
from urllib.parse import urlparse
import requests
from ._truncated_formatter_base import TruncatedFormatterBase
from .._logging import logger
from ..message import (
Msg,
URLSource,
TextBlock,
ImageBlock,
AudioBlock,
Base64Source,
ToolUseBlock,
ToolResultBlock,
)
from ..token import TokenCounterBase
def _to_openai_image_url(url: str) -> str:
"""Convert an image url to openai format. If the given url is a local
file, it will be converted to base64 format. Otherwise, it will be
returned directly.
Args:
url (`str`):
The local or public url of the image.
"""
    # See https://platform.openai.com/docs/guides/vision for details of
    # supported image extensions.
support_image_extensions = (
".png",
".jpg",
".jpeg",
".gif",
".webp",
)
parsed_url = urlparse(url)
lower_url = url.lower()
# Web url
if not os.path.exists(url) and parsed_url.scheme != "":
if any(lower_url.endswith(_) for _ in support_image_extensions):
return url
# Check if it is a local file
elif os.path.exists(url) and os.path.isfile(url):
if any(lower_url.endswith(_) for _ in support_image_extensions):
with open(url, "rb") as image_file:
base64_image = base64.b64encode(image_file.read()).decode(
"utf-8",
)
extension = parsed_url.path.lower().split(".")[-1]
mime_type = f"image/{extension}"
return f"data:{mime_type};base64,{base64_image}"
raise TypeError(f'"{url}" should end with {support_image_extensions}.')
def _to_openai_audio_data(source: URLSource | Base64Source) -> dict:
"""Covert an audio source to OpenAI format."""
if source["type"] == "url":
extension = source["url"].split(".")[-1].lower()
if extension not in ["wav", "mp3"]:
raise TypeError(
f"Unsupported audio file extension: {extension}, "
"wav and mp3 are supported.",
)
parsed_url = urlparse(source["url"])
if os.path.exists(source["url"]):
with open(source["url"], "rb") as audio_file:
data = base64.b64encode(audio_file.read()).decode("utf-8")
# web url
elif parsed_url.scheme != "":
response = requests.get(source["url"])
response.raise_for_status()
data = base64.b64encode(response.content).decode("utf-8")
else:
raise ValueError(
f"Unsupported audio source: {source['url']}, "
"it should be a local file or a web URL.",
)
return {
"data": data,
"format": extension,
}
if source["type"] == "base64":
data = source["data"]
media_type = source["media_type"]
if media_type not in ["audio/wav", "audio/mp3"]:
raise TypeError(
f"Unsupported audio media type: {media_type}, "
"only audio/wav and audio/mp3 are supported.",
)
return {
"data": data,
"format": media_type.split("/")[-1],
}
raise TypeError(f"Unsupported audio source: {source['type']}.")
class OpenAIChatFormatter(TruncatedFormatterBase):
"""The class used to format message objects into the OpenAI API required
format."""
    support_tools_api: bool = True
    """Whether the tools API is supported"""
    support_multiagent: bool = True
    """Whether multi-agent conversations are supported"""
    support_vision: bool = True
    """Whether vision models are supported"""
supported_blocks: list[type] = [
TextBlock,
ImageBlock,
AudioBlock,
ToolUseBlock,
ToolResultBlock,
]
"""Supported message blocks for OpenAI API"""
async def _format(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Format message objects into OpenAI API required format.
Args:
msgs (`list[Msg]`):
The list of Msg objects to format.
Returns:
`list[dict[str, Any]]`:
A list of dictionaries, where each dictionary has "name",
"role", and "content" keys.
"""
self.assert_list_of_msgs(msgs)
messages: list[dict] = []
for msg in msgs:
content_blocks = []
tool_calls = []
for block in msg.get_content_blocks():
typ = block.get("type")
if typ == "text":
content_blocks.append({**block})
elif typ == "tool_use":
tool_calls.append(
{
"id": block.get("id"),
"type": "function",
"function": {
"name": block.get("name"),
"arguments": json.dumps(
block.get("input", {}),
ensure_ascii=False,
),
},
},
)
elif typ == "tool_result":
messages.append(
{
"role": "tool",
"tool_call_id": block.get("id"),
"content": self.convert_tool_result_to_string(
block.get("output"), # type: ignore[arg-type]
),
"name": block.get("name"),
},
)
elif typ == "image":
source_type = block["source"]["type"]
if source_type == "url":
url = _to_openai_image_url(block["source"]["url"])
elif source_type == "base64":
data = block["source"]["data"]
media_type = block["source"]["media_type"]
url = f"data:{media_type};base64,{data}"
else:
raise ValueError(
f"Unsupported image source type: {source_type}",
)
content_blocks.append(
{
"type": "image_url",
"image_url": {
"url": url,
},
},
)
elif typ == "audio":
input_audio = _to_openai_audio_data(block["source"])
content_blocks.append(
{
"type": "input_audio",
"input_audio": input_audio,
},
)
else:
logger.warning(
"Unsupported block type %s in the message, skipped.",
typ,
)
msg_openai = {
"role": msg.role,
"name": msg.name,
"content": content_blocks or None,
}
if tool_calls:
msg_openai["tool_calls"] = tool_calls
            # Skip the message when both content and tool_calls are empty
if msg_openai["content"] or msg_openai.get("tool_calls"):
messages.append(msg_openai)
return messages
class OpenAIMultiAgentFormatter(TruncatedFormatterBase):
"""
OpenAI formatter for multi-agent conversations, where more than
a user and an agent are involved.
.. tip:: This formatter is compatible with OpenAI API and
OpenAI-compatible services like vLLM, Azure OpenAI, and others.
"""
    support_tools_api: bool = True
    """Whether the tools API is supported"""
    support_multiagent: bool = True
    """Whether multi-agent conversations are supported"""
    support_vision: bool = True
    """Whether vision models are supported"""
supported_blocks: list[type] = [
TextBlock,
ImageBlock,
AudioBlock,
ToolUseBlock,
ToolResultBlock,
]
"""Supported message blocks for OpenAI API"""
def __init__(
self,
conversation_history_prompt: str = (
"# Conversation History\n"
"The content between <history></history> tags contains "
"your conversation history\n"
),
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the OpenAI multi-agent formatter.
Args:
conversation_history_prompt (`str`):
The prompt to use for the conversation history section.
"""
super().__init__(token_counter=token_counter, max_tokens=max_tokens)
self.conversation_history_prompt = conversation_history_prompt
async def _format_tool_sequence(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Given a sequence of tool call/result messages, format them into
the required format for the OpenAI API."""
return await OpenAIChatFormatter().format(msgs)
async def _format_agent_message(
self,
msgs: list[Msg],
is_first: bool = True,
) -> list[dict[str, Any]]:
"""Given a sequence of messages without tool calls/results, format
them into the required format for the OpenAI API."""
if is_first:
conversation_history_prompt = self.conversation_history_prompt
else:
conversation_history_prompt = ""
# Format into required OpenAI format
formatted_msgs: list[dict] = []
conversation_blocks: list = []
accumulated_text = []
images = []
audios = []
for msg in msgs:
for block in msg.get_content_blocks():
if block["type"] == "text":
accumulated_text.append(f"{msg.name}: {block['text']}")
elif block["type"] == "image":
source_type = block["source"]["type"]
if source_type == "url":
url = _to_openai_image_url(block["source"]["url"])
elif source_type == "base64":
data = block["source"]["data"]
media_type = block["source"]["media_type"]
url = f"data:{media_type};base64,{data}"
else:
raise ValueError(
f"Unsupported image source type: {source_type}",
)
images.append(
{
"type": "image_url",
"image_url": {
"url": url,
},
},
)
elif block["type"] == "audio":
input_audio = _to_openai_audio_data(block["source"])
audios.append(
{
"type": "input_audio",
"input_audio": input_audio,
},
)
if accumulated_text:
conversation_blocks.append(
{"text": "\n".join(accumulated_text)},
)
if conversation_blocks:
if conversation_blocks[0].get("text"):
conversation_blocks[0]["text"] = (
conversation_history_prompt
+ "<history>\n"
+ conversation_blocks[0]["text"]
)
else:
conversation_blocks.insert(
0,
{
"text": conversation_history_prompt + "<history>\n",
},
)
if conversation_blocks[-1].get("text"):
conversation_blocks[-1]["text"] += "\n</history>"
else:
conversation_blocks.append({"text": "</history>"})
conversation_blocks_text = "\n".join(
conversation_block.get("text", "")
for conversation_block in conversation_blocks
)
content_list: list[dict[str, Any]] = []
if conversation_blocks_text:
content_list.append(
{
"type": "text",
"text": conversation_blocks_text,
},
)
if images:
content_list.extend(images)
if audios:
content_list.extend(audios)
user_message = {
"role": "user",
"content": content_list,
}
if content_list:
formatted_msgs.append(user_message)
return formatted_msgs
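
A sketch of the image handling in _to_openai_image_url: web URLs pass through unchanged, while local files are inlined as base64 data URLs (photo.jpg is a hypothetical local file that must exist for this to run):

import asyncio

from agentscope.formatter import OpenAIChatFormatter
from agentscope.message import ImageBlock, Msg, TextBlock, URLSource

async def main() -> None:
    msg = Msg(
        "user",
        [
            TextBlock(type="text", text="Caption this photo."),
            ImageBlock(
                type="image",
                source=URLSource(type="url", url="photo.jpg"),
            ),
        ],
        "user",
    )
    formatted = await OpenAIChatFormatter().format([msg])
    # Expected second content entry:
    # {"type": "image_url", "image_url": {"url": "data:image/jpg;base64,..."}}
    print(formatted[0]["content"][1]["image_url"]["url"][:22])

asyncio.run(main())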

…/agentscope/formatter/_truncated_formatter_base.py

@@ -0,0 +1,297 @@
# -*- coding: utf-8 -*-
"""The truncated formatter base class, which allows to truncate the input
messages."""
from abc import ABC
from copy import deepcopy
from typing import (
Any,
Tuple,
Literal,
AsyncGenerator,
)
from ._formatter_base import FormatterBase
from ..message import Msg
from ..token import TokenCounterBase
from ..tracing import trace_format
class TruncatedFormatterBase(FormatterBase, ABC):
"""Base class for truncated formatters, which formats input messages into
required formats with tokens under a specified limit."""
def __init__(
self,
token_counter: TokenCounterBase | None = None,
max_tokens: int | None = None,
) -> None:
"""Initialize the TruncatedFormatterBase.
Args:
token_counter (`TokenCounterBase | None`, optional):
A token counter instance used to count tokens in the messages.
If not provided, the formatter will format the messages
without considering token limits.
max_tokens (`int | None`, optional):
The maximum number of tokens allowed in the formatted
messages. If not provided, the formatter will not truncate
the messages.
"""
self.token_counter = token_counter
assert (
max_tokens is None or 0 < max_tokens
), "max_tokens must be greater than 0"
self.max_tokens = max_tokens
@trace_format
async def format(
self,
msgs: list[Msg],
**kwargs: Any,
) -> list[dict[str, Any]]:
"""Format the input messages into the required format. If token
counter and max token limit are provided, the messages will be
truncated to fit the limit.
Args:
msgs (`list[Msg]`):
The input messages to be formatted.
Returns:
`list[dict[str, Any]]`:
The formatted messages in the required format.
"""
# Check if the input messages are valid
self.assert_list_of_msgs(msgs)
msgs = deepcopy(msgs)
while True:
formatted_msgs = await self._format(msgs)
n_tokens = await self._count(formatted_msgs)
if (
n_tokens is None
or self.max_tokens is None
or n_tokens <= self.max_tokens
):
return formatted_msgs
# truncate the input messages
msgs = await self._truncate(msgs)
async def _format(self, msgs: list[Msg]) -> list[dict[str, Any]]:
"""Format the input messages into the required format. This method
should be implemented by the subclasses."""
formatted_msgs = []
start_index = 0
if len(msgs) > 0 and msgs[0].role == "system":
formatted_msgs.append(
await self._format_system_message(msgs[0]),
)
start_index = 1
is_first_agent_message = True
async for typ, group in self._group_messages(msgs[start_index:]):
match typ:
case "tool_sequence":
formatted_msgs.extend(
await self._format_tool_sequence(group),
)
case "agent_message":
formatted_msgs.extend(
await self._format_agent_message(
group,
is_first_agent_message,
),
)
is_first_agent_message = False
return formatted_msgs
async def _format_system_message(
self,
msg: Msg,
) -> dict[str, Any]:
"""Format system message for the LLM API.
.. note:: This is the default implementation. For certain LLM APIs
with specific requirements, you may need to implement a custom
formatting function to accommodate those particular needs.
"""
return {
"role": "system",
"content": msg.get_content_blocks("text"),
}
async def _format_tool_sequence(
self,
msgs: list[Msg],
) -> list[dict[str, Any]]:
"""Given a sequence of tool call/result messages, format them into
the required format for the LLM API."""
raise NotImplementedError(
"_format_tool_sequence is not implemented",
)
async def _format_agent_message(
self,
msgs: list[Msg],
is_first: bool = True,
) -> list[dict[str, Any]]:
"""Given a sequence of messages without tool calls/results, format
them into the required format for the LLM API."""
raise NotImplementedError(
"_format_agent_message is not implemented",
)
async def _truncate(self, msgs: list[Msg]) -> list[Msg]:
"""Truncate the input messages, so that it can fit the token limit.
This function is called only when
- both `token_counter` and `max_tokens` are provided,
- the formatted output of the input messages exceeds the token limit.
.. tip:: This function only provides a simple strategy, and developers
can override this method to implement more sophisticated
truncation strategies.
.. note:: The tool call message should be truncated together with
its corresponding tool result message to satisfy the LLM API
requirements.
Args:
msgs (`list[Msg]`):
The input messages to be truncated.
Raises:
`ValueError`:
If the system prompt message already exceeds the token limit,
or if there are tool calls without corresponding tool results.
Returns:
`list[Msg]`:
The truncated messages.
"""
start_index = 0
if len(msgs) > 0 and msgs[0].role == "system":
if len(msgs) == 1:
# If the system prompt already exceeds the token limit, we
# raise an error.
raise ValueError(
f"The system prompt message already exceeds the token "
f"limit ({self.max_tokens} tokens).",
)
start_index = 1
        # Track the tool call IDs so that each truncated tool call message
        # also drops its corresponding tool result message
tool_call_ids = set()
for i in range(start_index, len(msgs)):
msg = msgs[i]
for block in msg.get_content_blocks("tool_use"):
tool_call_ids.add(block["id"])
for block in msg.get_content_blocks("tool_result"):
try:
tool_call_ids.remove(block["id"])
except KeyError:
pass
            # We can stop truncating once the set is empty
if len(tool_call_ids) == 0:
return msgs[:start_index] + msgs[i + 1 :]
if len(tool_call_ids) > 0:
raise ValueError(
"The input messages contains tool call(s) that do not have "
f"the corresponding tool result(s): {tool_call_ids}. ",
)
return msgs[:start_index]
async def _count(self, msgs: list[dict[str, Any]]) -> int | None:
"""Count the number of tokens in the input messages. If token counter
is not provided, `None` will be returned.
Args:
            msgs (`list[dict[str, Any]]`):
                The formatted messages to count tokens for.
"""
if self.token_counter is None:
return None
return await self.token_counter.count(msgs)
@staticmethod
async def _group_messages(
msgs: list[Msg],
) -> AsyncGenerator[
Tuple[Literal["tool_sequence", "agent_message"], list[Msg]],
None,
]:
"""Group the input messages into two types and yield them as a
generator. The two types are:
- agent message that doesn't contain tool calls/results, and
        - tool sequence that consists of a sequence of tool calls/results
        .. note:: The grouping is used in multi-agent scenarios, where
            multiple entities are involved in the input messages. To stay
            compatible with the tools API, the messages are grouped and
            formatted with different strategies.
Args:
msgs (`list[Msg]`):
The input messages to be grouped, where the system prompt
message shouldn't be included.
Yields:
`AsyncGenerator[Tuple[str, list[Msg]], None]`:
A generator that yields tuples of group type and the list of
messages in that group. The group type can be either
"tool_sequence" or "agent_message".
"""
group_type: Literal["tool_sequence", "agent_message"] | None = None
group = []
for msg in msgs:
if group_type is None:
if msg.has_content_blocks(
"tool_use",
) or msg.has_content_blocks("tool_result"):
group_type = "tool_sequence"
else:
group_type = "agent_message"
group.append(msg)
continue
# determine if this msg has the same type as the current group
if group_type == "tool_sequence":
if msg.has_content_blocks(
"tool_use",
) or msg.has_content_blocks("tool_result"):
group.append(msg)
else:
yield group_type, group
group = [msg]
group_type = "agent_message"
elif group_type == "agent_message":
if msg.has_content_blocks(
"tool_use",
) or msg.has_content_blocks("tool_result"):
yield group_type, group
group = [msg]
group_type = "tool_sequence"
else:
group.append(msg)
if group_type:
yield group_type, group
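
The truncation loop in format() only activates when both a token counter and max_tokens are supplied. A toy sketch (CharCounter and its count signature are illustrative assumptions about TokenCounterBase, whose count coroutine is invoked as token_counter.count(msgs) above):

import asyncio
import json
from typing import Any

from agentscope.formatter import DeepSeekChatFormatter
from agentscope.message import Msg
from agentscope.token import TokenCounterBase

class CharCounter(TokenCounterBase):
    """A toy counter that treats every four characters as one token."""

    async def count(self, messages: list[dict[str, Any]], **kwargs: Any) -> int:
        return len(json.dumps(messages, ensure_ascii=False)) // 4

async def main() -> None:
    formatter = DeepSeekChatFormatter(
        token_counter=CharCounter(),
        max_tokens=50,
    )
    msgs = [
        Msg("user", "A" * 400, "user"),  # oversized; truncated away first
        Msg("user", "short", "user"),
    ]
    # format() alternates formatting, counting, and dropping the oldest
    # non-system messages until the result fits under max_tokens.
    print(await formatter.format(msgs))

asyncio.run(main())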