chore: 添加虚拟环境到仓库

- 添加 backend_service/venv 虚拟环境
- 包含所有Python依赖包
- 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
2025-12-03 10:19:25 +08:00
parent a6c2027caa
commit c4f851d387
12655 changed files with 3009376 additions and 0 deletions

View File

@@ -0,0 +1,133 @@
# -*- coding: utf-8 -*-
"""The agentscope serialization module"""
import os
import requests
from . import exception
from . import module
from . import message
from . import model
from . import tool
from . import formatter
from . import memory
from . import agent
from . import session
from . import embedding
from . import token
from . import evaluate
from . import pipeline
from . import tracing
from . import rag
from ._logging import (
logger,
setup_logger,
)
from .hooks import _equip_as_studio_hooks
from ._version import __version__
def init(
    project: str | None = None,
    name: str | None = None,
    logging_path: str | None = None,
    logging_level: str = "INFO",
    studio_url: str | None = None,
    tracing_url: str | None = None,
) -> None:
    """Initialize the agentscope library.

    Args:
        project (`str | None`, optional):
            The project name.
        name (`str | None`, optional):
            The name of the run.
        logging_path (`str | None`, optional):
            The path to saving the log file. If not provided, logs will not be
            saved.
        logging_level (`str`, defaults to `"INFO"`):
            The logging level.
        studio_url (`str | None`, optional):
            The URL of the AgentScope Studio to connect to.
        tracing_url (`str | None`, optional):
            The URL of the tracing endpoint, which can connect to third-party
            OpenTelemetry tracing platforms like Arize-Phoenix and Langfuse.
            If not provided and `studio_url` is provided, it will send traces
            to the AgentScope Studio's tracing endpoint.
    """
    # Imported lazily to avoid a circular import at package load time.
    from . import _config

    if project:
        _config.project = project
    if name:
        _config.name = name

    setup_logger(logging_level, logging_path)

    if studio_url:
        # Register this run with the studio so it appears in the UI.
        data = {
            "id": _config.run_id,
            "project": _config.project,
            "name": _config.name,
            "timestamp": _config.created_at,
            "pid": os.getpid(),
            "status": "running",
            # Deprecated fields
            "run_dir": "",
        }
        # A timeout prevents init() from hanging forever when the studio
        # endpoint is unreachable or unresponsive.
        response = requests.post(
            url=f"{studio_url}/trpc/registerRun",
            json=data,
            timeout=30,
        )
        response.raise_for_status()

        # Route user input through the studio instead of the terminal.
        from .agent import UserAgent, StudioUserInput

        UserAgent.override_class_input_method(
            StudioUserInput(
                studio_url=studio_url,
                run_id=_config.run_id,
                max_retries=3,
            ),
        )
        _equip_as_studio_hooks(studio_url)

    if tracing_url:
        endpoint = tracing_url
    else:
        # Fall back to the studio's tracing endpoint when available.
        # rstrip (not strip) so only trailing slashes are removed; strip
        # would also eat leading characters from the URL.
        endpoint = (
            studio_url.rstrip("/") + "/v1/traces" if studio_url else None
        )

    if endpoint:
        from .tracing import setup_tracing

        setup_tracing(endpoint=endpoint)
# Public API of the agentscope package.
__all__ = [
    # modules
    "exception",
    "module",
    "message",
    "model",
    "tool",
    "formatter",
    "memory",
    "agent",
    "session",
    "logger",
    "embedding",
    "token",
    "evaluate",
    "pipeline",
    "tracing",
    "rag",
    # functions
    "init",
    "setup_logger",
    "__version__",
]

View File

@@ -0,0 +1,23 @@
# -*- coding: utf-8 -*-
"""The runtime configuration in agentscope.
.. note:: Import this module as ``from . import _config`` and access its
variables as attributes, instead of ``from ._config import xxx``. Names
bound with ``from ... import`` are fixed at import time, so later changes
to these variables would not be visible through them.
"""
from datetime import datetime
import shortuuid
def _generate_random_suffix(length: int) -> str:
    """Generate a random suffix of the given length."""
    # Take the first `length` characters of a fresh short UUID.
    random_id = shortuuid.uuid()
    return random_id[:length]
# Default project name, e.g. "UnnamedProject_At20251203".
project = "UnnamedProject_At" + datetime.now().strftime("%Y%m%d")
# Default run name: current time plus a 4-character random suffix.
name = datetime.now().strftime("%H%M%S_") + _generate_random_suffix(4)
# Unique identifier of this run.
run_id: str = shortuuid.uuid()
# Creation timestamp of this run, with millisecond precision.
created_at: str = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
# Tracing flag; presumably toggled by the tracing setup — confirm at callers.
trace_enabled: bool = False

View File

@@ -0,0 +1,47 @@
# -*- coding: utf-8 -*-
"""The logger for agentscope."""
import logging
_DEFAULT_FORMAT = (
"%(asctime)s | %(levelname)-7s | "
"%(module)s:%(funcName)s:%(lineno)s - %(message)s"
)
logger = logging.getLogger("as")
def setup_logger(
level: str,
filepath: str | None = None,
) -> None:
"""Set up the agentscope logger.
Args:
level (`str`):
The logging level, chosen from "INFO", "DEBUG", "WARNING",
"ERROR", "CRITICAL".
filepath (`str | None`, optional):
The filepath to save the logging output.
"""
if level not in ["INFO", "DEBUG", "WARNING", "ERROR", "CRITICAL"]:
raise ValueError(
f"Invalid logging level: {level}. Must be one of "
f"'INFO', 'DEBUG', 'WARNING', 'ERROR', 'CRITICAL'.",
)
logger.handlers.clear()
logger.setLevel(level)
handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter(_DEFAULT_FORMAT))
logger.addHandler(handler)
if filepath:
handler = logging.FileHandler(filepath)
handler.setFormatter(logging.Formatter(_DEFAULT_FORMAT))
logger.addHandler(handler)
logger.propagate = False
setup_logger("INFO")

View File

@@ -0,0 +1,285 @@
# -*- coding: utf-8 -*-
"""The common utilities for agentscope library."""
import asyncio
import base64
import functools
import inspect
import json
import os
import tempfile
import types
import typing
import uuid
from datetime import datetime
from typing import Union, Any, Callable, Type, Dict
import requests
from json_repair import repair_json
from pydantic import BaseModel
from .._logging import logger
if typing.TYPE_CHECKING:
from mcp.types import Tool
else:
Tool = "mcp.types.Tool"
def _json_loads_with_repair(
json_str: str,
) -> Union[dict, list, str, float, int, bool, None]:
"""The given json_str maybe incomplete, e.g. '{"key', so we need to
repair and load it into a Python object.
"""
repaired = json_str
try:
repaired = repair_json(json_str)
except Exception:
pass
try:
return json.loads(repaired)
except json.JSONDecodeError as e:
raise ValueError(
f"Failed to decode JSON string `{json_str}` after repairing it "
f"into `{repaired}`. Error: {e}",
) from e
def _is_accessible_local_file(url: str) -> bool:
"""Check if the given URL is a local URL."""
return os.path.isfile(url)
def _get_timestamp(add_random_suffix: bool = False) -> str:
"""Get the current timestamp in the format YYYY-MM-DD HH:MM:SS.sss."""
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S.%f")[:-3]
if add_random_suffix:
# Add a random suffix to the timestamp
timestamp += f"_{os.urandom(3).hex()}"
return timestamp
async def _is_async_func(func: Callable) -> bool:
"""Check if the given function is an async function, including
coroutine functions, async generators, and coroutine objects.
"""
return (
inspect.iscoroutinefunction(func)
or inspect.isasyncgenfunction(func)
or isinstance(func, types.CoroutineType)
or isinstance(func, types.GeneratorType)
and asyncio.iscoroutine(func)
or isinstance(func, functools.partial)
and await _is_async_func(func.func)
)
async def _execute_async_or_sync_func(
    func: Callable,
    *args: Any,
    **kwargs: Any,
) -> Any:
    """Execute the given callable, awaiting the result when it is
    asynchronous.

    Args:
        func (`Callable`):
            The function to be executed, which can be either async or sync.
        *args (`Any`):
            Positional arguments to be passed to the function.
        **kwargs (`Any`):
            Keyword arguments to be passed to the function.

    Returns:
        `Any`:
            The result of the function execution.
    """
    needs_await = await _is_async_func(func)
    if needs_await:
        return await func(*args, **kwargs)
    return func(*args, **kwargs)
def _get_bytes_from_web_url(
    url: str,
    max_retries: int = 3,
    timeout: float = 30.0,
) -> str:
    """Fetch the content of a URL, returning the text when it decodes as
    UTF-8, otherwise the base64-encoded bytes.

    Args:
        url (`str`):
            The URL to fetch the bytes from.
        max_retries (`int`, defaults to `3`):
            The maximum number of retries.
        timeout (`float`, defaults to `30.0`):
            Seconds to wait for the server before giving up on one attempt.

    Raises:
        `RuntimeError`:
            If all retry attempts fail.
    """
    for _ in range(max_retries):
        try:
            # Without a timeout, a silent server would hang this call
            # forever.
            response = requests.get(url, timeout=timeout)
            response.raise_for_status()
            return response.content.decode("utf-8")
        except UnicodeDecodeError:
            # Binary content: return it base64-encoded instead.
            return base64.b64encode(response.content).decode("ascii")
        except Exception as e:
            logger.info(
                "Failed to fetch bytes from URL %s. Error %s. Retrying...",
                url,
                str(e),
            )
    raise RuntimeError(
        f"Failed to fetch bytes from URL `{url}` after {max_retries} retries.",
    )
def _save_base64_data(
media_type: str,
base64_data: str,
) -> str:
"""Save the base64 data to a temp file and return the file path. The
extension is guessed from the MIME type.
Args:
media_type (`str`):
The MIME type of the data, e.g. "image/png", "audio/mpeg".
base64_data (`str):
The base64 data to be saved.
"""
extension = "." + media_type.split("/")[-1]
with tempfile.NamedTemporaryFile(
suffix=f".{extension}",
delete=False,
) as temp_file:
decoded_data = base64.b64decode(base64_data)
temp_file.write(decoded_data)
temp_file.close()
return temp_file.name
def _extract_json_schema_from_mcp_tool(tool: Tool) -> dict[str, Any]:
"""Extract JSON schema from MCP tool."""
return {
"type": "function",
"function": {
"name": tool.name,
"description": tool.description,
"parameters": {
"type": "object",
"properties": tool.inputSchema.get(
"properties",
{},
),
"required": tool.inputSchema.get(
"required",
[],
),
},
},
}
def _remove_title_field(schema: dict) -> None:
"""Remove the title field from the JSON schema to avoid
misleading the LLM."""
# The top level title field
if "title" in schema:
schema.pop("title")
# properties
if "properties" in schema:
for prop in schema["properties"].values():
if isinstance(prop, dict):
_remove_title_field(prop)
# items
if "items" in schema and isinstance(schema["items"], dict):
_remove_title_field(schema["items"])
# additionalProperties
if "additionalProperties" in schema and isinstance(
schema["additionalProperties"],
dict,
):
_remove_title_field(
schema["additionalProperties"],
)
def _create_tool_from_base_model(
    structured_model: Type[BaseModel],
    tool_name: str = "generate_structured_output",
) -> Dict[str, Any]:
    """Create a function tool definition from a Pydantic BaseModel.

    This converts a Pydantic BaseModel class into a tool definition usable
    with function calling APIs. The model's JSON schema becomes the tool's
    parameters, so structured output can be produced by forcing the model
    to call this function with properly formatted data.

    Args:
        structured_model (`Type[BaseModel]`):
            A Pydantic BaseModel class that defines the expected structure
            for the tool's output.
        tool_name (`str`, default `"generate_structured_output"`):
            The tool name that used to force the LLM to generate structured
            output by calling this function.

    Returns:
        `Dict[str, Any]`: A tool definition dictionary compatible with
            function calling API, containing type ("function") and
            function dictionary with name, description, and parameters
            (JSON schema).

    .. code-block:: python
        :caption: Example usage

        from pydantic import BaseModel

        class PersonInfo(BaseModel):
            name: str
            age: int
            email: str

        tool = _create_tool_from_base_model(PersonInfo, "extract_person")
        print(tool["function"]["name"])  # extract_person
        print(tool["type"])  # function

    .. note:: The 'title' fields are removed from the JSON schema via
        ``_remove_title_field()`` to ensure compatibility with the
        function calling format.
    """
    params = structured_model.model_json_schema()
    # Titles are noise for function calling, so drop them everywhere.
    _remove_title_field(params)
    return {
        "type": "function",
        "function": {
            "name": tool_name,
            "description": "Generate the required structured output with "
            "this function",
            "parameters": params,
        },
    }
def _map_text_to_uuid(text: str) -> str:
"""Map the given text to a deterministic UUID string.
Args:
text (`str`):
The input text to be mapped to a UUID.
Returns:
`str`:
A deterministic UUID string derived from the input text.
"""
return str(uuid.uuid3(uuid.NAMESPACE_DNS, text))

View File

@@ -0,0 +1,9 @@
# -*- coding: utf-8 -*-
"""The mixin for agentscope."""
class DictMixin(dict):
    """A dict subclass that also allows attribute-style access, i.e.
    ``obj.key`` is equivalent to ``obj["key"]``.

    .. note:: The previous implementation aliased ``__getattr__`` directly
        to ``dict.__getitem__``, so missing attributes raised ``KeyError``
        instead of ``AttributeError``. That breaks ``hasattr``, ``copy``
        and pickling, which all probe attributes and expect
        ``AttributeError``. Missing attributes now raise
        ``AttributeError`` as the attribute protocol requires.
    """

    # Setting an attribute stores it as a dictionary item.
    __setattr__ = dict.__setitem__

    def __getattr__(self, key: str) -> object:
        """Look up ``key`` as a dictionary item, raising ``AttributeError``
        (not ``KeyError``) when it is absent."""
        try:
            return self[key]
        except KeyError:
            raise AttributeError(key) from None

View File

@@ -0,0 +1,4 @@
# -*- coding: utf-8 -*-
"""The version of agentscope."""
__version__ = "1.0.7"

View File

@@ -0,0 +1,24 @@
# -*- coding: utf-8 -*-
"""The agent base class."""
from ._agent_base import AgentBase
from ._react_agent_base import ReActAgentBase
from ._react_agent import ReActAgent
from ._user_input import (
UserInputBase,
UserInputData,
TerminalUserInput,
StudioUserInput,
)
from ._user_agent import UserAgent
# Public API of the agent subpackage.
__all__ = [
    "AgentBase",
    "ReActAgentBase",
    "ReActAgent",
    "UserInputData",
    "UserInputBase",
    "TerminalUserInput",
    "StudioUserInput",
    "UserAgent",
]

View File

