#!/usr/bin/env python3
# -*- coding: utf-8 -*-
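"""Query a local llama-server (OpenAI-compatible API) with a configurable
system prompt and user prompt, then print the model's reply."""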

import os
import sys
import json
import argparse
from typing import Any, Dict

import requests


def build_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run inference against a local llama-server (OpenAI-compatible); supports custom system/user prompts"
    )
    parser.add_argument(
        "--base-url",
        default=os.getenv("SIMPLE_BASE_URL", "http://127.0.0.1:8081/v1"),
        help="Base URL of the llama-server (default: http://127.0.0.1:8081/v1, or env var SIMPLE_BASE_URL)",
    )
    parser.add_argument(
        "--model",
        default=os.getenv("SIMPLE_MODEL", "local-model"),
        help="Model name (default: local-model, or env var SIMPLE_MODEL)",
    )
    parser.add_argument(
        "--system",
        default="You are a helpful assistant.",
        help="System prompt (system role)",
    )
    parser.add_argument(
        "--system-file",
        default=None,
        help="Path to a system prompt file (txt); if provided, overrides the --system string",
    )
    parser.add_argument(
        "--user",
        default=None,
        help="User prompt (user role); if omitted, read interactively from stdin",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.2,
        help="Sampling temperature (default: 0.2)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=4096,
        help="Maximum number of tokens to generate (default: 4096)",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=120.0,
        help="HTTP timeout in seconds (default: 120)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print the full response JSON",
    )
    return parser.parse_args()


def call_llama_server(
    base_url: str,
    model: str,
    system_prompt: str,
    user_prompt: str,
    temperature: float,
    max_tokens: int,
    timeout: float,
) -> Dict[str, Any]:
    endpoint = base_url.rstrip("/") + "/chat/completions"
    headers: Dict[str, str] = {"Content-Type": "application/json"}

    # For proxies/services that require an API key (llama-server usually does not enforce one)
    api_key = os.getenv("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    payload: Dict[str, Any] = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }

    resp = requests.post(endpoint, headers=headers, json=payload, timeout=timeout)
    resp.raise_for_status()
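    # On success, llama-server returns the OpenAI chat-completions shape, e.g.
    # {"choices": [{"message": {"role": "assistant", "content": "..."}}], ...}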
    return resp.json()


def main() -> None:
    args = build_args()

    user_prompt = args.user
    if not user_prompt:
        try:
            user_prompt = input("Enter the user prompt: ")
        except KeyboardInterrupt:
            print("\nCancelled.")
            sys.exit(1)

    # Resolve the system prompt: --system-file takes precedence over --system
    system_prompt = args.system
    if args.system_file:
        try:
            with open(args.system_file, "r", encoding="utf-8") as f:
                system_prompt = f.read()
        except (OSError, UnicodeDecodeError) as e:
            print("\n❌ Failed to read the system prompt file:")
            print(str(e))
            sys.exit(1)

    try:
        print("--- llama-server inference ---")
        print(f"Base URL: {args.base_url}")
        print(f"Model: {args.model}")
        if args.system_file:
            print(f"System (from file): {args.system_file}")
        else:
            print(f"System: {system_prompt}")
        print(f"User: {user_prompt}")

        data = call_llama_server(
            base_url=args.base_url,
            model=args.model,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            timeout=args.timeout,
        )

        if args.verbose:
            print("\nFull response JSON:")
            print(json.dumps(data, ensure_ascii=False, indent=2))

        # Try to extract the assistant content per the OpenAI-compatible format
        content = None
        try:
            content = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            pass

        if content is not None:
            print("\nModel output:")
            print(content)
        else:
            # Fallback: print the raw response
            print("\nCould not parse content from the OpenAI-compatible fields; raw response follows:")
            print(json.dumps(data, ensure_ascii=False))

    except requests.exceptions.RequestException as e:
        print(f"\n❌ Request failed: make sure llama-server is running and reachable at {args.base_url}.")
        print(f"Details: {e}")
        sys.exit(2)
    except Exception as e:
        print("\n❌ Unexpected error:")
        print(str(e))
        sys.exit(3)
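

# Example invocations (a sketch; the script filename below is a placeholder):
#   python llama_infer.py --user "Hello"
#   python llama_infer.py --system-file prompt.txt --user "Summarize this" --verbose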
if __name__ == "__main__":
    main()