#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import os
import sys
import json
import argparse
from typing import Any, Dict

import requests


def build_args() -> argparse.Namespace:
    parser = argparse.ArgumentParser(
        description="Run inference against a local llama-server (OpenAI-compatible), with custom system/user prompts"
    )
    parser.add_argument(
        "--base-url",
        default=os.getenv("SIMPLE_BASE_URL", "http://127.0.0.1:8081/v1"),
        help="Base URL of llama-server (default: http://127.0.0.1:8081/v1, or env var SIMPLE_BASE_URL)",
    )
    parser.add_argument(
        "--model",
        default=os.getenv("SIMPLE_MODEL", "local-model"),
        help="Model name (default: local-model, or env var SIMPLE_MODEL)",
    )
    parser.add_argument(
        "--system",
        default="You are a helpful assistant.",
        help="System prompt (system role)",
    )
    parser.add_argument(
        "--system-file",
        default=None,
        help="Path to a system prompt file (txt); if given, overrides the --system string",
    )
    parser.add_argument(
        "--user",
        default=None,
        help="User prompt (user role); if omitted, read interactively from stdin",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.2,
        help="Sampling temperature (default: 0.2)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        default=4096,
        help="Maximum number of generated tokens (default: 4096)",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=120.0,
        help="HTTP timeout in seconds (default: 120)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Print the full response JSON",
    )
    return parser.parse_args()


def call_llama_server(
    base_url: str,
    model: str,
    system_prompt: str,
    user_prompt: str,
    temperature: float,
    max_tokens: int,
    timeout: float,
) -> Dict[str, Any]:
    endpoint = base_url.rstrip("/") + "/chat/completions"
    headers: Dict[str, str] = {"Content-Type": "application/json"}
    # Support proxies/services that require an API key (llama-server usually does not enforce one)
    api_key = os.getenv("OPENAI_API_KEY")
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"
    payload: Dict[str, Any] = {
        "model": model,
        "messages": [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        "temperature": temperature,
        "max_tokens": max_tokens,
    }
    resp = requests.post(endpoint, headers=headers, json=payload, timeout=timeout)
    resp.raise_for_status()
    return resp.json()


def main() -> None:
    args = build_args()

    user_prompt = args.user
    if not user_prompt:
        try:
            user_prompt = input("Enter the user prompt: ")
        except KeyboardInterrupt:
            print("\nCancelled.")
            sys.exit(1)

    # Resolve the system prompt: --system-file takes precedence over --system
    system_prompt = args.system
    if args.system_file:
        try:
            with open(args.system_file, "r", encoding="utf-8") as f:
                system_prompt = f.read()
        except Exception as e:
            print("\n❌ Failed to read the system prompt file:")
            print(str(e))
            sys.exit(1)

    try:
        print("--- llama-server inference ---")
        print(f"Base URL: {args.base_url}")
        print(f"Model: {args.model}")
        if args.system_file:
            print(f"System (from file): {args.system_file}")
        else:
            print(f"System: {system_prompt}")
        print(f"User: {user_prompt}")

        data = call_llama_server(
            base_url=args.base_url,
            model=args.model,
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            temperature=args.temperature,
            max_tokens=args.max_tokens,
            timeout=args.timeout,
        )

        if args.verbose:
            print("\nFull response JSON:")
            print(json.dumps(data, ensure_ascii=False, indent=2))

        # Try to extract the assistant content per the OpenAI-compatible schema
        content = None
        try:
            content = data["choices"][0]["message"]["content"]
        except (KeyError, IndexError, TypeError):
            pass

        if content is not None:
            print("\nModel output:")
            print(content)
        else:
            # Fallback: dump the raw response
            print("\nCould not parse content via OpenAI-compatible fields; raw response follows:")
            print(json.dumps(data, ensure_ascii=False))
    except requests.exceptions.RequestException as e:
        print(f"\n❌ Request failed: make sure llama-server is running and reachable at {args.base_url}.")
        print(f"Details: {e}")
        sys.exit(2)
    except Exception as e:
        print("\n❌ An unexpected error occurred:")
        print(str(e))
        sys.exit(3)


if __name__ == "__main__":
    main()
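
# --- Usage sketch ---
# A minimal sketch of how this script might be invoked, assuming llama-server
# is serving an OpenAI-compatible API at http://127.0.0.1:8081/v1 and the file
# is saved as simple_infer.py (the filename is an assumption, not from the source):
#
#   # One-shot prompt with the default system message:
#   python simple_infer.py --user "Explain beam search in one paragraph."
#
#   # System prompt read from a file, against a different server URL,
#   # with the full response JSON printed:
#   SIMPLE_BASE_URL=http://127.0.0.1:8080/v1 python simple_infer.py \
#       --system-file system.txt --user "Summarize this repository" --verbose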