chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,46 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The formatter module in agentscope."""
|
||||
|
||||
from ._formatter_base import FormatterBase
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from ._dashscope_formatter import (
|
||||
DashScopeChatFormatter,
|
||||
DashScopeMultiAgentFormatter,
|
||||
)
|
||||
from ._anthropic_formatter import (
|
||||
AnthropicChatFormatter,
|
||||
AnthropicMultiAgentFormatter,
|
||||
)
|
||||
from ._openai_formatter import (
|
||||
OpenAIChatFormatter,
|
||||
OpenAIMultiAgentFormatter,
|
||||
)
|
||||
from ._gemini_formatter import (
|
||||
GeminiChatFormatter,
|
||||
GeminiMultiAgentFormatter,
|
||||
)
|
||||
from ._ollama_formatter import (
|
||||
OllamaChatFormatter,
|
||||
OllamaMultiAgentFormatter,
|
||||
)
|
||||
from ._deepseek_formatter import (
|
||||
DeepSeekChatFormatter,
|
||||
DeepSeekMultiAgentFormatter,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"FormatterBase",
|
||||
"TruncatedFormatterBase",
|
||||
"DashScopeChatFormatter",
|
||||
"DashScopeMultiAgentFormatter",
|
||||
"OpenAIChatFormatter",
|
||||
"OpenAIMultiAgentFormatter",
|
||||
"AnthropicChatFormatter",
|
||||
"AnthropicMultiAgentFormatter",
|
||||
"GeminiChatFormatter",
|
||||
"GeminiMultiAgentFormatter",
|
||||
"OllamaChatFormatter",
|
||||
"OllamaMultiAgentFormatter",
|
||||
"DeepSeekChatFormatter",
|
||||
"DeepSeekMultiAgentFormatter",
|
||||
]
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -0,0 +1,250 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""The Anthropic formatter module."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._logging import logger
|
||||
from ..message import Msg, TextBlock, ImageBlock, ToolUseBlock, ToolResultBlock
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
class AnthropicChatFormatter(TruncatedFormatterBase):
    """Formatter for Anthropic messages."""

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into Anthropic API format.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.

        .. note:: Anthropic suggests always passing all previous thinking
         blocks back to the API in subsequent calls to maintain reasoning
         continuity. For more details, please refer to
         `Anthropic's documentation
         <https://docs.anthropic.com/en/docs/build-with-claude/extended-thinking#preserving-thinking-blocks>`_.
        """
        self.assert_list_of_msgs(msgs)

        formatted: list[dict] = []
        for index, msg in enumerate(msgs):
            blocks: list[dict] = []

            for block in msg.get_content_blocks():
                block_type = block.get("type")

                if block_type in ("thinking", "text", "image"):
                    # These block types are forwarded as shallow copies.
                    blocks.append(dict(block))

                elif block_type == "tool_use":
                    blocks.append(
                        {
                            "id": block.get("id"),
                            "type": "tool_use",
                            "name": block.get("name"),
                            "input": block.get("input", {}),
                        },
                    )

                elif block_type == "tool_result":
                    raw_output = block.get("output")
                    # Normalize the tool output into a list of content blocks.
                    if raw_output is None:
                        result_content = [{"type": "text", "text": None}]
                    elif isinstance(raw_output, list):
                        result_content = raw_output
                    else:
                        result_content = [
                            {"type": "text", "text": str(raw_output)},
                        ]
                    # Tool results are carried by a separate user message.
                    formatted.append(
                        {
                            "role": "user",
                            "content": [
                                {
                                    "type": "tool_result",
                                    "tool_use_id": block.get("id"),
                                    "content": result_content,
                                },
                            ],
                        },
                    )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        block_type,
                    )

            # Claude only allow the first message to be system message
            role = (
                "user" if (msg.role == "system" and index != 0) else msg.role
            )

            candidate = {
                "role": role,
                "content": blocks or None,
            }

            # When both content and tool_calls are None, skipped
            if candidate["content"] or candidate.get("tool_calls"):
                formatted.append(candidate)

        return formatted