@@ -0,0 +1,703 @@
# -*- coding: utf-8 -*-
"""The agent base class in agentscope."""
import asyncio
import io
import json
from asyncio import Task, Queue
from collections import OrderedDict
from copy import deepcopy
from typing import Callable, Any
import base64
import shortuuid
import numpy as np
from typing_extensions import deprecated
from ._agent_meta import _AgentMeta
from .._logging import logger
from ..module import StateModule
from ..message import (
Msg,
AudioBlock,
ToolUseBlock,
ToolResultBlock,
ImageBlock,
VideoBlock,
)
from ..types import AgentHookTypes
class AgentBase(StateModule, metaclass=_AgentMeta):
    """Base class for asynchronous agents."""

    id: str
    """The agent's unique identifier, generated using shortuuid."""

    supported_hook_types: list[str] = [
        "pre_reply",
        "post_reply",
        "pre_print",
        "post_print",
        "pre_observe",
        "post_observe",
    ]
    """Supported hook types for the agent base class."""

    _class_pre_reply_hooks: dict[
        str,
        Callable[
            [
                "AgentBase",  # self
                dict[str, Any],  # kwargs
            ],
            dict[str, Any] | None,  # The modified kwargs or None
        ],
    ] = OrderedDict()
    """The class-level hook functions that will be called before the reply
    function, taking `self` object, the input arguments as input, and
    generating the modified arguments (if needed). Then input arguments of the
    reply function will be re-organized into a keyword arguments dictionary.
    If the one hook returns a new dictionary, the modified arguments will be
    passed to the next hook or the original reply function."""

    _class_post_reply_hooks: dict[
        str,
        Callable[
            [
                "AgentBase",  # self
                dict[str, Any],  # kwargs
                Msg,  # output, the output message
            ],
            Msg | None,
        ],
    ] = OrderedDict()
    """The class-level hook functions that will be called after the reply
    function, which takes the `self` object and deep copied
    positional and keyword arguments (args and kwargs), and the output message
    as input. If the hook returns a message, the new message will be passed
    to the next hook or the original reply function. Otherwise, the original
    output will be passed instead."""

    _class_pre_print_hooks: dict[
        str,
        Callable[
            [
                "AgentBase",  # self
                dict[str, Any],  # kwargs
            ],
            dict[str, Any] | None,  # The modified kwargs or None
        ],
    ] = OrderedDict()
    """The class-level hook functions that will be called before printing,
    which takes the `self` object, a deep copied arguments dictionary as input,
    and output the modified arguments (if needed). """

    _class_post_print_hooks: dict[
        str,
        Callable[
            [
                "AgentBase",  # self
                dict[str, Any],  # kwargs
                Any,  # output, `None` if no output
            ],
            Any,
        ],
    ] = OrderedDict()
    """The class-level hook functions that will be called after the speak
    function, which takes the `self` object as input."""

    _class_pre_observe_hooks: dict[
        str,
        Callable[
            [
                "AgentBase",  # self
                dict[str, Any],  # kwargs
            ],
            dict[str, Any] | None,  # The modified kwargs or None
        ],
    ] = OrderedDict()
    """The class-level hook functions that will be called before the observe
    function, which takes the `self` object and a deep copied input
    arguments dictionary as input. To change the input arguments, the hook
    function needs to output the modified arguments dictionary, which will be
    used as the input of the next hook function or the original observe
    function."""

    _class_post_observe_hooks: dict[
        str,
        Callable[
            [
                "AgentBase",  # self
                dict[str, Any],  # kwargs
                None,  # The output, `None` if no output
            ],
            None,
        ],
    ] = OrderedDict()
    """The class-level hook functions that will be called after the observe
    function, which takes the `self` object as input."""

    def __init__(self) -> None:
        """Initialize the agent."""
        super().__init__()
        self.id = shortuuid.uuid()
        # The replying task and identify of the current replying
        self._reply_task: Task | None = None
        self._reply_id: str | None = None
        # Initialize the instance-level hooks
        self._instance_pre_print_hooks = OrderedDict()
        self._instance_post_print_hooks = OrderedDict()
        self._instance_pre_reply_hooks = OrderedDict()
        self._instance_post_reply_hooks = OrderedDict()
        self._instance_pre_observe_hooks = OrderedDict()
        self._instance_post_observe_hooks = OrderedDict()
        # The prefix used in streaming printing, which will save the
        # accumulated text and audio streaming data for each message id.
        # e.g. {"text": "xxx", "audio": (stream_obj, "{base64_data}")}
        self._stream_prefix = {}
        # The subscribers that will receive the reply message by their
        # `observe` method. The key is the MsgHub id, and the value is the
        # list of agents.
        self._subscribers: dict[str, list[AgentBase]] = {}
        # We add this variable in case developers want to disable the console
        # output of the agent, e.g., in a production environment.
        self._disable_console_output: bool = False
        # The streaming message queue used to export the messages as a
        # generator
        self._disable_msg_queue: bool = True
        self.msg_queue = None

    async def observe(self, msg: Msg | list[Msg] | None) -> None:
        """Receive the given message(s) without generating a reply.

        Args:
            msg (`Msg | list[Msg] | None`):
                The message(s) to be observed.
        """
        raise NotImplementedError(
            f"The observe function is not implemented in"
            f" {self.__class__.__name__} class.",
        )

    async def reply(self, *args: Any, **kwargs: Any) -> Msg:
        """The main logic of the agent, which generates a reply based on the
        current state and input arguments."""
        raise NotImplementedError(
            "The reply function is not implemented in "
            f"{self.__class__.__name__} class.",
        )

    async def print(self, msg: Msg, last: bool = True) -> None:
        """The function to display the message.

        Args:
            msg (`Msg`):
                The message object to be printed.
            last (`bool`, defaults to `True`):
                Whether this is the last one in streaming messages. For
                non-streaming message, this should always be `True`.
        """
        # Mirror the message into the queue first, so streaming consumers
        # receive it even when console output is disabled.
        if not self._disable_msg_queue:
            await self.msg_queue.put((deepcopy(msg), last))
        if self._disable_console_output:
            return
        # The accumulated textual content to print, including the text blocks
        # and the thinking blocks
        thinking_and_text_to_print = []
        for block in msg.get_content_blocks():
            if block["type"] == "audio":
                self._process_audio_block(msg.id, block)
            elif block["type"] == "text":
                self._print_text_block(
                    msg.id,
                    name_prefix=msg.name,
                    text_content=block["text"],
                    thinking_and_text_to_print=thinking_and_text_to_print,
                )
            elif block["type"] == "thinking":
                self._print_text_block(
                    msg.id,
                    name_prefix=f"{msg.name}(thinking)",
                    text_content=block["thinking"],
                    thinking_and_text_to_print=thinking_and_text_to_print,
                )
            elif last:
                # Non-text blocks (tool use/result, image, video) are only
                # printed once, when the message is complete.
                self._print_last_block(block, msg)
        # Clean up resources if this is the last message in streaming
        if last and msg.id in self._stream_prefix:
            if "audio" in self._stream_prefix[msg.id]:
                player, _ = self._stream_prefix[msg.id]["audio"]
                # Close the miniaudio player
                player.close()
            stream_prefix = self._stream_prefix.pop(msg.id)
            # Terminate the streamed text with a newline if it lacks one.
            if "text" in stream_prefix and not stream_prefix["text"].endswith(
                "\n",
            ):
                print()

    def _process_audio_block(
        self,
        msg_id: str,
        audio_block: AudioBlock,
    ) -> None:
        """Process audio block content.

        Args:
            msg_id (`str`):
                The unique identifier of the message
            audio_block (`AudioBlock`):
                The audio content block
        """
        if "source" not in audio_block:
            raise ValueError(
                "The audio block must contain the 'source' field.",
            )
        if audio_block["source"]["type"] == "url":
            import urllib.request
            import wave
            import sounddevice as sd

            url = audio_block["source"]["url"]
            try:
                with urllib.request.urlopen(url) as response:
                    audio_data = response.read()
                with wave.open(io.BytesIO(audio_data), "rb") as wf:
                    samplerate = wf.getframerate()
                    n_frames = wf.getnframes()
                    audio_frames = wf.readframes(n_frames)
                # Convert byte data to numpy array
                audio_np = np.frombuffer(audio_frames, dtype=np.int16)
                # Play audio
                sd.play(audio_np, samplerate)
                sd.wait()
            except Exception as e:
                logger.error(
                    "Failed to play audio from url %s: %s",
                    url,
                    str(e),
                )
        elif audio_block["source"]["type"] == "base64":
            data = audio_block["source"]["data"]
            if msg_id not in self._stream_prefix:
                self._stream_prefix[msg_id] = {}
            audio_prefix = self._stream_prefix[msg_id].get("audio", None)
            import sounddevice as sd

            # The player and the prefix data is cached for streaming audio
            if audio_prefix:
                player, audio_prefix_data = audio_prefix
            else:
                # NOTE(review): the stream is opened as 24 kHz mono float32;
                # this assumes the base64 payload is int16 PCM at 24 kHz —
                # confirm the upstream audio format.
                player = sd.OutputStream(
                    samplerate=24000,
                    channels=1,
                    dtype=np.float32,
                    blocksize=1024,
                    latency="low",
                )
                player.start()
                audio_prefix_data = ""
            # play the audio data
            new_audio_data = data[len(audio_prefix_data) :]
            if new_audio_data:
                audio_bytes = base64.b64decode(new_audio_data)
                # Decode int16 samples and rescale to float32 in [-1, 1).
                audio_np = np.frombuffer(audio_bytes, dtype=np.int16)
                audio_float = audio_np.astype(np.float32) / 32768.0
                # Write to the audio output stream
                player.write(audio_float)
            # save the player and the prefix data
            self._stream_prefix[msg_id]["audio"] = (
                player,
                data,
            )
        else:
            raise ValueError(
                "Unsupported audio source type: "
                f"{audio_block['source']['type']}",
            )

    def _print_text_block(
        self,
        msg_id: str,
        name_prefix: str,
        text_content: str,
        thinking_and_text_to_print: list[str],
    ) -> None:
        """Print the text block and thinking block content.

        Args:
            msg_id (`str`):
                The unique identifier of the message
            name_prefix (`str`):
                The prefix for the message, e.g. "{name}: " for text block and
                "{name}(thinking): " for thinking block.
            text_content (`str`):
                The textual content to be printed.
            thinking_and_text_to_print (`list[str]`):
                A list of textual content to be printed together. Here we
                gather the text and thinking blocks to print them together.
        """
        thinking_and_text_to_print.append(
            f"{name_prefix}: {text_content}",
        )
        # The accumulated text and thinking blocks to print
        to_print = "\n".join(thinking_and_text_to_print)
        # The text prefix that has been printed
        if msg_id not in self._stream_prefix:
            self._stream_prefix[msg_id] = {}
        text_prefix = self._stream_prefix[msg_id].get("text", "")
        # Only print when there is new text content
        if len(to_print) > len(text_prefix):
            print(to_print[len(text_prefix) :], end="")
            # Save the printed text prefix
            self._stream_prefix[msg_id]["text"] = to_print

    def _print_last_block(
        self,
        block: ToolUseBlock | ToolResultBlock | ImageBlock | VideoBlock,
        msg: Msg,
    ) -> None:
        """Process and print the last content block, and the block type
        is not audio, text, or thinking.

        Args:
            block (`ToolUseBlock | ToolResultBlock | ImageBlock | VideoBlock`):
                The content block to be printed
            msg (`Msg`):
                The message object
        """
        text_prefix = self._stream_prefix.get(msg.id, {}).get("text", "")
        if text_prefix:
            # Add a newline to separate from previous text content
            print_newline = "" if text_prefix.endswith("\n") else "\n"
            print(
                f"{print_newline}"
                f"{json.dumps(block, indent=4, ensure_ascii=False)}",
            )
        else:
            print(
                f"{msg.name}:"
                f" {json.dumps(block, indent=4, ensure_ascii=False)}",
            )

    async def __call__(self, *args: Any, **kwargs: Any) -> Msg:
        """Call the reply function with the given arguments."""
        self._reply_id = shortuuid.uuid()
        reply_msg: Msg | None = None
        try:
            # Remember the running task so interrupt() can cancel it.
            self._reply_task = asyncio.current_task()
            reply_msg = await self.reply(*args, **kwargs)
        # The interruption is triggered by calling the interrupt method
        except asyncio.CancelledError:
            reply_msg = await self.handle_interrupt(*args, **kwargs)
        finally:
            # Broadcast the reply message to all subscribers
            if reply_msg:
                await self._broadcast_to_subscribers(reply_msg)
            self._reply_task = None
        return reply_msg

    async def _broadcast_to_subscribers(
        self,
        msg: Msg | list[Msg] | None,
    ) -> None:
        """Broadcast the message to all subscribers."""
        for subscribers in self._subscribers.values():
            for subscriber in subscribers:
                await subscriber.observe(msg)

    async def handle_interrupt(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Msg:
        """The post-processing logic when the reply is interrupted by the
        user or something else."""
        raise NotImplementedError(
            f"The handle_interrupt function is not implemented in "
            f"{self.__class__.__name__}",
        )

    async def interrupt(self, msg: Msg | list[Msg] | None = None) -> None:
        """Interrupt the current reply process."""
        # Cancelling the task raises asyncio.CancelledError inside reply(),
        # which __call__ turns into a handle_interrupt() call.
        if self._reply_task and not self._reply_task.done():
            self._reply_task.cancel(msg)

    def register_instance_hook(
        self,
        hook_type: AgentHookTypes,
        hook_name: str,
        hook: Callable,
    ) -> None:
        """Register a hook to the agent instance, which only takes effect
        for the current instance.

        Args:
            hook_type (`str`):
                The type of the hook, indicating where the hook is to be
                triggered.
            hook_name (`str`):
                The name of the hook. If the name is already registered, the
                hook will be overwritten.
            hook (`Callable`):
                The hook function.
        """
        if not isinstance(self, AgentBase):
            raise TypeError(
                "The register_instance_hook method should be called on an "
                f"instance of AsyncAgentBase, but got {self} of "
                f"type {type(self)}.",
            )
        hooks = getattr(self, f"_instance_{hook_type}_hooks")
        hooks[hook_name] = hook

    def remove_instance_hook(
        self,
        hook_type: AgentHookTypes,
        hook_name: str,
    ) -> None:
        """Remove an instance-level hook from the agent instance.

        Args:
            hook_type (`AgentHookTypes`):
                The type of the hook, indicating where the hook is to be
                triggered.
            hook_name (`str`):
                The name of the hook to remove.
        """
        if not isinstance(self, AgentBase):
            raise TypeError(
                "The remove_instance_hook method should be called on an "
                f"instance of AsyncAgentBase, but got {self} of "
                f"type {type(self)}.",
            )
        hooks = getattr(self, f"_instance_{hook_type}_hooks")
        if hook_name in hooks:
            del hooks[hook_name]
        else:
            raise ValueError(
                f"Hook '{hook_name}' not found in '{hook_type}' hooks of "
                f"{self.__class__.__name__} instance.",
            )

    @classmethod
    def register_class_hook(
        cls,
        hook_type: AgentHookTypes,
        hook_name: str,
        hook: Callable,
    ) -> None:
        """The universal function to register a hook to the agent class, which
        will take effect for all instances of the class.

        Args:
            hook_type (`AgentHookTypes`):
                The type of the hook, indicating where the hook is to be
                triggered.
            hook_name (`str`):
                The name of the hook. If the name is already registered, the
                hook will be overwritten.
            hook (`Callable`):
                The hook function.
        """
        assert (
            hook_type in cls.supported_hook_types
        ), f"Invalid hook type: {hook_type}"
        hooks = getattr(cls, f"_class_{hook_type}_hooks")
        hooks[hook_name] = hook

    @classmethod
    def remove_class_hook(
        cls,
        hook_type: AgentHookTypes,
        hook_name: str,
    ) -> None:
        """Remove a class-level hook from the agent class.

        Args:
            hook_type (`AgentHookTypes`):
                The type of the hook, indicating where the hook is to be
                triggered.
            hook_name (`str`):
                The name of the hook to remove.
        """
        assert (
            hook_type in cls.supported_hook_types
        ), f"Invalid hook type: {hook_type}"
        hooks = getattr(cls, f"_class_{hook_type}_hooks")
        if hook_name in hooks:
            del hooks[hook_name]
        else:
            raise ValueError(
                f"Hook '{hook_name}' not found in '{hook_type}' hooks of "
                f"{cls.__name__} class.",
            )

    @classmethod
    def clear_class_hooks(
        cls,
        hook_type: AgentHookTypes | None = None,
    ) -> None:
        """Clear all class-level hooks.

        Args:
            hook_type (`AgentHookTypes`, optional):
                The type of the hook to clear. If not specified, all
                class-level hooks will be cleared.
        """
        if hook_type is None:
            for typ in cls.supported_hook_types:
                hooks = getattr(cls, f"_class_{typ}_hooks")
                hooks.clear()
        else:
            assert (
                hook_type in cls.supported_hook_types
            ), f"Invalid hook type: {hook_type}"
            hooks = getattr(cls, f"_class_{hook_type}_hooks")
            hooks.clear()

    def clear_instance_hooks(
        self,
        hook_type: AgentHookTypes | None = None,
    ) -> None:
        """If `hook_type` is not specified, clear all instance-level hooks.
        Otherwise, clear the specified type of instance-level hooks."""
        if hook_type is None:
            for typ in self.supported_hook_types:
                if not hasattr(self, f"_instance_{typ}_hooks"):
                    raise ValueError(
                        f"Call super().__init__() in the constructor "
                        f"to initialize the instance-level hooks for "
                        f"{self.__class__.__name__}.",
                    )
                hooks = getattr(self, f"_instance_{typ}_hooks")
                hooks.clear()
        else:
            assert (
                hook_type in self.supported_hook_types
            ), f"Invalid hook type: {hook_type}"
            if not hasattr(self, f"_instance_{hook_type}_hooks"):
                raise ValueError(
                    f"Call super().__init__() in the constructor "
                    f"to initialize the instance-level hooks for "
                    f"{self.__class__.__name__}.",
                )
            hooks = getattr(self, f"_instance_{hook_type}_hooks")
            hooks.clear()

    def reset_subscribers(
        self,
        msghub_name: str,
        subscribers: list["AgentBase"],
    ) -> None:
        """Reset the subscribers of the agent.

        Args:
            msghub_name (`str`):
                The name of the MsgHub that manages the subscribers.
            subscribers (`list[AgentBase]`):
                A list of agents that will receive the reply message from
                this agent via their `observe` method.
        """
        # Exclude self so the agent never observes its own reply.
        self._subscribers[msghub_name] = [_ for _ in subscribers if _ != self]

    def remove_subscribers(self, msghub_name: str) -> None:
        """Remove the msghub subscribers by the given msg hub name.

        Args:
            msghub_name (`str`):
                The name of the MsgHub that manages the subscribers.
        """
        if msghub_name not in self._subscribers:
            logger.warning(
                "MsgHub named '%s' not found",
                msghub_name,
            )
        else:
            self._subscribers.pop(msghub_name)

    @deprecated("Please use set_console_output_enabled() instead.")
    def disable_console_output(self) -> None:
        """This function will disable the console output of the agent, e.g.
        in a production environment to avoid messy logs."""
        self._disable_console_output = True

    def set_console_output_enabled(self, enabled: bool) -> None:
        """Enable or disable the console output of the agent. E.g. in a
        production environment, you may want to disable the console output to
        avoid messy logs.

        Args:
            enabled (`bool`):
                If `True`, enable the console output. If `False`, disable
                the console output.
        """
        self._disable_console_output = not enabled

    def set_msg_queue_enabled(
        self,
        enabled: bool,
        queue: Queue | None = None,
    ) -> None:
        """Enable or disable the message queue for streaming outputs.

        Args:
            enabled (`bool`):
                If `True`, enable the message queue to allow streaming
                outputs. If `False`, disable the message queue.
            queue (`Queue | None`, optional):
                The queue instance that will be used to initialize the
                message queue when `enable` is `True`.
        """
        if enabled:
            if queue is None:
                # Reuse an existing queue if one is already attached.
                if self.msg_queue is None:
                    self.msg_queue = asyncio.Queue(maxsize=100)
            else:
                self.msg_queue = queue
        else:
            self.msg_queue = None
        self._disable_msg_queue = not enabled

View File

@@ -0,0 +1,180 @@
# -*- coding: utf-8 -*-
"""The metaclass for agents in agentscope."""
import inspect
from copy import deepcopy
from functools import wraps
from typing import (
Any,
Dict,
TYPE_CHECKING,
Callable,
)
from .._utils._common import _execute_async_or_sync_func
if TYPE_CHECKING:
from ._agent_base import AgentBase
else:
AgentBase = "AgentBase"
def _normalize_to_kwargs(
func: Callable,
self: Any,
*args: Any,
**kwargs: Any,
) -> dict:
"""Normalize the provided positional and keyword arguments into a
keyword arguments dictionary that matches the function signature."""
sig = inspect.signature(func)
try:
# Bind the provided arguments to the function signature
bound = sig.bind(self, *args, **kwargs)
# Apply the default values for parameters
bound.apply_defaults()
# Return the arguments in a dictionary format
res = dict(bound.arguments)
res.pop("self")
return res
except TypeError as e:
# If failed to bind, we raise a TypeError with more context
param_names = list(sig.parameters.keys())
provided_args = len(args)
provided_kwargs = list(kwargs.keys())
raise TypeError(
f"Failed to bind parameters for function '{func.__name__}': {e}\n"
f"Expected parameters: {param_names}\n"
f"Provided {provided_args} positional args and kwargs: "
f"{provided_kwargs}",
) from e
def _wrap_with_hooks(
    original_func: Callable,
) -> Callable:
    """A decorator to wrap the original async function with pre- and
    post-hooks.

    Args:
        original_func (`Callable`):
            The original async function to be wrapped with hooks.
    """
    # "reply" -> "reply", "_reasoning" -> "reasoning", etc. Note this strips
    # ALL underscores, so it is only safe for the single-word method names
    # wrapped by the metaclasses below ("reply", "print", "observe",
    # "_reasoning", "_acting").
    func_name = original_func.__name__.replace("_", "")

    @wraps(original_func)
    async def async_wrapper(
        self: AgentBase,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """The wrapped function, which calls the pre- and post-hooks before
        and after the original function."""
        # Unify all positional and keyword arguments into a keyword arguments
        # dict so hooks see one canonical view of the call.
        normalized_kwargs = _normalize_to_kwargs(
            original_func,
            self,
            *args,
            **kwargs,
        )
        current_normalized_kwargs = normalized_kwargs
        # The hook registries are created by AgentBase/ReActAgentBase; if
        # they are missing the class was set up without the expected
        # metaclass machinery.
        assert (
            hasattr(self, f"_instance_pre_{func_name}_hooks")
            and hasattr(self, f"_instance_post_{func_name}_hooks")
            and hasattr(self.__class__, f"_class_pre_{func_name}_hooks")
            and hasattr(self.__class__, f"_class_post_{func_name}_hooks")
        ), f"Hooks for {func_name} not found in {self.__class__.__name__}"
        # pre-hooks: instance-level hooks run before class-level hooks.
        pre_hooks = list(
            getattr(self, f"_instance_pre_{func_name}_hooks").values(),
        ) + list(
            getattr(self, f"_class_pre_{func_name}_hooks").values(),
        )
        for pre_hook in pre_hooks:
            # Each hook receives a deep copy, so mutating the dict in place
            # has no effect; a hook must RETURN a dict to modify the call.
            modified_keywords = await _execute_async_or_sync_func(
                pre_hook,
                self,
                deepcopy(current_normalized_kwargs),
            )
            if modified_keywords is not None:
                assert isinstance(modified_keywords, dict), (
                    f"Pre-hook must return a dict of keyword arguments, rather"
                    f" than {type(modified_keywords)} from hook "
                    f"{pre_hook.__name__}"
                )
                current_normalized_kwargs = modified_keywords
        # original function
        # handle positional and keyword arguments specifically: if the
        # wrapped function itself declared *args/**kwargs, those entries are
        # re-expanded; everything else is passed by name.
        args = current_normalized_kwargs.get("args", [])
        kwargs = current_normalized_kwargs.get("kwargs", {})
        others = {
            k: v
            for k, v in current_normalized_kwargs.items()
            if k not in ["args", "kwargs"]
        }
        current_output = await original_func(
            self,
            *args,
            **others,
            **kwargs,
        )
        # post_hooks: again instance-level before class-level; each sees a
        # deep copy of the (possibly modified) kwargs and the current output.
        post_hooks = list(
            getattr(self, f"_instance_post_{func_name}_hooks").values(),
        ) + list(
            getattr(self, f"_class_post_{func_name}_hooks").values(),
        )
        for post_hook in post_hooks:
            modified_output = await _execute_async_or_sync_func(
                post_hook,
                self,
                deepcopy(current_normalized_kwargs),
                deepcopy(current_output),
            )
            if modified_output is not None:
                # A non-None return replaces the output for later hooks
                # and for the caller.
                current_output = modified_output
        return current_output

    return async_wrapper
class _AgentMeta(type):
"""The agent metaclass that wraps the agent's reply, observe and print
functions with pre- and post-hooks."""
def __new__(mcs, name: Any, bases: Any, attrs: Dict) -> Any:
"""Wrap the agent's functions with hooks."""
for func_name in [
"reply",
"print",
"observe",
]:
if func_name in attrs:
attrs[func_name] = _wrap_with_hooks(attrs[func_name])
return super().__new__(mcs, name, bases, attrs)
class _ReActAgentMeta(_AgentMeta):
    """The ReAct metaclass that adds pre- and post-hooks for the _reasoning
    and _acting functions."""

    def __new__(mcs, name: Any, bases: Any, attrs: Dict) -> Any:
        """Wrap the ReAct agent's _reasoning and _acting functions with
        hooks (reply/print/observe are handled by the parent metaclass)."""
        for hooked_name in ("_reasoning", "_acting"):
            if hooked_name not in attrs:
                continue
            attrs[hooked_name] = _wrap_with_hooks(attrs[hooked_name])
        return super().__new__(mcs, name, bases, attrs)

View File

@@ -0,0 +1,767 @@
# -*- coding: utf-8 -*-
# pylint: disable=not-an-iterable
# mypy: disable-error-code="list-item"
"""ReAct agent class in agentscope."""
import asyncio
from typing import Type, Any, AsyncGenerator, Literal
import shortuuid
from pydantic import BaseModel, ValidationError, Field
from ._react_agent_base import ReActAgentBase
from .._logging import logger
from ..formatter import FormatterBase
from ..memory import MemoryBase, LongTermMemoryBase, InMemoryMemory
from ..message import Msg, ToolUseBlock, ToolResultBlock, TextBlock
from ..model import ChatModelBase
from ..rag import KnowledgeBase, Document
from ..plan import PlanNotebook
from ..tool import Toolkit, ToolResponse
from ..tracing import trace_reply
class _QueryRewriteModel(BaseModel):
    """The structured model used for query rewriting."""

    # Filled in by the LLM when the agent rewrites the user query before
    # RAG retrieval (see ReActAgent._retrieve_from_knowledge).
    rewritten_query: str = Field(
        description=(
            "The rewritten query, which should be specific and concise. "
        ),
    )
def finish_function_pre_print_hook(
self: "ReActAgent",
kwargs: dict[str, Any],
) -> dict[str, Any] | None:
"""A pre-speak hook function that check if finish_function is called. If
so, it will wrap the response argument into a message and return it to
replace the original message. By this way, the calling of the finish
function will be displayed as a text reply instead of a tool call."""
msg = kwargs["msg"]
if isinstance(msg.content, str):
return None
if isinstance(msg.content, list):
for i, block in enumerate(msg.content):
if (
block["type"] == "tool_use"
and block["name"] == self.finish_function_name
):
# Convert the response argument into a text block for
# displaying
try:
msg.content[i] = TextBlock(
type="text",
text=block["input"].get("response", ""),
)
return kwargs
except Exception:
print("Error in block input", block["input"])
return None
class ReActAgent(ReActAgentBase):
    """A ReAct agent implementation in AgentScope, which supports

    - Realtime steering
    - API-based (parallel) tool calling
    - Hooks around reasoning, acting, reply, observe and print functions
    - Structured output generation
    """

    finish_function_name: str = "generate_response"
    """The function name used to finish replying and return a response to
    the user."""

    def __init__(
        self,
        name: str,
        sys_prompt: str,
        model: ChatModelBase,
        formatter: FormatterBase,
        toolkit: Toolkit | None = None,
        memory: MemoryBase | None = None,
        long_term_memory: LongTermMemoryBase | None = None,
        long_term_memory_mode: Literal[
            "agent_control",
            "static_control",
            "both",
        ] = "both",
        enable_meta_tool: bool = False,
        parallel_tool_calls: bool = False,
        knowledge: KnowledgeBase | list[KnowledgeBase] | None = None,
        enable_rewrite_query: bool = True,
        plan_notebook: PlanNotebook | None = None,
        print_hint_msg: bool = False,
        max_iters: int = 10,
    ) -> None:
        """Initialize the ReAct agent.

        Args:
            name (`str`):
                The name of the agent.
            sys_prompt (`str`):
                The system prompt of the agent.
            model (`ChatModelBase`):
                The chat model used by the agent.
            formatter (`FormatterBase`):
                The formatter used to format the messages into the required
                format of the model API provider.
            toolkit (`Toolkit | None`, optional):
                A `Toolkit` object that contains the tool functions. If not
                provided, a default empty `Toolkit` will be created.
            memory (`MemoryBase | None`, optional):
                The memory used to store the dialogue history. If not provided,
                a default `InMemoryMemory` will be created, which stores
                messages in a list in memory.
            long_term_memory (`LongTermMemoryBase | None`, optional):
                The optional long-term memory, which will provide two tool
                functions: `retrieve_from_memory` and `record_to_memory`, and
                will attach the retrieved information to the system prompt
                before each reply.
            long_term_memory_mode (`Literal['agent_control', 'static_control',\
            'both']`, defaults to `both`):
                The mode of the long-term memory. If `agent_control`, two
                tool functions `retrieve_from_memory` and `record_to_memory`
                will be registered in the toolkit to allow the agent to
                manage the long-term memory. If `static_control`, retrieving
                and recording will happen in the beginning and end of
                each reply respectively.
            enable_meta_tool (`bool`, defaults to `False`):
                If `True`, a meta tool function `reset_equipped_tools` will be
                added to the toolkit, which allows the agent to manage its
                equipped tools dynamically.
            parallel_tool_calls (`bool`, defaults to `False`):
                When LLM generates multiple tool calls, whether to execute
                them in parallel.
            knowledge (`KnowledgeBase | list[KnowledgeBase] | None`, optional):
                The knowledge object(s) used by the agent to retrieve
                relevant documents at the beginning of each reply.
            enable_rewrite_query (`bool`, defaults to `True`):
                Whether ask the agent to rewrite the user input query before
                retrieving from the knowledge base(s), e.g. rewrite "Who am I"
                to "{user's name}" to get more relevant documents. Only works
                when the knowledge base(s) is provided.
            plan_notebook (`PlanNotebook | None`, optional):
                The plan notebook instance, allow the agent to finish the
                complex task by decomposing it into a sequence of subtasks.
            print_hint_msg (`bool`, defaults to `False`):
                Whether to print the hint messages, including the reasoning
                hint from the plan notebook, the retrieved information from
                the long-term memory and knowledge base(s).
            max_iters (`int`, defaults to `10`):
                The maximum number of iterations of the reasoning-acting loops.
        """
        super().__init__()
        assert long_term_memory_mode in [
            "agent_control",
            "static_control",
            "both",
        ]
        # Static variables in the agent
        self.name = name
        self._sys_prompt = sys_prompt
        self.max_iters = max_iters
        self.model = model
        self.formatter = formatter
        # -------------- Memory management --------------
        # Record the dialogue history in the memory
        self.memory = memory or InMemoryMemory()
        # If provided, the long-term memory will be used to retrieve info
        # in the beginning of each reply, and the result will be added to the
        # system prompt
        self.long_term_memory = long_term_memory
        # The long-term memory mode; both flags are falsy when no long-term
        # memory is given.
        self._static_control = long_term_memory and long_term_memory_mode in [
            "static_control",
            "both",
        ]
        self._agent_control = long_term_memory and long_term_memory_mode in [
            "agent_control",
            "both",
        ]
        # -------------- Tool management --------------
        # If None, a default Toolkit will be created
        self.toolkit = toolkit or Toolkit()
        # The finish function is itself registered as a tool, so the model
        # "replies" by calling it.
        self.toolkit.register_tool_function(
            getattr(self, self.finish_function_name),
        )
        if self._agent_control:
            # Adding two tool functions into the toolkit to allow self-control
            self.toolkit.register_tool_function(
                long_term_memory.record_to_memory,
            )
            self.toolkit.register_tool_function(
                long_term_memory.retrieve_from_memory,
            )
        # Add a meta tool function to allow agent-controlled tool management
        if enable_meta_tool or plan_notebook:
            self.toolkit.register_tool_function(
                self.toolkit.reset_equipped_tools,
            )
        self.parallel_tool_calls = parallel_tool_calls
        # -------------- RAG management --------------
        # The knowledge base(s) used by the agent; normalize to a list.
        if isinstance(knowledge, KnowledgeBase):
            knowledge = [knowledge]
        self.knowledge: list[KnowledgeBase] = knowledge or []
        self.enable_rewrite_query = enable_rewrite_query
        # -------------- Plan management --------------
        # Equip the plan-related tools provided by the plan notebook as
        # a tool group named "plan_related", so that the agent can activate
        # the plan tools by the meta tool function
        self.plan_notebook = None
        if plan_notebook:
            self.plan_notebook = plan_notebook
            # When enable_meta_tool is True, plan tools are in the
            # "plan_related" group and activated by the agent.
            # Otherwise, plan tools are in the basic group and always active.
            if enable_meta_tool:
                self.toolkit.create_tool_group(
                    "plan_related",
                    description=self.plan_notebook.description,
                )
                for tool in plan_notebook.list_tools():
                    self.toolkit.register_tool_function(
                        tool,
                        group_name="plan_related",
                    )
            else:
                for tool in plan_notebook.list_tools():
                    self.toolkit.register_tool_function(
                        tool,
                    )
        # If print the reasoning hint messages
        self.print_hint_msg = print_hint_msg
        # The maximum number of iterations of the reasoning-acting loops
        # NOTE(review): max_iters is also assigned above; this second
        # assignment is redundant but harmless — confirm before removing.
        self.max_iters = max_iters
        # The hint messages that will be attached to the prompt to guide the
        # agent's behavior before each reasoning step, and cleared after
        # each reasoning step, meaning the hint messages is one-time use only.
        # We use an InMemoryMemory instance to store the hint messages
        self._reasoning_hint_msgs = InMemoryMemory()
        # Variables to record the intermediate state
        # If required structured output model is provided
        self._required_structured_model: Type[BaseModel] | None = None
        # -------------- State registration and hooks --------------
        # Register the status variables
        self.register_state("name")
        self.register_state("_sys_prompt")
        # Display finish-function calls as plain text replies when printing.
        self.register_instance_hook(
            "pre_print",
            "finish_function_pre_print_hook",
            finish_function_pre_print_hook,
        )

    @property
    def sys_prompt(self) -> str:
        """The dynamic system prompt of the agent."""
        return self._sys_prompt

    @trace_reply
    async def reply(
        self,
        msg: Msg | list[Msg] | None = None,
        structured_model: Type[BaseModel] | None = None,
    ) -> Msg:
        """Generate a reply based on the current state and input arguments.

        Args:
            msg (`Msg | list[Msg] | None`, optional):
                The input message(s) to the agent.
            structured_model (`Type[BaseModel] | None`, optional):
                The required structured output model. If provided, the agent
                is expected to generate structured output in the `metadata`
                field of the output message.

        Returns:
            `Msg`:
                The output message generated by the agent.
        """
        # Record the input message(s) in the memory
        await self.memory.add(msg)
        # Retrieve relevant records from the long-term memory if activated
        await self._retrieve_from_long_term_memory(msg)
        # Retrieve relevant documents from the knowledge base(s) if any
        await self._retrieve_from_knowledge(msg)
        self._required_structured_model = structured_model
        # Record structured output model if provided
        if structured_model:
            self.toolkit.set_extended_model(
                self.finish_function_name,
                structured_model,
            )
        # The reasoning-acting loop
        reply_msg = None
        for _ in range(self.max_iters):
            msg_reasoning = await self._reasoning()
            futures = [
                self._acting(tool_call)
                for tool_call in msg_reasoning.get_content_blocks(
                    "tool_use",
                )
            ]
            # Parallel tool calls or not
            if self.parallel_tool_calls:
                acting_responses = await asyncio.gather(*futures)
            else:
                # Sequential tool calls
                acting_responses = [await _ for _ in futures]
            # Find the first non-None replying message from the acting
            for acting_msg in acting_responses:
                reply_msg = reply_msg or acting_msg
            if reply_msg:
                break
        # When the maximum iterations are reached
        if reply_msg is None:
            reply_msg = await self._summarizing()
        # Post-process the memory, long-term memory
        if self._static_control:
            await self.long_term_memory.record(
                [
                    *([*msg] if isinstance(msg, list) else [msg]),
                    *await self.memory.get_memory(),
                    reply_msg,
                ],
            )
        await self.memory.add(reply_msg)
        return reply_msg

    async def _reasoning(
        self,
    ) -> Msg:
        """Perform the reasoning process: format the prompt, call the model,
        and normalize the output into a tool-calling message."""
        if self.plan_notebook:
            # Insert the reasoning hint from the plan notebook
            hint_msg = await self.plan_notebook.get_current_hint()
            if self.print_hint_msg and hint_msg:
                await self.print(hint_msg)
            await self._reasoning_hint_msgs.add(hint_msg)
        # Convert Msg objects into the required format of the model API
        prompt = await self.formatter.format(
            msgs=[
                Msg("system", self.sys_prompt, "system"),
                *await self.memory.get_memory(),
                # The hint messages to guide the agent's behavior, maybe empty
                *await self._reasoning_hint_msgs.get_memory(),
            ],
        )
        # Clear the hint messages after use (one-time use only)
        await self._reasoning_hint_msgs.clear()
        res = await self.model(
            prompt,
            tools=self.toolkit.get_json_schemas(),
        )
        # handle output from the model
        interrupted_by_user = False
        msg = None
        try:
            if self.model.stream:
                msg = Msg(self.name, [], "assistant")
                async for content_chunk in res:
                    msg.content = content_chunk.content
                    await self.print(msg, False)
                await self.print(msg, True)
                # Add a tiny sleep to yield the last message object in the
                # message queue
                await asyncio.sleep(0.001)
            else:
                msg = Msg(self.name, list(res.content), "assistant")
                await self.print(msg, True)
        except asyncio.CancelledError as e:
            # The user interrupted the generation; finish bookkeeping in the
            # finally block, then let the cancellation propagate.
            interrupted_by_user = True
            raise e from None
        finally:
            if msg and not msg.has_content_blocks("tool_use"):
                # Turn plain text response into a tool call of the finish
                # function
                msg = Msg.from_dict(msg.to_dict())
                msg.content = [
                    ToolUseBlock(
                        id=shortuuid.uuid(),
                        type="tool_use",
                        name=self.finish_function_name,
                        input={"response": msg.get_text_content()},
                    ),
                ]
            # None will be ignored by the memory
            await self.memory.add(msg)
            # Post-process for user interruption
            if interrupted_by_user and msg:
                # Fake tool results so the dialogue history stays consistent
                tool_use_blocks: list = msg.get_content_blocks(
                    "tool_use",
                )
                for tool_call in tool_use_blocks:
                    msg_res = Msg(
                        "system",
                        [
                            ToolResultBlock(
                                type="tool_result",
                                id=tool_call["id"],
                                name=tool_call["name"],
                                output="The tool call has been interrupted "
                                "by the user.",
                            ),
                        ],
                        "system",
                    )
                    await self.memory.add(msg_res)
                    await self.print(msg_res, True)
        return msg

    async def _acting(self, tool_call: ToolUseBlock) -> Msg | None:
        """Perform the acting process.

        Args:
            tool_call (`ToolUseBlock`):
                The tool use block to be executed.

        Returns:
            `Union[Msg, None]`:
                Return a message to the user if the `finish_function` is
                called, otherwise return `None`.
        """
        tool_res_msg = Msg(
            "system",
            [
                ToolResultBlock(
                    type="tool_result",
                    id=tool_call["id"],
                    name=tool_call["name"],
                    output=[],
                ),
            ],
            "system",
        )
        try:
            # Execute the tool call
            tool_res = await self.toolkit.call_tool_function(tool_call)
            response_msg = None
            # Async generator handling: tool results may stream in chunks
            async for chunk in tool_res:
                # Turn into a tool result block
                tool_res_msg.content[0][  # type: ignore[index]
                    "output"
                ] = chunk.content
                # Skip the printing of the finish function call (unless it
                # failed, in which case the error is still shown)
                if (
                    tool_call["name"] != self.finish_function_name
                    or tool_call["name"] == self.finish_function_name
                    and (
                        chunk.metadata is None
                        or not chunk.metadata.get("success")
                    )
                ):
                    await self.print(tool_res_msg, chunk.is_last)
                # Raise the CancelledError to handle the interruption in the
                # handle_interrupt function
                if chunk.is_interrupted:
                    raise asyncio.CancelledError()
                # Return message if generate_response is called successfully
                if (
                    tool_call["name"] == self.finish_function_name
                    and chunk.metadata
                    and chunk.metadata.get(
                        "success",
                        True,
                    )
                ):
                    response_msg = chunk.metadata.get("response_msg")
            return response_msg
        finally:
            # Record the tool result message in the memory even on
            # interruption or error
            await self.memory.add(tool_res_msg)

    async def observe(self, msg: Msg | list[Msg] | None) -> None:
        """Receive observing message(s) without generating a reply.

        Args:
            msg (`Msg | list[Msg] | None`):
                The message or messages to be observed.
        """
        await self.memory.add(msg)

    async def _summarizing(self) -> Msg:
        """Generate a response when the agent fails to solve the problem in
        the maximum iterations."""
        hint_msg = Msg(
            "user",
            "You have failed to generate response within the maximum "
            "iterations. Now respond directly by summarizing the current "
            "situation.",
            role="user",
        )
        # Generate a reply by summarizing the current situation
        prompt = await self.formatter.format(
            [
                Msg("system", self.sys_prompt, "system"),
                *await self.memory.get_memory(),
                hint_msg,
            ],
        )
        # TODO: handle the structured output here, maybe force calling the
        # finish_function here
        res = await self.model(prompt)
        res_msg = Msg(self.name, [], "assistant")
        if isinstance(res, AsyncGenerator):
            async for chunk in res:
                res_msg.content = chunk.content
                await self.print(res_msg, False)
            await self.print(res_msg, True)
        else:
            res_msg.content = res.content
            await self.print(res_msg, True)
        return res_msg

    async def handle_interrupt(
        self,
        _msg: Msg | list[Msg] | None = None,
    ) -> Msg:
        """The post-processing logic when the reply is interrupted by the
        user or something else."""
        response_msg = Msg(
            self.name,
            "I noticed that you have interrupted me. What can I "
            "do for you?",
            "assistant",
            metadata={
                # Expose this field to indicate the interruption
                "is_interrupted": True,
            },
        )
        await self.print(response_msg, True)
        await self.memory.add(response_msg)
        return response_msg

    def generate_response(
        self,
        response: str,
        **kwargs: Any,
    ) -> ToolResponse:
        """Generate a response. Note only the input argument `response` is
        visible to the others, you should include all the necessary
        information in the `response` argument.

        Args:
            response (`str`):
                Your response to the user.
        """
        response_msg = Msg(
            self.name,
            response,
            "assistant",
        )
        # Prepare structured output
        if self._required_structured_model:
            try:
                # Use the metadata field of the message to store the
                # structured output
                response_msg.metadata = (
                    self._required_structured_model.model_validate(
                        kwargs,
                    ).model_dump()
                )
            except ValidationError as e:
                # Report the failure back to the model so it can retry
                return ToolResponse(
                    content=[
                        TextBlock(
                            type="text",
                            text=f"Arguments Validation Error: {e}",
                        ),
                    ],
                    metadata={
                        "success": False,
                        "response_msg": None,
                    },
                )
        return ToolResponse(
            content=[
                TextBlock(
                    type="text",
                    text="Successfully generated response.",
                ),
            ],
            metadata={
                "success": True,
                "response_msg": response_msg,
            },
            is_last=True,
        )

    async def _retrieve_from_long_term_memory(
        self,
        msg: Msg | list[Msg] | None,
    ) -> None:
        """Insert the retrieved information from the long-term memory into
        the short-term memory as a Msg object.

        Args:
            msg (`Msg | list[Msg] | None`):
                The input message to the agent.
        """
        if self._static_control and msg:
            # Retrieve information from the long-term memory if available
            retrieved_info = await self.long_term_memory.retrieve(msg)
            if retrieved_info:
                retrieved_msg = Msg(
                    name="long_term_memory",
                    content="<long_term_memory>The content below are "
                    "retrieved from long-term memory, which maybe "
                    f"useful:\n{retrieved_info}</long_term_memory>",
                    role="user",
                )
                if self.print_hint_msg:
                    await self.print(retrieved_msg, True)
                await self.memory.add(retrieved_msg)

    async def _retrieve_from_knowledge(
        self,
        msg: Msg | list[Msg] | None,
    ) -> None:
        """Insert the retrieved documents from the RAG knowledge base(s) if
        available.

        Args:
            msg (`Msg | list[Msg] | None`):
                The input message to the agent.
        """
        if self.knowledge and msg:
            # Prepare the user input query
            query = None
            if isinstance(msg, Msg):
                query = msg.get_text_content()
            elif isinstance(msg, list):
                query = "\n".join(_.get_text_content() for _ in msg)
            # Skip if the query is empty
            if not query:
                return
            # Rewrite the query by the LLM if enabled
            if self.enable_rewrite_query:
                try:
                    rewrite_prompt = await self.formatter.format(
                        msgs=[
                            Msg("system", self.sys_prompt, "system"),
                            *await self.memory.get_memory(),
                            Msg(
                                "user",
                                "<system-hint>Now you need to rewrite "
                                "the above user query to be more specific and "
                                "concise for knowledge retrieval. For "
                                "example, rewrite the query 'what happened "
                                "last day' to 'what happened on 2023-10-01' "
                                "(assuming today is 2023-10-02)."
                                "</system-hint>",
                                "user",
                            ),
                        ],
                    )
                    # Temporarily disable streaming so the rewrite returns a
                    # single structured result
                    stream_tmp = self.model.stream
                    self.model.stream = False
                    res = await self.model(
                        rewrite_prompt,
                        structured_model=_QueryRewriteModel,
                    )
                    self.model.stream = stream_tmp
                    if res.metadata and res.metadata.get("rewritten_query"):
                        query = res.metadata["rewritten_query"]
                except Exception as e:
                    # Best-effort: fall back to the raw query on any failure
                    logger.warning(
                        "Skipping the query rewriting due to error: %s",
                        str(e),
                    )
            docs: list[Document] = []
            for kb in self.knowledge:
                # retrieve the user input query
                docs.extend(
                    await kb.retrieve(query=query),
                )
            if docs:
                # Rerank by the relevance score
                docs = sorted(
                    docs,
                    key=lambda doc: doc.score or 0.0,
                    reverse=True,
                )
                # Prepare the retrieved knowledge string
                retrieved_msg = Msg(
                    name="user",
                    content=[
                        TextBlock(
                            type="text",
                            text=(
                                "<retrieved_knowledge>Use the following "
                                "content from the knowledge base(s) if it's "
                                "helpful:\n"
                            ),
                        ),
                        *[_.metadata.content for _ in docs],
                        TextBlock(
                            type="text",
                            text="</retrieved_knowledge>",
                        ),
                    ],
                    role="user",
                )
                if self.print_hint_msg:
                    await self.print(retrieved_msg, True)
                await self.memory.add(retrieved_msg)

View File

@@ -0,0 +1,116 @@
# -*- coding: utf-8 -*-
"""The base class for ReAct agent in agentscope."""
from abc import abstractmethod
from collections import OrderedDict
from typing import Callable, Any
from ._agent_base import AgentBase
from ._agent_meta import _ReActAgentMeta
from ..message import Msg
class ReActAgentBase(AgentBase, metaclass=_ReActAgentMeta):
    """The ReAct agent base class.

    To support ReAct algorithm, this class extends the AgentBase class by
    adding two abstract interfaces: reasoning and acting, while supporting
    hook functions at four positions: pre-reasoning, post-reasoning,
    pre-acting, and post-acting by the `_ReActAgentMeta` metaclass.
    """

    supported_hook_types: list[str] = [
        "pre_reply",
        "post_reply",
        "pre_print",
        "post_print",
        "pre_observe",
        "post_observe",
        "pre_reasoning",
        "post_reasoning",
        "pre_acting",
        "post_acting",
    ]
    """Supported hook types for the agent base class."""

    # NOTE(review): the four registries below are class attributes, so they
    # are shared by every subclass that does not re-declare them — a
    # class-level hook registered on one subclass is visible to all.
    # Presumably intentional (mirrors AgentBase); confirm before changing.
    _class_pre_reasoning_hooks: dict[
        str,
        Callable[
            [
                "ReActAgentBase",  # self
                dict[str, Any],  # kwargs
            ],
            dict[str, Any] | None,  # The modified kwargs or None
        ],
    ] = OrderedDict()
    """The class-level pre-reasoning hooks, taking `self` object, the input
    arguments as input"""

    _class_post_reasoning_hooks: dict[
        str,
        Callable[
            [
                "ReActAgentBase",  # self
                dict[str, Any],  # kwargs
                Any,  # output
            ],
            Msg | None,  # the modified output message or None
        ],
    ] = OrderedDict()
    """The class-level post-reasoning hooks, taking `self` object, the input
    arguments and the output message as input, and return the modified output
    message or None if no modification is needed."""

    _class_pre_acting_hooks: dict[
        str,
        Callable[
            [
                "ReActAgentBase",  # self
                dict[str, Any],  # kwargs
            ],
            dict[str, Any] | None,  # The modified kwargs or None
        ],
    ] = OrderedDict()
    """The class-level pre-acting hooks, taking `self` object, the input
    arguments as input, and return the modified input arguments or None if no
    modification is needed."""

    _class_post_acting_hooks: dict[
        str,
        Callable[
            [
                "ReActAgentBase",  # self
                dict[str, Any],  # kwargs
                Any,  # output
            ],
            Msg | None,  # the modified output message or None
        ],
    ] = OrderedDict()
    """The class-level post-acting hooks, taking `self` object, the input
    arguments and the output message as input, and return the modified output
    message or None if no modification is needed."""

    def __init__(
        self,
    ) -> None:
        """Initialize the ReAct agent base class."""
        super().__init__()
        # Init reasoning and acting hooks; these are per-instance (unlike
        # the shared class-level registries above) and ordered by insertion.
        self._instance_pre_reasoning_hooks = OrderedDict()
        self._instance_post_reasoning_hooks = OrderedDict()
        self._instance_pre_acting_hooks = OrderedDict()
        self._instance_post_acting_hooks = OrderedDict()

    @abstractmethod
    async def _reasoning(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Any:
        """The reasoning process of the ReAct agent, which will be wrapped
        with pre- and post-hooks."""

    @abstractmethod
    async def _acting(self, *args: Any, **kwargs: Any) -> Any:
        """The acting process of the ReAct agent, which will be wrapped with
        pre- and post-hooks."""

View File

@@ -0,0 +1,128 @@
# -*- coding: utf-8 -*-
"""The user agent class."""
from typing import Type, Any
from pydantic import BaseModel
from ._agent_base import AgentBase
from ._user_input import UserInputBase, TerminalUserInput
from ..message import Msg
class UserAgent(AgentBase):
    """The class for user interaction, allowing developers to handle the user
    input from different sources, such as web UI, cli, and other interfaces.
    """

    _input_method: UserInputBase = TerminalUserInput()
    """The user input method, can be overridden by calling the
    `override_instance/class_input_method` function."""

    def __init__(
        self,
        name: str,
    ) -> None:
        """Initialize the user agent with a name.

        Args:
            name (`str`):
                The name of the user agent.
        """
        super().__init__()
        self.name = name

    async def reply(
        self,
        msg: Msg | list[Msg] | None = None,
        structured_model: Type[BaseModel] | None = None,
    ) -> Msg:
        """Receive input message(s) and generate a reply message from the user.

        Args:
            msg (`Msg | list[Msg] | None`, defaults to `None`):
                The message(s) to be replied. If `None`, the agent will wait
                for user input.
            structured_model (`Type[BaseModel] | None`, defaults to `None`):
                A child class of `pydantic.BaseModel` that defines the
                structured output format. If provided, the user will be
                prompted to fill in the required fields.

        Returns:
            `Msg`:
                The reply message generated by the user.
        """
        # Get the input from the specified input method.
        input_data = await self._input_method(
            agent_id=self.id,
            agent_name=self.name,
            structured_model=structured_model,
        )
        blocks_input = input_data.blocks_input
        if (
            blocks_input
            and len(blocks_input) == 1
            and blocks_input[0].get("type") == "text"
        ):
            # Turn blocks_input into a string if only one text block exists
            blocks_input = blocks_input[0].get("text")
        # NOTE(review): the incoming `msg` parameter is ignored and rebound
        # to the freshly built user message — presumably intentional, since
        # the user's input IS the reply; confirm against callers.
        msg = Msg(
            self.name,
            content=blocks_input,
            role="user",
            metadata=input_data.structured_input,
        )
        await self.print(msg)
        return msg

    def override_instance_input_method(
        self,
        input_method: UserInputBase,
    ) -> None:
        """Override the input method of the current UserAgent instance.

        Args:
            input_method (`UserInputBase`):
                The callable input method, which should be an object of a
                class that inherits from `UserInputBase`.
        """
        if not isinstance(input_method, UserInputBase):
            raise ValueError(
                f"The input method should be an instance of the child class "
                f"of `UserInputBase`, but got {type(input_method)} instead.",
            )
        # Instance attribute shadows the class-level default.
        self._input_method = input_method

    @classmethod
    def override_class_input_method(
        cls,
        input_method: UserInputBase,
    ) -> None:
        """Override the input method of the current UserAgent class.

        Args:
            input_method (`UserInputBase`):
                The callable input method, which should be an object of a
                class that inherits from `UserInputBase`.
        """
        if not isinstance(input_method, UserInputBase):
            raise ValueError(
                f"The input method should be an instance of the child class "
                f"of `UserInputBase`, but got {type(input_method)} instead.",
            )
        # Affects every instance that has not overridden it per-instance.
        cls._input_method = input_method

    async def handle_interrupt(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> Msg:
        """The post-processing logic when the reply is interrupted by the
        user or something else. Not supported for the user agent."""
        raise NotImplementedError(
            f"The handle_interrupt function is not implemented in "
            f"{self.__class__.__name__}",
        )

    async def observe(self, msg: Msg | list[Msg] | None) -> None:
        """Observe the message(s) from the other agents or the environment.
        The user agent deliberately ignores observations (no memory)."""

View File

@@ -0,0 +1,411 @@
# -*- coding: utf-8 -*-
"""The user input related classes."""
import json.decoder
import time
from abc import abstractmethod
from dataclasses import dataclass
from queue import Queue
from threading import Event
from typing import Any, Type, List
import jsonschema
import requests
import shortuuid
import socketio
from pydantic import BaseModel
import json5
from .. import _config
from .._logging import logger
from ..message import (
TextBlock,
VideoBlock,
AudioBlock,
ImageBlock,
)
@dataclass
class UserInputData:
    """The user input data."""

    # Annotation widened to `| None` to match the `None` default.
    blocks_input: (
        List[TextBlock | ImageBlock | AudioBlock | VideoBlock] | None
    ) = None
    """The text input from the user"""

    structured_input: dict[str, Any] | None = None
    """The structured input from the user"""
class UserInputBase:
    """The base class used to handle the user input from different sources."""

    # NOTE(review): this class does not inherit from `abc.ABC`, so the
    # `@abstractmethod` marker is not enforced at instantiation time;
    # subclasses override `__call__` by convention.
    @abstractmethod
    async def __call__(
        self,
        agent_id: str,
        agent_name: str,
        *args: Any,
        structured_model: Type[BaseModel] | None = None,
        **kwargs: Any,
    ) -> UserInputData:
        """The user input method, which returns the user input and the
        required structured data.

        Args:
            agent_id (`str`):
                The agent identifier.
            agent_name (`str`):
                The agent name.
            structured_model (`Type[BaseModel] | None`, optional):
                A base model class that defines the structured input format.

        Returns:
            `UserInputData`:
                The user input data.
        """
class TerminalUserInput(UserInputBase):
    """The terminal user input."""

    def __init__(self, input_hint: str = "User Input: ") -> None:
        """Initialize the terminal user input with a hint.

        Args:
            input_hint (`str`, defaults to `"User Input: "`):
                The prompt printed before reading free text from stdin.
        """
        self.input_hint = input_hint

    async def __call__(
        self,
        agent_id: str,
        agent_name: str,
        *args: Any,
        structured_model: Type[BaseModel] | None = None,
        **kwargs: Any,
    ) -> UserInputData:
        """Handle the user input from the terminal.

        Args:
            agent_id (`str`):
                The agent identifier.
            agent_name (`str`):
                The agent name.
            structured_model (`Type[BaseModel] | None`, optional):
                A base model class that defines the structured input format.
                When provided, the user is prompted field by field until each
                value validates against the field's JSON schema.

        Returns:
            `UserInputData`:
                The user input data.
        """
        text_input = input(self.input_hint)
        structured_input = None
        if structured_model is not None:
            structured_input = {}
            json_schema = structured_model.model_json_schema()
            required = json_schema.get("required", [])
            # Fixed: the prompt previously ended with a stray ")".
            print("Structured input (press Enter to skip for optional):")
            # `properties` may be absent for a model without fields.
            for key, item in json_schema.get("properties", {}).items():
                requirements = {**item}
                requirements.pop("title", None)
                while True:
                    res = input(f"\t{key} ({requirements}): ")
                    if res == "":
                        if key in required:
                            print(f"Key {key} is required.")
                            continue
                        res = item.get("default", None)
                        if res is None:
                            # Optional field skipped with no schema default:
                            # record None directly; json5-parsing or schema-
                            # validating None would raise.
                            structured_input[key] = None
                            break
                    # Parse integers typed at the prompt; a non-string `res`
                    # here is a schema default that is already typed.
                    # TODO: non-integer types (number, boolean, array) are
                    # still passed through as raw strings.
                    if (
                        isinstance(res, str)
                        and (item.get("type") or "").lower() == "integer"
                    ):
                        try:
                            res = json5.loads(res)
                        except json.decoder.JSONDecodeError as e:
                            print(
                                "\033[31mInvalid input with error:\n"
                                "```\n"
                                f"{e}\n"
                                "```\033[0m",
                            )
                            continue
                    try:
                        jsonschema.validate(res, item)
                        structured_input[key] = res
                        break
                    except jsonschema.ValidationError as e:
                        print(
                            f"\033[31mValidation error:\n```\n{e}\n```\033[0m",
                        )
                        # Give the user a moment to read the error before
                        # the prompt is shown again.
                        time.sleep(0.5)
        return UserInputData(
            blocks_input=[TextBlock(type="text", text=text_input)],
            structured_input=structured_input,
        )
class StudioUserInput(UserInputBase):
    """The class that host the user input on the AgentScope Studio.

    A websocket connection (python-socketio) is opened to the Studio at
    construction time; user-input requests are sent over HTTP
    (``/trpc/requestUserInput``) and the answers are pushed back over the
    websocket via the ``forwardUserInput`` event.
    """

    # Socket.IO namespace this client registers its handlers under.
    _websocket_namespace: str = "/python"

    def __init__(
        self,
        studio_url: str,
        run_id: str,
        max_retries: int = 3,
        reconnect_attempts: int = 3,
        reconnection_delay: int = 1,
        reconnection_delay_max: int = 5,
    ) -> None:
        """Initialize the StudioUserInput object.

        Args:
            studio_url (`str`):
                The URL of the AgentScope Studio.
            run_id (`str`):
                The current run identity.
            max_retries (`int`, defaults to `3`):
                The maximum number of retries to get user input.
            reconnect_attempts (`int`, defaults to `3`):
                How many times the websocket client tries to reconnect.
            reconnection_delay (`int`, defaults to `1`):
                Initial delay in seconds between reconnection attempts.
            reconnection_delay_max (`int`, defaults to `5`):
                Maximum delay in seconds between reconnection attempts.

        Raises:
            `RuntimeError`:
                If the initial websocket connection to the Studio fails.
        """
        # Connection-state flags, toggled by the socketio handlers below.
        self._is_connected = False
        self._is_reconnecting = False
        self.studio_url = studio_url
        self.run_id = run_id
        self.max_retries = max_retries
        # Init Websocket
        self.sio = socketio.Client(
            reconnection=True,
            reconnection_attempts=reconnect_attempts,
            reconnection_delay=reconnection_delay,
            reconnection_delay_max=reconnection_delay_max,
        )
        # Per-request queues/events keyed by request id: filled by the
        # "forwardUserInput" handler below, consumed in __call__.
        self.input_queues = {}
        self.input_events = {}

        @self.sio.on("connect", namespace=self._websocket_namespace)
        def on_connect() -> None:
            # Mark the link up and log where the run can be viewed.
            self._is_connected = True
            logger.info(
                'Connected to AgentScope Studio at "%s" with '
                'run name "%s".',
                self.studio_url,
                run_id,
            )
            logger.info(
                "View the run at: %s/dashboard/projects/%s",
                self.studio_url,
                _config.project,
            )

        @self.sio.on("disconnect", namespace=self._websocket_namespace)
        def on_disconnect() -> None:
            self._is_connected = False
            logger.info(
                "Disconnected from AgentScope Studio at %s",
                self.studio_url,
            )

        @self.sio.on("reconnect", namespace=self._websocket_namespace)
        def on_reconnect(attempt_number: int) -> None:
            self._is_connected = True
            self._is_reconnecting = False
            logger.info(
                "Reconnected to AgentScope Studio at %s with run_id %s after "
                "%d attempts",
                self.studio_url,
                self.run_id,
                attempt_number,
            )

        @self.sio.on("reconnect_attempt", namespace=self._websocket_namespace)
        def on_reconnect_attempt(attempt_number: int) -> None:
            # _ensure_connected polls this flag while waiting for recovery.
            self._is_reconnecting = True
            logger.info(
                "Attempting to reconnect to AgentScope Studio at %s "
                "(attempt %d)",
                self.studio_url,
                attempt_number,
            )

        @self.sio.on("reconnect_failed", namespace=self._websocket_namespace)
        def on_reconnect_failed() -> None:
            self._is_reconnecting = False
            logger.error(
                "Failed to reconnect to AgentScope Studio at %s",
                self.studio_url,
            )

        @self.sio.on("reconnect_error", namespace=self._websocket_namespace)
        def on_reconnect_error(error: Any) -> None:
            logger.error(
                "Error while reconnecting to AgentScope Studio at %s: %s",
                self.studio_url,
                str(error),
            )

        # The AgentScope Studio backend send the "sendUserInput" event to
        # the current python run
        @self.sio.on("forwardUserInput", namespace=self._websocket_namespace)
        def receive_user_input(
            request_id: str,
            blocks_input: List[
                TextBlock | ImageBlock | AudioBlock | VideoBlock
            ],
            structured_input: dict[str, Any],
        ) -> None:
            # Deliver the answer to the waiting __call__ (runs on another
            # thread); unknown request ids are silently ignored.
            if request_id in self.input_queues:
                self.input_queues[request_id].put(
                    UserInputData(
                        blocks_input=blocks_input,
                        structured_input=structured_input,
                    ),
                )
                self.input_events[request_id].set()

        try:
            self.sio.connect(
                f"{self.studio_url}",
                namespaces=["/python"],
                auth={"run_id": self.run_id},
            )
        except Exception as e:
            raise RuntimeError(
                f"Failed to connect to AgentScope Studio at {self.studio_url}",
            ) from e

    def _ensure_connected(
        self,
        timeout: float = 30.0,
        check_interval: float = 5.0,
    ) -> None:
        """Ensure the connection is established or wait for reconnection.

        Args:
            timeout (`float`):
                Maximum time to wait for reconnection in seconds. Defaults
                to 30.0.
            check_interval (`float`):
                Interval between connection checks in seconds. Defaults
                to 5.0.

        Raises:
            `RuntimeError`:
                If connection cannot be established within timeout.
        """
        if self._is_connected:
            return
        if self._is_reconnecting:
            start_time = time.time()
            while self._is_reconnecting:
                # Check timeout
                elapsed_time = time.time() - start_time
                if elapsed_time > timeout:
                    raise RuntimeError(
                        f"Reconnection timeout after {elapsed_time} seconds",
                    )
                # Log status
                logger.info(
                    "Waiting for reconnection... (%.1fs / %.1fs)",
                    elapsed_time,
                    timeout,
                )
                # Wait for next check
                time.sleep(check_interval)
            # After reconnection attempt completed, check final status
            if self._is_connected:
                return
        # Not connected and not reconnecting
        raise RuntimeError(
            f"Not connected to AgentScope Studio at {self.studio_url}.",
        )

    async def __call__(  # type: ignore[override]
        self,
        agent_id: str,
        agent_name: str,
        *args: Any,
        structured_model: Type[BaseModel] | None = None,
    ) -> UserInputData:
        """Get the user input from AgentScope Studio.

        Args:
            agent_id (`str`):
                The identity of the agent.
            agent_name (`str`):
                The name of the agent.
            structured_model (`Type[BaseModel] | None`, optional):
                The base model class of the structured input.

        Raises:
            `RuntimeError`:
                Failed to get user input from AgentScope Studio.

        Returns:
            `UserInputData`:
                The user input.
        """
        self._ensure_connected()
        request_id = shortuuid.uuid()
        # Register the rendezvous objects before issuing the HTTP request so
        # the websocket handler can never race ahead of us.
        self.input_queues[request_id] = Queue()
        self.input_events[request_id] = Event()
        if structured_model is None:
            structured_input = None
        else:
            structured_input = structured_model.model_json_schema()
        n_retry = 0
        while True:
            try:
                response = requests.post(
                    f"{self.studio_url}/trpc/requestUserInput",
                    json={
                        "requestId": request_id,
                        "runId": self.run_id,
                        "agentId": agent_id,
                        "agentName": agent_name,
                        "structuredInput": structured_input,
                    },
                )
                response.raise_for_status()
                break
            except Exception as e:
                # NOTE(review): retries fire immediately with no delay or
                # backoff between attempts — confirm this is intended.
                if n_retry < self.max_retries:
                    n_retry += 1
                    continue
                raise RuntimeError(
                    "Failed to get user input from AgentScope Studio",
                ) from e
        try:
            # Block until the websocket handler delivers the answer.
            self.input_events[request_id].wait()
            response_data = self.input_queues[request_id].get()
            return response_data
        finally:
            # Always drop the per-request bookkeeping, even on error.
            self.input_queues.pop(request_id, None)
            self.input_events.pop(request_id, None)

    def __del__(self) -> None:
        """Cleanup socket connection when object it destroyed"""
        try:
            self.sio.disconnect()
        except Exception as e:
            logger.error(
                "Failed to disconnect from AgentScope Studio at %s: %s",
                self.studio_url,
                str(e),
            )

View File

@@ -0,0 +1,27 @@
# -*- coding: utf-8 -*-
"""The embedding module in agentscope."""
from ._embedding_base import EmbeddingModelBase
from ._embedding_usage import EmbeddingUsage
from ._embedding_response import EmbeddingResponse
from ._dashscope_embedding import DashScopeTextEmbedding
from ._dashscope_multimodal_embedding import DashScopeMultiModalEmbedding
from ._openai_embedding import OpenAITextEmbedding
from ._gemini_embedding import GeminiTextEmbedding
from ._ollama_embedding import OllamaTextEmbedding
from ._cache_base import EmbeddingCacheBase
from ._file_cache import FileEmbeddingCache
__all__ = [
"EmbeddingModelBase",
"EmbeddingUsage",
"EmbeddingResponse",
"DashScopeTextEmbedding",
"DashScopeMultiModalEmbedding",
"OpenAITextEmbedding",
"GeminiTextEmbedding",
"OllamaTextEmbedding",
"EmbeddingCacheBase",
"FileEmbeddingCache",
]

View File

@@ -0,0 +1,63 @@
# -*- coding: utf-8 -*-
"""The embedding cache base class."""
from abc import abstractmethod
from typing import List, Any
from ..types import (
JSONSerializableObject,
Embedding,
)
class EmbeddingCacheBase:
    """Base class for embedding caches, which is responsible for storing and
    retrieving embeddings.

    The ``identifier`` passed to each method must be JSON serializable so
    that concrete implementations can derive a stable lookup key from it.

    NOTE(review): this class does not inherit from ``abc.ABC``, so the
    ``@abstractmethod`` markers below are not enforced at instantiation
    time — confirm whether that is intentional.
    """

    @abstractmethod
    async def store(
        self,
        embeddings: List[Embedding],
        identifier: JSONSerializableObject,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
        """Store the embeddings with the given identifier.

        Args:
            embeddings (`List[Embedding]`):
                The embeddings to store.
            identifier (`JSONSerializableObject`):
                The identifier to distinguish the embeddings.
            overwrite (`bool`, defaults to `False`):
                Whether to overwrite existing embeddings with the same
                identifier. If `True`, existing embeddings will be replaced.
        """

    @abstractmethod
    async def retrieve(
        self,
        identifier: JSONSerializableObject,
    ) -> List[Embedding] | None:
        """Retrieve the embeddings with the given identifier. If not
        found, return `None`.

        Args:
            identifier (`JSONSerializableObject`):
                The identifier to retrieve the embeddings.
        """

    @abstractmethod
    async def remove(
        self,
        identifier: JSONSerializableObject,
    ) -> None:
        """Remove the embeddings with the given identifier.

        Args:
            identifier (`JSONSerializableObject`):
                The identifier to remove the embeddings.
        """

    @abstractmethod
    async def clear(self) -> None:
        """Clear all cached embeddings."""

View File

@@ -0,0 +1,169 @@
# -*- coding: utf-8 -*-
"""The dashscope embedding module in agentscope."""
from datetime import datetime
from typing import Any, List, Literal
from ._cache_base import EmbeddingCacheBase
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._embedding_base import EmbeddingModelBase
from .._logging import logger
from ..message import TextBlock
class DashScopeTextEmbedding(EmbeddingModelBase):
    """DashScope text embedding API class.

    .. note:: From the `official documentation
     <https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712515>`_:

     - The max batch size that DashScope text embedding API
     supports is 10 for `text-embedding-v4` and `text-embedding-v3` models,
     and 25 for `text-embedding-v2` and `text-embedding-v1` models.
     - The max token limit for a single input is 8192 tokens for `v4` and
     `v3` models, and 2048 tokens for `v2` and `v1` models.
    """

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 1024,
        embedding_cache: EmbeddingCacheBase | None = None,
    ) -> None:
        """Initialize the DashScope text embedding model class.

        Args:
            api_key (`str`):
                The dashscope API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 1024):
                The dimension of the embedding vector, refer to the
                `official documentation
                <https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712515>`_
                for more details.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
        """
        super().__init__(model_name, dimensions)
        self.api_key = api_key
        self.embedding_cache = embedding_cache
        # Conservative limit; v2/v1 models would allow 25 (see class note).
        self.batch_size_limit = 10

    async def _call_api(self, kwargs: dict[str, Any]) -> EmbeddingResponse:
        """Call the DashScope embedding API by the given keyword arguments.

        Args:
            kwargs (`dict[str, Any]`):
                Keyword arguments for
                ``dashscope.embeddings.TextEmbedding.call``; also used as
                the cache identifier.

        Raises:
            `RuntimeError`:
                If the API responds with a non-200 status code.
        """
        # Serve from the cache first to avoid repeated API calls.
        if self.embedding_cache:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=kwargs,
            )
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        import dashscope

        start_time = datetime.now()
        response = dashscope.embeddings.TextEmbedding.call(
            api_key=self.api_key,
            **kwargs,
        )
        time = (datetime.now() - start_time).total_seconds()

        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to get embedding from DashScope API: {response}",
            )

        # Extract once; reused for both caching and the returned response.
        embeddings = [_["embedding"] for _ in response.output["embeddings"]]

        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=kwargs,
                embeddings=embeddings,
            )

        return EmbeddingResponse(
            embeddings=embeddings,
            usage=EmbeddingUsage(
                tokens=response.usage["total_tokens"],
                time=time,
            ),
        )

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the DashScope embedding API.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. It can be a list of strings
                or TextBlock dicts.

        Returns:
            `EmbeddingResponse`:
                The merged embeddings and usage across all internal batches.
        """
        gather_text = []
        for _ in text:
            if isinstance(_, dict) and "text" in _:
                gather_text.append(_["text"])
            elif isinstance(_, str):
                gather_text.append(_)
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        if len(gather_text) > self.batch_size_limit:
            # Ceiling division: number of API calls needed for this input.
            n_calls = (
                len(gather_text) + self.batch_size_limit - 1
            ) // self.batch_size_limit
            # Pure %-style lazy formatting (was a mix of %-args and
            # f-string fragments, which defeated lazy logging).
            logger.info(
                "The input texts (%d) will be embedded with %d API calls "
                "due to the batch size limit of %d for DashScope "
                "embedding API.",
                len(gather_text),
                n_calls,
                self.batch_size_limit,
            )

        # Handle the batch size limit for DashScope embedding API
        collected_embeddings = []
        collected_time = 0.0
        collected_tokens = 0
        # Report "cache" only if every batch was served from the cache.
        collected_source: Literal["cache", "api"] = "cache"
        for _ in range(0, len(gather_text), self.batch_size_limit):
            batch_texts = gather_text[_ : _ + self.batch_size_limit]
            batch_kwargs = {
                "input": batch_texts,
                "model": self.model_name,
                "dimension": self.dimensions,
                **kwargs,
            }
            res = await self._call_api(batch_kwargs)
            collected_embeddings.extend(res.embeddings)
            collected_time += res.usage.time
            if res.usage.tokens:
                collected_tokens += res.usage.tokens
            if res.source == "api":
                collected_source = "api"

        return EmbeddingResponse(
            embeddings=collected_embeddings,
            usage=EmbeddingUsage(
                tokens=collected_tokens,
                time=collected_time,
            ),
            source=collected_source,
        )

View File

@@ -0,0 +1,244 @@
# -*- coding: utf-8 -*-
"""The dashscope multimodal embedding model in agentscope."""
from datetime import datetime
from typing import Any, Literal
from ._cache_base import EmbeddingCacheBase
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._embedding_base import EmbeddingModelBase
from ..message import (
VideoBlock,
ImageBlock,
TextBlock,
)
class DashScopeMultiModalEmbedding(EmbeddingModelBase):
    """The DashScope multimodal embedding API, supporting text, image and
    video embedding."""

    supported_modalities: list[str] = ["text", "image", "video"]
    """This class supports text, image and video input."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int | None = None,
        embedding_cache: EmbeddingCacheBase | None = None,
    ) -> None:
        """Initialize the DashScope multimodal embedding model class.

        Args:
            api_key (`str`):
                The dashscope API key.
            model_name (`str`):
                The name of the embedding model, e.g. "multimodal-embedding-
                v1", "tongyi-embedding-vision-plus".
            dimensions (`int | None`, defaults to `None`):
                The dimension of the embedding vector. Each model family
                has a fixed dimension (1152 for "tongyi-embedding-vision-
                plus", 768 for "tongyi-embedding-vision-flash", 1024 for
                "multimodal-embedding-v*"); when `None` the model's fixed
                dimension is used. Refer to the `official documentation
                <https://bailian.console.aliyun.com/?tab=api#/api/?type=model&url=2712517>`_
                for more details.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.

        Raises:
            `ValueError`:
                If `dimensions` conflicts with the model's fixed dimension.
        """
        path_doc = (
            "https://bailian.console.aliyun.com/?tab=api#/api/?type=model&"
            "url=2712517"
        )
        # Most models accept one input item per call; the
        # "tongyi-embedding-vision-*" families accept up to 8.
        self.batch_size_limit = 1
        if model_name.startswith("tongyi-embedding-vision-plus"):
            self.batch_size_limit = 8
            if dimensions is None:
                dimensions = 1152
            elif dimensions != 1152:
                raise ValueError(
                    f"The dimension of model {model_name} must be 1152, "
                    "refer to the official documentation for more details: "
                    f"{path_doc}",
                )
        if model_name.startswith("tongyi-embedding-vision-flash"):
            self.batch_size_limit = 8
            if dimensions is None:
                dimensions = 768
            elif dimensions != 768:
                raise ValueError(
                    f"The dimension of model {model_name} must be 768, "
                    "refer to the official documentation for more details: "
                    f"{path_doc}",
                )
        if model_name.startswith("multimodal-embedding-v"):
            if dimensions is None:
                dimensions = 1024
            elif dimensions != 1024:
                raise ValueError(
                    f"The dimension of model {model_name} must be 1024, "
                    "refer to the official documentation for more details: "
                    f"{path_doc}",
                )
        # Unknown model families fall back to 1024.
        refined_dimensions: int = 1024
        if dimensions is not None:
            refined_dimensions = dimensions
        super().__init__(model_name, refined_dimensions)
        self.api_key = api_key
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        inputs: list[TextBlock | ImageBlock | VideoBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the DashScope multimodal embedding API, which accepts text,
        image, and video data.

        Args:
            inputs (`list[TextBlock | ImageBlock | VideoBlock]`):
                The input data to be embedded. It can be a list of text,
                image, and video blocks.

        Raises:
            `ValueError`:
                If an input block is malformed or of an unsupported form
                (e.g. a non-URL video source).

        Returns:
            `EmbeddingResponse`:
                The embedding response object, which contains the embeddings
                and usage information.
        """
        # check data type
        formatted_data = []
        for _ in inputs:
            if (
                not isinstance(_, dict)
                or "type" not in _
                or _["type"]
                not in [
                    "text",
                    "image",
                    "video",
                ]
            ):
                raise ValueError(
                    f"Invalid data : {_}. It should be a list of "
                    "TextBlock, ImageBlock, or VideoBlock.",
                )
            if (
                _["type"] == "video"
                and _.get("source", {}).get("type") != "url"
            ):
                raise ValueError(
                    f"The multimodal embedding API only supports URL input "
                    f"for video data, but got {_}.",
                )
            if _["type"] == "text":
                # Raise a real exception (not assert) so the check is not
                # stripped under ``python -O`` and matches the ValueError
                # style of the other branches.
                if "text" not in _:
                    raise ValueError(
                        f"Invalid text block: {_}. It should contain a "
                        f"'text' field.",
                    )
                formatted_data.append({"text": _["text"]})
            elif _["type"] == "video":
                formatted_data.append({"video": _["source"]["url"]})
            elif (
                _["type"] == "image"
                and "source" in _
                and _["source"].get("type") in ["base64", "url"]
            ):
                typ = _["source"]["type"]
                if typ == "base64":
                    # Inline the image as a data URI.
                    formatted_data.append(
                        {
                            "image": f'data:{_["source"]["media_type"]};'
                            f'base64,{_["source"]["data"]}',
                        },
                    )
                elif typ == "url":
                    formatted_data.append(
                        {"image": _["source"]["url"]},
                    )
            else:
                raise ValueError(
                    f"Invalid block {_}. It should be a valid TextBlock, "
                    f"ImageBlock, or VideoBlock.",
                )

        # Handle the batch size limit of the DashScope multimodal embedding API
        collected_embeddings = []
        collected_time = 0.0
        collected_tokens = 0
        # Report "cache" only if every batch was served from the cache.
        collected_source: Literal["cache", "api"] = "cache"
        for _ in range(0, len(formatted_data), self.batch_size_limit):
            batch_data = formatted_data[_ : _ + self.batch_size_limit]
            batch_kwargs = {
                "input": batch_data,
                "model": self.model_name,
                **kwargs,
            }
            res = await self._call_api(batch_kwargs)
            collected_embeddings.extend(res.embeddings)
            collected_time += res.usage.time
            if res.usage.tokens:
                collected_tokens += res.usage.tokens
            if res.source == "api":
                collected_source = "api"

        return EmbeddingResponse(
            embeddings=collected_embeddings,
            usage=EmbeddingUsage(
                tokens=collected_tokens,
                time=collected_time,
            ),
            source=collected_source,
        )

    async def _call_api(self, kwargs: dict[str, Any]) -> EmbeddingResponse:
        """
        Call the DashScope multimodal embedding API by the given arguments.

        Args:
            kwargs (`dict[str, Any]`):
                Keyword arguments for ``dashscope.MultiModalEmbedding.call``;
                also used as the cache identifier (the API key is
                intentionally excluded from it).

        Raises:
            `RuntimeError`:
                If the API responds with a non-200 status code.
        """
        # Search in cache first
        if self.embedding_cache:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=kwargs,
            )
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        import dashscope

        # Pass the API key separately instead of mutating ``kwargs``, so the
        # cache identifier used for ``store`` below matches the one used by
        # ``retrieve`` above (and never contains the secret key).
        start_time = datetime.now()
        res = dashscope.MultiModalEmbedding.call(
            api_key=self.api_key,
            **kwargs,
        )
        time = (datetime.now() - start_time).total_seconds()

        if res.status_code != 200:
            raise RuntimeError(
                f"Failed to get embedding from DashScope API: {res}",
            )

        embeddings = [_["embedding"] for _ in res.output["embeddings"]]

        # Populate the cache. This store call was previously missing, so the
        # cache was read but never written.
        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=kwargs,
                embeddings=embeddings,
            )

        return EmbeddingResponse(
            embeddings=embeddings,
            usage=EmbeddingUsage(
                tokens=res.usage.get(
                    "image_tokens",
                    0,
                )
                + res.usage.get(
                    "input_tokens",
                    0,
                ),
                time=time,
            ),
            source="api",
        )

View File

@@ -0,0 +1,45 @@
# -*- coding: utf-8 -*-
"""The embedding model base class."""
from typing import Any
from ._embedding_response import EmbeddingResponse
class EmbeddingModelBase:
    """Common base class for all embedding model wrappers.

    Concrete subclasses override :meth:`__call__` to invoke a specific
    provider's embedding API.
    """

    model_name: str
    """The embedding model name"""

    supported_modalities: list[str]
    """The supported data modalities, e.g. "text", "image", "video"."""

    dimensions: int
    """The dimensions of the embedding vector."""

    def __init__(self, model_name: str, dimensions: int) -> None:
        """Record the model identity and output dimensionality.

        Args:
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`):
                The dimension of the embedding vector.
        """
        self.model_name = model_name
        self.dimensions = dimensions

    async def __call__(self, *args: Any, **kwargs: Any) -> EmbeddingResponse:
        """Invoke the embedding API; must be provided by subclasses."""
        raise NotImplementedError(
            f"The {self.__class__.__name__} class does not implement "
            f"the __call__ method.",
        )

View File

@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""The embedding response class."""
from dataclasses import dataclass, field
from typing import Literal, List
from ._embedding_usage import EmbeddingUsage
from .._utils._common import _get_timestamp
from .._utils._mixin import DictMixin
from ..types import Embedding
@dataclass
class EmbeddingResponse(DictMixin):
    """The embedding response class.

    Immutable-constant defaults are plain ``=`` assignments;
    ``field(default_factory=...)`` is reserved for values that must be
    computed per instance (``id``, ``created_at``).
    """

    embeddings: List[Embedding]
    """The embedding data"""

    id: str = field(default_factory=lambda: _get_timestamp(True))
    """The identity of the embedding response"""

    created_at: str = field(default_factory=_get_timestamp)
    """The timestamp of the embedding response creation"""

    type: Literal["embedding"] = "embedding"
    """The type of the response, must be `embedding`."""

    usage: EmbeddingUsage | None = None
    """The usage of the embedding model API invocation, if available."""

    source: Literal["cache", "api"] = "api"
    """If the response comes from the cache or the API."""

View File

@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
"""The embedding usage class in agentscope."""
from dataclasses import dataclass, field
from typing import Literal
from .._utils._mixin import DictMixin
@dataclass
class EmbeddingUsage(DictMixin):
    """The usage of an embedding model API invocation.

    Immutable-constant defaults are plain ``=`` assignments instead of
    needless ``field(default_factory=lambda: ...)`` wrappers.
    """

    time: float
    """The time used in seconds."""

    tokens: int | None = None
    """The number of tokens used, if available."""

    type: Literal["embedding"] = "embedding"
    """The type of the usage, must be `embedding`."""

View File

@@ -0,0 +1,187 @@
# -*- coding: utf-8 -*-
"""A file embedding cache implementation for storing and retrieving
embeddings in binary files."""
import hashlib
import json
import os
from typing import Any, List
import numpy as np
from ._cache_base import EmbeddingCacheBase
from .._logging import logger
from ..types import (
Embedding,
JSONSerializableObject,
)
class FileEmbeddingCache(EmbeddingCacheBase):
    """An embedding cache that persists each embedding list as a ``.npy``
    file on disk, keyed by the SHA-256 hash of a JSON-serializable
    identifier."""

    def __init__(
        self,
        cache_dir: str = "./.cache/embeddings",
        max_file_number: int | None = None,
        max_cache_size: int | None = None,
    ) -> None:
        """Set up the on-disk embedding cache.

        Args:
            cache_dir (`str`, defaults to `"./.cache/embeddings"`):
                Directory where the embedding files are kept.
            max_file_number (`int | None`, defaults to `None`):
                Upper bound on the number of cached files; the oldest files
                are evicted once it is exceeded.
            max_cache_size (`int | None`, defaults to `None`):
                Upper bound on the total cache size in MB; the oldest files
                are evicted until the size fits within the limit.
        """
        self._cache_dir = os.path.abspath(cache_dir)
        self.max_file_number = max_file_number
        self.max_cache_size = max_cache_size

    @property
    def cache_dir(self) -> str:
        """The cache directory, created lazily on first access."""
        if not os.path.exists(self._cache_dir):
            os.makedirs(self._cache_dir, exist_ok=True)
        return self._cache_dir

    async def store(
        self,
        embeddings: List[Embedding],
        identifier: JSONSerializableObject,
        overwrite: bool = False,
        **kwargs: Any,
    ) -> None:
        """Persist the embeddings under the given identifier.

        Args:
            embeddings (`List[Embedding]`):
                The embeddings to store.
            identifier (`JSONSerializableObject`):
                The identifier to distinguish the embeddings; hashed into
                the filename, so it must be JSON serializable.
            overwrite (`bool`, defaults to `False`):
                Whether to overwrite an existing entry with the same
                identifier.
        """
        path_file = os.path.join(
            self.cache_dir,
            self._get_filename(identifier),
        )
        if os.path.exists(path_file):
            if not os.path.isfile(path_file):
                raise RuntimeError(
                    f"Path {path_file} exists but is not a file.",
                )
            if not overwrite:
                # Entry already cached and overwriting was not requested.
                return
        np.save(path_file, embeddings)
        await self._maintain_cache_dir()

    async def retrieve(
        self,
        identifier: JSONSerializableObject,
    ) -> List[Embedding] | None:
        """Load the embeddings stored under the given identifier, or
        return `None` when no entry exists.

        Args:
            identifier (`JSONSerializableObject`):
                The identifier to retrieve the embeddings; hashed into the
                filename, so it must be JSON serializable.
        """
        path_file = os.path.join(
            self.cache_dir,
            self._get_filename(identifier),
        )
        if not os.path.exists(path_file):
            return None
        return np.load(path_file).tolist()

    async def remove(self, identifier: JSONSerializableObject) -> None:
        """Delete the cache entry for the given identifier.

        Args:
            identifier (`JSONSerializableObject`):
                The identifier whose embeddings should be removed; hashed
                into the filename, so it must be JSON serializable.

        Raises:
            `FileNotFoundError`:
                If no entry exists for the identifier.
        """
        path_file = os.path.join(
            self.cache_dir,
            self._get_filename(identifier),
        )
        if not os.path.exists(path_file):
            raise FileNotFoundError(f"File {path_file} does not exist.")
        os.remove(path_file)

    async def clear(self) -> None:
        """Delete every ``.npy`` file from the cache directory."""
        for entry in os.listdir(self.cache_dir):
            if entry.endswith(".npy"):
                os.remove(os.path.join(self.cache_dir, entry))

    def _get_cache_size(self) -> float:
        """Return the total size of cached ``.npy`` files in MB."""
        total_bytes = 0
        for entry in os.listdir(self.cache_dir):
            if not entry.endswith(".npy"):
                continue
            candidate = os.path.join(self.cache_dir, entry)
            if os.path.isfile(candidate):
                total_bytes += os.path.getsize(candidate)
        return total_bytes / (1024.0 * 1024.0)

    @staticmethod
    def _get_filename(identifier: JSONSerializableObject) -> str:
        """Derive a deterministic ``.npy`` filename from the identifier."""
        serialized = json.dumps(identifier, ensure_ascii=False)
        digest = hashlib.sha256(serialized.encode("utf-8")).hexdigest()
        return digest + ".npy"

    async def _maintain_cache_dir(self) -> None:
        """Evict the oldest cached files when the configured file-count or
        size limits are exceeded."""
        entries = [
            (entry.name, entry.stat().st_mtime)
            for entry in os.scandir(self.cache_dir)
            if entry.is_file() and entry.name.endswith(".npy")
        ]
        # Oldest first, by modification time.
        entries.sort(key=lambda pair: pair[1])

        if self.max_file_number and len(entries) > self.max_file_number:
            evicted = entries[: -self.max_file_number]
            entries = entries[-self.max_file_number :]
            for file_name, _ in evicted:
                os.remove(os.path.join(self.cache_dir, file_name))
                logger.info(
                    "Remove cached embedding file %s for limited number "
                    "of files (%d).",
                    file_name,
                    self.max_file_number,
                )

        if (
            self.max_cache_size is not None
            and self._get_cache_size() > self.max_cache_size
        ):
            removed = []
            for file_name, _ in entries:
                os.remove(os.path.join(self.cache_dir, file_name))
                removed.append(file_name)
                if self._get_cache_size() <= self.max_cache_size:
                    break
            if removed:
                logger.info(
                    "Remove %d cached embedding file(s) for limited "
                    "cache size (%d MB).",
                    len(removed),
                    self.max_cache_size,
                )

View File

@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
"""The gemini text embedding model class."""
from datetime import datetime
from typing import Any, List
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._cache_base import EmbeddingCacheBase
from ._embedding_base import EmbeddingModelBase
from ..message import TextBlock
class GeminiTextEmbedding(EmbeddingModelBase):
    """The Gemini text embedding model."""

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 3072,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Gemini text embedding model class.

        Args:
            api_key (`str`):
                The Gemini API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 3072):
                The dimension of the embedding vector; see the `official
                documentation
                <https://ai.google.dev/gemini-api/docs/embeddings?hl=zh-cn#control-embedding-size>`_
                for the supported sizes.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                An optional cache instance used to avoid repeated API calls
                for identical requests.
        """
        from google import genai

        super().__init__(model_name, dimensions)
        self.client = genai.Client(api_key=api_key, **kwargs)
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """The Gemini embedding API call.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. It can be a list of strings
                or TextBlock dicts.
        """
        # TODO: handle the batch size limit
        gather_text = []
        for block in text:
            if isinstance(block, str):
                gather_text.append(block)
            elif isinstance(block, dict) and "text" in block:
                gather_text.append(block["text"])
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        # This dict doubles as the cache identifier for the request.
        kwargs = {
            "model": self.model_name,
            "contents": gather_text,
            "config": kwargs,
        }

        if self.embedding_cache:
            cached = await self.embedding_cache.retrieve(identifier=kwargs)
            if cached:
                return EmbeddingResponse(
                    embeddings=cached,
                    usage=EmbeddingUsage(tokens=0, time=0),
                    source="cache",
                )

        started = datetime.now()
        response = self.client.models.embed_content(**kwargs)
        elapsed = (datetime.now() - started).total_seconds()

        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=kwargs,
                embeddings=[_.values for _ in response.embeddings],
            )

        return EmbeddingResponse(
            embeddings=[_.values for _ in response.embeddings],
            usage=EmbeddingUsage(
                time=elapsed,
            ),
        )

View File

@@ -0,0 +1,106 @@
# -*- coding: utf-8 -*-
"""The ollama text embedding model class."""
from datetime import datetime
from typing import List, Any
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._cache_base import EmbeddingCacheBase
from ..embedding import EmbeddingModelBase
from ..message import TextBlock
class OllamaTextEmbedding(EmbeddingModelBase):
    """The Ollama embedding model."""

    supported_modalities: list[str] = ["text"]
    """This class only supports text input."""

    def __init__(
        self,
        model_name: str,
        dimensions: int,
        host: str | None = None,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the Ollama text embedding model class.

        Args:
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`):
                The dimension of the embedding vector produced by the model.
                Ollama does not allow configuring the dimension per request,
                so this value must match the chosen model's output size.
            host (`str | None`, defaults to `None`):
                The host URL for the Ollama API.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                The embedding cache class instance, used to cache the
                embedding results to avoid repeated API calls.
        """
        import ollama

        super().__init__(model_name, dimensions)
        self.client = ollama.AsyncClient(host=host, **kwargs)
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Call the Ollama embedding API.

        Args:
            text (`List[str | TextBlock]`):
                The input text to be embedded. It can be a list of strings
                or TextBlock dicts.

        Returns:
            `EmbeddingResponse`:
                The embeddings and usage information.
        """
        gather_text = []
        for _ in text:
            if isinstance(_, dict) and "text" in _:
                gather_text.append(_["text"])
            elif isinstance(_, str):
                gather_text.append(_)
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        # NOTE: ``dimensions`` is intentionally NOT forwarded here. The
        # Ollama ``embed()`` API does not accept a ``dimensions`` argument
        # (the output size is fixed by the model), and passing it raises a
        # ``TypeError`` on every call.
        kwargs = {
            "input": gather_text,
            "model": self.model_name,
            **kwargs,
        }

        if self.embedding_cache:
            cached_embeddings = await self.embedding_cache.retrieve(
                identifier=kwargs,
            )
            if cached_embeddings:
                return EmbeddingResponse(
                    embeddings=cached_embeddings,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        start_time = datetime.now()
        response = await self.client.embed(**kwargs)
        time = (datetime.now() - start_time).total_seconds()

        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=kwargs,
                embeddings=response.embeddings,
            )

        return EmbeddingResponse(
            embeddings=response.embeddings,
            usage=EmbeddingUsage(
                time=time,
            ),
        )

View File

@@ -0,0 +1,109 @@
# -*- coding: utf-8 -*-
"""The OpenAI text embedding model class."""
from datetime import datetime
from typing import Any, List
from ._embedding_response import EmbeddingResponse
from ._embedding_usage import EmbeddingUsage
from ._cache_base import EmbeddingCacheBase
from ._embedding_base import EmbeddingModelBase
from ..message import TextBlock
class OpenAITextEmbedding(EmbeddingModelBase):
    """Text embedding model class backed by the OpenAI embeddings API."""

    # This model class accepts textual input only.
    supported_modalities: list[str] = ["text"]

    def __init__(
        self,
        api_key: str,
        model_name: str,
        dimensions: int = 1024,
        embedding_cache: EmbeddingCacheBase | None = None,
        **kwargs: Any,
    ) -> None:
        """Initialize the OpenAI text embedding model class.

        Args:
            api_key (`str`):
                The OpenAI API key.
            model_name (`str`):
                The name of the embedding model.
            dimensions (`int`, defaults to 1024):
                The dimension of the embedding vector.
            embedding_cache (`EmbeddingCacheBase | None`, defaults to `None`):
                Optional cache instance used to store and reuse embedding
                results, avoiding repeated API calls.
        """
        # TODO: handle batch size limit and token limit
        import openai

        super().__init__(model_name, dimensions)
        self.client = openai.AsyncClient(api_key=api_key, **kwargs)
        self.embedding_cache = embedding_cache

    async def __call__(
        self,
        text: List[str | TextBlock],
        **kwargs: Any,
    ) -> EmbeddingResponse:
        """Embed the given texts with the OpenAI embedding API.

        Args:
            text (`List[str | TextBlock]`):
                The input to embed: a list of plain strings and/or
                TextBlock dicts carrying a "text" field.

        Raises:
            `ValueError`:
                If an element of ``text`` is neither a string nor a
                TextBlock-like dict.
        """
        plain_texts: list[str] = []
        for item in text:
            if isinstance(item, dict) and "text" in item:
                plain_texts.append(item["text"])
            elif isinstance(item, str):
                plain_texts.append(item)
            else:
                raise ValueError(
                    "Input text must be a list of strings or TextBlock dicts.",
                )

        # The full request dict doubles as the cache identifier below.
        call_args = {
            "input": plain_texts,
            "model": self.model_name,
            "dimensions": self.dimensions,
            "encoding_format": "float",
            **kwargs,
        }

        if self.embedding_cache:
            hit = await self.embedding_cache.retrieve(identifier=call_args)
            if hit:
                # Served from cache: no tokens consumed, no API latency.
                return EmbeddingResponse(
                    embeddings=hit,
                    usage=EmbeddingUsage(
                        tokens=0,
                        time=0,
                    ),
                    source="cache",
                )

        started = datetime.now()
        response = await self.client.embeddings.create(**call_args)
        elapsed = (datetime.now() - started).total_seconds()

        vectors = [_.embedding for _ in response.data]
        if self.embedding_cache:
            await self.embedding_cache.store(
                identifier=call_args,
                embeddings=vectors,
            )

        return EmbeddingResponse(
            embeddings=vectors,
            usage=EmbeddingUsage(
                tokens=response.usage.total_tokens,
                time=elapsed,
            ),
        )

View File

@@ -0,0 +1,44 @@
# -*- coding: utf-8 -*-
"""The evaluation module in AgentScope."""
from ._evaluator import (
EvaluatorBase,
RayEvaluator,
GeneralEvaluator,
)
from ._metric_base import (
MetricBase,
MetricResult,
MetricType,
)
from ._task import Task
from ._solution import SolutionOutput
from ._benchmark_base import BenchmarkBase
from ._evaluator_storage import (
EvaluatorStorageBase,
FileEvaluatorStorage,
)
from ._ace_benchmark import (
ACEBenchmark,
ACEAccuracy,
ACEProcessAccuracy,
ACEPhone,
)
# Explicit public API of the evaluation subpackage; controls what
# `from agentscope.evaluate import *` exposes.
__all__ = [
    "BenchmarkBase",
    "EvaluatorBase",
    "RayEvaluator",
    "GeneralEvaluator",
    "MetricBase",
    "MetricResult",
    "MetricType",
    "EvaluatorStorageBase",
    "FileEvaluatorStorage",
    "Task",
    "SolutionOutput",
    "ACEBenchmark",
    "ACEAccuracy",
    "ACEProcessAccuracy",
    "ACEPhone",
]

View File

@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
"""The ACE benchmark related implementations in AgentScope."""
from ._ace_benchmark import ACEBenchmark
from ._ace_metric import (
ACEAccuracy,
ACEProcessAccuracy,
)
from ._ace_tools_zh import ACEPhone
# Explicit public API of the ACE benchmark subpackage.
__all__ = [
    "ACEBenchmark",
    "ACEPhone",
    "ACEAccuracy",
    "ACEProcessAccuracy",
]

View File

@@ -0,0 +1,240 @@
# -*- coding: utf-8 -*-
"""The ACE benchmark class in agentscope. The code is implemented with
reference to the `ACEBench <https://github.com/ACEBench/ACEBench>`_
under the MIT license."""
import json
import os
from typing import Generator
import json5
import requests
from tqdm import tqdm
from ._ace_metric import ACEAccuracy, ACEProcessAccuracy
from ._ace_tools_zh import ACEPhone
from .._benchmark_base import BenchmarkBase
from .._task import Task
class ACEBenchmark(BenchmarkBase):
    """The ACE benchmark for evaluating AI agents.

    On construction the dataset files are downloaded from the ACEBench
    GitHub repository (when not already present in ``data_dir``) and then
    loaded into memory together with their ground-truth answers.
    """

    data_dir_url: str = (
        "https://raw.githubusercontent.com/ACEBench/ACEBench/main/data_all"
    )
    """The URL to the data dir"""

    # Language-specific sub-directories of the dataset.
    data_subdir: list[str] = [
        # "data_en",  # TODO: enable English version
        "data_zh",
    ]

    # Sub-directory (inside each language dir) holding the ground truth.
    ground_truth_dir: str = "possible_answer"

    data_files: list[str] = [
        "data_agent_multi_step.json",
        "data_agent_multi_turn.json",
        # "data_normal_atom_bool.json",
        # "data_normal_atom_enum.json",
        # "data_normal_atom_list.json",
        # "data_normal_atom_number.json",
        # "data_normal_atom_object_deep.json",
        # "data_normal_atom_object_short.json",
        #
        # "data_normal_multi_turn_user_adjust.json",
        # "data_normal_multi_turn_user_switch.json",
        #
        # "data_normal_preference.json",
        # "data_normal_similar_api.json",
        # "data_normal_single_turn_parallel_function.json",
        # "data_normal_single_turn_single_function.json",
        #
        # "data_special_error_param.json",
        # "data_special_incomplete.json",
        # "data_special_irrelevant.json",
    ]
    """The data filenames"""

    def __init__(
        self,
        data_dir: str,
    ) -> None:
        """Initialize the ACEBenchmark

        Args:
            data_dir (`str`):
                The directory where the dataset is downloaded and saved.

        Raises:
            `RuntimeError`:
                If ``data_dir`` exists but is not a directory.
        """
        super().__init__(
            name="ACEBench",
            description="The ACE benchmark for evaluating AI agents.",
        )
        self.data_dir = os.path.abspath(data_dir)
        if os.path.exists(data_dir) and not os.path.isdir(data_dir):
            raise RuntimeError(
                f"The data_dir `{data_dir}` is not a valid directory path.",
            )
        os.makedirs(data_dir, exist_ok=True)
        # Download only when some expected file is missing locally.
        if not self._verify_data():
            self._download_data()
        self.dataset = self._load_data()

    def _load_data(self) -> list[dict]:
        """Load the dataset from the data directory, joining each item
        with its ground-truth record and tagging it with its language
        and category."""
        dataset = []
        for subdir in self.data_subdir:
            for filename in self.data_files:
                file_path = os.path.join(self.data_dir, subdir, filename)
                gt_path = os.path.join(
                    self.data_dir,
                    subdir,
                    self.ground_truth_dir,
                    filename,
                )
                # The files are JSON-lines; json5 tolerates the formatting
                # quirks present in the ACEBench data.
                gt_dataset = {}
                with open(gt_path, "r", encoding="utf-8") as gt_file:
                    for line in gt_file:
                        gt_data = json5.loads(line)
                        gt_dataset[gt_data["id"]] = gt_data
                with open(file_path, "r", encoding="utf-8") as f:
                    for line in f:
                        data = json5.loads(line)
                        gt = gt_dataset[data["id"]]
                        gt.pop("id", None)
                        data["ground_truth"] = gt["ground_truth"]
                        data["mile_stone"] = gt["mile_stone"]
                        # e.g. "data_zh" -> "zh"
                        data["language"] = subdir.rsplit(
                            "_",
                            maxsplit=1,
                        )[-1]
                        data["tags"] = {
                            "language": data["language"],
                            # e.g. "data_agent_multi_step.json"
                            # -> "agent_multi_step"
                            "category": filename.split(
                                ".",
                                maxsplit=1,
                            )[0].removeprefix(
                                "data_",
                            ),
                        }
                        dataset.append(data)
        return dataset

    def _verify_data(self) -> bool:
        """Verify the data completeness, i.e. that every expected data
        file and its ground-truth counterpart exist locally."""
        for subdir in self.data_subdir:
            for filename in self.data_files:
                file_path = os.path.join(self.data_dir, subdir, filename)
                if not os.path.exists(file_path):
                    return False
                gt_path = os.path.join(
                    self.data_dir,
                    subdir,
                    self.ground_truth_dir,
                    filename,
                )
                if not os.path.exists(gt_path):
                    return False
        return True

    def _download_data(self) -> None:
        """Download the data and ground-truth files from the GitHub raw URL.

        Raises:
            `requests.HTTPError`:
                If any file fails to download.
        """
        for subdir in self.data_subdir:
            subdir_path = os.path.join(self.data_dir, subdir)
            subdir_gt_path = os.path.join(subdir_path, self.ground_truth_dir)
            os.makedirs(subdir_path, exist_ok=True)
            os.makedirs(subdir_gt_path, exist_ok=True)
            for filename in tqdm(
                self.data_files,
                desc=f"Downloading {subdir}",
            ):
                # Fix: interpolate the actual filename into the URL (the
                # previous code requested a literal "(unknown)" path), and
                # bound each request with a timeout.
                response = requests.get(
                    f"{self.data_dir_url}/{subdir}/{filename}",
                    timeout=60,
                )
                response.raise_for_status()
                with open(os.path.join(subdir_path, filename), "wb") as f:
                    f.write(response.content)
                gt_response = requests.get(
                    f"{self.data_dir_url}/{subdir}/"
                    f"{self.ground_truth_dir}/{filename}",
                    timeout=60,
                )
                gt_response.raise_for_status()
                with open(os.path.join(subdir_gt_path, filename), "wb") as f:
                    f.write(gt_response.content)

    @staticmethod
    def _data_to_task(item: dict) -> Task:
        """Convert a dataset item to a Task object, setting up the
        simulated phone and the tool functions for the agent."""
        # Start the simulated phone and load initial configuration
        ace_phone = ACEPhone()
        ace_phone.load_initial_config(item["initial_config"])

        # Obtain tool functions paired with their (normalized) schemas
        tools: list[tuple] = []
        for function_schema in item["function"]:
            name = function_schema["name"]
            # The dataset uses `"type": "dict"`, which is not a valid JSON
            # schema type; rewrite it to `"type": "object"`.
            formatted_schema = json.loads(
                json.dumps(
                    function_schema,
                ).replace(
                    '"type": "dict"',
                    '"type": "object"',
                ),
            )
            tool_function = ace_phone.get_tool_function(name)
            tools.append(
                (
                    tool_function,
                    {
                        "type": "function",
                        "function": formatted_schema,
                    },
                ),
            )

        return Task(
            id=item["id"],
            input=item["question"],
            ground_truth={
                "state": item["ground_truth"],
                "mile_stone": item.get("mile_stone", []),
            },
            tags=item.get("tags", {}),
            metrics=[
                ACEAccuracy(item["ground_truth"]),
                ACEProcessAccuracy(item["mile_stone"]),
            ],
            metadata={
                # The phone is used to extract the final state after
                # finishing the task.
                "phone": ace_phone,
                # The provided tools for this task, used to equip the agent
                "tools": tools,
            },
        )

    def __iter__(self) -> Generator[Task, None, None]:
        """Iterate over the benchmark, yielding one `Task` per item."""
        for item in self.dataset:
            yield self._data_to_task(item)

    def __getitem__(self, index: int) -> Task:
        """Get a task by index."""
        return self._data_to_task(self.dataset[index])

    def __len__(self) -> int:
        """Get the number of items in the benchmark."""
        return len(self.dataset)

View File

@@ -0,0 +1,131 @@
# -*- coding: utf-8 -*-
"""The ACE benchmark metric implementations in AgentScope."""
from .._solution import SolutionOutput
from .._metric_base import MetricBase, MetricResult, MetricType
class ACEProcessAccuracy(MetricBase):
    """The ace benchmark process accuracy metric."""

    def __init__(
        self,
        mile_stone: list[str],
    ) -> None:
        """Initialize the AceBench process accuracy metric.

        Args:
            mile_stone (`list[str]`):
                The milestone tool calls that must appear in the
                solution trajectory.
        """
        super().__init__(
            name="process_accuracy",
            metric_type=MetricType.NUMERICAL,
            description="The AceBench Agent eval process accuracy metric.",
        )
        self.mile_stone = mile_stone

    async def __call__(
        self,
        solution: SolutionOutput,
    ) -> MetricResult:
        """Score 1 if every milestone call appears in the trajectory,
        otherwise 0 with a message naming the first missing milestone."""
        # Render each tool-use block in the ACEBench textual format,
        # e.g. [func(arg1='dfd', arg2=44)]
        formatted_calls = []
        for block in solution.trajectory:
            if block.get("type") != "tool_use":
                continue
            rendered_args = ", ".join(
                f"{key}='{value}'" if isinstance(value, str) else f"{key}={value}"
                for key, value in block.get("input").items()
            )
            formatted_calls.append(
                f"[{block.get('name')}({rendered_args})]",
            )

        # Report the first milestone (in order) that is absent.
        missing = next(
            (_ for _ in self.mile_stone if _ not in formatted_calls),
            None,
        )
        if missing is not None:
            return MetricResult(
                name=self.name,
                result=0,
                message=f"Error: Missing milestone '{missing}' in "
                "the given trajectory.",
            )
        return MetricResult(
            name=self.name,
            result=1,
            message="Success",
        )
class ACEAccuracy(MetricBase):
    """The ace benchmark end-state accuracy metric.

    Compares the final simulated-phone state produced by the agent
    (``solution.output``) against the dataset's ground-truth state.
    """

    def __init__(
        self,
        state: list[dict],
    ) -> None:
        """Initialize the _metric object.

        Args:
            state (`list[dict]`):
                The ground-truth state, a list of per-API state dicts.
        """
        super().__init__(
            "accuracy",
            MetricType.NUMERICAL,
            "The AceBench Agent eval accuracy metric.",
        )
        self.state = state

    async def __call__(
        self,
        solution: SolutionOutput,
    ) -> MetricResult:
        """Calculate the metric result.

        Raises:
            `ValueError`:
                If ``solution.output`` is not a list, or if it is missing
                keys present in the ground-truth state.
        """
        # Check if the solution matches the ground truth
        if not isinstance(solution.output, list):
            # Fix: the message previously said "Ground truth state" even
            # though the check is on the solution output.
            raise ValueError("Solution output must be a list.")

        # Handle the typos in ACEBench dataset ("API"/"rpi" key suffixes)
        gathered_state = {}
        for item in self.state:
            for key, value in item.items():
                if key.endswith("API"):
                    key = key.replace("API", "Api")
                elif key.endswith("rpi"):
                    # NOTE(review): replaces every "pi" occurrence in the
                    # key; assumed safe for the dataset's key names — TODO
                    # confirm against the ACEBench data.
                    key = key.replace("pi", "Api")
                gathered_state[key] = value

        gathered_output = {}
        for item in solution.output:
            for key, value in item.items():
                gathered_output[key] = value

        # Every ground-truth key must be present in the solution output.
        if not set(gathered_state.keys()).issubset(gathered_output.keys()):
            raise ValueError(
                "Missing keys in solution output compared to state, "
                f"ground truth keys: {gathered_state.keys()}, "
                f"solution keys: {gathered_output.keys()}",
            )

        for key, value in gathered_state.items():
            if value != gathered_output.get(key):
                return MetricResult(
                    name=self.name,
                    result=0,
                    message=(
                        f"Error: Mismatch in key '{key}':"
                        f"\n{value}\n{gathered_output.get(key)}"
                    ),
                )
        return MetricResult(
            name=self.name,
            result=1,
            message="Success: All keys match",
        )

View File

@@ -0,0 +1,14 @@
# -*- coding: utf-8 -*-
"""The ACEBench simulation tools in AgentScope."""
from ._message_api import MessageApi
from ._travel_api import TravelApi
from ._reminder_api import ReminderApi
from ._food_platform_api import FoodPlatformApi
# Explicit public API of the ACEBench simulation-tools subpackage.
__all__ = [
    "MessageApi",
    "TravelApi",
    "ReminderApi",
    "FoodPlatformApi",
]

View File

@@ -0,0 +1,302 @@
# -*- coding: utf-8 -*-
"""The food platform API in the ACEBench evaluation."""
from ._shared_state import SharedState
class FoodPlatformApi(SharedState):
    """The food platform Api in the ACEBench evaluation.

    Simulates a food-delivery platform with preset users, merchants and
    menus, supporting login, balance queries and order management.
    """

    # Names of the methods exposed to the agent as tool functions.
    tool_functions: list[str] = [
        "login_food_platform",
        "view_logged_in_users",
        "check_balance",
        "add_food_delivery_order",
        "get_products",
        "view_orders",
        "search_orders",
    ]

    def __init__(self, shared_state: dict) -> None:
        """Initialize the platform with preset users and merchants.

        Args:
            shared_state (`dict`):
                The device state (wifi, login) shared across tool APIs.
        """
        super().__init__(shared_state)
        # Preset users and their initial balances
        self.users: dict = {
            "Eve": {
                "user_id": "U100",
                "password": "password123",
                "balance": 500.0,
            },
            "Frank": {
                "user_id": "U101",
                "password": "password456",
                "balance": 300.0,
            },
            "Grace": {
                "user_id": "U102",
                "password": "password789",
                "balance": 150.0,
            },
            "Helen": {
                "user_id": "U103",
                "password": "password321",
                "balance": 800.0,
            },
            "Isaac": {
                "user_id": "U104",
                "password": "password654",
                "balance": 400.0,
            },
            "Jack": {
                "user_id": "U105",
                "password": "password654",
                "balance": 120.0,
            },
        }
        # Six merchants and their menus
        self.merchant_list: dict[str, dict] = {
            "达美乐": {
                "merchant_id": "M100",
                "service_type": "Pizza",
                "menu": [
                    {"product": "玛格丽特披萨", "price": 68.0},
                    {"product": "超级至尊披萨", "price": 88.0},
                ],
            },
            "米村拌饭": {
                "merchant_id": "M101",
                "service_type": "Bibimbap",
                "menu": [
                    {"product": "石锅拌饭", "price": 35.0},
                    {"product": "韩式牛肉拌饭", "price": 45.0},
                ],
            },
            "海底捞": {
                "merchant_id": "M102",
                "service_type": "Hotpot",
                "menu": [
                    {"product": "牛肉卷", "price": 68.0},
                    {"product": "海鲜拼盘", "price": 88.0},
                ],
            },
            "喜茶": {
                "merchant_id": "M103",
                "service_type": "Milk Tea",
                "menu": [
                    {"product": "芝士奶茶", "price": 25.0},
                    {"product": "四季春奶茶", "price": 22.0},
                ],
            },
            "盒马生鲜": {
                "merchant_id": "M104",
                "service_type": "Fresh Grocery",
                "menu": [
                    {"product": "有机蔬菜包", "price": 15.0},
                    {"product": "生鲜大礼包", "price": 99.0},
                ],
            },
            "九田家烤肉": {
                "merchant_id": "M105",
                "service_type": "BBQ",
                "menu": [
                    {"product": "韩式烤牛肉", "price": 128.0},
                    {"product": "烤五花肉", "price": 78.0},
                ],
            },
        }
        # Users currently logged in on the platform
        self.logged_in_users: list[str] = []
        # Orders placed so far
        self.orders: list = []

    def get_state_dict(self) -> dict:
        """Get the current state dict of the FoodPlatformApi."""
        return {
            "FoodPlatform": {
                "logged_in_users": self.logged_in_users,
                "orders": self.orders,
                "users": self.users,
            },
        }

    def login_food_platform(
        self,
        username: str,
        password: str,
    ) -> dict[str, bool | str]:
        """Log into the food-delivery platform with username and password.

        Args:
            username (`str`):
                The user's username.
            password (`str`):
                The user's password.
        """
        # Wi-Fi must be on before any network operation succeeds
        if not self.wifi:
            return {"status": False, "message": "wifi未打开无法登录"}
        if username not in self.users:
            return {"status": False, "message": "用户不存在"}
        if self.users[username]["password"] != password:
            return {"status": False, "message": "密码错误"}
        # Reject duplicate logins for the same user
        if username in self.logged_in_users:
            return {"status": False, "message": f"{username} 已经登录"}
        # Record the newly logged-in user
        self.logged_in_users.append(username)
        return {"status": True, "message": f"用户{username}登陆成功!"}

    def view_logged_in_users(self) -> dict:
        """View all users currently logged in."""
        if not self.logged_in_users:
            return {
                "status": False,
                "message": "当前没有登录food platform",
            }
        return {"status": True, "logged_in_users": self.logged_in_users}

    def check_balance(self, user_name: str) -> float:
        """Query the balance of the given user.

        Args:
            user_name (`str`):
                The user's username.
        """
        if user_name in self.users:
            return self.users[user_name]["balance"]
        else:
            # NOTE(review): unknown users silently yield 0.0 — kept as-is
            # to match the ACEBench reference behavior.
            return 0.0

    def add_food_delivery_order(
        self,
        username: str,
        merchant_name: str,
        items: list[dict[str, str | int]],
    ) -> dict[str, bool | str]:
        """Place a food-delivery order.

        Args:
            username (`str`):
                The name of the ordering user.
            merchant_name (`str`):
                The merchant to order from.
            items (`list[dict[str, str | int]]`):
                The ordered products, each with a name and a quantity.
        """
        if username not in self.logged_in_users:
            return {
                "status": False,
                "message": f"用户 {username} 未登录food platform",
            }
        if merchant_name not in self.merchant_list:
            return {"status": False, "message": "商家不存在"}

        total_price = 0.0
        order_items = []
        for item in items:
            product_name = item.get("product")
            quantity = item.get("quantity", 1)
            if not isinstance(quantity, int) or quantity <= 0:
                return {
                    "status": False,
                    "message": f"无效的数量 {quantity} 对于商品 {product_name}",
                }
            # Look up the product price on the merchant's menu
            product_found = False
            for product in self.merchant_list[merchant_name]["menu"]:
                if product["product"] == product_name:
                    total_price += product["price"] * quantity
                    order_items.append(
                        {
                            "product": product_name,
                            "quantity": quantity,
                            "price_per_unit": product["price"],
                        },
                    )
                    product_found = True
                    break
            if not product_found:
                return {
                    "status": False,
                    "message": f"商品 {product_name} 不存在于 "
                    f"{merchant_name} 的菜单中",
                }

        # NOTE(review): `>=` rejects orders that exactly equal the balance;
        # kept as-is to match the ACEBench reference behavior.
        if total_price >= self.users[username]["balance"]:
            return {"status": False, "message": "余额不足,无法下单"}

        # Deduct the balance and record the order
        self.users[username]["balance"] -= total_price
        order = {
            "user_name": username,
            "merchant_name": merchant_name,
            "items": order_items,
            "total_price": total_price,
        }
        self.orders.append(order)
        return {
            "status": True,
            "message": f"外卖订单成功下单给 {merchant_name}" f"总金额为 {total_price}",
        }

    def get_products(
        self,
        merchant_name: str,
    ) -> list[dict[str, str | float]] | dict[str, bool | str]:
        """Get the product list of a specific merchant.

        Args:
            merchant_name (`str`):
                The merchant whose menu to fetch.
        """
        merchant = self.merchant_list.get(merchant_name)
        if merchant:
            return merchant["menu"]
        else:
            return {
                "status": False,
                "message": f"商家 '{merchant_name}' 不存在",
            }

    def view_orders(
        self,
        user_name: str,
    ) -> dict[str, bool | str | list[dict[str, str | int | float]]]:
        """View all orders of the given user."""
        user_orders = [
            order for order in self.orders if order["user_name"] == user_name
        ]
        if not user_orders:
            return {"status": False, "message": "用户没有订单记录"}
        return {"status": True, "orders": user_orders}

    def search_orders(
        self,
        keyword: str,
    ) -> dict[str, bool | str | list[dict[str, str | float]]]:
        """Search orders by keyword against merchant and product names."""
        needle = keyword.lower()
        # Fix: order items are dicts, so match the keyword against each
        # item's product name (the previous code called `.lower()` on the
        # dict itself, raising AttributeError).
        matched_orders = [
            order
            for order in self.orders
            if needle in order["merchant_name"].lower()
            or any(
                needle in str(item.get("product", "")).lower()
                for item in order.get("items", [])
            )
        ]
        if not matched_orders:
            return {"status": False, "message": "没有找到匹配的订单"}
        return {"status": True, "orders": matched_orders}

View File

@@ -0,0 +1,340 @@
# -*- coding: utf-8 -*-
"""The Message API in the ACEBench evaluation."""
from datetime import datetime
from ._shared_state import SharedState
class MessageApi(SharedState):
    """The message Api in the ACEBench evaluation.

    Simulates an SMS inbox on the device: six preset users, a capped
    message store, and tools for sending, deleting and querying messages.
    """

    # Names of the methods exposed to the agent as tool functions.
    tool_functions: list[str] = [
        "send_message",
        "delete_message",
        "view_messages_between_users",
        "search_messages",
        "get_all_message_times_with_ids",
        "get_latest_message_id",
        "get_earliest_message_id",
    ]

    def __init__(self, share_state: dict) -> None:
        """Initialize the MessageApi with shared state."""
        super().__init__(share_state)
        # Six preset users; the inbox holds at most `max_capacity` messages.
        self.max_capacity = 6
        self.user_list: dict[str, dict[str, str | int]] = {
            "Eve": {
                "user_id": "USR100",
                "phone_number": "123-456-7890",
                "occupation": "Software Engineer",
            },
            "Frank": {
                "user_id": "USR101",
                "phone_number": "234-567-8901",
                "occupation": "Data Scientist",
            },
            "Grace": {
                "user_id": "USR102",
                "phone_number": "345-678-9012",
                "occupation": "Product Manager",
            },
            "Helen": {
                "user_id": "USR103",
                "phone_number": "456-789-0123",
                "occupation": "UX Designer",
            },
            "Isaac": {
                "user_id": "USR104",
                "phone_number": "567-890-1234",
                "occupation": "DevOps Engineer",
            },
            "Jack": {
                "user_id": "USR105",
                "phone_number": "678-901-2345",
                "occupation": "Marketing Specialist",
            },
        }
        # Preset SMS records between the six users.
        # Message 1 pairs with the reminder task; message 2 with the food
        # task (the texts are task instructions for the agent).
        self.inbox: dict[int, dict[str, str | int]] = {
            1: {
                "sender_id": "USR100",
                "receiver_id": "USR101",
                "message": "Hey Frank, don't forget about our meeting on "
                "2024-06-11 at 4 PM in Conference Room 1.",
                "time": "2024-06-09",
            },
            2: {
                "sender_id": "USR101",
                "receiver_id": "USR102",
                "message": """你能帮我点一个\"玛格丽特披萨\"的外卖吗,商家是达美乐。""",
                "time": "2024-03-09",
            },
            3: {
                "sender_id": "USR102",
                "receiver_id": "USR103",
                "message": "帮我查一些喜茶有哪些奶茶外卖,买一杯便宜些的奶茶。"
                "买完以后记得回复我,回复的内容是(已经买好了)",
                "time": "2023-12-05",
            },
            4: {
                "sender_id": "USR103",
                "receiver_id": "USR102",
                "message": "No problem Helen, I can assist you.",
                "time": "2024-09-09",
            },
            5: {
                "sender_id": "USR104",
                "receiver_id": "USR105",
                "message": "Isaac, are you available for a call?",
                "time": "2024-06-06",
            },
            6: {
                "sender_id": "USR105",
                "receiver_id": "USR104",
                "message": "Yes Jack, let's do it in 30 minutes.",
                "time": "2024-01-15",
            },
        }
        # Monotonic id source for new messages (seeded past the presets).
        self.message_id_counter: int = 6

    def get_state_dict(self) -> dict:
        """Get the current state dict of the MessageApi."""
        # To avoid the error in ACEBench dataset: the ground truth uses
        # string keys, so the inbox keys are stringified here.
        inbox_state = {}
        for key, value in self.inbox.items():
            inbox_state[str(key)] = value
        return {
            "MessageApi": {
                "inbox": inbox_state,
            },
        }

    def send_message(
        self,
        sender_name: str,
        receiver_name: str,
        message: str,
    ) -> dict[str, bool | str]:
        """Send a message from one user to another.

        Args:
            sender_name (`str`):
                The name of the sending user.
            receiver_name (`str`):
                The name of the receiving user.
            message (`str`):
                The message content to send.
        """
        if not self.logged_in:
            return {"status": False, "message": "device未登录无法发送短信"}
        if not self.wifi:
            return {"status": False, "message": "wifi关闭此时不能发送信息"}
        if len(self.inbox) >= self.max_capacity:
            return {
                "status": False,
                "message": "内存容量不够了你需要询问user删除哪一条短信。",
            }
        # Verify that both sender and receiver exist
        if (
            sender_name not in self.user_list
            or receiver_name not in self.user_list
        ):
            return {"status": False, "message": "发送者或接收者不存在"}

        sender_id = self.user_list[sender_name]["user_id"]
        receiver_id = self.user_list[receiver_name]["user_id"]

        # Append the message to the inbox.
        # NOTE(review): new messages are stored WITHOUT a "time" field, so
        # the time-based lookups below will KeyError on them — presumably
        # intentional (the state is diffed against ACEBench ground truth,
        # which would not contain a send time); confirm before changing.
        self.message_id_counter += 1
        self.inbox[self.message_id_counter] = {
            "sender_id": sender_id,
            "receiver_id": receiver_id,
            "message": message,
        }
        return {"status": True, "message": f"短信成功发送给{receiver_name}"}

    def delete_message(self, message_id: int) -> dict[str, bool | str]:
        """Delete a message by its message ID.

        Args:
            message_id (`int`):
                The ID of the message to delete.
        """
        if not self.logged_in:
            return {"status": False, "message": "device未登录无法删除短信"}
        if message_id not in self.inbox:
            return {"status": False, "message": "短信ID不存在"}
        del self.inbox[message_id]
        return {"status": True, "message": f"短信ID {message_id} 已成功删除。"}

    def view_messages_between_users(
        self,
        sender_name: str,
        receiver_name: str,
    ) -> dict:
        """Get all messages sent by one specific user to another.

        Args:
            sender_name (`str`):
                The name of the sending user.
            receiver_name (`str`):
                The name of the receiving user.
        """
        if not self.logged_in:
            return {
                "status": False,
                "message": "device未登录无法查看短信信息",
            }
        if sender_name not in self.user_list:
            return {"status": False, "message": "发送者不存在"}
        if receiver_name not in self.user_list:
            return {"status": False, "message": "接收者不存在"}

        sender_id = self.user_list[sender_name]["user_id"]
        receiver_id = self.user_list[receiver_name]["user_id"]

        messages_between_users = []
        # Scan the inbox for messages sent by sender_id to receiver_id
        for msg_id, msg_data in self.inbox.items():
            if (
                msg_data["sender_id"] == sender_id
                and msg_data["receiver_id"] == receiver_id
            ):
                messages_between_users.append(
                    {
                        "id": msg_id,
                        "sender": sender_name,
                        "receiver": receiver_name,
                        "message": msg_data["message"],
                    },
                )
        if not messages_between_users:
            return {"status": False, "message": "没有找到相关的短信记录"}
        return {"status": True, "messages": messages_between_users}

    def search_messages(
        self,
        user_name: str,
        keyword: str,
    ) -> dict:
        """Search a specific user's messages for a keyword.

        Note: unlike most other tools, this one performs no login check.

        Args:
            user_name (`str`):
                The name of the user whose messages to search.
            keyword (`str`):
                The keyword to search for in the messages.
        """
        if user_name not in self.user_list:
            return {"status": False, "message": "用户不存在"}

        user_id = self.user_list[user_name]["user_id"]
        matched_messages = []
        # Scan the inbox for messages sent or received by the user that
        # contain the keyword (case-insensitive)
        for msg_id, msg_data in self.inbox.items():
            if (
                user_id in (msg_data["sender_id"], msg_data["receiver_id"])
                and keyword.lower() in msg_data["message"].lower()
            ):
                matched_messages.append(
                    {
                        "id": msg_id,
                        "sender_id": msg_data["sender_id"],
                        "receiver_id": msg_data["receiver_id"],
                        "message": msg_data["message"],
                    },
                )
        if not matched_messages:
            return {"status": False, "message": "没有找到包含关键词的短信"}
        return {"status": True, "messages": matched_messages}

    def get_all_message_times_with_ids(
        self,
    ) -> dict:
        """Get the times of all messages together with their message IDs.

        Note: on success this returns the bare {id: time} mapping rather
        than a {"status": ...} wrapper like the other tools.
        """
        if not self.logged_in:
            return {
                "status": False,
                "message": "device未登录获取所有短信的时间以及对应的短信编号。",
            }
        message_times_with_ids = {
            msg_id: msg_data["time"] for msg_id, msg_data in self.inbox.items()
        }
        return message_times_with_ids

    def get_latest_message_id(self) -> dict:
        """Get the ID of the most recently sent message."""
        if not self.logged_in:
            return {
                "status": False,
                "message": "device未登录无法获取最新发送的短信ID。",
            }
        if not self.inbox:
            return {"status": False, "message": "短信记录为空"}
        # Scan all messages for the one with the latest "time" value
        latest_message_id = None
        latest_time = None
        for message_id, message_data in self.inbox.items():
            message_time = datetime.strptime(
                str(message_data["time"]),
                "%Y-%m-%d",
            )
            if latest_time is None or message_time > latest_time:
                latest_time = message_time
                latest_message_id = message_id
        return {
            "status": True,
            "message": f"最新的短信ID是 {latest_message_id}",
            "message_id": latest_message_id,
        }

    def get_earliest_message_id(self) -> dict:
        """Get the ID of the earliest sent message."""
        if not self.logged_in:
            return {
                "status": False,
                "message": "device未登录无法获取最早发送的短信ID",
            }
        if not self.inbox:
            return {"status": False, "message": "短信记录为空"}
        # Scan all messages for the one with the earliest "time" value
        earliest_message_id = None
        earliest_time = None
        for message_id, message_data in self.inbox.items():
            message_time = datetime.strptime(
                str(message_data["time"]),
                "%Y-%m-%d",
            )
            if earliest_time is None or message_time < earliest_time:
                earliest_time = message_time
                earliest_message_id = message_id
        return {
            "status": True,
            "message": f"最早的短信ID是 {earliest_message_id}",
            "message_id": earliest_message_id,
        }

