chore: 添加虚拟环境到仓库
- 添加 backend_service/venv 虚拟环境 - 包含所有Python依赖包 - 注意:虚拟环境约393MB,包含12655个文件
This commit is contained in:
@@ -0,0 +1,10 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The learning module of AgentScope, including RL and SFT."""
|
||||
|
||||
from ._tune import tune
|
||||
from ._workflow import WorkflowType
|
||||
|
||||
__all__ = [
|
||||
"tune",
|
||||
"WorkflowType",
|
||||
]
|
||||
@@ -0,0 +1,72 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""The main entry point for agent learning."""
|
||||
from dataclasses import dataclass
|
||||
from ._workflow import (
|
||||
WorkflowType,
|
||||
_validate_function_signature,
|
||||
)
|
||||
|
||||
|
||||
def tune(workflow_func: WorkflowType, config_path: str) -> None:
|
||||
"""Train the agent workflow with the specific configuration.
|
||||
|
||||
Args:
|
||||
workflow_func (WorkflowType): The learning workflow function
|
||||
to execute.
|
||||
config_path (str): The configuration for the learning process.
|
||||
"""
|
||||
try:
|
||||
from trinity.cli.launcher import run_stage
|
||||
from trinity.common.config import Config
|
||||
from omegaconf import OmegaConf
|
||||
except ImportError as e:
|
||||
raise ImportError(
|
||||
"Trinity-RFT is not installed. Please install it with "
|
||||
"`pip install trinity-rft`.",
|
||||
) from e
|
||||
|
||||
if not _validate_function_signature(workflow_func):
|
||||
raise ValueError(
|
||||
"Invalid workflow function signature, please "
|
||||
"check the types of your workflow input/output.",
|
||||
)
|
||||
|
||||
@dataclass
|
||||
class TuneConfig(Config):
|
||||
"""Configuration for learning process."""
|
||||
|
||||
def to_trinity_config(self, workflow_func: WorkflowType) -> Config:
|
||||
"""Convert to Trinity-RFT compatible configuration."""
|
||||
workflow_name = "agentscope_workflow_adapter"
|
||||
self.buffer.explorer_input.taskset.default_workflow_type = (
|
||||
workflow_name
|
||||
)
|
||||
self.buffer.explorer_input.default_workflow_type = workflow_name
|
||||
self.buffer.explorer_input.taskset.workflow_args[
|
||||
"workflow_func"
|
||||
] = workflow_func
|
||||
return self.check_and_update()
|
||||
|
||||
@classmethod
|
||||
def load_config(cls, config_path: str) -> "TuneConfig":
|
||||
"""Load the learning configuration from a YAML file.
|
||||
|
||||
Args:
|
||||
config_path (str): The path to the configuration file.
|
||||
|
||||
Returns:
|
||||
TuneConfig: The loaded learning configuration.
|
||||
"""
|
||||
schema = OmegaConf.structured(cls)
|
||||
yaml_config = OmegaConf.load(config_path)
|
||||
try:
|
||||
config = OmegaConf.merge(schema, yaml_config)
|
||||
return OmegaConf.to_object(config)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Invalid configuration: {e}") from e
|
||||
|
||||
return run_stage(
|
||||
config=TuneConfig.load_config(config_path).to_trinity_config(
|
||||
workflow_func,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,77 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
"""Workflow for agent learning."""
|
||||
|
||||
from typing import (
|
||||
Dict,
|
||||
Callable,
|
||||
Awaitable,
|
||||
get_type_hints,
|
||||
)
|
||||
|
||||
import inspect
|
||||
|
||||
from .._logging import logger
|
||||
from ..model import TrinityChatModel
|
||||
|
||||
|
||||
WorkflowType = Callable[[Dict, TrinityChatModel], Awaitable[float]]
|
||||
|
||||
|
||||
def _validate_function_signature(func: Callable) -> bool:
|
||||
"""Validate if a function matches the workflow type signature.
|
||||
|
||||
Args:
|
||||
func (Callable): The function to validate.
|
||||
"""
|
||||
# check if the function is asynchronous
|
||||
if not inspect.iscoroutinefunction(func):
|
||||
logger.warning("The function is not asynchronous.")
|
||||
return False
|
||||
# Define expected parameter types and return type manually
|
||||
expected_params = [
|
||||
("task", Dict),
|
||||
("model", TrinityChatModel),
|
||||
]
|
||||
expected_return = float
|
||||
|
||||
func_signature = inspect.signature(func)
|
||||
func_hints = get_type_hints(func)
|
||||
|
||||
# Check if the number of parameters matches
|
||||
if len(func_signature.parameters) != len(expected_params):
|
||||
logger.warning(
|
||||
"Expected %d parameters, but got %d",
|
||||
len(expected_params),
|
||||
len(func_signature.parameters),
|
||||
)
|
||||
return False
|
||||
|
||||
# Validate each parameter's name and type
|
||||
for (param_name, _), (expected_name, expected_type) in zip(
|
||||
func_signature.parameters.items(),
|
||||
expected_params,
|
||||
):
|
||||
if (
|
||||
param_name != expected_name
|
||||
or func_hints.get(param_name) != expected_type
|
||||
):
|
||||
logger.warning(
|
||||
"Expected parameter %s of type %s, but got %s of type %s",
|
||||
expected_name,
|
||||
expected_type,
|
||||
param_name,
|
||||
func_hints.get(param_name),
|
||||
)
|
||||
return False
|
||||
|
||||
# Validate the return type
|
||||
return_annotation = func_hints.get("return", None)
|
||||
if return_annotation != expected_return:
|
||||
logger.warning(
|
||||
"Expected return type %s, but got %s",
|
||||
expected_return,
|
||||
return_annotation,
|
||||
)
|
||||
return False
|
||||
|
||||
return True
|
||||
Reference in New Issue
Block a user