|
||||
|
||||
|
||||
class AnthropicMultiAgentFormatter(TruncatedFormatterBase):
    """
    Anthropic formatter for multi-agent conversations, where more than
    a user and an agent are involved.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Anthropic multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The prompt to use for the conversation history section.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter used for truncation. If `None`, no
                truncation will be applied.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If `None`, no truncation will be applied.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Given a sequence of tool call/result messages, format them into
        the required format for the Anthropic API.

        Args:
            msgs (`list[Msg]`):
                The list of messages containing tool calls/results to format.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the Anthropic API.
        """
        return await AnthropicChatFormatter().format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Given a sequence of messages without tool calls/results, format
        them into a single user message wrapped in conversation history
        tags, as required by the Anthropic API.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the conversation.
                If `True`, the conversation history prompt will be included.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the Anthropic API.
        """
        if is_first:
            conversation_history_prompt = self.conversation_history_prompt
        else:
            conversation_history_prompt = ""

        # Format into required Anthropic format
        formatted_msgs: list[dict] = []

        # Collect the multimodal files
        conversation_blocks: list = []
        accumulated_text = []
        for msg in msgs:
            for block in msg.get_content_blocks():
                if block["type"] == "text":
                    # Prefix the speaker's name so the LLM can tell who
                    # said what in the merged history.
                    accumulated_text.append(f"{msg.name}: {block['text']}")

                elif block["type"] == "image":
                    # Handle the accumulated text as a single block
                    if accumulated_text:
                        conversation_blocks.append(
                            {
                                "text": "\n".join(accumulated_text),
                                "type": "text",
                            },
                        )
                        accumulated_text.clear()

                    conversation_blocks.append({**block})

                # NOTE(review): other block types (e.g. audio) are silently
                # dropped here — tool blocks are handled by
                # `_format_tool_sequence` instead.

        # Flush any trailing text after the loop
        if accumulated_text:
            conversation_blocks.append(
                {
                    "text": "\n".join(accumulated_text),
                    "type": "text",
                },
            )

        if conversation_blocks:
            # Prepend the history prompt and opening tag to the first block
            if conversation_blocks[0].get("text"):
                conversation_blocks[0]["text"] = (
                    conversation_history_prompt
                    + "<history>\n"
                    + conversation_blocks[0]["text"]
                )

            else:
                conversation_blocks.insert(
                    0,
                    {
                        "type": "text",
                        "text": conversation_history_prompt + "<history>\n",
                    },
                )

            # Append the closing tag to the last block
            if conversation_blocks[-1].get("text"):
                conversation_blocks[-1]["text"] += "\n</history>"

            else:
                conversation_blocks.append(
                    {"type": "text", "text": "</history>"},
                )

        if conversation_blocks:
            formatted_msgs.append(
                {
                    "role": "user",
                    "content": conversation_blocks,
                },
            )

        return formatted_msgs
|
||||
@@ -0,0 +1,426 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""The dashscope formatter module."""
|
||||
|
||||
import json
|
||||
import os.path
|
||||
from typing import Any
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._logging import logger
|
||||
from .._utils._common import _is_accessible_local_file
|
||||
from ..message import (
|
||||
Msg,
|
||||
TextBlock,
|
||||
ImageBlock,
|
||||
AudioBlock,
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
)
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
def _reformat_messages(
|
||||
messages: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Reformat the content to be compatible with HuggingFaceTokenCounter.
|
||||
|
||||
This function processes a list of messages and converts multi-part
|
||||
text content into single string content when all parts are plain text.
|
||||
This is necessary for compatibility with HuggingFaceTokenCounter which
|
||||
expects simple string content rather than structured content with
|
||||
multiple parts.
|
||||
|
||||
Args:
|
||||
messages (list[dict[str, Any]]):
|
||||
A list of message dictionaries where each message may contain a
|
||||
"content" field. The content can be either:
|
||||
- A string (unchanged)
|
||||
- A list of content items, where each item is a dict that may
|
||||
contain "text", "type", and other fields
|
||||
|
||||
Returns:
|
||||
list[dict[str, Any]]:
|
||||
A list of reformatted messages. For messages where all content
|
||||
items are plain text (have "text" field and either no "type"
|
||||
field or "type" == "text"), the content list is converted to a
|
||||
single newline-joined string. Other messages remain unchanged.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
# Case 1: All text content - will be converted
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"text": "Hello", "type": "text"},
|
||||
{"text": "World", "type": "text"}
|
||||
]
|
||||
}
|
||||
]
|
||||
result = _reformat_messages(messages)
|
||||
print(result[0]["content"])
|
||||
# Output: "Hello\nWorld"
|
||||
|
||||
# Case 2: Mixed content - will remain unchanged
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"text": "Hello", "type": "text"},
|
||||
{"image_url": "...", "type": "image"}
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
result = _reformat_messages(messages) # remain unchanged
|
||||
print(type(result[0]["content"]))
|
||||
# Output: <class 'list'>
|
||||
|
||||
"""
|
||||
for message in messages:
|
||||
content = message.get("content", [])
|
||||
|
||||
is_all_text = True
|
||||
texts = []
|
||||
for item in content:
|
||||
if not isinstance(item, dict) or "text" not in item:
|
||||
is_all_text = False
|
||||
break
|
||||
if "type" in item and item["type"] != "text":
|
||||
is_all_text = False
|
||||
break
|
||||
if item["text"]:
|
||||
texts.append(item["text"])
|
||||
|
||||
if is_all_text and texts:
|
||||
message["content"] = "\n".join(texts)
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
class DashScopeChatFormatter(TruncatedFormatterBase):
    """Formatter for DashScope messages."""

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        ImageBlock,
        AudioBlock,
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into DashScope API format.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.
        """
        self.assert_list_of_msgs(msgs)

        results: list[dict] = []
        for msg in msgs:
            parts: list[dict] = []
            tool_calls: list[dict] = []

            for block in msg.get_content_blocks():
                block_type = block.get("type")

                if block_type == "text":
                    parts.append({"text": block.get("text")})

                elif block_type in ("image", "audio"):
                    source = block["source"]
                    source_type = source["type"]

                    if source_type == "url":
                        url = source["url"]
                        if _is_accessible_local_file(url):
                            # DashScope expects local files as file:// URIs
                            parts.append(
                                {
                                    block_type: "file://"
                                    + os.path.abspath(url),
                                },
                            )
                        else:
                            # treat as web url
                            parts.append({block_type: url})

                    elif source_type == "base64":
                        media_type = source["media_type"]
                        data = source["data"]
                        parts.append(
                            {block_type: f"data:{media_type};base64,{data}"},
                        )

                    else:
                        raise NotImplementedError(
                            f"Unsupported source type '{source.get('type')}' "
                            f"for {block_type} block.",
                        )

                elif block_type == "tool_use":
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": json.dumps(
                                    block.get("input", {}),
                                    ensure_ascii=False,
                                ),
                            },
                        },
                    )

                elif block_type == "tool_result":
                    # Tool results become standalone "tool" role messages
                    results.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": self.convert_tool_result_to_string(
                                block.get("output"),  # type: ignore[arg-type]
                            ),
                            "name": block.get("name"),
                        },
                    )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        block_type,
                    )

            candidate: dict = {
                "role": msg.role,
                "content": parts or [{"text": None}],
            }
            if tool_calls:
                candidate["tool_calls"] = tool_calls

            # Skip messages carrying neither real content nor tool calls
            if candidate["content"] != [
                {"text": None},
            ] or candidate.get(
                "tool_calls",
            ):
                results.append(candidate)

        return _reformat_messages(results)