View File

@@ -0,0 +1,214 @@
# -*- coding: utf-8 -*-
"""The reminder API in ACEBench simulation tools."""
from datetime import datetime
from ._shared_state import SharedState
class ReminderApi(SharedState):
    """The reminder Api in the ACEBench evaluation.

    Simulates a reminder/memo app on the device with a capped list of
    reminders and tools to add, delete, view, mark and search them.
    """

    # Names of the methods exposed to the agent as tool functions.
    tool_functions: list[str] = [
        "view_reminder_by_title",
        "add_reminder",
        "delete_reminder",
        "view_all_reminders",
        "mark_as_notified",
        "search_reminders",
    ]

    def __init__(self, share_state: dict) -> None:
        """Initialize the Reminder Api in the ACEBench evaluation."""
        super().__init__(share_state)
        # The reminder list holds at most `max_capacity` entries.
        self.max_capacity = 6
        # NOTE: the seeded "time" values are plain strings, while
        # `add_reminder` stores whatever it receives.
        self.reminder_list: dict[
            int,
            dict,
        ] = {
            1: {
                "reminder_id": 1001,
                "title": "Doctor's Appointment",
                "description": "Visit Dr. Smith for a checkup.",
                "time": "2024-07-15 09:30",
                "notified": False,
            },
            2: {
                "reminder_id": 1002,
                "title": "Team Meeting",
                "description": "Monthly project review with the team.",
                "time": "2024-07-17 11:00",
                "notified": False,
            },
            3: {
                "reminder_id": 1003,
                "title": "To-do list",
                "description": '首先帮Frank在"盒马生鲜"点外卖,'
                '需要定两个"生鲜大礼包"再发短信告诉Frank'
                '"购买商品的价格是()元"。要把括号换成实际金额,'
                "保留一位小数。",
                "time": "2024-07-16 11:00",
                "notified": False,
            },
        }
        # Monotonic id source for new reminders (seeded past the presets).
        self.reminder_id_counter: int = 3

    def get_state_dict(self) -> dict:
        """Get the current state dict of the ReminderApi."""
        return {
            "ReminderApi": {
                "reminder_list": self.reminder_list,
            },
        }

    def _check_capacity(self) -> bool:
        """Return True when the reminder list is already full."""
        return len(self.reminder_list) >= self.max_capacity

    def view_reminder_by_title(
        self,
        title: str,
    ) -> dict[str, str | bool | dict[str, str | bool | datetime]]:
        """View a specific reminder by its title.

        Args:
            title (`str`):
                The title of the reminder.

        Returns:
            `dict`:
                The lookup status and, on success, the reminder details.
        """
        if not self.logged_in:
            return {"status": False, "message": "device未登录无法查看提醒"}
        for reminder in self.reminder_list.values():
            if reminder["title"] == title:
                return {"status": True, "reminder": reminder}
        return {"status": False, "message": f"没有找到标题为 '{title}' 的提醒"}

    def add_reminder(
        self,
        title: str,
        description: str,
        time: datetime | str,
    ) -> dict[str, bool | str]:
        """Add a new reminder.

        Args:
            title (`str`):
                The reminder title.
            description (`str`):
                The reminder description.
            time (`datetime | str`):
                The reminder time, always in the format
                "YYYY-MM-DD HH:MM". In practice the agent passes a
                string, matching the seeded reminders.

        Returns:
            `dict[str, bool | str]`:
                The operation status and result message.
        """
        if not self.logged_in:
            return {
                "status": False,
                "message": "device未登录无法添加一个新的提醒",
            }
        if self._check_capacity():
            return {"status": False, "message": "提醒容量已满,无法添加新的提醒"}
        self.reminder_id_counter += 1
        reminder_id = self.reminder_id_counter
        self.reminder_list[reminder_id] = {
            "reminder_id": reminder_id,
            "title": title,
            "description": description,
            "time": time,
            "notified": False,
        }
        return {"status": True, "message": f"提醒 '{title}' 已成功添加"}

    def delete_reminder(self, reminder_id: int) -> dict[str, bool | str]:
        """Delete the specified reminder.

        Args:
            reminder_id (`int`):
                The ID of the reminder to delete.

        Returns:
            `dict[str, bool | str]`:
                The operation status and result message.
        """
        if not self.logged_in:
            return {"status": False, "message": "device未登录无法删除指定的提醒"}
        if reminder_id not in self.reminder_list:
            return {"status": False, "message": "提醒ID不存在"}
        del self.reminder_list[reminder_id]
        return {"status": True, "message": f"提醒ID {reminder_id} 已成功删除"}

    def view_all_reminders(
        self,
    ) -> dict:
        """View all reminders.

        Returns:
            `dict`:
                The status and the list of reminder dicts.
        """
        if not self.reminder_list:
            return {"status": False, "message": "没有任何提醒"}
        reminders = []
        for reminder in self.reminder_list.values():
            reminders.append(
                {
                    "title": reminder["title"],
                    "description": reminder["description"],
                    "time": reminder["time"],
                    "notified": reminder["notified"],
                },
            )
        return {"status": True, "reminders": reminders}

    def mark_as_notified(
        self,
        reminder_id: int,
    ) -> dict[str, bool | str]:
        """Mark a reminder as notified.

        Args:
            reminder_id (`int`):
                The ID of the reminder to mark as notified.

        Returns:
            `dict[str, bool | str]`:
                The operation result.
        """
        if reminder_id not in self.reminder_list:
            return {"status": False, "message": "提醒ID不存在"}
        self.reminder_list[reminder_id]["notified"] = True
        return {"status": True, "message": f"提醒ID {reminder_id} 已标记为已通知"}

    def search_reminders(
        self,
        keyword: str,
    ) -> dict:
        """Search reminders by keyword in title or description.

        Args:
            keyword (`str`):
                The search keyword.

        Returns:
            `dict`:
                The status and the list of matching reminder dicts.
        """
        matched_reminders = []
        for reminder in self.reminder_list.values():
            if (
                keyword.lower() in reminder["title"].lower()
                or keyword.lower() in reminder["description"].lower()
            ):
                # Fix: seeded times are strings, so calling strftime on
                # them raised AttributeError; only format datetime values.
                time_value = reminder["time"]
                if isinstance(time_value, datetime):
                    time_value = time_value.strftime("%Y-%m-%d %H:%M")
                matched_reminders.append(
                    {
                        "title": reminder["title"],
                        "description": reminder["description"],
                        "time": time_value,
                    },
                )
        if not matched_reminders:
            return {"status": False, "message": "没有找到包含该关键词的提醒"}
        return {"status": True, "reminders": matched_reminders}

