Add notes

2025-09-21 01:16:33 +08:00
parent 8e333ac03f
commit fd89745950
18 changed files with 525 additions and 19 deletions

tools/test_llama_server.py Normal file

@@ -0,0 +1,174 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import json
import argparse
from typing import Any, Dict
import requests


def build_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Call a local llama-server (OpenAI-compatible) for inference, with custom system/user prompts"
    )
    parser.add_argument(
        "--base-url",
        default=os.getenv("SIMPLE_BASE_URL", "http://127.0.0.1:8081/v1"),
        help="Base URL of llama-server (default: http://127.0.0.1:8081/v1, or env var SIMPLE_BASE_URL)",
    )
    parser.add_argument(
        "--model",
        default=os.getenv("SIMPLE_MODEL", "local-model"),
        help="Model name (default: local-model, or env var SIMPLE_MODEL)",
    )
    parser.add_argument(
        "--system",
        default="You are a helpful assistant.",
        help="System prompt (system role)",
    )
    parser.add_argument(
        "--system-file",
        default=None,
        help="Path to a system prompt file (txt); if provided, overrides the --system string",
    )
    parser.add_argument(
        "--user",
        default=None,
        help="User prompt (user role); if omitted, read from interactive input",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.2,
        help="Sampling temperature (default: 0.2)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=4096,
        help="Maximum number of generated tokens (default: 4096)",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=120.0,
        help="HTTP timeout in seconds (default: 120)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print the full response JSON",
    )
    return parser.parse_args()


def call_llama_server(
    base_url: str,
    model: str,
    system_prompt: str,
    user_prompt: str,
    temperature: float,
    max_tokens: int,
    timeout: float,
) -> Dict[str, Any]:
    endpoint = base_url.rstrip("/") + "/chat/completions"
    headers: Dict[str, str] = {"Content-Type": "application/json"}
    # For proxies/services that require an API key (llama-server usually does not)
    api_key = os.getenv("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    payload: Dict[str, Any] = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    resp = requests.post(endpoint, headers=headers, data=json.dumps(payload), timeout=timeout)
    resp.raise_for_status()
    return resp.json()


def main() -> None:
    args = build_args()
    user_prompt = args.user
    if not user_prompt:
        try:
            user_prompt = input("Enter the user prompt: ")
        except KeyboardInterrupt:
            print("\nCancelled.")
            sys.exit(1)
    # Resolve the system prompt: --system-file takes precedence
    system_prompt = args.system
    if args.system_file:
        try:
            with open(args.system_file, "r", encoding="utf-8") as f:
                system_prompt = f.read()
        except Exception as e:
            print("\n❌ Failed to read the system prompt file:")
            print(str(e))
            sys.exit(1)
    try:
        print("--- llama-server inference ---")
        print(f"Base URL: {args.base_url}")
        print(f"Model: {args.model}")
        if args.system_file:
            print(f"System (from file): {args.system_file}")
        else:
            print(f"System: {system_prompt}")
        print(f"User: {user_prompt}")
        data = call_llama_server(
            base_url=args.base_url,
            model=args.model,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            timeout=args.timeout,
        )
        if args.verbose:
            print("\nFull response JSON:")
            print(json.dumps(data, ensure_ascii=False, indent=2))
        # Try to extract the assistant content in the OpenAI-compatible format
        content = None
        try:
            content = data["choices"][0]["message"]["content"]
        except Exception:
            pass
        if content is not None:
            print("\nModel output:")
            print(content)
        else:
            # Fallback: dump the raw response
            print("\nCould not parse the content via OpenAI-compatible fields; raw response follows:")
            print(json.dumps(data, ensure_ascii=False))
    except requests.exceptions.RequestException as e:
        print("\n❌ Request failed: make sure llama-server is running on port 8081 and reachable.")
        print(f"Details: {e}")
        sys.exit(2)
    except Exception as e:
        print("\n❌ Unexpected error:")
        print(str(e))
        sys.exit(3)


if __name__ == "__main__":
    main()
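
A minimal invocation sketch, assuming the script is run from the repository root and a llama-server instance is already listening on the script's default base URL (http://127.0.0.1:8081/v1); every flag used here is defined in build_args above:

    python tools/test_llama_server.py \
        --system "You are a helpful assistant." \
        --user "Say hello in one sentence." \
        --verbose

Omitting --user drops into the interactive prompt, and --system-file can replace the inline --system string when the system prompt is long.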