|
||||
|
||||
|
||||
class DashScopeMultiAgentFormatter(TruncatedFormatterBase):
    """DashScope formatter for multi-agent conversations, where more than
    a user and an agent are involved.

    .. note:: This formatter will combine previous messages (except tool
     calls/results) into a history section in the first system message with
     the conversation history prompt.

    .. note:: For tool calls/results, they will be presented as separate
     messages as required by the DashScope API. Therefore, the tool calls/
     results messages are expected to be placed at the end of the input
     messages.

    .. tip:: Telling the assistant's name in the system prompt is very
     important in multi-agent conversations. So that LLM can know who it
     is playing as.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        AudioBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the DashScope multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The prompt to use for the conversation history section.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter used for truncation.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If `None`, no truncation will be applied.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Given a sequence of tool call/result messages, format them into
        the required format for the DashScope API.

        Args:
            msgs (`list[Msg]`):
                The list of messages containing tool calls/results to format.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the DashScope API.
        """
        return await DashScopeChatFormatter().format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Given a sequence of messages without tool calls/results, format
        them into a user message with conversation history tags. For the
        first agent message, it will include the conversation history prompt.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the conversation.
                If `True`, the conversation history prompt will be included.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the DashScope API.
        """
        if is_first:
            conversation_history_prompt = self.conversation_history_prompt
        else:
            conversation_history_prompt = ""

        # Format into required DashScope format
        formatted_msgs: list[dict] = []

        # Collect the multimodal files
        conversation_blocks: list = []
        accumulated_text: list = []
        for msg in msgs:
            for block in msg.get_content_blocks():
                if block["type"] == "text":
                    # Prefix the speaker's name so the LLM can tell who
                    # said what in the merged history.
                    accumulated_text.append(f"{msg.name}: {block['text']}")

                elif block["type"] in ["image", "audio"]:
                    # Handle the accumulated text as a single block
                    if accumulated_text:
                        conversation_blocks.append(
                            {"text": "\n".join(accumulated_text)},
                        )
                        accumulated_text.clear()

                    if block["source"]["type"] == "url":
                        url = block["source"]["url"]
                        if _is_accessible_local_file(url):
                            # DashScope expects local files as file:// URIs
                            conversation_blocks.append(
                                {
                                    block["type"]: "file://"
                                    + os.path.abspath(url),
                                },
                            )
                        else:
                            conversation_blocks.append({block["type"]: url})

                    elif block["source"]["type"] == "base64":
                        media_type = block["source"]["media_type"]
                        base64_data = block["source"]["data"]
                        conversation_blocks.append(
                            {
                                block[
                                    "type"
                                ]: f"data:{media_type};base64,{base64_data}",
                            },
                        )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, "
                        "skipped.",
                        block["type"],
                    )

        # Flush any trailing text after the loop
        if accumulated_text:
            conversation_blocks.append({"text": "\n".join(accumulated_text)})

        if conversation_blocks:
            # Prepend the history prompt and opening tag to the first block
            if conversation_blocks[0].get("text"):
                conversation_blocks[0]["text"] = (
                    conversation_history_prompt
                    + "<history>\n"
                    + conversation_blocks[0]["text"]
                )

            else:
                conversation_blocks.insert(
                    0,
                    {
                        "text": conversation_history_prompt + "<history>\n",
                    },
                )

            # Append the closing tag to the last block
            if conversation_blocks[-1].get("text"):
                conversation_blocks[-1]["text"] += "\n</history>"

            else:
                conversation_blocks.append({"text": "</history>"})

        # Consistency with the Anthropic/DeepSeek multi-agent formatters:
        # emit no user message at all when there is nothing to wrap.
        if conversation_blocks:
            formatted_msgs.append(
                {
                    "role": "user",
                    "content": conversation_blocks,
                },
            )

        return _reformat_messages(formatted_msgs)

    async def _format_system_message(
        self,
        msg: Msg,
    ) -> dict[str, Any]:
        """Format system message for DashScope API.

        Args:
            msg (`Msg`):
                The system message to format.

        Returns:
            `dict[str, Any]`:
                A system-role dictionary with plain-string content.
        """
        return {
            "role": "system",
            "content": msg.get_text_content(),
        }
|
||||
@@ -0,0 +1,250 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""The DeepSeek formatter module."""
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._logging import logger
|
||||
from ..message import Msg, TextBlock, ToolUseBlock, ToolResultBlock
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
class DeepSeekChatFormatter(TruncatedFormatterBase):
    """Formatter for DeepSeek messages."""

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = False
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into DeepSeek API format.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.
        """
        self.assert_list_of_msgs(msgs)

        results: list[dict] = []
        for msg in msgs:
            text_blocks: list = []
            tool_calls: list = []

            for block in msg.get_content_blocks():
                block_type = block.get("type")

                if block_type == "text":
                    text_blocks.append(dict(block))

                elif block_type == "tool_use":
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": json.dumps(
                                    block.get("input", {}),
                                    ensure_ascii=False,
                                ),
                            },
                        },
                    )

                elif block_type == "tool_result":
                    # Tool results become standalone "tool" role messages
                    results.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": self.convert_tool_result_to_string(
                                block.get("output"),  # type: ignore[arg-type]
                            ),
                            "name": block.get("name"),
                        },
                    )

                else:
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        block_type,
                    )

            # DeepSeek takes plain-string content, so join all text blocks
            joined_text = "\n".join(
                item.get("text", "") for item in text_blocks
            )
            candidate = {
                "role": msg.role,
                "content": joined_text or None,
            }

            if tool_calls:
                candidate["tool_calls"] = tool_calls

            # Skip messages carrying neither content nor tool calls
            if candidate["content"] or candidate.get("tool_calls"):
                results.append(candidate)

        return results
|
||||
|
||||
|
||||
class DeepSeekMultiAgentFormatter(TruncatedFormatterBase):
    """
    DeepSeek formatter for multi-agent conversations, where more than
    a user and an agent are involved.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = False
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the DeepSeek multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The prompt to use for the conversation history section.
            token_counter (`TokenCounterBase | None`, optional):
                A token counter instance used to count tokens in the
                messages. If not provided, the formatter will format the
                messages without considering token limits.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If not provided, the formatter will not truncate
                the messages.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Given a sequence of tool call/result messages, format them into
        the required format for the DeepSeek API.

        Args:
            msgs (`list[Msg]`):
                The list of messages containing tool calls/results to format.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the DeepSeek API.
        """
        return await DeepSeekChatFormatter().format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Given a sequence of messages without tool calls/results, format
        them into a single user message whose text content is wrapped in
        conversation history tags, as required by the DeepSeek API.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the conversation.
                If `True`, the conversation history prompt will be included.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the DeepSeek API.
        """
        prompt_prefix = self.conversation_history_prompt if is_first else ""

        # Format into required DeepSeek format
        formatted_msgs: list[dict] = []

        history_blocks: list = []
        speaker_lines: list = []
        for msg in msgs:
            for block in msg.get_content_blocks():
                # DeepSeek is text-only; other block types are ignored here
                if block["type"] == "text":
                    speaker_lines.append(f"{msg.name}: {block['text']}")

        if speaker_lines:
            history_blocks.append(
                {"text": "\n".join(speaker_lines)},
            )

        if history_blocks:
            # Prepend the history prompt and opening tag to the first block
            if history_blocks[0].get("text"):
                history_blocks[0]["text"] = (
                    prompt_prefix
                    + "<history>\n"
                    + history_blocks[0]["text"]
                )

            else:
                history_blocks.insert(
                    0,
                    {
                        "text": prompt_prefix + "<history>\n",
                    },
                )

            # Append the closing tag to the last block
            if history_blocks[-1].get("text"):
                history_blocks[-1]["text"] += "\n</history>"

            else:
                history_blocks.append({"text": "</history>"})

        merged_text = "\n".join(
            item.get("text", "") for item in history_blocks
        )

        if history_blocks:
            formatted_msgs.append(
                {
                    "role": "user",
                    "content": merged_text,
                },
            )

        return formatted_msgs
|
||||
@@ -0,0 +1,106 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The formatter module."""
|
||||
|
||||
from abc import abstractmethod
|
||||
from typing import Any, List
|
||||
|
||||
from .._utils._common import _save_base64_data
|
||||
from ..message import Msg, AudioBlock, ImageBlock, TextBlock
|
||||
|
||||
|
||||
class FormatterBase:
    """The base class for formatters."""

    @abstractmethod
    async def format(self, *args: Any, **kwargs: Any) -> list[dict[str, Any]]:
        """Format the Msg objects to a list of dictionaries that satisfy the
        API requirements."""

    @staticmethod
    def assert_list_of_msgs(msgs: list[Msg]) -> None:
        """Assert that the input is a list of Msg objects.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be validated.

        Raises:
            `TypeError`:
                If the input is not a list, or any element is not a Msg.
        """
        if not isinstance(msgs, list):
            raise TypeError("Input must be a list of Msg objects.")

        for msg in msgs:
            if not isinstance(msg, Msg):
                raise TypeError(
                    f"Expected Msg object, got {type(msg)} instead.",
                )

    @staticmethod
    def convert_tool_result_to_string(
        output: str | List[TextBlock | ImageBlock | AudioBlock],
    ) -> str:
        """Turn the tool result list into a textual output to be compatible
        with the LLM API that doesn't support multimodal data.

        Args:
            output (`str | List[TextBlock | ImageBlock | AudioBlock]`):
                The output of the tool response, including text and multimodal
                data like images and audio.

        Returns:
            `str`:
                A string representation of the tool result, with text blocks
                concatenated and multimodal data represented by file paths
                or URLs.
        """
        # A plain string passes straight through
        if isinstance(output, str):
            return output

        parts: list = []
        for item in output:
            assert isinstance(item, dict) and "type" in item, (
                f"Invalid block: {item}, a TextBlock, ImageBlock, or "
                f"AudioBlock is expected."
            )
            item_type = item["type"]

            if item_type == "text":
                parts.append(item["text"])

            elif item_type in ("image", "audio", "video"):
                assert "source" in item, (
                    f"Invalid {item['type']} block: {item}, 'source' key "
                    "is required."
                )
                source = item["source"]

                if source["type"] == "url":
                    # Point the reader at the remote resource
                    parts.append(
                        f"The returned {item['type']} can be found "
                        f"at: {source['url']}",
                    )

                elif source["type"] == "base64":
                    # Save the data locally and report the file path
                    saved_path = _save_base64_data(
                        source["media_type"],
                        source["data"],
                    )
                    parts.append(
                        f"The returned {item['type']} can be found "
                        f"at: {saved_path}",
                    )

                else:
                    raise ValueError(
                        f"Invalid image source: {item['source']}, "
                        "expected 'url' or 'base64'.",
                    )

            else:
                raise ValueError(
                    f"Unsupported block type: {item['type']}, "
                    "expected 'text', 'image', 'audio', or 'video'.",
                )

        # A single entry is returned verbatim; multiple entries become a
        # bulleted list
        if len(parts) == 1:
            return parts[0]
        return "\n".join("- " + part for part in parts)
|
||||
@@ -0,0 +1,405 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""Google gemini API formatter in agentscope."""
|
||||
import base64
|
||||
import os
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._utils._common import _get_bytes_from_web_url
|
||||
from ..message import (
|
||||
Msg,
|
||||
TextBlock,
|
||||
ImageBlock,
|
||||
AudioBlock,
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
VideoBlock,
|
||||
)
|
||||
from .._logging import logger
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
def _to_gemini_inline_data(url: str) -> dict:
    """Convert a web URL or local file path into the Gemini API inline
    data format.

    Args:
        url (`str`):
            A web URL or a local file path pointing to the multimodal data.

    Returns:
        `dict`:
            A dictionary with "data" (base64-encoded string for local
            files, raw bytes fetched for web URLs) and "mime_type" keys.

    Raises:
        TypeError:
            If the file extension is not in
            ``GeminiChatFormatter.supported_extensions``.
        ValueError:
            If the URL is neither a reachable web URL nor an existing
            local file.
    """
    parsed_url = urlparse(url)
    extension = url.split(".")[-1].lower()

    # Pre-calculate media type from extension (image/audio/video).
    typ = None
    for media_type, extensions in GeminiChatFormatter.supported_extensions.items():
        if extension in extensions:
            typ = media_type
            break

    is_local_file = os.path.exists(url)

    # Reject inputs that are neither a local file nor a URL with a scheme
    # before checking the extension, matching the original error order.
    if not is_local_file and parsed_url.scheme == "":
        raise ValueError(
            f"The URL `{url}` is not a valid image URL or local file.",
        )

    if typ is None:
        raise TypeError(
            f"Unsupported file extension: {extension}, expected "
            f"{GeminiChatFormatter.supported_extensions}",
        )

    if is_local_file:
        # Local file: read and base64-encode the content.
        with open(url, "rb") as f:
            data = base64.b64encode(f.read()).decode("utf-8")
    else:
        # Web URL: download the raw bytes.
        data = _get_bytes_from_web_url(url)

    return {
        "data": data,
        "mime_type": f"{typ}/{extension}",
    }