View File

@@ -0,0 +1,20 @@
# -*- coding: utf-8 -*-
"""The shared state class for ACEBench simulation tools."""
class SharedState:
    """Mixin giving ACEBench simulation tool APIs read access to the
    mutable device state shared between them."""

    def __init__(self, shared_state: dict) -> None:
        """Bind this tool API to the shared device-state mapping."""
        # The mapping is held by reference, so later external updates
        # (e.g. toggling wifi) are visible through the properties below.
        self._shared_state = shared_state

    @property
    def wifi(self) -> bool:
        """Current value of the shared "wifi" switch."""
        return self._shared_state["wifi"]

    @property
    def logged_in(self) -> bool:
        """Current value of the shared "logged_in" switch."""
        return self._shared_state["logged_in"]

View File

@@ -0,0 +1,834 @@
# -*- coding: utf-8 -*-
# type: ignore
# pylint: disable=too-many-lines
# pylint: disable=too-many-statements
# pylint: disable=too-many-branches
# pylint: disable=too-many-statements
# pylint: disable=too-many-return-statements
"""The travel API for the ACEBench simulation tools in AgentScope."""
from datetime import datetime, timedelta
class TravelApi:
"""旅行预订系统类。
提供航班查询、用户认证、预订管理等功能的旅行系统。
支持直飞和中转航班查询、航班预订、预订修改和取消等功能。
"""
tool_functions: list[str] = [
"get_user_details",
"get_flight_details",
"get_reservation_details",
"reserve_flight",
"cancel_reservation",
"modify_flight",
]
def __init__(self) -> None:
    """Initialize the travel system.

    Seeds the in-memory user profiles, flight table and reservation
    records used by the simulation.
    """
    # Seed user accounts. Passwords are stored in plain text because
    # this is a closed benchmark simulation, not a real service.
    self.users = {
        "user1": {
            "user_name": "Eve",
            "password": "password123",
            "cash_balance": 2000.0,
            "bank_balance": 50000.0,
            "membership_level": "regular",
        },
        "user2": {
            "user_name": "Frank",
            "password": "password456",
            "cash_balance": 8000.0,
            "bank_balance": 8000.0,
            "membership_level": "silver",
        },
        "user3": {
            "user_name": "Grace",
            "password": "password789",
            "cash_balance": 1000.0,
            "bank_balance": 5000.0,
            "membership_level": "gold",
        },
    }
    # Seed the flight table.
    # NOTE(review): "CZ1765" and "MH2616" each appear twice with
    # different schedules; flight lookups elsewhere in this class use
    # `next(...)` on flight_no and will always pick the FIRST entry —
    # confirm this matches the upstream ACEBench data.
    self.flights = [
        {
            "flight_no": "CA1234",
            "origin": "北京",
            "destination": "上海",
            "depart_time": "2024-07-15 08:00:00",
            "arrival_time": "2024-07-15 10:30:00",
            "status": "available",
            "seats_available": 5,
            "economy_price": 1200,
            "business_price": 3000,
        },
        {
            "flight_no": "MU5678",
            "origin": "上海",
            "destination": "北京",
            "depart_time": "2024-07-16 09:00:00",
            "arrival_time": "2024-07-16 11:30:00",
            "status": "available",
            "seats_available": 3,
            "economy_price": 1900,
            "business_price": 3000,
        },
        {
            "flight_no": "CZ4321",
            "origin": "上海",
            "destination": "北京",
            "depart_time": "2024-07-16 20:00:00",
            "arrival_time": "2024-07-16 22:00:00",
            "status": "available",
            "seats_available": 8,
            "economy_price": 2500,
            "business_price": 4000,
        },
        {
            "flight_no": "CZ4352",
            "origin": "上海",
            "destination": "北京",
            "depart_time": "2024-07-17 20:00:00",
            "arrival_time": "2024-07-17 22:00:00",
            "status": "available",
            "seats_available": 8,
            "economy_price": 1600,
            "business_price": 2500,
        },
        {
            "flight_no": "MU3561",
            "origin": "北京",
            "destination": "南京",
            "depart_time": "2024-07-18 08:00:00",
            "arrival_time": "2024-07-18 10:00:00",
            "status": "available",
            "seats_available": 8,
            "economy_price": 1500,
            "business_price": 4000,
        },
        {
            "flight_no": "MU1566",
            "origin": "北京",
            "destination": "南京",
            "depart_time": "2024-07-18 20:00:00",
            "arrival_time": "2024-07-18 22:00:00",
            "status": "available",
            "seats_available": 8,
            "economy_price": 1500,
            "business_price": 4000,
        },
        {
            "flight_no": "CZ1765",
            "origin": "南京",
            "destination": "深圳",
            "depart_time": "2024-07-17 20:30:00",
            "arrival_time": "2024-07-17 22:00:00",
            "status": "available",
            "seats_available": 8,
            "economy_price": 1500,
            "business_price": 2500,
        },
        {
            "flight_no": "CZ1765",
            "origin": "南京",
            "destination": "深圳",
            "depart_time": "2024-07-18 12:30:00",
            "arrival_time": "2024-07-18 15:00:00",
            "status": "available",
            "seats_available": 8,
            "economy_price": 1500,
            "business_price": 2500,
        },
        {
            "flight_no": "MH1765",
            "origin": "厦门",
            "destination": "成都",
            "depart_time": "2024-07-17 12:30:00",
            "arrival_time": "2024-07-17 15:00:00",
            "status": "available",
            "seats_available": 8,
            "economy_price": 1500,
            "business_price": 2500,
        },
        {
            "flight_no": "MH2616",
            "origin": "成都",
            "destination": "厦门",
            "depart_time": "2024-07-18 18:30:00",
            "arrival_time": "2024-07-18 21:00:00",
            "status": "available",
            "seats_available": 8,
            "economy_price": 1500,
            "business_price": 2500,
        },
        {
            "flight_no": "MH2616",
            "origin": "成都",
            "destination": "福州",
            "depart_time": "2024-07-16 18:30:00",
            "arrival_time": "2024-07-16 21:00:00",
            "status": "available",
            "seats_available": 8,
            "economy_price": 1500,
            "business_price": 2500,
        },
    ]
    # Seed pre-existing reservations.
    # NOTE(review): res_4 references flight "MU2616", which is absent
    # from `self.flights` ("MH2616" exists) — presumably a typo carried
    # over from the benchmark data; verify against upstream ACEBench
    # before "fixing", since the evaluation may compare raw state.
    self.reservations = [
        {
            "reservation_id": "res_1",
            "user_id": "user1",
            "flight_no": "CA1234",
            "payment_method": "bank",
            "cabin": "经济舱",
            "baggage": 1,
            "origin": "北京",
            "destination": "上海",
        },
        {
            "reservation_id": "res_2",
            "user_id": "user1",
            "flight_no": "MU5678",
            "payment_method": "bank",
            "cabin": "商务舱",
            "baggage": 1,
            "origin": "上海",
            "destination": "北京",
        },
        {
            "reservation_id": "res_3",
            "user_id": "user2",
            "flight_no": "MH1765",
            "payment_method": "bank",
            "cabin": "商务舱",
            "baggage": 1,
            "origin": "厦门",
            "destination": "成都",
        },
        {
            "reservation_id": "res_4",
            "user_id": "user2",
            "flight_no": "MU2616",
            "payment_method": "bank",
            "cabin": "商务舱",
            "baggage": 1,
            "origin": "成都",
            "destination": "厦门",
        },
    ]
