DronePlanning/tools/test_llama_server.py
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Quick test client for a local llama-server via its OpenAI-compatible chat API."""
import os
import sys
import json
import argparse
from typing import Any, Dict

import requests


def build_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Call a local llama-server (OpenAI-compatible) for inference, with customizable system/user prompts"
    )
    parser.add_argument(
        "--base-url",
        default=os.getenv("SIMPLE_BASE_URL", "http://127.0.0.1:8081/v1"),
        help="Base URL of llama-server (default: http://127.0.0.1:8081/v1, or the SIMPLE_BASE_URL env var)",
    )
    parser.add_argument(
        "--model",
        default=os.getenv("SIMPLE_MODEL", "local-model"),
        help="Model name (default: local-model, or the SIMPLE_MODEL env var)",
    )
    parser.add_argument(
        "--system",
        default="You are a helpful assistant.",
        help="System prompt (system role)",
    )
    parser.add_argument(
        "--system-file",
        default=None,
        help="Path to a system prompt file (txt); if provided, it overrides the --system string",
    )
    parser.add_argument(
        "--user",
        default=None,
        help="User prompt (user role); if omitted, it is read interactively from stdin",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.2,
        help="Sampling temperature (default: 0.2)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=4096,
        help="Maximum number of tokens to generate (default: 4096)",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=120.0,
        help="HTTP timeout in seconds (default: 120)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print the full response JSON",
    )
    return parser.parse_args()


def call_llama_server(
    base_url: str,
    model: str,
    system_prompt: str,
    user_prompt: str,
    temperature: float,
    max_tokens: int,
    timeout: float,
) -> Dict[str, Any]:
    endpoint = base_url.rstrip("/") + "/chat/completions"
    headers: Dict[str, str] = {"Content-Type": "application/json"}
    # Support proxies/services that require an API key (llama-server usually does not enforce one)
    api_key = os.getenv("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    payload: Dict[str, Any] = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    resp = requests.post(endpoint, headers=headers, data=json.dumps(payload), timeout=timeout)
    resp.raise_for_status()
    return resp.json()
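
# For reference, a minimal sketch of the response shape the code below expects,
# following the OpenAI chat-completions format (assumed; exact fields can vary
# by llama-server version):
# {
#   "choices": [
#     {"message": {"role": "assistant", "content": "..."}, "finish_reason": "stop"}
#   ],
#   "usage": {"prompt_tokens": 12, "completion_tokens": 34}
# }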


def main() -> None:
    args = build_args()
    user_prompt = args.user
    if not user_prompt:
        try:
            user_prompt = input("Enter the user prompt: ")
        except KeyboardInterrupt:
            print("\nCancelled.")
            sys.exit(1)
    # Resolve the system prompt: --system-file takes precedence over --system
    system_prompt = args.system
    if args.system_file:
        try:
            with open(args.system_file, "r", encoding="utf-8") as f:
                system_prompt = f.read()
        except Exception as e:
            print("\n❌ Failed to read the system prompt file:")
            print(str(e))
            sys.exit(1)
    try:
        print("--- llama-server inference ---")
        print(f"Base URL: {args.base_url}")
        print(f"Model: {args.model}")
        if args.system_file:
            print(f"System (from file): {args.system_file}")
        else:
            print(f"System: {system_prompt}")
        print(f"User: {user_prompt}")
        data = call_llama_server(
            base_url=args.base_url,
            model=args.model,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            timeout=args.timeout,
        )
        if args.verbose:
            print("\nFull response JSON:")
            print(json.dumps(data, ensure_ascii=False, indent=2))
        # Try to extract the assistant message in the OpenAI-compatible format
        content = None
        try:
            content = data["choices"][0]["message"]["content"]
        except Exception:
            pass
        if content is not None:
            print("\nModel output:")
            print(content)
        else:
            # Fallback: dump the raw response
            print("\nCould not parse the content via OpenAI-compatible fields; raw response follows:")
            print(json.dumps(data, ensure_ascii=False))
    except requests.exceptions.RequestException as e:
        print("\n❌ Request failed: make sure llama-server is running on port 8081 and reachable.")
        print(f"Details: {e}")
        sys.exit(2)
    except Exception as e:
        print("\n❌ An unexpected error occurred:")
        print(str(e))
        sys.exit(3)


if __name__ == "__main__":
    main()
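
# Example invocations (a sketch; assumes llama-server is already listening on
# 127.0.0.1:8081 and "system_prompt.txt" is a placeholder path):
#   python test_llama_server.py --user "Hello"
#   python test_llama_server.py --system-file system_prompt.txt --user "Plan a survey route" --verbose
#   SIMPLE_BASE_URL=http://127.0.0.1:9000/v1 python test_llama_server.py --user "Hi"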