|
||||
|
||||
|
||||
class GeminiChatFormatter(TruncatedFormatterBase):
    """The formatter for Google Gemini API.

    Converts a list of ``Msg`` objects into the ``contents`` structure
    expected by the Gemini ``generateContent`` endpoint, where each
    message is a dict with ``role`` ("user"/"model") and ``parts``.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        VideoBlock,
        AudioBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    # Extensions accepted per media type; also used by
    # _to_gemini_inline_data to derive the MIME type from a URL/path.
    supported_extensions: dict[str, list[str]] = {
        "image": ["png", "jpeg", "webp", "heic", "heif"],
        "video": [
            "mp4",
            "mpeg",
            "mov",
            "avi",
            "x-flv",
            "mpg",
            "webm",
            "wmv",
            "3gpp",
        ],
        "audio": ["mp3", "wav", "aiff", "aac", "ogg", "flac"],
    }

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict]:
        """Format message objects into Gemini API required format.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format.

        Returns:
            `list[dict]`:
                Messages in Gemini format, each with "role" and "parts".
        """
        self.assert_list_of_msgs(msgs)

        messages: list = []
        for msg in msgs:
            # Parts accumulated for this message; tool results are emitted
            # as standalone "user" messages instead (Gemini requirement).
            parts = []

            for block in msg.get_content_blocks():
                typ = block.get("type")
                if typ == "text":
                    parts.append(
                        {
                            "text": block.get("text"),
                        },
                    )

                elif typ == "tool_use":
                    parts.append(
                        {
                            "function_call": {
                                "id": block["id"],
                                "name": block["name"],
                                "args": block["input"],
                            },
                        },
                    )

                elif typ == "tool_result":
                    # Flatten multimodal tool output into plain text first.
                    text_output = self.convert_tool_result_to_string(
                        block["output"],  # type: ignore[arg-type]
                    )
                    # Function responses go into their own user message,
                    # separate from the surrounding text/media parts.
                    messages.append(
                        {
                            "role": "user",
                            "parts": [
                                {
                                    "function_response": {
                                        "id": block["id"],
                                        "name": block["name"],
                                        "response": {
                                            "output": text_output,
                                        },
                                    },
                                },
                            ],
                        },
                    )

                elif typ in ["image", "audio", "video"]:
                    if block["source"]["type"] == "base64":
                        media_type = block["source"]["media_type"]
                        base64_data = block["source"]["data"]

                        parts.append(
                            {
                                "inline_data": {
                                    "data": base64_data,
                                    "mime_type": media_type,
                                },
                            },
                        )

                    elif block["source"]["type"] == "url":
                        # URL sources (web or local path) are converted to
                        # inline data by _to_gemini_inline_data.
                        parts.append(
                            {
                                "inline_data": _to_gemini_inline_data(
                                    block["source"]["url"],
                                ),
                            },
                        )

                else:
                    # Unknown block types are skipped rather than raising.
                    logger.warning(
                        "Unsupported block type: %s in the message, skipped. ",
                        typ,
                    )

            # Gemini only accepts "user" and "model" roles; everything
            # that is not an assistant message is sent as "user".
            role = "model" if msg.role == "assistant" else "user"

            # Skip messages that produced no parts (e.g. tool-result-only).
            if parts:
                messages.append(
                    {
                        "role": role,
                        "parts": parts,
                    },
                )

        return messages
|
||||
|
||||
|
||||
class GeminiMultiAgentFormatter(TruncatedFormatterBase):
    """The multi-agent formatter for Google Gemini API, where more than a
    user and an agent are involved.

    .. note:: This formatter will combine previous messages (except tool
     calls/results) into a history section in the first system message with
     the conversation history prompt.

    .. note:: For tool calls/results, they will be presented as separate
     messages as required by the Gemini API. Therefore, the tool calls/
     results messages are expected to be placed at the end of the input
     messages.

    .. tip:: Telling the assistant's name in the system prompt is very
     important in multi-agent conversations. So that LLM can know who it
     is playing as.

    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        VideoBlock,
        AudioBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Gemini multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The prompt to be used for the conversation history section.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter used for truncation.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If `None`, no truncation will be applied.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_system_message(
        self,
        msg: Msg,
    ) -> dict[str, Any]:
        """Format system message for the Gemini API.

        Gemini has no dedicated system role here, so the system prompt is
        sent as a plain "user" message with a single text part.
        """
        return {
            "role": "user",
            "parts": [
                {
                    "text": msg.get_text_content(),
                },
            ],
        }

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Given a sequence of tool call/result messages, format them into
        the required format for the Gemini API.

        Args:
            msgs (`list[Msg]`):
                The list of messages containing tool calls/results to format.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the Gemini API.
        """
        # Delegate to the single-agent formatter, which already handles
        # function_call / function_response parts.
        return await GeminiChatFormatter().format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Given a sequence of messages without tool calls/results, format
        them into the required format for the Gemini API.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the conversation.
                If `True`, the conversation history prompt will be included.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the Gemini API.
        """

        # Only the first chunk of history carries the explanatory prompt.
        if is_first:
            conversation_history_prompt = self.conversation_history_prompt
        else:
            conversation_history_prompt = ""

        # Format into Gemini API required format
        formatted_msgs: list = []

        # Collect the multimodal files
        conversation_parts: list = []
        # Consecutive text blocks are merged ("name: text" lines) into one
        # part; a multimodal block flushes the accumulated text first so
        # ordering between text and media is preserved.
        accumulated_text = []
        for msg in msgs:
            for block in msg.get_content_blocks():
                if block["type"] == "text":
                    accumulated_text.append(f"{msg.name}: {block['text']}")

                elif block["type"] in ["image", "video", "audio"]:
                    # handle the accumulated text as a single part if exists
                    if accumulated_text:
                        conversation_parts.append(
                            {
                                "text": "\n".join(accumulated_text),
                            },
                        )
                        accumulated_text.clear()

                    # handle the multimodal data
                    if block["source"]["type"] == "url":
                        conversation_parts.append(
                            {
                                "inline_data": _to_gemini_inline_data(
                                    block["source"]["url"],
                                ),
                            },
                        )

                    elif block["source"]["type"] == "base64":
                        media_type = block["source"]["media_type"]
                        base64_data = block["source"]["data"]
                        conversation_parts.append(
                            {
                                "inline_data": {
                                    "data": base64_data,
                                    "mime_type": media_type,
                                },
                            },
                        )

        # Flush any trailing text after the last multimodal block.
        if accumulated_text:
            conversation_parts.append(
                {
                    "text": "\n".join(accumulated_text),
                },
            )

        # Add prompt and <history></history> tags around conversation history
        if conversation_parts:
            if conversation_parts[0].get("text"):
                # Prepend the prompt + opening tag to the first text part.
                conversation_parts[0]["text"] = (
                    conversation_history_prompt
                    + "<history>"
                    + conversation_parts[0]["text"]
                )

            else:
                # First part is media: insert a dedicated opening text part.
                conversation_parts.insert(
                    0,
                    {"text": conversation_history_prompt + "<history>"},
                )

            if conversation_parts[-1].get("text"):
                conversation_parts[-1]["text"] += "\n</history>"

            else:
                # Last part is media: append a dedicated closing text part.
                conversation_parts.append(
                    {"text": "</history>"},
                )

            formatted_msgs.append(
                {
                    "role": "user",
                    "parts": conversation_parts,
                },
            )

        return formatted_msgs