def get_state_dict(self) -> dict:
    """Return a snapshot of the mutable TravelApi state.

    Returns:
        dict: A mapping with a single "Travel" key holding the current
            users and reservations (by reference, not copied).
    """
    travel_state = {
        "users": self.users,
        "reservations": self.reservations,
    }
    return {"Travel": travel_state}
# 根据出发地和到达地查询航班
def get_flight_details(
self,
origin: str = None,
destination: str = None,
) -> list[dict] | str:
"""根据出发地和到达地查询航班的基本信息。
Args:
origin (str, optional): 出发地城市名称。默认为None。
destination (str, optional): 目的地城市名称。默认为None。
Returns:
list[dict] | str: 符合条件的航班列表或无航班的提示信息。
"""
flights = self.flights
# 过滤出发地
if origin:
flights = [
flight for flight in flights if flight["origin"] == origin
]
# 过滤到达地
if destination:
flights = [
flight
for flight in flights
if flight["destination"] == destination
]
if len(flights) == 0:
return "没有符合条件的直达航班"
# 返回查询结果
return [
{
"flight_no": flight["flight_no"],
"origin": flight["origin"],
"destination": flight["destination"],
"depart_time": flight["depart_time"],
"arrival_time": flight["arrival_time"],
"status": flight["status"],
"seats_available": flight["seats_available"],
"economy_price": flight["economy_price"],
"business_price": flight["business_price"],
}
for flight in flights
]
def get_user_details(self, user_id: str, password: str) -> dict:
    """Fetch a user's profile after verifying the password.

    Args:
        user_id (str): The user ID.
        password (str): The user's password.

    Returns:
        dict: The profile without the password field, or an error dict
            when the credentials do not match.
    """
    record = self.users.get(user_id)
    if not record or record["password"] != password:
        return {"status": "error", "message": "用户名或密码不正确"}
    sanitized = dict(record)
    del sanitized["password"]
    return sanitized