|
||||
@@ -0,0 +1,320 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""The Ollama formatter module."""
|
||||
import base64
|
||||
import os
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._logging import logger
|
||||
from .._utils._common import _get_bytes_from_web_url
|
||||
from ..message import Msg, TextBlock, ImageBlock, ToolUseBlock, ToolResultBlock
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
def _convert_ollama_image_url_to_base64_data(url: str) -> str:
    """Resolve an image URL or local path to the data Ollama expects.

    Local files are read and base64-encoded; web URLs are fetched via
    ``_get_bytes_from_web_url``. Anything else raises ``ValueError``.
    """
    is_local = os.path.exists(url)

    if is_local:
        # Local file: base64-encode its raw bytes.
        with open(url, "rb") as file_handle:
            return base64.b64encode(file_handle.read()).decode("utf-8")

    if urlparse(url).scheme != "":
        # Web URL: delegate the download to the shared helper.
        return _get_bytes_from_web_url(url)

    raise ValueError(
        f"The URL `{url}` is not a valid image URL or local file.",
    )
|
||||
|
||||
|
||||
class OllamaChatFormatter(TruncatedFormatterBase):
    """Formatter for Ollama messages.

    Converts ``Msg`` objects into Ollama chat messages, where each message
    is a dict with "role", "content", and optional "tool_calls"/"images".
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = False
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into Ollama API format.

        Args:
            msgs (`list[Msg]`):
                The list of message objects to format.

        Returns:
            `list[dict[str, Any]]`:
                The formatted messages as a list of dictionaries.
        """
        self.assert_list_of_msgs(msgs)

        messages: list[dict] = []
        for msg in msgs:
            content_blocks: list = []
            tool_calls = []
            images = []

            for block in msg.get_content_blocks():
                typ = block.get("type")
                if typ == "text":
                    content_blocks.append({**block})

                elif typ == "tool_use":
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                "arguments": block.get("input", {}),
                            },
                        },
                    )

                elif typ == "tool_result":
                    # Tool results become standalone "tool" role messages.
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": self.convert_tool_result_to_string(
                                block.get("output"),  # type: ignore[arg-type]
                            ),
                            "name": block.get("name"),
                        },
                    )

                elif typ == "image":
                    source_type = block["source"]["type"]
                    if source_type == "url":
                        # URL sources (web or local path) are converted to
                        # base64 as required by the Ollama API.
                        images.append(
                            _convert_ollama_image_url_to_base64_data(
                                block["source"]["url"],
                            ),
                        )
                    elif source_type == "base64":
                        images.append(block["source"]["data"])

                else:
                    # Unknown block types are skipped rather than raising.
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        typ,
                    )

            # Join all text blocks into a single content string.
            content_msg = "\n".join(
                content.get("text", "") for content in content_blocks
            )
            msg_ollama = {
                "role": msg.role,
                "content": content_msg or None,
            }

            if tool_calls:
                msg_ollama["tool_calls"] = tool_calls

            if images:
                msg_ollama["images"] = images

            # Fix: previously an image-only message (no text, no tool
            # calls) was silently dropped; include it when images exist.
            if (
                msg_ollama["content"]
                or msg_ollama.get("tool_calls")
                or msg_ollama.get("images")
            ):
                messages.append(msg_ollama)

        return messages
|
||||
|
||||
|
||||
class OllamaMultiAgentFormatter(TruncatedFormatterBase):
    """
    Ollama formatter for multi-agent conversations, where more than
    a user and an agent are involved.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversations"""

    support_vision: bool = True
    """Whether support vision data"""

    supported_blocks: list[type] = [
        TextBlock,
        # Multimodal
        ImageBlock,
        # Tool use
        ToolUseBlock,
        ToolResultBlock,
    ]
    """The list of supported message blocks"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the Ollama multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The prompt to use for the conversation history section.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter used for truncation.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If `None`, no truncation will be applied.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_system_message(
        self,
        msg: Msg,
    ) -> dict[str, Any]:
        """Format system message for the Ollama API."""
        return {
            "role": "system",
            "content": msg.get_text_content(),
        }

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Given a sequence of tool call/result messages, format them into
        the required format for the Ollama API.

        Args:
            msgs (`list[Msg]`):
                The list of messages containing tool calls/results to format.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the Ollama API.
        """
        # Delegate to the single-agent formatter, which already handles
        # tool_use / tool_result blocks.
        return await OllamaChatFormatter().format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Given a sequence of messages without tool calls/results, format
        them into the required format for the Ollama API.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the conversation.
                If `True`, the conversation history prompt will be included.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the ollama API.
        """

        # Only the first chunk of history carries the explanatory prompt.
        if is_first:
            conversation_history_prompt = self.conversation_history_prompt
        else:
            conversation_history_prompt = ""

        # Format into required Ollama format
        formatted_msgs: list[dict] = []

        # Collect the multimodal files
        conversation_blocks: list = []
        # Consecutive text blocks are merged ("name: text" lines); an
        # image block flushes the accumulated text first so the relative
        # order of text and images is preserved in the transcript.
        accumulated_text = []
        images = []
        for msg in msgs:
            for block in msg.get_content_blocks():
                if block["type"] == "text":
                    accumulated_text.append(f"{msg.name}: {block['text']}")

                elif block["type"] == "image":
                    # Handle the accumulated text as a single block
                    source = block["source"]
                    if accumulated_text:
                        conversation_blocks.append(
                            {"text": "\n".join(accumulated_text)},
                        )
                        accumulated_text.clear()

                    if source["type"] == "url":
                        # URL sources are converted to base64 for Ollama.
                        images.append(
                            _convert_ollama_image_url_to_base64_data(
                                source["url"],
                            ),
                        )

                    elif source["type"] == "base64":
                        images.append(source["data"])

                    # The raw image block is kept in conversation_blocks as
                    # a position marker; it contributes an empty string to
                    # the joined text below (no "text" key).
                    conversation_blocks.append({**block})

        # Flush any trailing text after the last image block.
        if accumulated_text:
            conversation_blocks.append(
                {"text": "\n".join(accumulated_text)},
            )

        # Wrap the collected history in the prompt and <history> tags.
        if conversation_blocks:
            if conversation_blocks[0].get("text"):
                conversation_blocks[0]["text"] = (
                    conversation_history_prompt
                    + "<history>\n"
                    + conversation_blocks[0]["text"]
                )

            else:
                # First block is an image: insert a dedicated opening block.
                conversation_blocks.insert(
                    0,
                    {
                        "text": conversation_history_prompt + "<history>\n",
                    },
                )

            if conversation_blocks[-1].get("text"):
                conversation_blocks[-1]["text"] += "\n</history>"

            else:
                conversation_blocks.append({"text": "</history>"})

        # Collapse all text blocks into one user content string; image
        # blocks have no "text" key and contribute empty lines.
        conversation_blocks_text = "\n".join(
            conversation_block.get("text", "")
            for conversation_block in conversation_blocks
        )

        user_message = {
            "role": "user",
            "content": conversation_blocks_text,
        }
        if images:
            user_message["images"] = images
        if conversation_blocks:
            formatted_msgs.append(user_message)

        return formatted_msgs