def get_reservation_details(
self,
reservation_id: str = None,
user_id: str = None,
) -> list[dict] | dict:
"""根据预订ID或用户ID查询预订信息包括对应航班的基本信息。
Args:
reservation_id (str, optional): 预订ID。默认为None。
user_id (str, optional): 用户ID。默认为None。
Returns:
`list[dict] | dict`:
详细预订信息列表或错误信息字典。
"""
# 根据预订ID或用户ID筛选预订信息
if reservation_id:
reservations = [
reservation
for reservation in self.reservations
if reservation["reservation_id"] == reservation_id
]
elif user_id:
reservations = [
reservation
for reservation in self.reservations
if reservation["user_id"] == user_id
]
else:
return {"status": "error", "message": "请提供有效的预订ID或用户ID"}
# 对每个预订,附加航班信息
detailed_reservations = []
for reservation in reservations:
flight_info = next(
(
flight
for flight in self.flights
if flight["flight_no"] == reservation["flight_no"]
),
None,
)
detailed_reservation = {**reservation, "flight_info": flight_info}
detailed_reservations.append(detailed_reservation)
return detailed_reservations
def authenticate_user(self, user_id: str, password: str) -> dict:
    """Verify a user's credentials.

    NOTE(review): on failure this returns a *truthy* error dict rather
    than None/False, so callers testing `if not user:` will not detect
    authentication failure — confirm callers check the "status" key.

    Args:
        user_id (str): The user ID.
        password (str): The user's password.

    Returns:
        `dict`:
            The full user record on success, otherwise an error dict
            with "status" and "message" keys.
    """
    record = self.users.get(user_id)
    matched = bool(record) and record["password"] == password
    if matched:
        return record
    return {"status": "error", "message": "用户名或密码不正确"}
def get_baggage_allowance(
    self,
    membership_level: str,
    cabin_class: str,
) -> int:
    """Free checked-baggage allowance for a membership level and cabin.

    Args:
        membership_level (str): One of "regular", "silver", "gold".
        cabin_class (str): The cabin class ("经济舱" or "商务舱").

    Returns:
        int: The number of free checked bags; 0 for unknown inputs.
    """
    table = {
        ("regular", "经济舱"): 1,
        ("regular", "商务舱"): 2,
        ("silver", "经济舱"): 2,
        ("silver", "商务舱"): 3,
        ("gold", "经济舱"): 3,
        ("gold", "商务舱"): 3,
    }
    return table.get((membership_level, cabin_class), 0)
def find_transfer_flights(
self,
origin_city: str,
transfer_city: str,
destination_city: str,
) -> list[dict] | str:
"""查找从出发城市到目的地城市的中转航班。
确保第一班航班降落时间早于第二班航班起飞时间。
Args:
origin_city (str): 出发城市。
transfer_city (str): 中转城市。
destination_city (str): 到达城市。
Returns:
list[dict] | str:
满足条件的中转航班列表,每个航班包含两段航程的信息,或无航班提示。
"""
# 获取从出发城市到中转城市的航班
first_leg_flights: list[dict] = [
flight
for flight in self.flights
if flight["origin"] == origin_city
and flight["destination"] == transfer_city
and flight["status"] == "available"
]
# 获取从中转城市到目的地城市的航班
second_leg_flights = [
flight
for flight in self.flights
if flight["origin"] == transfer_city
and flight["destination"] == destination_city
and flight["status"] == "available"
]
# 存储符合条件的中转航班
transfer_flights = []
# 遍历第一段航班和第二段航班,查找符合时间条件的组合
for first_flight in first_leg_flights:
first_arrival = datetime.strptime(
first_flight["arrival_time"],
"%Y-%m-%d %H:%M:%S",
)
for second_flight in second_leg_flights:
second_departure = datetime.strptime(
str(second_flight["depart_time"]),
"%Y-%m-%d %H:%M:%S",
)
# 检查第一班航班降落时间早于第二班航班起飞时间
if first_arrival < second_departure:
transfer_flights.append(
{
"first_leg": first_flight,
"second_leg": second_flight,
},
)
# 返回符合条件的中转航班列表
if transfer_flights:
return transfer_flights
else:
return "未找到符合条件的中转航班。"
def calculate_baggage_fee(
    self,
    membership_level: str,
    cabin_class: str,
    baggage_count: int,
) -> float:
    """Compute the extra checked-baggage fee.

    Bags beyond the free allowance for the given membership level and
    cabin cost a flat 50 yuan each.

    Bug fix: the original indexed the allowance table directly and
    raised ``KeyError`` for unknown membership levels or cabin classes;
    this now falls back to a free allowance of 0, consistent with
    ``get_baggage_allowance``.

    Args:
        membership_level (str): Membership level ("regular", "silver",
            "gold").
        cabin_class (str): Cabin class ("经济舱" or "商务舱").
        baggage_count (int): Total number of checked bags.

    Returns:
        float: The fee for bags exceeding the free allowance.
    """
    free_baggage = {
        "regular": {"经济舱": 1, "商务舱": 2},
        "silver": {"经济舱": 2, "商务舱": 3},
        "gold": {"经济舱": 3, "商务舱": 3},
    }
    # .get(...) defaults keep invalid inputs from raising KeyError.
    free_limit = free_baggage.get(membership_level, {}).get(cabin_class, 0)
    additional_baggage = max(baggage_count - free_limit, 0)
    return additional_baggage * 50
def update_balance(
    self,
    user: dict,
    payment_method: str,
    amount: float,
) -> bool:
    """Apply a balance change to the user's chosen account.

    Args:
        user (dict): The user record, mutated in place.
        payment_method (str): Either "cash" or "bank".
        amount (float): Signed delta (positive credits, negative debits).

    Returns:
        bool: True when the update was applied; False when the balance
            would go negative or the payment method is unknown.
    """
    if payment_method == "cash":
        if user["cash_balance"] + amount < 0:
            return False  # 余额不足
        user["cash_balance"] += amount
        return True
    if payment_method == "bank":
        if user["bank_balance"] + amount < 0:
            return False  # 余额不足
        user["bank_balance"] += amount
        return True
    # Bug fix: the original fell through and returned True for an
    # unknown payment method, reporting success without moving any
    # money. Report failure instead.
    return False
def reserve_flight(
    self,
    user_id: str,
    password: str,
    flight_no: str,
    cabin: str,
    payment_method: str,
    baggage_count: int,
) -> str:
    """Book a flight for an authenticated user.

    Args:
        user_id (str): The user ID.
        password (str): The user's password.
        flight_no (str): The flight number to book.
        cabin (str): Cabin class ("经济舱" or "商务舱").
        payment_method (str): "cash" or "bank".
        baggage_count (int): Number of checked bags.

    Returns:
        str: A human-readable result message.
    """
    user = self.authenticate_user(user_id, password)
    # Bug fix: authenticate_user returns a truthy error dict on failure,
    # so the original `if not user:` check never fired and the code
    # crashed later with a KeyError. Detect the failure explicitly.
    if not isinstance(user, dict) or user.get("status") == "error":
        return "认证失败,请检查用户ID和密码。"
    # Locate an available flight with this number.
    flight = next(
        (
            f
            for f in self.flights
            if f["flight_no"] == flight_no and f["status"] == "available"
        ),
        None,
    )
    # Bug fix: the original dereferenced `flight` without a None check,
    # raising TypeError for unknown or unavailable flight numbers.
    if flight is None:
        return "未找到可预订的航班,请检查航班号。"
    # Bug fix: the original never checked remaining seats, allowing
    # seats_available to go negative.
    if flight["seats_available"] <= 0:
        return "该航班已无余票。"
    # Cabin price; anything other than 经济舱 is billed as business,
    # mirroring the original behavior.
    price: int = (
        flight["economy_price"]
        if cabin == "经济舱"
        else flight["business_price"]
    )
    total_cost = price
    # Extra-baggage fee based on membership level and cabin.
    baggage_fee = self.calculate_baggage_fee(
        user["membership_level"],
        cabin,
        baggage_count,
    )
    total_cost += baggage_fee
    # Validate the payment method before touching balances.
    if payment_method not in ["cash", "bank"]:
        return "支付方式无效"
    # Deduct the total cost from the chosen account.
    if payment_method == "cash":
        if total_cost > self.users.get(user_id)["cash_balance"]:
            return "cash余额不足,请考虑换一种支付方式"
        self.users.get(user_id)["cash_balance"] -= total_cost
    else:
        if total_cost > self.users.get(user_id)["bank_balance"]:
            return "bank余额不足,请考虑换一种支付方式"
        self.users.get(user_id)["bank_balance"] -= total_cost
    # Take one seat and record the booking.
    flight["seats_available"] -= 1
    reservation_id = f"res_{len(self.reservations) + 1}"
    reservation = {
        "reservation_id": reservation_id,
        "user_id": user_id,
        "flight_no": flight_no,
        "payment_method": payment_method,
        "cabin": cabin,
        "baggage": baggage_count,
    }
    self.reservations.append(reservation)
    return f"预订成功,预订号:{reservation_id}" f"总费用:{total_cost}元(包含行李费用)。"
def modify_flight(
    self,
    user_id: str,
    reservation_id: str,
    new_flight_no: str = None,
    new_cabin: str = None,
    add_baggage: int = 0,
    new_payment_method: str = None,
) -> str:
    """Modify a flight reservation: change flight, cabin and/or add
    checked baggage.

    NOTE(review): several partial-failure behaviors look suspicious and
    should be confirmed against the benchmark's expectations:
    - the cabin is written to the reservation *before* the price
      difference is charged, so an insufficient balance still leaves
      the cabin changed;
    - added baggage is recorded even when the baggage fee cannot be
      paid;
    - `new_payment_method` is not validated against {"cash", "bank"}
      before being forwarded to `update_balance`.

    Args:
        user_id (str): The user ID.
        reservation_id (str): The reservation ID.
        new_flight_no (str, optional): New flight number. Defaults to
            None (keep the current flight).
        new_cabin (str, optional): New cabin class. Defaults to None
            (keep the current cabin).
        add_baggage (int, optional): Number of checked bags to add.
            Defaults to 0.
        new_payment_method (str, optional): Payment method to use for
            any fees. Defaults to None (use the reservation's method).

    Returns:
        str: A human-readable summary of all applied changes.
    """
    # Locate the reservation owned by this user.
    reservation = next(
        (
            r
            for r in self.reservations
            if r["reservation_id"] == reservation_id
            and r["user_id"] == user_id
        ),
        None,
    )
    if not reservation:
        return "预订未找到或用户ID不匹配。"
    # Look up the currently booked flight.
    current_flight = next(
        (
            f
            for f in self.flights
            if f["flight_no"] == reservation["flight_no"]
        ),
        None,
    )
    if not current_flight:
        return "航班信息未找到。"
    # Use the newly provided payment method, or fall back to the one
    # stored on the reservation.
    payment_method = (
        new_payment_method
        if new_payment_method
        else reservation["payment_method"]
    )
    user = self.users[user_id]
    if not user:
        return "用户信息未找到。"
    # Accumulate the outcome message of each requested change.
    result_messages = []
    if new_flight_no and new_flight_no != reservation["flight_no"]:
        # Change the flight number (if provided); the new flight must
        # serve the same origin and destination.
        new_flight = next(
            (f for f in self.flights if f["flight_no"] == new_flight_no),
            None,
        )
        if (
            new_flight
            and new_flight["origin"] == current_flight["origin"]
            and new_flight["destination"] == current_flight["destination"]
        ):
            reservation["flight_no"] = new_flight_no
            result_messages.append("航班号已更改。")
        else:
            return "航班更改失败:新的航班号无效或目的地不匹配。"
    # Change the cabin (if provided) and settle the price difference.
    # NOTE(review): the difference is computed against `current_flight`
    # even when the flight number was just changed above — confirm.
    if new_cabin and new_cabin != reservation.get("cabin"):
        price_difference = self.calculate_price_difference(
            current_flight,
            reservation["cabin"],
            new_cabin,
        )
        reservation["cabin"] = new_cabin
        if price_difference > 0:
            # Charge the difference from the chosen account.
            if self.update_balance(
                user,
                payment_method,
                -price_difference,
            ):
                result_messages.append(
                    f"舱位更改成功。已支付差价: {price_difference}",
                )
            else:
                result_messages.append("余额不足,无法支付舱位差价。")
        elif price_difference < 0:
            # Refund the (negative) difference back to the account.
            self.update_balance(user, payment_method, -price_difference)
            result_messages.append(f"舱位更改成功。已退款差价: {-price_difference}")
    # Add checked baggage, honoring the free allowance and charging a
    # flat 50 per excess bag.
    if add_baggage > 0:
        membership = user["membership_level"]
        max_free_baggage = self.get_baggage_allowance(
            membership,
            reservation["cabin"],
        )
        current_baggage = reservation.get("baggage", 0)
        total_baggage = current_baggage + add_baggage
        extra_baggage = max(0, total_baggage - max_free_baggage)
        baggage_cost = extra_baggage * 50
        if baggage_cost > 0:
            # Charge for the extra bags.
            if self.update_balance(user, payment_method, -baggage_cost):
                result_messages.append(
                    f"行李已增加。需支付额外费用: {baggage_cost}",
                )
            else:
                result_messages.append("余额不足,无法支付额外行李费用。")
        reservation["baggage"] = total_baggage
    # Final summary of everything that happened.
    if not result_messages:
        result_messages.append("修改完成,无需额外费用。")
    return " ".join(result_messages)
def cancel_reservation(
    self,
    user_id: str,
    reservation_id: str,
    reason: str,
) -> str:
    """Cancel a reservation and refund according to the cancel policy.

    The simulation clock is fixed at 2024-07-14 06:00:00 and refunds
    are always credited to the cash balance.

    NOTE(review): the reservation record is not removed and the seat is
    not returned to the flight — confirm this matches the benchmark's
    expected end state.

    Args:
        user_id (str): The user ID.
        reservation_id (str): The reservation ID.
        reason (str): The cancellation reason.

    Returns:
        str: A human-readable result message.
    """
    # The simulated "now": 2024-07-14 06:00:00.
    current_time = datetime(2024, 7, 14, 6, 0, 0)
    # Validate the user and the reservation ownership.
    user = self.users.get(user_id, None)
    if not user:
        return "用户ID无效。"
    reservation = next(
        (
            r
            for r in self.reservations
            if r["reservation_id"] == reservation_id
            and r["user_id"] == user_id
        ),
        None,
    )
    if not reservation:
        return "预订ID无效或与该用户无关。"
    # Resolve the booked flight.
    flight = next(
        (
            f
            for f in self.flights
            if f["flight_no"] == reservation["flight_no"]
        ),
        None,
    )
    if not flight:
        return "航班信息无效。"
    # Already departed flights cannot be cancelled.
    depart_time = datetime.strptime(
        flight["depart_time"],
        "%Y-%m-%d %H:%M:%S",
    )
    if current_time > depart_time:
        return "航段已使用,无法取消。"
    # Ticket price depends on the booked cabin.
    flight_price = (
        flight["economy_price"]
        if reservation["cabin"] == "经济舱"
        else flight["business_price"]
    )
    # Airline-initiated cancellations are fully refunded.
    if reason == "航空公司取消航班":
        refund_amount = flight_price
        self.process_refund(user, refund_amount)
        return f"航班已取消,您的预订将被免费取消,已退款{refund_amount}元。"
    # More than 24 hours before departure: free cancellation.
    if depart_time - current_time > timedelta(days=1):
        refund_amount = flight_price
        self.process_refund(user, refund_amount)
        return f"距离出发时间超过24小时,免费取消成功,已退款{refund_amount}元。"
    # Otherwise a 10% cancellation fee is withheld from the refund.
    cancel_fee = flight_price * 0.1  # 假设取消费为票价的10%
    refund_amount = flight_price - cancel_fee
    self.process_refund(user, refund_amount)
    return f"距离出发时间不足24小时,已扣除取消费{cancel_fee}元,退款{refund_amount}元。"
def process_refund(self, user: dict, amount: float) -> str:
    """Credit a refund to the user's cash balance.

    Args:
        user (dict): The user record, mutated in place.
        amount (float): The refund amount.

    Returns:
        str: A confirmation message.
    """
    new_balance = user["cash_balance"] + amount
    user["cash_balance"] = new_balance
    return f"已成功处理退款,{user['user_name']}的现金余额增加了{amount}元。"
def calculate_price_difference(
    self,
    flight: dict,
    old_cabin: str,
    new_cabin: str,
) -> float:
    """Price delta when switching cabins on the same flight.

    Args:
        flight (dict): The flight record.
        old_cabin (str): The current cabin class.
        new_cabin (str): The requested cabin class.

    Returns:
        float: Positive when the new cabin costs more (extra payment
            due), negative when it is cheaper (refund due). Unknown
            cabin names are priced as 0.
    """

    def _price(cabin: str) -> float:
        if cabin == "经济舱":
            return flight["economy_price"]
        if cabin == "商务舱":
            return flight["business_price"]
        return 0

    return _price(new_cabin) - _price(old_cabin)

View File

@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-
"""The Chinese tools for ACEBench evaluation."""
from functools import wraps
from typing import Callable, Any
from ._ace_tools_api import (
ReminderApi,
FoodPlatformApi,
TravelApi,
MessageApi,
)
from ...message import TextBlock
from ...tool import ToolResponse
def _tool_function_wrapper(get_tool_function: Callable) -> Callable:
    """Decorator for a tool-function getter: every tool function it
    returns is wrapped so that its raw result is stringified and
    packaged into a ``ToolResponse``."""

    @wraps(get_tool_function)
    def wrapper(self: "ACEPhone", name: str) -> Callable:
        """Look up the tool function and wrap its return value."""
        inner = get_tool_function(self, name)

        @wraps(inner)
        def as_tool_response(*args: Any, **kwargs: Any) -> ToolResponse:
            """Call the tool and package the result as a ToolResponse."""
            raw = inner(*args, **kwargs)
            block = TextBlock(type="text", text=str(raw))
            return ToolResponse(content=[block])

        return as_tool_response

    return wrapper
class ACEPhone:
    """Simulate a user phone with various apps and functionalities in
    ACEBench. The code is implemented with reference to the
    `ACEBench <https://github.com/ACEBench/ACEBench>`_.
    """

    def __init__(self) -> None:
        """Initialize the shared state and apps for the ACEPhone."""
        # Device-level state shared (by reference) with every app below.
        self._state = {
            "wifi": False,
            "logged_in": False,
        }
        self._message_app = MessageApi(self._state)
        self._reminder_app = ReminderApi(self._state)
        self._food_platform_app = FoodPlatformApi(self._state)
        # TravelApi keeps its own state and does not see self._state.
        self._travel = TravelApi()

    def turn_on_wifi(self) -> dict[str, bool | str]:
        """Turn on the WiFi connection."""
        self._state["wifi"] = True
        return {"status": True, "message": "wifi已经打开"}

    def login_device(self) -> dict[str, bool | str]:
        """Log in to the device."""
        self._state["logged_in"] = True
        return {"status": True, "message": "设备已经登录"}

    def load_initial_config(self, initial_config: dict) -> None:
        """Load the initial config from the application configuration.

        Args:
            initial_config (`dict`):
                The per-task initial device state; may be empty.
        """
        # Empty initial config
        if len(initial_config) == 0:
            return
        # Fix the typo in ACEBench by renaming "Baspi" to "BaseApi"
        if "Baspi" in initial_config:
            initial_config["BaseApi"] = initial_config.pop("Baspi")
        # Verify state
        assert (
            "BaseApi" in initial_config
            and "wifi" in initial_config["BaseApi"]
            and "logged_in" in initial_config["BaseApi"]
        ), f"Invalid initial config: {initial_config}"
        self._state["wifi"] = initial_config["BaseApi"]["wifi"]
        self._state["logged_in"] = initial_config["BaseApi"]["logged_in"]

    def get_current_state(self) -> list[dict]:
        """Follow ACEBench to get the current state of the ACEPhone.

        Returns:
            `list[dict]`:
                The device-level state plus one state dict per app.
        """
        return [
            {"BaseApi": self._state},
            self._message_app.get_state_dict(),
            self._reminder_app.get_state_dict(),
            self._food_platform_app.get_state_dict(),
            self._travel.get_state_dict(),
        ]

    @_tool_function_wrapper
    def get_tool_function(self, name: str) -> Callable:
        """Get a tool function by name.

        The lookup order is: device-level tools, then the message, food
        platform, reminder and travel apps. The decorator wraps the
        returned callable so it yields a ``ToolResponse``.

        Args:
            name (`str`):
                The tool function name to look up.

        Raises:
            ValueError: If no app provides a tool with this name.
        """
        if name in [
            "turn_on_wifi",
            "login_device",
        ]:
            return getattr(self, name)
        if name in self._message_app.tool_functions:
            return getattr(self._message_app, name)
        if name in self._food_platform_app.tool_functions:
            return getattr(self._food_platform_app, name)
        if name in self._reminder_app.tool_functions:
            return getattr(self._reminder_app, name)
        if name in self._travel.tool_functions:
            return getattr(self._travel, name)
        raise ValueError(
            f"Tool function '{name}' not found in ACEPhone.",
        )

View File

@@ -0,0 +1,43 @@
# -*- coding: utf-8 -*-
"""The base class for benchmark evaluation."""
from abc import ABC, abstractmethod
from typing import Generator
from ._task import Task
class BenchmarkBase(ABC):
    """Abstract base class for benchmarks.

    A benchmark is a named, iterable, indexable collection of `Task`
    objects that defines an evaluation dataset.
    """

    name: str
    """The name of the benchmark."""

    description: str
    """The description of the benchmark."""

    def __init__(self, name: str, description: str) -> None:
        """Store the benchmark metadata.

        Args:
            name (`str`):
                The name of the benchmark.
            description (`str`):
                A brief description of the benchmark.
        """
        self.name = name
        self.description = description

    @abstractmethod
    def __iter__(self) -> Generator[Task, None, None]:
        """Yield the tasks of the benchmark one by one."""
        raise NotImplementedError("Subclasses must implement this method.")

    @abstractmethod
    def __len__(self) -> int:
        """Return the number of tasks in the benchmark."""
        raise NotImplementedError("Subclasses must implement this method.")

    @abstractmethod
    def __getitem__(self, index: int) -> Task:
        """Return the task at the given index."""
        raise NotImplementedError("Subclasses must implement this method.")

View File

@@ -0,0 +1,12 @@
# -*- coding: utf-8 -*-
"""The evaluator module in AgentScope."""
from ._evaluator_base import EvaluatorBase
from ._ray_evaluator import RayEvaluator
from ._general_evaluator import GeneralEvaluator

# The public API of the evaluator subpackage.
__all__ = [
    "EvaluatorBase",
    "RayEvaluator",
    "GeneralEvaluator",
]

View File

@@ -0,0 +1,192 @@
# -*- coding: utf-8 -*-
"""The base class for evaluator in evaluation."""
import collections
import json
from abc import abstractmethod
from typing import Callable, Coroutine, Any
from .._solution import SolutionOutput
from .._task import Task
from .._benchmark_base import BenchmarkBase
from .._evaluator_storage import EvaluatorStorageBase
from .._metric_base import MetricType
from ..._utils._common import _get_timestamp
class EvaluatorBase:
    """The class that runs the evaluation process.

    NOTE(review): ``@abstractmethod`` is used below but this class does
    not inherit from ``abc.ABC``, so the abstractness of ``run`` is not
    enforced at instantiation time — confirm whether deriving from
    ``ABC`` was intended.
    """

    def __init__(
        self,
        name: str,
        benchmark: BenchmarkBase,
        n_repeat: int,
        storage: EvaluatorStorageBase,
    ) -> None:
        """Initialize the evaluator.

        Args:
            name (`str`):
                The name of this evaluator.
            benchmark (`BenchmarkBase`):
                A benchmark instance inheriting from `BenchmarkBase`
                that defines the evaluation dataset.
            n_repeat (`int`):
                How many times to repeat the evaluation for each task.
            storage (`EvaluatorStorageBase`):
                An instance inheriting from a child class of
                `EvaluatorStorageBase` that supports storing and
                loading solution output and evaluation results.
        """
        self.name = name
        self.benchmark = benchmark
        self.n_repeat = n_repeat
        self.storage = storage

    @abstractmethod
    async def run(
        self,
        solution: Callable[
            [Task, Callable],
            Coroutine[Any, Any, SolutionOutput],
        ],
    ) -> None:
        """Run the evaluation and return the results.

        Args:
            solution (`Callable[[Task, Callable], Coroutine[Any, Any, \
SolutionOutput]]`):
                An async function that takes a `Task` instance and a
                pre-hook as input and returns a `SolutionOutput`
                instance.
        """

    async def _save_evaluation_meta(self) -> None:
        """Save the evaluation meta information."""
        self.storage.save_evaluation_meta(
            {
                "evaluation_name": self.name,
                "created_at": _get_timestamp(),
                "total_repeats": self.n_repeat,
                "benchmark": {
                    "name": self.benchmark.name,
                    "description": self.benchmark.description,
                    "total_tasks": len(self.benchmark),
                },
                # Bump this when the layout of the meta dict changes.
                "schema_version": 1,
            },
        )

    async def aggregate(self) -> None:  # pylint: disable=too-many-branches
        """Aggregate the per-task evaluation results from storage, print
        a per-repeat summary, and save an overall result.

        NOTE(review): for numerical metrics the mean divides by
        ``involved_tasks`` (not ``completed_tasks``), so incomplete
        tasks drag the mean down, and ``max``/``min`` raise
        ``ValueError`` when no task completed — confirm both are
        intended.
        """
        meta_info: dict = {
            "total_tasks": len(self.benchmark),
            "total_repeats": self.n_repeat,
            "repeats": {},
            "schema_version": 1,
        }
        for repeat_index in range(self.n_repeat):
            repeat_id = str(repeat_index)
            # Per-repeat accumulator of completion counters and metric
            # distributions.
            current_repeat: dict = {
                "completed_tasks": 0,
                "incomplete_tasks": 0,
                "metrics": {},
                "completed_ids": [],
                "incomplete_ids": [],
            }
            for task in self.benchmark:
                for metric in task.metrics:
                    # Create a new dict in aggregated_result
                    if metric.name not in current_repeat["metrics"]:
                        current_repeat["metrics"][metric.name] = {
                            "type": metric.metric_type,
                            "involved_tasks": 0,
                            "completed_tasks": 0,
                            "incomplete_tasks": 0,
                            "aggregation": {},
                            "distribution": collections.defaultdict(list),
                        }
                    # Record the submitted task
                    current_repeat["metrics"][metric.name][
                        "involved_tasks"
                    ] += 1
                    # Not finished
                    if not self.storage.evaluation_result_exists(
                        task.id,
                        repeat_id,
                        metric.name,
                    ):
                        if task.id not in current_repeat["incomplete_ids"]:
                            current_repeat["incomplete_tasks"] += 1
                            current_repeat["incomplete_ids"].append(task.id)
                        current_repeat["metrics"][metric.name][
                            "incomplete_tasks"
                        ] += 1
                        continue
                    if task.id not in current_repeat["completed_ids"]:
                        current_repeat["completed_tasks"] += 1
                        current_repeat["completed_ids"].append(task.id)
                    current_repeat["metrics"][metric.name][
                        "completed_tasks"
                    ] += 1
                    # Get the evaluation result
                    eval_result = self.storage.get_evaluation_result(
                        task.id,
                        repeat_id,
                        metric.name,
                    )
                    # Record the metric result
                    if metric.metric_type == MetricType.CATEGORY:
                        # Category metric: bucket the task ids by the
                        # category value they scored.
                        current_repeat["metrics"][metric.name]["distribution"][
                            eval_result.result
                        ].append(
                            task.id,
                        )
                    elif metric.metric_type == MetricType.NUMERICAL:
                        # Numerical metric: one score per task id.
                        current_repeat["metrics"][metric.name]["distribution"][
                            task.id
                        ] = eval_result.result
            print("Repeat ID:", repeat_id)
            for metric, value in current_repeat["metrics"].items():
                print("\tMetric:", metric)
                print("\t\tType:", value["type"])
                print("\t\tInvolved tasks:", value["involved_tasks"])
                print("\t\tCompleted tasks:", value["completed_tasks"])
                print("\t\tIncomplete tasks:", value["incomplete_tasks"])
                if value["type"] == MetricType.CATEGORY:
                    # Count the distribution
                    for category, task_ids in value["distribution"].items():
                        value["aggregation"][category] = (
                            len(task_ids) * 1.0 / value["involved_tasks"]
                        )
                elif value["type"] == MetricType.NUMERICAL:
                    scores = list(value["distribution"].values())
                    value["aggregation"] = {
                        "mean": sum(scores) / value["involved_tasks"],
                        "max": max(scores),
                        "min": min(scores),
                    }
                print(
                    "\t\tAggregation:",
                    json.dumps(
                        value["aggregation"],
                        indent=4,
                        ensure_ascii=False,
                    ).replace("\n", "\n\t\t"),
                )
            meta_info["repeats"][repeat_id] = current_repeat
        # save
        self.storage.save_aggregation_result(meta_info)

View File

@@ -0,0 +1,127 @@
# -*- coding: utf-8 -*-
"""General evaluator implementation in AgentScope, which is easy to debug
compared to the RayEvaluator."""
from typing import Callable, Awaitable, Coroutine, Any
from ._evaluator_base import EvaluatorBase
from .._evaluator_storage import EvaluatorStorageBase
from .._task import Task
from .._solution import SolutionOutput
from .._benchmark_base import BenchmarkBase
class GeneralEvaluator(EvaluatorBase):
    """A sequential evaluator that runs tasks one by one in the current
    process, which makes it easy to debug compared to the Ray-based
    evaluator."""

    def __init__(
        self,
        name: str,
        benchmark: BenchmarkBase,
        n_repeat: int,
        storage: EvaluatorStorageBase,
        n_workers: int,
    ) -> None:
        """Initialize the evaluator.

        Args:
            name (`str`):
                The name of this evaluator.
            benchmark (`BenchmarkBase`):
                The benchmark that provides the tasks.
            n_repeat (`int`):
                How many times to repeat the evaluation for each task.
            storage (`EvaluatorStorageBase`):
                The storage used to persist solution and evaluation
                results.
            n_workers (`int`):
                Kept for interface compatibility with `RayEvaluator`;
                this evaluator always runs sequentially.
        """
        super().__init__(
            name=name,
            benchmark=benchmark,
            n_repeat=n_repeat,
            storage=storage,
        )
        assert isinstance(benchmark, BenchmarkBase)
        assert n_repeat >= 1, "n_repeat must be at least 1"
        assert n_workers >= 1, "n_workers must be at least 1"
        self.benchmark = benchmark
        self.n_repeat = n_repeat
        self.n_workers = n_workers

    async def run_evaluation(
        self,
        task: Task,
        repeat_id: str,
        solution_output: SolutionOutput,
    ) -> None:
        """Evaluate a solution output against all metrics of the task
        and persist every metric result.

        Args:
            task (`Task`): The task to evaluate.
            repeat_id (`str`): The repeat index as a string.
            solution_output (`SolutionOutput`): The solution to score.
        """
        evaluation_results = await task.evaluate(solution_output)
        # Store every metric result produced by the task.
        for result in evaluation_results:
            self.storage.save_evaluation_result(
                task_id=task.id,
                repeat_id=repeat_id,
                evaluation=result,
            )

    async def run_solution(
        self,
        repeat_id: str,
        task: Task,
        solution: Callable[[Task, Callable], Awaitable[SolutionOutput]],
    ) -> None:
        """Obtain a solution for the task (from storage when cached,
        otherwise by running `solution`) and evaluate it when any metric
        result is missing.

        Args:
            repeat_id (`str`): The repeat index as a string.
            task (`Task`): The task to solve.
            solution (`Callable[[Task, Callable], \
Awaitable[SolutionOutput]]`):
                An async function producing a `SolutionOutput`.
        """
        if self.storage.solution_result_exists(task.id, repeat_id):
            # Obtain from storage
            solution_result = self.storage.get_solution_result(
                task.id,
                repeat_id,
            )
        else:
            # Run the solution
            solution_result = await solution(
                task,
                self.storage.get_agent_pre_print_hook(
                    task.id,
                    repeat_id,
                ),
            )
            self.storage.save_solution_result(
                task.id,
                repeat_id,
                solution_result,
            )
        # Bug fix: the original called `run_evaluation` once per missing
        # metric, and since `run_evaluation` scores *all* metrics of the
        # task, a task with several missing metrics was evaluated (and
        # its results saved) multiple times. Evaluate at most once.
        needs_evaluation = any(
            not self.storage.evaluation_result_exists(
                task.id,
                repeat_id,
                metric.name,
            )
            for metric in task.metrics
        )
        if needs_evaluation:
            await self.run_evaluation(task, repeat_id, solution_result)

    async def run(
        self,
        solution: Callable[
            [Task, Callable],
            Coroutine[Any, Any, SolutionOutput],
        ],
    ) -> None:
        """Run the sequential evaluation over every repeat and task,
        then aggregate the results.

        Args:
            solution (`Callable[[Task, Callable], Coroutine[Any, Any, \
SolutionOutput]]`):
                An async function that takes a `Task` instance and a
                pre-print hook function as input and returns a
                `SolutionOutput` instance.
        """
        await self._save_evaluation_meta()
        for repeat_id in range(self.n_repeat):
            for task in self.benchmark:
                await self.run_solution(
                    str(repeat_id),
                    task,
                    solution,
                )
        await self.aggregate()