|
||||
@@ -0,0 +1,414 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
# pylint: disable=too-many-branches
|
||||
"""The OpenAI formatter for agentscope."""
|
||||
import base64
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import requests
|
||||
|
||||
from ._truncated_formatter_base import TruncatedFormatterBase
|
||||
from .._logging import logger
|
||||
from ..message import (
|
||||
Msg,
|
||||
URLSource,
|
||||
TextBlock,
|
||||
ImageBlock,
|
||||
AudioBlock,
|
||||
Base64Source,
|
||||
ToolUseBlock,
|
||||
ToolResultBlock,
|
||||
)
|
||||
from ..token import TokenCounterBase
|
||||
|
||||
|
||||
def _to_openai_image_url(url: str) -> str:
|
||||
"""Convert an image url to openai format. If the given url is a local
|
||||
file, it will be converted to base64 format. Otherwise, it will be
|
||||
returned directly.
|
||||
|
||||
Args:
|
||||
url (`str`):
|
||||
The local or public url of the image.
|
||||
"""
|
||||
# See https://platform.openai.com/docs/guides/vision for details of
|
||||
# support image extensions.
|
||||
support_image_extensions = (
|
||||
".png",
|
||||
".jpg",
|
||||
".jpeg",
|
||||
".gif",
|
||||
".webp",
|
||||
)
|
||||
|
||||
parsed_url = urlparse(url)
|
||||
|
||||
lower_url = url.lower()
|
||||
|
||||
# Web url
|
||||
if not os.path.exists(url) and parsed_url.scheme != "":
|
||||
if any(lower_url.endswith(_) for _ in support_image_extensions):
|
||||
return url
|
||||
|
||||
# Check if it is a local file
|
||||
elif os.path.exists(url) and os.path.isfile(url):
|
||||
if any(lower_url.endswith(_) for _ in support_image_extensions):
|
||||
with open(url, "rb") as image_file:
|
||||
base64_image = base64.b64encode(image_file.read()).decode(
|
||||
"utf-8",
|
||||
)
|
||||
extension = parsed_url.path.lower().split(".")[-1]
|
||||
mime_type = f"image/{extension}"
|
||||
return f"data:{mime_type};base64,{base64_image}"
|
||||
|
||||
raise TypeError(f'"{url}" should end with {support_image_extensions}.')
|
||||
|
||||
|
||||
def _to_openai_audio_data(source: URLSource | Base64Source) -> dict:
|
||||
"""Covert an audio source to OpenAI format."""
|
||||
if source["type"] == "url":
|
||||
extension = source["url"].split(".")[-1].lower()
|
||||
if extension not in ["wav", "mp3"]:
|
||||
raise TypeError(
|
||||
f"Unsupported audio file extension: {extension}, "
|
||||
"wav and mp3 are supported.",
|
||||
)
|
||||
|
||||
parsed_url = urlparse(source["url"])
|
||||
|
||||
if os.path.exists(source["url"]):
|
||||
with open(source["url"], "rb") as audio_file:
|
||||
data = base64.b64encode(audio_file.read()).decode("utf-8")
|
||||
|
||||
# web url
|
||||
elif parsed_url.scheme != "":
|
||||
response = requests.get(source["url"])
|
||||
response.raise_for_status()
|
||||
data = base64.b64encode(response.content).decode("utf-8")
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unsupported audio source: {source['url']}, "
|
||||
"it should be a local file or a web URL.",
|
||||
)
|
||||
|
||||
return {
|
||||
"data": data,
|
||||
"format": extension,
|
||||
}
|
||||
|
||||
if source["type"] == "base64":
|
||||
data = source["data"]
|
||||
media_type = source["media_type"]
|
||||
|
||||
if media_type not in ["audio/wav", "audio/mp3"]:
|
||||
raise TypeError(
|
||||
f"Unsupported audio media type: {media_type}, "
|
||||
"only audio/wav and audio/mp3 are supported.",
|
||||
)
|
||||
|
||||
return {
|
||||
"data": data,
|
||||
"format": media_type.split("/")[-1],
|
||||
}
|
||||
|
||||
raise TypeError(f"Unsupported audio source: {source['type']}.")
|
||||
|
||||
|
||||
class OpenAIChatFormatter(TruncatedFormatterBase):
    """The class used to format message objects into the OpenAI API required
    format."""

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversation"""

    support_vision: bool = True
    """Whether support vision models"""

    supported_blocks: list[type] = [
        TextBlock,
        ImageBlock,
        AudioBlock,
        ToolUseBlock,
        ToolResultBlock,
    ]
    """Supported message blocks for OpenAI API"""

    async def _format(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Format message objects into OpenAI API required format.

        Args:
            msgs (`list[Msg]`):
                The list of Msg objects to format.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries, where each dictionary has "name",
                "role", and "content" keys.
        """
        self.assert_list_of_msgs(msgs)

        messages: list[dict] = []
        for msg in msgs:
            content_blocks = []
            tool_calls = []

            for block in msg.get_content_blocks():
                typ = block.get("type")
                if typ == "text":
                    content_blocks.append({**block})

                elif typ == "tool_use":
                    tool_calls.append(
                        {
                            "id": block.get("id"),
                            "type": "function",
                            "function": {
                                "name": block.get("name"),
                                # OpenAI expects arguments as a JSON string,
                                # not a dict.
                                "arguments": json.dumps(
                                    block.get("input", {}),
                                    ensure_ascii=False,
                                ),
                            },
                        },
                    )

                elif typ == "tool_result":
                    # Tool results become standalone "tool" role messages.
                    messages.append(
                        {
                            "role": "tool",
                            "tool_call_id": block.get("id"),
                            "content": self.convert_tool_result_to_string(
                                block.get("output"),  # type: ignore[arg-type]
                            ),
                            "name": block.get("name"),
                        },
                    )

                elif typ == "image":
                    source_type = block["source"]["type"]
                    if source_type == "url":
                        # Local paths are converted to data URIs; web URLs
                        # pass through unchanged.
                        url = _to_openai_image_url(block["source"]["url"])

                    elif source_type == "base64":
                        data = block["source"]["data"]
                        media_type = block["source"]["media_type"]
                        url = f"data:{media_type};base64,{data}"

                    else:
                        raise ValueError(
                            f"Unsupported image source type: {source_type}",
                        )

                    content_blocks.append(
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": url,
                            },
                        },
                    )

                elif typ == "audio":
                    input_audio = _to_openai_audio_data(block["source"])
                    content_blocks.append(
                        {
                            "type": "input_audio",
                            "input_audio": input_audio,
                        },
                    )

                else:
                    # Unknown block types are skipped rather than raising.
                    logger.warning(
                        "Unsupported block type %s in the message, skipped.",
                        typ,
                    )

            msg_openai = {
                "role": msg.role,
                "name": msg.name,
                "content": content_blocks or None,
            }

            if tool_calls:
                msg_openai["tool_calls"] = tool_calls

            # When both content and tool_calls are None, skipped
            if msg_openai["content"] or msg_openai.get("tool_calls"):
                messages.append(msg_openai)

        return messages