View File

@@ -0,0 +1,211 @@
# -*- coding: utf-8 -*-
"""The evaluator base class in agentscope."""
import asyncio
from typing import Callable, Awaitable, Coroutine, Any
from .._benchmark_base import BenchmarkBase
from .._evaluator._evaluator_base import EvaluatorBase
from .._solution import SolutionOutput
from .._task import Task
from .._evaluator_storage import EvaluatorStorageBase
def _check_ray_available() -> None:
"""Check if ray is available and raise ImportError if not."""
try:
import ray # noqa # pylint: disable=unused-import
except ImportError as e:
raise ImportError(
"Ray is not installed. Please install it with `pip install ray` "
"to use the RayEvaluator.",
) from e
# Create a conditional decorator for ray.remote
def _ray_remote_decorator(cls: Any) -> Any:
"""
Conditional ray.remote decorator that only applies when ray is available.
"""
try:
import ray
return ray.remote(cls)
except ImportError:
return cls
@_ray_remote_decorator
class RayEvaluationActor:
    """Ray actor that scores a solution output against a task's metrics
    and persists every metric result."""

    @staticmethod
    async def run(
        storage: EvaluatorStorageBase,
        task: Task,
        repeat_id: str,
        solution_output: SolutionOutput,
    ) -> None:
        """Evaluate `solution_output` on `task` and save each result.

        Args:
            storage (EvaluatorStorageBase): Evaluator storage.
            task (Task): Task to be evaluated.
            repeat_id (str): Repeat ID.
            solution_output (SolutionOutput): The output produced by the
                agent solution.
        """
        results = await task.evaluate(solution_output)
        # Persist every metric result for later aggregation.
        for metric_result in results:
            storage.save_evaluation_result(
                task_id=task.id,
                repeat_id=repeat_id,
                evaluation=metric_result,
            )
@_ray_remote_decorator
class RaySolutionActor:
    """Ray actor that produces a solution for a task and dispatches its
    evaluation to a nested `RayEvaluationActor`."""

    def __init__(self, n_workers: int = 1):
        """Create the nested evaluation actor.

        Args:
            n_workers (int): Maximum concurrency of the evaluation
                actor.
        """
        self.eval_actor = RayEvaluationActor.options(
            max_concurrency=n_workers,
        ).remote()

    async def run(
        self,
        storage: EvaluatorStorageBase,
        repeat_id: str,
        task: Task,
        solution: Callable[
            [Task, Callable],
            Coroutine[Any, Any, SolutionOutput],
        ],
    ) -> None:
        """Generate a solution to a task and evaluate it.

        Args:
            storage (EvaluatorStorageBase): Evaluator storage.
            repeat_id (str): Repeat ID.
            task (Task): Task to be evaluated.
            solution (Callable[[Task, Callable], Coroutine[Any, Any, \
SolutionOutput]]):
                Async callable that executes the agents and produces a
                `SolutionOutput`.
        """
        if storage.solution_result_exists(task.id, repeat_id):
            # Obtain from storage
            solution_result = storage.get_solution_result(
                task.id,
                repeat_id,
            )
        else:
            # Run the solution
            solution_result = await solution(
                task,
                storage.get_agent_pre_print_hook(
                    task.id,
                    repeat_id,
                ),
            )
            storage.save_solution_result(
                task.id,
                repeat_id,
                solution_result,
            )
        # Bug fix: the original submitted one `eval_actor.run` call per
        # missing metric, but that call evaluates *all* metrics of the
        # task, so tasks with several missing metrics were evaluated
        # (and their results saved) multiple times. Dispatch the
        # evaluation at most once.
        any_missing = any(
            not storage.evaluation_result_exists(
                task.id,
                repeat_id,
                metric.name,
            )
            for metric in task.metrics
        )
        if any_missing:
            await self.eval_actor.run.remote(
                storage,
                task,
                repeat_id,
                solution_result,
            )
class RayEvaluator(EvaluatorBase):
    """The ray-based evaluator that supports distributed and parallel
    evaluation."""

    def __init__(
        self,
        name: str,
        benchmark: BenchmarkBase,
        n_repeat: int,
        storage: EvaluatorStorageBase,
        n_workers: int,
    ) -> None:
        """Initialize the evaluator.

        Args:
            name (`str`):
                The name of this evaluator.
            benchmark (`BenchmarkBase`):
                The benchmark that provides the tasks to evaluate.
            n_repeat (`int`):
                How many times to repeat the evaluation for each task.
            storage (`EvaluatorStorageBase`):
                The storage used to persist solution and evaluation
                results.
            n_workers (`int`):
                The maximum concurrency of the ray solution actor.

        Raises:
            ImportError: If ray is not installed.
        """
        super().__init__(
            name=name,
            benchmark=benchmark,
            n_repeat=n_repeat,
            storage=storage,
        )
        # Check ray availability early
        _check_ray_available()
        assert isinstance(benchmark, BenchmarkBase)
        assert n_repeat >= 1, "n_repeat must be at least 1"
        assert n_workers >= 1, "n_workers must be at least 1"
        self.benchmark = benchmark
        self.n_repeat = n_repeat
        self.n_workers = n_workers

    async def run(
        self,
        solution: Callable[
            [Task, Callable],
            Awaitable[SolutionOutput] | SolutionOutput,
        ],
    ) -> None:
        """Run the ray-based distributed and parallel evaluation, and
        get the results.

        Args:
            solution (`Callable[[Task, Callable], \
Awaitable[SolutionOutput] | SolutionOutput]`):
                A function that takes a `Task` instance and a pre-print
                hook as input and returns a `SolutionOutput` instance
                (possibly awaitable).
        """
        await self._save_evaluation_meta()
        futures = []
        # A single solution actor handles all tasks; max_concurrency
        # bounds how many of its coroutines run at once.
        solution_actor = RaySolutionActor.options(
            max_concurrency=self.n_workers,
        ).remote(n_workers=self.n_workers)
        for repeat_id in range(self.n_repeat):
            for task in self.benchmark:
                futures.append(
                    solution_actor.run.remote(
                        self.storage,
                        str(repeat_id),
                        task,
                        solution,
                    ),
                )
        if futures:
            await asyncio.gather(*futures)
        await self.aggregate()

View File

@@ -0,0 +1,10 @@
# -*- coding: utf-8 -*-
"""The evaluator storage module in AgentScope."""
from ._evaluator_storage_base import EvaluatorStorageBase
from ._file_evaluator_storage import FileEvaluatorStorage
__all__ = [
"EvaluatorStorageBase",
"FileEvaluatorStorage",
]

View File

@@ -0,0 +1,195 @@
# -*- coding: utf-8 -*-
"""The evaluator storage base class for storing solution and evaluation
results."""
from abc import ABC, abstractmethod
from typing import Any, Callable

from .._metric_base import MetricResult
from .._solution import SolutionOutput
from ...agent import AgentBase
class EvaluatorStorageBase(ABC):
    """Used to store the solution results and evaluation results to support
    resuming the evaluation process.

    Inherits from ``ABC`` so that the ``@abstractmethod`` decorators below
    are actually enforced: without the ``ABC`` base, Python would happily
    instantiate subclasses that forget to implement these methods.
    """

    @abstractmethod
    def save_solution_result(
        self,
        task_id: str,
        repeat_id: str,
        output: SolutionOutput,
        **kwargs: Any,
    ) -> None:
        """Save the solution result.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.
            output (`SolutionOutput`):
                The solution output to be saved.
        """

    @abstractmethod
    def get_evaluation_result(
        self,
        task_id: str,
        repeat_id: str,
        metric_name: str,
    ) -> MetricResult:
        """Get the evaluation result by the given task id and repeat id.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.
            metric_name (`str`):
                The metric name.

        Returns:
            `MetricResult`:
                The evaluation result for the given task and repeat ID.
        """

    @abstractmethod
    def save_evaluation_result(
        self,
        task_id: str,
        repeat_id: str,
        evaluation: MetricResult,
        **kwargs: Any,
    ) -> None:
        """Save the evaluation result.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.
            evaluation (`MetricResult`):
                The evaluation result to be saved.
        """

    @abstractmethod
    def get_solution_result(
        self,
        task_id: str,
        repeat_id: str,
        **kwargs: Any,
    ) -> SolutionOutput:
        """Get the solution result for the given task and repeat id.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.

        Returns:
            `SolutionOutput`:
                The solution output for the given task and repeat ID.
        """

    @abstractmethod
    def solution_result_exists(self, task_id: str, repeat_id: str) -> bool:
        """Check if the solution for the given task and repeat is finished.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.

        Returns:
            `bool`:
                True if the solution result exists, False otherwise.
        """

    @abstractmethod
    def evaluation_result_exists(
        self,
        task_id: str,
        repeat_id: str,
        metric_name: str,
    ) -> bool:
        """Check if the evaluation result for the given solution and metric
        is finished.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.
            metric_name (`str`):
                The name of the metric.

        Returns:
            `bool`:
                True if the evaluation result exists, False otherwise.
        """

    @abstractmethod
    def save_aggregation_result(
        self,
        aggregation_result: dict,
        **kwargs: Any,
    ) -> None:
        """Save the aggregation result.

        Args:
            aggregation_result (`dict`):
                A dictionary containing the aggregation result.
        """

    @abstractmethod
    def aggregation_result_exists(
        self,
        **kwargs: Any,
    ) -> bool:
        """Check if the aggregation result exists.

        Returns:
            `bool`:
                `True` if the aggregation result exists.
        """

    @abstractmethod
    def save_evaluation_meta(self, meta_info: dict) -> None:
        """Save the evaluation meta information.

        Args:
            meta_info (`dict`):
                A dictionary containing the meta information.
        """

    @abstractmethod
    def get_agent_pre_print_hook(
        self,
        task_id: str,
        repeat_id: str,
    ) -> Callable[[AgentBase, dict], None]:
        """Get a pre-print hook function for the agent to save the agent
        printing in the evaluation storage.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.

        Returns:
            `Callable[[AgentBase, dict], None]`:
                A hook function that takes an `AgentBase` instance and a
                keyword arguments dictionary as input, saving the agent's
                printing Msg into the evaluation storage.
        """

View File

@@ -0,0 +1,345 @@
# -*- coding: utf-8 -*-
"""A file system based evaluator storage."""
import json
import os
from json import JSONDecodeError
from typing import Any, Callable
from ._evaluator_storage_base import EvaluatorStorageBase
from .._solution import SolutionOutput
from .._metric_base import MetricResult
from ...agent import AgentBase
from ...message import Msg
class FileEvaluatorStorage(EvaluatorStorageBase):
    """File system based evaluator storage, providing methods to save and
    retrieve evaluation results. So that the evaluation process can be resumed
    from the last saved state.

    The files are organized in a directory structure (note that the
    ``{repeat_id}`` level comes *before* the ``{task_id}`` level, as
    implemented in ``_get_save_path``):

    - save_dir/
        - evaluation_result.json
        - evaluation_meta.json
        - {repeat_id}/
            - {task_id}/
                - solution.json
                - logging.txt
                - evaluation/
                    - {metric_name}.json
    """

    # File/directory names used under each {repeat_id}/{task_id} directory.
    SOLUTION_FILE_NAME = "solution.json"
    EVALUATION_DIR_NAME = "evaluation"
    # Aggregated result and meta info live directly under save_dir.
    EVALUATION_RESULT_FILE = "evaluation_result.json"
    EVALUATION_META_FILE = "evaluation_meta.json"
    # Per-(repeat, task) agent printing log, appended by the pre-print hook.
    AGENT_PRINTING_LOG = "logging.txt"

    def __init__(self, save_dir: str) -> None:
        """Initialize the file evaluator storage.

        Args:
            save_dir (`str`):
                The root directory under which all results are stored.
        """
        self.save_dir = save_dir

    def _get_save_path(self, task_id: str, repeat_id: str, *args: str) -> str:
        """Get the save path for a given task and repeat ID."""
        # Layout is {save_dir}/{repeat_id}/{task_id}/..., i.e. grouped by
        # repeat first, then by task.
        return os.path.join(self.save_dir, repeat_id, task_id, *args)

    def save_solution_result(
        self,
        task_id: str,
        repeat_id: str,
        output: SolutionOutput,
        **kwargs: Any,
    ) -> None:
        """Save the solution result.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.
            output (`SolutionOutput`):
                The solution output to be saved.
        """
        path_file = self._get_save_path(
            task_id,
            repeat_id,
            self.SOLUTION_FILE_NAME,
        )
        os.makedirs(os.path.dirname(path_file), exist_ok=True)
        # NOTE(review): json.dump on a SolutionOutput dataclass relies on
        # DictMixin making it JSON-serializable — confirm against DictMixin.
        with open(path_file, "w", encoding="utf-8") as f:
            json.dump(output, f, ensure_ascii=False, indent=4)

    def save_evaluation_result(
        self,
        task_id: str,
        repeat_id: str,
        evaluation: MetricResult,
        **kwargs: Any,
    ) -> None:
        """Save the evaluation result.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.
            evaluation (`MetricResult`):
                The evaluation result to be saved.
        """
        # One JSON file per metric under the evaluation/ subdirectory.
        path_file = self._get_save_path(
            task_id,
            repeat_id,
            self.EVALUATION_DIR_NAME,
            f"{evaluation.name}.json",
        )
        os.makedirs(os.path.dirname(path_file), exist_ok=True)
        # NOTE(review): as above, serialization relies on DictMixin.
        with open(path_file, "w", encoding="utf-8") as f:
            json.dump(evaluation, f, ensure_ascii=False, indent=4)

    def get_evaluation_result(
        self,
        task_id: str,
        repeat_id: str,
        metric_name: str,
    ) -> MetricResult:
        """Get the evaluation result by the given task id and repeat id.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.
            metric_name (`str`):
                The metric name.

        Raises:
            `FileNotFoundError`:
                If the evaluation result file does not exist.

        Returns:
            `MetricResult`:
                The evaluation result for the given task and repeat ID.
        """
        path_file = self._get_save_path(
            task_id,
            repeat_id,
            self.EVALUATION_DIR_NAME,
            f"{metric_name}.json",
        )
        if not os.path.exists(path_file):
            raise FileNotFoundError(path_file)
        with open(path_file, "r", encoding="utf-8") as f:
            evaluation = json.load(f)
        # Re-hydrate the stored dict back into a MetricResult.
        return MetricResult(**evaluation)

    def get_solution_result(
        self,
        task_id: str,
        repeat_id: str,
        **kwargs: Any,
    ) -> SolutionOutput:
        """Get the solution result for the given task and repeat id from the
        file system.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.

        Raises:
            `FileNotFoundError`:
                If the solution result file does not exist for the given task
                and repeat ID.
            `JSONDecodeError`:
                If the stored file exists but is not valid JSON (e.g. a
                partially written file from an interrupted run).

        Returns:
            `SolutionOutput`:
                The solution output for the given task and repeat ID.
        """
        path_file = self._get_save_path(
            task_id,
            repeat_id,
            self.SOLUTION_FILE_NAME,
        )
        if not os.path.exists(path_file):
            raise FileNotFoundError(
                f"Solution result for task {task_id} and repeat {repeat_id} "
                "not found.",
            )
        try:
            with open(path_file, "r", encoding="utf-8") as f:
                solution_data = json.load(f)
        except JSONDecodeError as e:
            # Re-raise with the file path included for easier debugging.
            raise JSONDecodeError(
                f"Failed to load JSON from {path_file}: {e.msg}",
                e.doc,
                e.pos,
            ) from e
        return SolutionOutput(**solution_data)

    def solution_result_exists(self, task_id: str, repeat_id: str) -> bool:
        """Check if the solution for the given task and repeat is finished.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.

        Returns:
            `bool`:
                True if the solution result file exists, False otherwise.
        """
        path_file = self._get_save_path(
            task_id,
            repeat_id,
            self.SOLUTION_FILE_NAME,
        )
        # A zero-byte file (e.g. from a crashed write) counts as missing.
        return os.path.exists(path_file) and os.path.getsize(path_file) > 0

    def evaluation_result_exists(
        self,
        task_id: str,
        repeat_id: str,
        metric_name: str,
    ) -> bool:
        """Check if the evaluation result for the given solution and metric
        is finished.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.
            metric_name (`str`):
                The name of the metric.

        Returns:
            `bool`:
                True if the evaluation result file exists, False otherwise.
        """
        path_file = self._get_save_path(
            task_id,
            repeat_id,
            self.EVALUATION_DIR_NAME,
            f"{metric_name}.json",
        )
        # A zero-byte file (e.g. from a crashed write) counts as missing.
        return os.path.exists(path_file) and os.path.getsize(path_file) > 0

    def save_aggregation_result(
        self,
        aggregation_result: dict,
        **kwargs: Any,
    ) -> None:
        """Save the aggregation result.

        Args:
            aggregation_result (`dict`):
                A dictionary containing the aggregation result.
        """
        path_file = os.path.join(
            self.save_dir,
            self.EVALUATION_RESULT_FILE,
        )
        os.makedirs(os.path.dirname(path_file), exist_ok=True)
        with open(path_file, "w", encoding="utf-8") as f:
            json.dump(aggregation_result, f, ensure_ascii=False, indent=4)

    def aggregation_result_exists(
        self,
        **kwargs: Any,
    ) -> bool:
        """Check if the aggregation result exists.

        Returns:
            `bool`:
                `True` if the aggregation result file exists.
        """
        path_file = os.path.join(
            self.save_dir,
            self.EVALUATION_RESULT_FILE,
        )
        return os.path.exists(path_file) and os.path.getsize(path_file) > 0

    def save_evaluation_meta(self, meta_info: dict) -> None:
        """Save the evaluation meta information.

        Args:
            meta_info (`dict`):
                A dictionary containing the meta information.
        """
        path_file = os.path.join(
            self.save_dir,
            self.EVALUATION_META_FILE,
        )
        os.makedirs(os.path.dirname(path_file), exist_ok=True)
        with open(path_file, "w", encoding="utf-8") as f:
            json.dump(meta_info, f, ensure_ascii=False, indent=4)

    def get_agent_pre_print_hook(
        self,
        task_id: str,
        repeat_id: str,
    ) -> Callable[[AgentBase, dict], None]:
        """Get a pre-print hook function for the agent to save the agent
        printing in the evaluation storage.

        Args:
            task_id (`str`):
                The task ID.
            repeat_id (`str`):
                The repeat ID for the task, usually the index of the repeat
                evaluation.

        Returns:
            `Callable[[AgentBase, dict], None]`:
                A hook function that takes an `AgentBase` instance and a
                keyword arguments dictionary as input, saving the agent's
                printing Msg into the evaluation storage.
        """

        def pre_print_hook(_agent: AgentBase, kwargs: dict) -> None:
            """Hook function to save agent's printing."""
            msg: Msg | None = kwargs.get("msg", None)
            last: bool = kwargs.get("last", False)
            # Skip intermediate (streaming) prints; only the last, complete
            # message of a print call is persisted.
            if msg is None or not last:
                return
            # Only save the last message
            printing_str = []
            for block in msg.get_content_blocks():
                match block["type"]:
                    case "text":
                        printing_str.append(
                            f"{msg.name}: {block['text']}",
                        )
                    case "thinking":
                        printing_str.append(
                            f"{msg.name} (thinking): {block['text']}",
                        )
                    case _:
                        # Non-text blocks (tool use/result, etc.) are dumped
                        # as pretty-printed JSON; only the first block of a
                        # message gets the "name:" prefix.
                        block_str = json.dumps(
                            block,
                            ensure_ascii=False,
                            indent=4,
                        )
                        if printing_str:
                            printing_str.append(block_str)
                        else:
                            printing_str.append(f"{msg.name}: {block_str}")
            path_file = self._get_save_path(
                task_id,
                repeat_id,
                self.AGENT_PRINTING_LOG,
            )
            os.makedirs(os.path.dirname(path_file), exist_ok=True)
            # Append, so a single log accumulates all prints of the run.
            with open(path_file, "a", encoding="utf-8") as f:
                f.write("\n".join(printing_str) + "\n")

        return pre_print_hook

View File

@@ -0,0 +1,87 @@
# -*- coding: utf-8 -*-
"""The base class for _metric in evaluation."""
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from enum import Enum
from typing import Any
from .._utils._common import _get_timestamp
from .._utils._mixin import DictMixin
from ..types import JSONSerializableObject
@dataclass
class MetricResult(DictMixin):
    """The result produced by evaluating a metric."""

    name: str
    """The metric name."""

    result: str | float | int
    """The metric result value (a category label or a numerical score)."""

    created_at: str = field(default_factory=_get_timestamp)
    """The timestamp when the metric result was created."""

    # Plain `= None` replaces the original `field(default_factory=lambda:
    # None)` / `field(default=None)`: None is immutable, so no factory or
    # field() wrapper is needed and the defaults are identical.
    message: str | None = None
    """An optional message for the metric result, can be used to provide
    additional information or context about the result."""

    metadata: dict[str, JSONSerializableObject] | None = None
    """Optional metadata for the metric result, can be used to store
    additional information related to the metric result."""
class MetricType(str, Enum):
    """The metric type enum, determining how a metric's result should be
    interpreted and displayed."""

    CATEGORY = "category"
    """The metric result is a category, e.g. "pass" or "fail"."""

    NUMERICAL = "numerical"
    """The metric result is a numerical value, e.g. 0.95 or 100."""
class MetricBase(ABC):
    """Abstract base class for metrics used in evaluation."""

    def __init__(
        self,
        name: str,
        metric_type: MetricType,
        description: str | None = None,
        categories: list[str] | None = None,
    ) -> None:
        """Initialize the metric object.

        Args:
            name (`str`):
                The name of the metric.
            metric_type (`MetricType`):
                The type of the metric, either "category" or "numerical";
                determines how the result is displayed.
            description (`str | None`, optional):
                The description of the metric.
            categories (`list[str] | None`, optional):
                The candidate categories. Required when `metric_type` is
                "category"; should be `None` otherwise.

        Raises:
            `ValueError`:
                If `metric_type` is "category" but no categories are given.
        """
        # Validate up front, before any state is stored on the instance.
        if categories is None and metric_type == MetricType.CATEGORY:
            raise ValueError(
                "Categories must be provided for category metrics.",
            )
        self.name = name
        self.metric_type = metric_type
        self.description = description
        self.categories = categories

    @abstractmethod
    async def __call__(
        self,
        *args: Any,
        **kwargs: Any,
    ) -> MetricResult:
        """Compute and return the metric result."""

View File

@@ -0,0 +1,36 @@
# -*- coding: utf-8 -*-
"""Solution class for evaluation tasks."""
from dataclasses import dataclass, field
from typing import Any
from ..message import (
ToolResultBlock,
ToolUseBlock,
TextBlock,
)
from ..types._json import JSONSerializableObject
from .._utils._mixin import DictMixin
@dataclass
class SolutionOutput(DictMixin):
    """The output of a solution in an evaluation task."""

    success: bool
    """Indicates whether the solution is executed successfully. When the
    solution raise exception, this should be set to False."""

    output: JSONSerializableObject
    """The final output of the solution."""

    trajectory: list[ToolUseBlock | ToolResultBlock | TextBlock]
    """The tool calls and results trajectory"""

    # Plain `= None` replaces the original `field(default_factory=lambda:
    # None)`: None is immutable, so no factory is needed and the default
    # value is identical.
    meta: dict[str, Any] | None = None
    """Additional metadata for the solution"""

    def __getstate__(self) -> dict[str, Any]:
        """Custom pickling to handle dataclass + DictMixin inheritance."""
        return self.__dict__.copy()

    def __setstate__(self, state: dict[str, Any]) -> None:
        """Custom unpickling to handle dataclass + DictMixin inheritance."""
        self.__dict__.update(state)

View File

@@ -0,0 +1,53 @@
# -*- coding: utf-8 -*-
"""The base class for task in evaluation."""
from dataclasses import dataclass, field
from typing import Any
from ._solution import SolutionOutput
from ._metric_base import MetricBase, MetricResult
from ..types._json import JSONSerializableObject
@dataclass
class Task:
    """The base class for task in evaluation."""

    id: str
    """The unique identifier for the task."""

    input: JSONSerializableObject
    """The task input, which should be a JSON serializable object."""

    ground_truth: JSONSerializableObject
    """The task ground truth if exists, which should be a JSON serializable
    object."""

    metrics: list[MetricBase]
    """The metrics to evaluate the task, which should be a list of
    `MetricBase` objects."""

    # Plain `= None` replaces the original `field(default_factory=lambda:
    # None)`: None is immutable, so no factory is needed and the defaults
    # are identical.
    tags: dict[str, str] | None = None
    """Tags to categorize the task, e.g. `{"difficulty": "easy",
    "cate": "math"}`."""

    metadata: dict[str, Any] | None = None
    """Additional metadata for the task."""

    async def evaluate(self, solution: SolutionOutput) -> list[MetricResult]:
        """Evaluate the task with the given solution.

        Args:
            solution (`SolutionOutput`):
                The solution to evaluate the task with.

        Returns:
            `list[MetricResult]`:
                One evaluation result per metric in ``self.metrics``, in the
                same order.
        """
        evaluations = []
        # Metrics are awaited sequentially so results keep metric order.
        for metric in self.metrics:
            result = await metric(solution)
            evaluations.append(result)
        return evaluations

View File

@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
"""The exception module in agentscope."""
from ._exception_base import AgentOrientedExceptionBase
from ._tool import (
ToolInterruptedError,
ToolNotFoundError,
ToolInvalidArgumentsError,
)
__all__ = [
"AgentOrientedExceptionBase",
"ToolInterruptedError",
"ToolNotFoundError",
"ToolInvalidArgumentsError",
]

View File

@@ -0,0 +1,18 @@
# -*- coding: utf-8 -*-
"""The base exception class in agentscope."""
class AgentOrientedExceptionBase(Exception):
    """The base class for all agent-oriented exceptions, i.e. exceptions that
    are expected to be captured and exposed to the agent at runtime so that
    the agent can handle the error appropriately."""

    def __init__(self, message: str):
        """Store the message and initialize the underlying ``Exception``.

        Args:
            message (`str`):
                The human-readable error message.
        """
        self.message = message
        super().__init__(message)

    def __str__(self) -> str:
        """Render as ``<ConcreteClassName>: <message>``."""
        cls_name = type(self).__name__
        return f"{cls_name}: {self.message}"

View File

@@ -0,0 +1,16 @@
# -*- coding: utf-8 -*-
"""The tool-related exceptions in agentscope."""
from ._exception_base import AgentOrientedExceptionBase
class ToolNotFoundError(AgentOrientedExceptionBase):
    """Raised when a requested tool cannot be found; agent-oriented, so it is
    meant to be surfaced to the agent rather than crash the run."""
class ToolInterruptedError(AgentOrientedExceptionBase):
    """Raised when a tool call is interrupted by the user; agent-oriented, so
    it is meant to be surfaced to the agent rather than crash the run."""
class ToolInvalidArgumentsError(AgentOrientedExceptionBase):
    """Raised when the arguments passed to a tool are invalid; agent-oriented,
    so it is meant to be surfaced to the agent rather than crash the run."""

Some files were not shown because too many files have changed in this diff Show More