|
||||
|
||||
|
||||
class OpenAIMultiAgentFormatter(TruncatedFormatterBase):
    """
    OpenAI formatter for multi-agent conversations, where more than
    a user and an agent are involved.
    .. tip:: This formatter is compatible with OpenAI API and
     OpenAI-compatible services like vLLM, Azure OpenAI, and others.
    """

    support_tools_api: bool = True
    """Whether support tools API"""

    support_multiagent: bool = True
    """Whether support multi-agent conversation"""

    support_vision: bool = True
    """Whether support vision models"""

    supported_blocks: list[type] = [
        TextBlock,
        ImageBlock,
        AudioBlock,
        ToolUseBlock,
        ToolResultBlock,
    ]
    """Supported message blocks for OpenAI API"""

    def __init__(
        self,
        conversation_history_prompt: str = (
            "# Conversation History\n"
            "The content between <history></history> tags contains "
            "your conversation history\n"
        ),
        token_counter: TokenCounterBase | None = None,
        max_tokens: int | None = None,
    ) -> None:
        """Initialize the OpenAI multi-agent formatter.

        Args:
            conversation_history_prompt (`str`):
                The prompt to use for the conversation history section.
            token_counter (`TokenCounterBase | None`, optional):
                The token counter used for truncation.
            max_tokens (`int | None`, optional):
                The maximum number of tokens allowed in the formatted
                messages. If `None`, no truncation will be applied.
        """
        super().__init__(token_counter=token_counter, max_tokens=max_tokens)
        self.conversation_history_prompt = conversation_history_prompt

    async def _format_tool_sequence(
        self,
        msgs: list[Msg],
    ) -> list[dict[str, Any]]:
        """Given a sequence of tool call/result messages, format them into
        the required format for the OpenAI API."""
        # Delegate to the single-agent formatter, which already handles
        # tool_use / tool_result blocks.
        return await OpenAIChatFormatter().format(msgs)

    async def _format_agent_message(
        self,
        msgs: list[Msg],
        is_first: bool = True,
    ) -> list[dict[str, Any]]:
        """Given a sequence of messages without tool calls/results, format
        them into the required format for the OpenAI API.

        Args:
            msgs (`list[Msg]`):
                A list of Msg objects to be formatted.
            is_first (`bool`, defaults to `True`):
                Whether this is the first agent message in the conversation.
                If `True`, the conversation history prompt will be included.

        Returns:
            `list[dict[str, Any]]`:
                A list of dictionaries formatted for the OpenAI API.
        """

        # Only the first chunk of history carries the explanatory prompt.
        if is_first:
            conversation_history_prompt = self.conversation_history_prompt
        else:
            conversation_history_prompt = ""

        # Format into required OpenAI format
        formatted_msgs: list[dict] = []

        # Text is merged into "name: text" lines; images and audios are
        # collected separately and appended after the text content.
        conversation_blocks: list = []
        accumulated_text = []
        images = []
        audios = []

        for msg in msgs:
            for block in msg.get_content_blocks():
                if block["type"] == "text":
                    accumulated_text.append(f"{msg.name}: {block['text']}")

                elif block["type"] == "image":
                    source_type = block["source"]["type"]
                    if source_type == "url":
                        # Local paths are converted to data URIs; web URLs
                        # pass through unchanged.
                        url = _to_openai_image_url(block["source"]["url"])

                    elif source_type == "base64":
                        data = block["source"]["data"]
                        media_type = block["source"]["media_type"]
                        url = f"data:{media_type};base64,{data}"

                    else:
                        raise ValueError(
                            f"Unsupported image source type: {source_type}",
                        )
                    images.append(
                        {
                            "type": "image_url",
                            "image_url": {
                                "url": url,
                            },
                        },
                    )
                elif block["type"] == "audio":
                    input_audio = _to_openai_audio_data(block["source"])
                    audios.append(
                        {
                            "type": "input_audio",
                            "input_audio": input_audio,
                        },
                    )

        # Flush the accumulated text into a single history block.
        if accumulated_text:
            conversation_blocks.append(
                {"text": "\n".join(accumulated_text)},
            )

        # Wrap the collected history in the prompt and <history> tags.
        if conversation_blocks:
            if conversation_blocks[0].get("text"):
                conversation_blocks[0]["text"] = (
                    conversation_history_prompt
                    + "<history>\n"
                    + conversation_blocks[0]["text"]
                )

            else:
                conversation_blocks.insert(
                    0,
                    {
                        "text": conversation_history_prompt + "<history>\n",
                    },
                )

            if conversation_blocks[-1].get("text"):
                conversation_blocks[-1]["text"] += "\n</history>"

            else:
                conversation_blocks.append({"text": "</history>"})

        conversation_blocks_text = "\n".join(
            conversation_block.get("text", "")
            for conversation_block in conversation_blocks
        )

        # Assemble the final content list: text first, then images/audios.
        content_list: list[dict[str, Any]] = []
        if conversation_blocks_text:
            content_list.append(
                {
                    "type": "text",
                    "text": conversation_blocks_text,
                },
            )
        if images:
            content_list.extend(images)
        if audios:
            content_list.extend(audios)

        user_message = {
            "role": "user",
            "content": content_list,
        }

        # Skip emitting an empty user message.
        if content_list:
            formatted_msgs.append(user_message)

        return formatted_msgs
|
||||
@@ -0,0 +1,297 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The truncated formatter base class, which allows to truncate the input
|
||||
messages."""
|
||||
from abc import ABC
|
||||
from copy import deepcopy
|
||||
from typing import (
|
||||
Any,
|
||||
Tuple,
|
||||
Literal,
|
||||
AsyncGenerator,
|
||||
)
|
||||
|
||||
from ._formatter_base import FormatterBase
|
||||
from ..message import Msg
|
||||
from ..token import TokenCounterBase
|
||||
from ..tracing import trace_format
|
||||
|
||||
|
||||
class TruncatedFormatterBase(FormatterBase, ABC):
|
||||
"""Base class for truncated formatters, which formats input messages into
|
||||
required formats with tokens under a specified limit."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
token_counter: TokenCounterBase | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> None:
|
||||
"""Initialize the TruncatedFormatterBase.
|
||||
|
||||
Args:
|
||||
token_counter (`TokenCounterBase | None`, optional):
|
||||
A token counter instance used to count tokens in the messages.
|
||||
If not provided, the formatter will format the messages
|
||||
without considering token limits.
|
||||
max_tokens (`int | None`, optional):
|
||||
The maximum number of tokens allowed in the formatted
|
||||
messages. If not provided, the formatter will not truncate
|
||||
the messages.
|
||||
"""
|
||||
self.token_counter = token_counter
|
||||
|
||||
assert (
|
||||
max_tokens is None or 0 < max_tokens
|
||||
), "max_tokens must be greater than 0"
|
||||
self.max_tokens = max_tokens
|
||||
|
||||
@trace_format
|
||||
async def format(
|
||||
self,
|
||||
msgs: list[Msg],
|
||||
**kwargs: Any,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Format the input messages into the required format. If token
|
||||
counter and max token limit are provided, the messages will be
|
||||
truncated to fit the limit.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to be formatted.
|
||||
|
||||
Returns:
|
||||
`list[dict[str, Any]]`:
|
||||
The formatted messages in the required format.
|
||||
"""
|
||||
|
||||
# Check if the input messages are valid
|
||||
self.assert_list_of_msgs(msgs)
|
||||
|
||||
msgs = deepcopy(msgs)
|
||||
|
||||
while True:
|
||||
formatted_msgs = await self._format(msgs)
|
||||
n_tokens = await self._count(formatted_msgs)
|
||||
|
||||
if (
|
||||
n_tokens is None
|
||||
or self.max_tokens is None
|
||||
or n_tokens <= self.max_tokens
|
||||
):
|
||||
return formatted_msgs
|
||||
|
||||
# truncate the input messages
|
||||
msgs = await self._truncate(msgs)
|
||||
|
||||
async def _format(self, msgs: list[Msg]) -> list[dict[str, Any]]:
|
||||
"""Format the input messages into the required format. This method
|
||||
should be implemented by the subclasses."""
|
||||
|
||||
formatted_msgs = []
|
||||
start_index = 0
|
||||
if len(msgs) > 0 and msgs[0].role == "system":
|
||||
formatted_msgs.append(
|
||||
await self._format_system_message(msgs[0]),
|
||||
)
|
||||
start_index = 1
|
||||
|
||||
is_first_agent_message = True
|
||||
async for typ, group in self._group_messages(msgs[start_index:]):
|
||||
match typ:
|
||||
case "tool_sequence":
|
||||
formatted_msgs.extend(
|
||||
await self._format_tool_sequence(group),
|
||||
)
|
||||
case "agent_message":
|
||||
formatted_msgs.extend(
|
||||
await self._format_agent_message(
|
||||
group,
|
||||
is_first_agent_message,
|
||||
),
|
||||
)
|
||||
is_first_agent_message = False
|
||||
|
||||
return formatted_msgs
|
||||
|
||||
async def _format_system_message(
|
||||
self,
|
||||
msg: Msg,
|
||||
) -> dict[str, Any]:
|
||||
"""Format system message for the LLM API.
|
||||
|
||||
.. note:: This is the default implementation. For certain LLM APIs
|
||||
with specific requirements, you may need to implement a custom
|
||||
formatting function to accommodate those particular needs.
|
||||
"""
|
||||
return {
|
||||
"role": "system",
|
||||
"content": msg.get_content_blocks("text"),
|
||||
}
|
||||
|
||||
async def _format_tool_sequence(
|
||||
self,
|
||||
msgs: list[Msg],
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Given a sequence of tool call/result messages, format them into
|
||||
the required format for the LLM API."""
|
||||
raise NotImplementedError(
|
||||
"_format_tool_sequence is not implemented",
|
||||
)
|
||||
|
||||
async def _format_agent_message(
|
||||
self,
|
||||
msgs: list[Msg],
|
||||
is_first: bool = True,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Given a sequence of messages without tool calls/results, format
|
||||
them into the required format for the LLM API."""
|
||||
raise NotImplementedError(
|
||||
"_format_agent_message is not implemented",
|
||||
)
|
||||
|
||||
async def _truncate(self, msgs: list[Msg]) -> list[Msg]:
|
||||
"""Truncate the input messages, so that it can fit the token limit.
|
||||
This function is called only when
|
||||
|
||||
- both `token_counter` and `max_tokens` are provided,
|
||||
- the formatted output of the input messages exceeds the token limit.
|
||||
|
||||
.. tip:: This function only provides a simple strategy, and developers
|
||||
can override this method to implement more sophisticated
|
||||
truncation strategies.
|
||||
|
||||
.. note:: The tool call message should be truncated together with
|
||||
its corresponding tool result message to satisfy the LLM API
|
||||
requirements.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to be truncated.
|
||||
|
||||
Raises:
|
||||
`ValueError`:
|
||||
If the system prompt message already exceeds the token limit,
|
||||
or if there are tool calls without corresponding tool results.
|
||||
|
||||
Returns:
|
||||
`list[Msg]`:
|
||||
The truncated messages.
|
||||
"""
|
||||
start_index = 0
|
||||
if len(msgs) > 0 and msgs[0].role == "system":
|
||||
if len(msgs) == 1:
|
||||
# If the system prompt already exceeds the token limit, we
|
||||
# raise an error.
|
||||
raise ValueError(
|
||||
f"The system prompt message already exceeds the token "
|
||||
f"limit ({self.max_tokens} tokens).",
|
||||
)
|
||||
|
||||
start_index = 1
|
||||
|
||||
# Create a tool call IDs queues to delete the corresponding tool
|
||||
# result message
|
||||
tool_call_ids = set()
|
||||
for i in range(start_index, len(msgs)):
|
||||
msg = msgs[i]
|
||||
for block in msg.get_content_blocks("tool_use"):
|
||||
tool_call_ids.add(block["id"])
|
||||
|
||||
for block in msg.get_content_blocks("tool_result"):
|
||||
try:
|
||||
tool_call_ids.remove(block["id"])
|
||||
except KeyError:
|
||||
pass
|
||||
|
||||
# We can stop truncating if the queue is empty
|
||||
if len(tool_call_ids) == 0:
|
||||
return msgs[:start_index] + msgs[i + 1 :]
|
||||
|
||||
if len(tool_call_ids) > 0:
|
||||
raise ValueError(
|
||||
"The input messages contains tool call(s) that do not have "
|
||||
f"the corresponding tool result(s): {tool_call_ids}. ",
|
||||
)
|
||||
|
||||
return msgs[:start_index]
|
||||
|
||||
async def _count(self, msgs: list[dict[str, Any]]) -> int | None:
|
||||
"""Count the number of tokens in the input messages. If token counter
|
||||
is not provided, `None` will be returned.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to count tokens for.
|
||||
"""
|
||||
if self.token_counter is None:
|
||||
return None
|
||||
|
||||
return await self.token_counter.count(msgs)
|
||||
|
||||
@staticmethod
|
||||
async def _group_messages(
|
||||
msgs: list[Msg],
|
||||
) -> AsyncGenerator[
|
||||
Tuple[Literal["tool_sequence", "agent_message"], list[Msg]],
|
||||
None,
|
||||
]:
|
||||
"""Group the input messages into two types and yield them as a
|
||||
generator. The two types are:
|
||||
|
||||
- agent message that doesn't contain tool calls/results, and
|
||||
- tool sequence that consisted of a sequence of tool calls/results
|
||||
|
||||
.. note:: The group operation is used in multi-agent scenario, where
|
||||
multiple entities are involved in the input messages. So that to be
|
||||
compatible with tools API, we have to group the messages and format
|
||||
them with different strategies.
|
||||
|
||||
Args:
|
||||
msgs (`list[Msg]`):
|
||||
The input messages to be grouped, where the system prompt
|
||||
message shouldn't be included.
|
||||
|
||||
Yields:
|
||||
`AsyncGenerator[Tuple[str, list[Msg]], None]`:
|
||||
A generator that yields tuples of group type and the list of
|
||||
messages in that group. The group type can be either
|
||||
"tool_sequence" or "agent_message".
|
||||
"""
|
||||
|
||||
group_type: Literal["tool_sequence", "agent_message"] | None = None
|
||||
group = []
|
||||
for msg in msgs:
|
||||
if group_type is None:
|
||||
if msg.has_content_blocks(
|
||||
"tool_use",
|
||||
) or msg.has_content_blocks("tool_result"):
|
||||
group_type = "tool_sequence"
|
||||
else:
|
||||
group_type = "agent_message"
|
||||
|
||||
group.append(msg)
|
||||
continue
|
||||
|
||||
# determine if this msg has the same type as the current group
|
||||
if group_type == "tool_sequence":
|
||||
if msg.has_content_blocks(
|
||||
"tool_use",
|
||||
) or msg.has_content_blocks("tool_result"):
|
||||
group.append(msg)
|
||||
|
||||
else:
|
||||
yield group_type, group
|
||||
group = [msg]
|
||||
group_type = "agent_message"
|
||||
|
||||
elif group_type == "agent_message":
|
||||
if msg.has_content_blocks(
|
||||
"tool_use",
|
||||
) or msg.has_content_blocks("tool_result"):
|
||||
yield group_type, group
|
||||
group = [msg]
|
||||
group_type = "tool_sequence"
|
||||
|
||||
else:
|
||||
group.append(msg)
|
||||
if group_type:
|
||||
yield group_type, group
|
||||
Reference in New Issue
Block a user