first commit
This commit is contained in:
381
CHANGELOG.md
Normal file
381
CHANGELOG.md
Normal file
@@ -0,0 +1,381 @@
|
||||
# 项目修改日志 (Changelog)
|
||||
|
||||
本文档记录了项目的主要功能修改和代码变更。
|
||||
|
||||
## 修改日期
|
||||
2025-11-24
|
||||
|
||||
---
|
||||
|
||||
## 一、模型服务器 (Models Server) 修改
|
||||
|
||||
### 1.1 文件:`scripts/ai_agent/models/models_server.py`
|
||||
|
||||
#### 修改内容:
|
||||
|
||||
**1. 配置解析改进**
|
||||
- **位置**:`get_model_server()` 函数
|
||||
- **修改**:
|
||||
- 修正配置文件路径,使用 `project_root` 变量
|
||||
- 添加配置文件读取错误检查(`read_json_file()` 返回 `None` 的情况)
|
||||
- 支持 `models_server` 配置为列表或字典两种格式
|
||||
- 修复 `global _process_instance` 声明位置(必须在函数开始处)
|
||||
|
||||
**2. 消息构建逻辑修复**
|
||||
- **位置**:`_build_message()` 方法
|
||||
- **修改**:
|
||||
- 修复纯文本和多模态消息格式不一致的问题
|
||||
- 纯文本推理:`system_prompt` 和 `user_prompt` 的 `content` 字段为字符串
|
||||
- 多模态推理:`system_prompt` 和 `user_prompt` 的 `content` 字段为列表(包含文本和图像对象)
|
||||
- 解决了 `TypeError: can only concatenate str (not "list") to str` 错误
|
||||
|
||||
**3. JSON Schema 支持**
|
||||
- **位置**:`text_inference()`, `multimodal_inference()`, `model_inference()` 方法
|
||||
- **修改**:
|
||||
- 添加 `response_format` 参数支持,用于约束模型输出格式
|
||||
- 当 `request.json_schema` 存在时,构建 `response_format` 对象传递给 `create_chat_completion`
|
||||
- 格式:`{"type": "json_object", "schema": request.json_schema}`
|
||||
|
||||
**4. 错误处理增强**
|
||||
- **位置**:`text_inference()` 方法
|
||||
- **修改**:
|
||||
- 添加 `messages` 为 `None` 的检查
|
||||
- 显式设置 `image_data=None` 用于纯文本请求
|
||||
- 改进错误日志,使用 `exc_info=True` 记录完整堆栈信息
|
||||
|
||||
**5. GPU 日志增强**
|
||||
- **位置**:`init_vlm_llamacpp()` 方法
|
||||
- **修改**:
|
||||
- 当 `n_gpu_layers > 0` 时,设置 `verbose=True` 以输出 GPU 使用信息
|
||||
- 添加 GPU 层数日志输出
|
||||
|
||||
**6. 主函数启用**
|
||||
- **位置**:`main()` 函数
|
||||
- **修改**:
|
||||
- 取消注释 `uvicorn.run()` 调用,使模型服务器可以独立启动
|
||||
|
||||
---
|
||||
|
||||
### 1.2 文件:`scripts/ai_agent/models/model_definition.py`
|
||||
|
||||
#### 修改内容:
|
||||
|
||||
**JSON Schema 字段添加**
|
||||
- **位置**:`TextInferenceRequest` 和 `InferenceRequest` 类
|
||||
- **修改**:
|
||||
- 添加 `json_schema: Optional[Dict[str, Any]] = None` 字段
|
||||
- 用于在推理请求中传递 JSON Schema,约束模型输出格式
|
||||
|
||||
---
|
||||
|
||||
### 1.3 文件:`scripts/ai_agent/models/models_client.py`
|
||||
|
||||
#### 修改内容:
|
||||
|
||||
**Pydantic 兼容性改进**
|
||||
- **位置**:`text_inference()` 和 `multimodal_inference()` 方法
|
||||
- **修改**:
|
||||
- 添加 Pydantic v1/v2 兼容的 JSON 序列化逻辑
|
||||
- 优先使用 `model_dump_json()` (Pydantic v2)
|
||||
- 降级使用 `json()` (Pydantic v1)
|
||||
- 最后使用 `dict()`/`model_dump()` + `json.dumps()` 作为兜底方案
|
||||
- 解决了 `Object of type TextInferenceRequest is not JSON serializable` 错误
|
||||
|
||||
**服务器等待逻辑改进**
|
||||
- **位置**:`wait_for_server()` 方法
|
||||
- **修改**:
|
||||
- 接受 "loading" 状态作为有效状态
|
||||
- 增强连接失败时的错误消息
|
||||
|
||||
---
|
||||
|
||||
## 二、后端服务 (Backend Service) 修改
|
||||
|
||||
### 2.1 文件:`scripts/ai_agent/groundcontrol/backend_service/src/main.py`
|
||||
|
||||
#### 修改内容:
|
||||
|
||||
**子进程输出处理改进**
|
||||
- **位置**:`start_model_server()` 函数
|
||||
- **修改**:
|
||||
- 使用线程实时输出模型服务器的 stdout 和 stderr
|
||||
- 移除 `bufsize=1` 参数(二进制模式下不支持行缓冲)
|
||||
- 解决了 `RuntimeWarning: line buffering (buffering=1) isn't supported in binary mode` 警告
|
||||
|
||||
---
|
||||
|
||||
### 2.2 文件:`scripts/ai_agent/groundcontrol/backend_service/src/py_tree_generator.py`
|
||||
|
||||
#### 修改内容:
|
||||
|
||||
**1. RAG 初始化改进**
|
||||
- **位置**:`__init__()` 方法
|
||||
- **修改**:
|
||||
- 从 `rag_config.json` 读取 RAG 配置
|
||||
- 计算 `ai_agent_root` 路径(相对于 `base_dir` 向上三级)
|
||||
- 使用配置中的 `vectorstore_persist_directory` 和 `collection_name`
|
||||
- 添加向量数据库加载状态日志
|
||||
|
||||
**2. JSON Schema 支持**
|
||||
- **位置**:`generate()` 方法
|
||||
- **修改**:
|
||||
- 分类请求:传递 `classifier_schema` 约束输出格式
|
||||
- 生成请求:根据模式(simple/complex)传递对应的 schema(`self.schema` 或 `self.simple_schema`)
|
||||
- 确保模型输出符合预期的 JSON 结构
|
||||
|
||||
**3. RAG 检索调用更新**
|
||||
- **位置**:`_retrieve_context()` 方法
|
||||
- **修改**:
|
||||
- 将 `score_threshold=0.6` 改为 `score_threshold=None`
|
||||
- 使用自适应阈值,自动适应 L2 距离(ChromaDB 默认使用 L2 距离)
|
||||
|
||||
**4. 正则表达式修复**
|
||||
- **位置**:`_parse_allowed_nodes_from_prompt()` 函数
|
||||
- **修改**:
|
||||
- 修复正则表达式,正确匹配 "可用节点定义" 部分
|
||||
- 解决了 `ERROR - 在系统提示词中未找到'可用节点定义'部分的JSON代码块` 错误
|
||||
|
||||
---
|
||||
|
||||
## 三、RAG 系统 (Retrieval Augmented Generation) 修改
|
||||
|
||||
### 3.1 文件:`scripts/ai_agent/tools/memory_mag/rag/rag.py`
|
||||
|
||||
#### 修改内容:
|
||||
|
||||
**1. 向量数据库加载改进**
|
||||
- **位置**:`load_vector_database()` 函数
|
||||
- **修改**:
|
||||
- 添加集合存在性检查(使用 ChromaDB 客户端)
|
||||
- 列出所有可用集合及其文档数量
|
||||
- 检查目标集合是否有数据
|
||||
- 自动检测并提示有数据的集合
|
||||
- 即使集合为空也返回数据库对象(允许后续添加数据)
|
||||
- 添加详细的日志输出,包括集合信息和文档数量
|
||||
|
||||
**2. 检索逻辑优化**
|
||||
- **位置**:`retrieve_relevant_info()` 函数
|
||||
- **修改**:
|
||||
- 将 `score_threshold` 默认值改为 `None`,支持自适应阈值
|
||||
- 自适应阈值逻辑:
|
||||
- 取前 k 个结果中的最高分数
|
||||
- 乘以 1.5 作为阈值(确保包含所有前 k 个结果)
|
||||
- 自动适应 L2 距离(分数范围通常在 10000-20000)
|
||||
- 修复日志格式化错误(`score_threshold` 可能为 `None`)
|
||||
- 添加相似度分数范围日志输出
|
||||
- 解决了阈值过小(0.6)导致所有结果被过滤的问题
|
||||
|
||||
**3. 嵌入模型配置容错**
|
||||
- **位置**:`set_embeddings()` 函数
|
||||
- **修改**:
|
||||
- 使用 `.get()` 方法访问 `model_config` 参数,提供默认值
|
||||
- 支持的参数:`n_ctx`, `n_threads`, `n_gpu_layers`, `n_seq_max`, `n_threads_batch`, `flash_attn`, `verbose`
|
||||
- 解决了 `KeyError: 'n_threads_batch'` 等配置缺失错误
|
||||
|
||||
---
|
||||
|
||||
### 3.2 文件:`scripts/ai_agent/config/rag_config.json`
|
||||
|
||||
#### 修改内容:
|
||||
|
||||
**配置更新**
|
||||
- **修改项**:
|
||||
1. `vectorstore_persist_directory`:更新为 `/home/huangfukk/AI_Agent/scripts/ai_agent/memory/knowledge_base/map/vector_store/osm_map1`
|
||||
2. `embedding_model_path`:更新为 `/home/huangfukk/models/gguf/Qwen/Qwen3-Embedding-4B/Qwen3-Embedding-4B-Q4_K_M.gguf`
|
||||
3. `collection_name`:从 `"osm_map_docs"` 改为 `"drone_docs"`(使用有数据的集合)
|
||||
4. `model_config_llamacpp`:添加 `n_threads_batch: 4` 字段
|
||||
|
||||
---
|
||||
|
||||
### 3.3 文件:`scripts/ai_agent/config/model_config.json`
|
||||
|
||||
#### 修改内容:
|
||||
|
||||
**模型配置更新**
|
||||
- **修改项**:
|
||||
1. `verbose`:从 `0` 改为 `1`(启用详细日志)
|
||||
2. `n_gpu_layers`:设置为 `40`(启用 GPU 加速)
|
||||
|
||||
---
|
||||
|
||||
## 四、主要问题修复总结
|
||||
|
||||
### 4.1 已修复的错误
|
||||
|
||||
1. **模型服务器启动超时**
|
||||
- **原因**:`main()` 函数被注释,服务器无法启动
|
||||
- **修复**:取消注释 `uvicorn.run()` 调用
|
||||
|
||||
2. **配置解析 KeyError**
|
||||
- **原因**:`models_server` 配置格式不一致(列表 vs 字典)
|
||||
- **修复**:添加格式判断,支持两种格式
|
||||
|
||||
3. **消息格式类型错误**
|
||||
- **原因**:纯文本和多模态消息格式混用
|
||||
- **修复**:区分纯文本(字符串)和多模态(列表)格式
|
||||
|
||||
4. **JSON 序列化错误**
|
||||
- **原因**:Pydantic 模型直接序列化失败
|
||||
- **修复**:添加 Pydantic v1/v2 兼容的序列化逻辑
|
||||
|
||||
5. **RAG 检索返回空结果**
|
||||
- **原因**:
|
||||
- 集合名称不匹配(`osm_map_docs` 为空,`drone_docs` 有数据)
|
||||
- 相似度阈值过小(0.6),而 L2 距离分数通常在 10000+ 范围
|
||||
- **修复**:
|
||||
- 更新集合名称为 `drone_docs`
|
||||
- 实现自适应阈值机制
|
||||
|
||||
6. **配置缺失 KeyError**
|
||||
- **原因**:`rag_config.json` 缺少 `n_threads_batch` 字段
|
||||
- **修复**:添加字段,并在代码中使用 `.get()` 提供默认值
|
||||
|
||||
---
|
||||
|
||||
## 五、新增功能
|
||||
|
||||
### 5.1 JSON Schema 约束输出
|
||||
|
||||
- **功能描述**:支持在推理请求中传递 JSON Schema,约束模型输出格式
|
||||
- **应用场景**:
|
||||
- 分类任务:确保输出符合 `{"mode": "simple"|"complex"}` 格式
|
||||
- 生成任务:确保输出符合 Pytree JSON 结构
|
||||
- **实现位置**:
|
||||
- `model_definition.py`:添加 `json_schema` 字段
|
||||
- `models_server.py`:构建 `response_format` 参数
|
||||
- `py_tree_generator.py`:传递动态生成的 schema
|
||||
|
||||
### 5.2 自适应相似度阈值
|
||||
|
||||
- **功能描述**:根据检索结果自动调整相似度阈值
|
||||
- **优势**:
|
||||
- 自动适应不同的距离度量(L2、余弦距离等)
|
||||
- 无需手动调整阈值参数
|
||||
- 确保返回前 k 个最相关的结果
|
||||
- **实现位置**:`rag.py` 的 `retrieve_relevant_info()` 函数
|
||||
|
||||
### 5.3 向量数据库诊断功能
|
||||
|
||||
- **功能描述**:加载向量数据库时自动检查集合状态
|
||||
- **功能**:
|
||||
- 列出所有可用集合
|
||||
- 显示每个集合的文档数量
|
||||
- 提示有数据的集合
|
||||
- 诊断集合不存在或为空的情况
|
||||
- **实现位置**:`rag.py` 的 `load_vector_database()` 函数
|
||||
|
||||
---
|
||||
|
||||
## 六、配置变更
|
||||
|
||||
### 6.1 必须更新的配置
|
||||
|
||||
1. **`rag_config.json`**
|
||||
- 更新 `vectorstore_persist_directory` 路径
|
||||
- 更新 `embedding_model_path` 路径
|
||||
- 更新 `collection_name` 为有数据的集合
|
||||
- 确保 `model_config_llamacpp` 包含所有必需字段
|
||||
|
||||
2. **`model_config.json`**
|
||||
- 根据需求调整 `n_gpu_layers`(GPU 加速层数)
|
||||
- 根据需求调整 `verbose`(日志详细程度)
|
||||
|
||||
---
|
||||
|
||||
## 七、测试建议
|
||||
|
||||
### 7.1 功能测试
|
||||
|
||||
1. **模型服务器启动测试**
|
||||
```bash
|
||||
cd scripts/ai_agent/models
|
||||
python models_server.py
|
||||
```
|
||||
|
||||
2. **RAG 检索测试**
|
||||
```python
|
||||
from tools.memory_mag.rag.rag import set_embeddings, load_vector_database, retrieve_relevant_info
|
||||
# 测试检索功能
|
||||
```
|
||||
|
||||
3. **JSON Schema 约束测试**
|
||||
- 发送分类请求,验证输出格式
|
||||
- 发送生成请求,验证 Pytree JSON 结构
|
||||
|
||||
### 7.2 集成测试
|
||||
|
||||
1. **端到端测试**
|
||||
- 启动后端服务
|
||||
- 发送 "起飞,飞到匡亚明学院" 请求
|
||||
- 验证 RAG 检索是否返回地点信息
|
||||
- 验证生成的 Pytree 是否包含正确的地点坐标
|
||||
|
||||
---
|
||||
|
||||
## 八、注意事项
|
||||
|
||||
1. **向量数据库集合**
|
||||
- 确保使用有数据的集合(当前为 `drone_docs`)
|
||||
- 如果更换集合,需要更新 `rag_config.json` 中的 `collection_name`
|
||||
|
||||
2. **嵌入模型路径**
|
||||
- 确保嵌入模型路径正确
|
||||
- 嵌入模型必须与向量数据库创建时使用的模型一致
|
||||
|
||||
3. **GPU 配置**
|
||||
- 根据硬件配置调整 `n_gpu_layers`
|
||||
- 如果使用 CPU,设置 `n_gpu_layers=0`
|
||||
|
||||
4. **日志级别**
|
||||
- 生产环境建议设置 `verbose=0`
|
||||
- 调试时可以使用 `verbose=1` 查看详细信息
|
||||
|
||||
---
|
||||
|
||||
## 九、后续优化建议
|
||||
|
||||
1. **性能优化**
|
||||
- 考虑缓存嵌入模型实例
|
||||
- 优化向量数据库查询性能
|
||||
|
||||
2. **错误处理**
|
||||
- 添加更细粒度的错误分类
|
||||
- 提供更友好的错误消息
|
||||
|
||||
3. **配置管理**
|
||||
- 考虑使用环境变量覆盖配置
|
||||
- 添加配置验证逻辑
|
||||
|
||||
4. **文档完善**
|
||||
- 添加 API 文档
|
||||
- 添加部署指南
|
||||
|
||||
---
|
||||
|
||||
## 十、启动脚本
|
||||
|
||||
### 10.1 后端服务启动脚本
|
||||
|
||||
**文件**:`scripts/ai_agent/groundcontrol/backend_service/start_backend.sh`
|
||||
|
||||
**功能**:
|
||||
- 提供便捷的后端服务启动方式
|
||||
- 自动设置工作目录和 Python 路径
|
||||
- 启动 FastAPI 后端服务,监听 `0.0.0.0:8001`
|
||||
|
||||
**使用方法**:
|
||||
```bash
|
||||
cd scripts/ai_agent/groundcontrol/backend_service
|
||||
./start_backend.sh
|
||||
```
|
||||
|
||||
**脚本内容**:
|
||||
- 自动切换到脚本所在目录
|
||||
- 设置 Python 路径
|
||||
- 使用 uvicorn 启动 FastAPI 应用:`python -m uvicorn src.main:app --host 0.0.0.0 --port 8001`
|
||||
|
||||
---
|
||||
|
||||
**文档版本**:1.1
|
||||
**最后更新**:2025-11-24
|
||||
|
||||
@@ -2,16 +2,16 @@
|
||||
"models_server":
|
||||
{
|
||||
"model_inference_framework_type":"llama_cpp",
|
||||
"multi_modal":1,
|
||||
"multi_modal":0,
|
||||
"chat_handler": "Qwen25VLChatHandler",
|
||||
"vlm_model_path":"/home/ubuntu/Workspace/Projects/VLM_VLA/Models/Qwen2.5-VL-3B-Instruct-GGUF/Qwen2.5-VL-3B-Instruct-Q8_0.gguf",
|
||||
"mmproj_model_path":"/home/ubuntu/Workspace/Projects/VLM_VLA/Models/Qwen2.5-VL-3B-Instruct-GGUF/mmproj-model-f16.gguf",
|
||||
"vlm_model_path":"/home/huangfukk/models/gguf/Qwen/Qwen3-4B/Qwen3-4B-Q5_K_M.gguf",
|
||||
"mmproj_model_path":"",
|
||||
"n_ctx":30720,
|
||||
"n_threads":4,
|
||||
"n_gpu_layers":40,
|
||||
"n_batch":24,
|
||||
"n_ubatch":24,
|
||||
"verbose":0,
|
||||
"verbose":1,
|
||||
"model_server_host":"localhost",
|
||||
"model_server_port":8000,
|
||||
"model_controller_workers":1
|
||||
|
||||
@@ -1,16 +1,16 @@
|
||||
{
|
||||
"rag_mag":{
|
||||
"embedding_model_path":"/home/ubuntu/Workspace/Projects/VLM_VLA/Models/Qwen/Qwen3-Embedding-0.6B-GGUF/Qwen3-Embedding-0.6B-f16.gguf",
|
||||
"vectorstore_persist_directory": "/home/ubuntu/Workspace/Projects/VLM_VLA/UAV_AI/src/Model/AI_Agent/memory/knowledge_base/map/vector_store/osm_map1",
|
||||
"embedding_model_path":"/home/huangfukk/models/gguf/Qwen/Qwen3-Embedding-4B/Qwen3-Embedding-4B-Q4_K_M.gguf",
|
||||
"vectorstore_persist_directory": "/home/huangfukk/AI_Agent/scripts/ai_agent/memory/knowledge_base/map/vector_store/osm_map1",
|
||||
"embedding_framework_type":"llamacpp_embedding",
|
||||
"collection_name":"osm_map_docs",
|
||||
"collection_name":"drone_docs",
|
||||
"model_config_llamacpp":
|
||||
{
|
||||
"n_ctx":512,
|
||||
"n_threads":4,
|
||||
"n_gpu_layers":36,
|
||||
"n_seq_max":256,
|
||||
|
||||
"n_threads_batch":4,
|
||||
"flash_attn":1,
|
||||
"verbose":0
|
||||
},
|
||||
|
||||
1
scripts/ai_agent/groundcontrol
Submodule
1
scripts/ai_agent/groundcontrol
Submodule
Submodule scripts/ai_agent/groundcontrol added at d026107bc2
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -14,6 +14,7 @@ class TextInferenceRequest(BaseModel):
|
||||
top_p: float = 0.95
|
||||
system_prompt: Optional[str] = None
|
||||
stop: Optional[List[str]] = None
|
||||
json_schema: Optional[Dict[str, Any]] = None # JSON Schema 用于约束输出格式
|
||||
|
||||
class MultimodalInferenceRequest(BaseModel):
|
||||
user_prompt: str
|
||||
@@ -32,6 +33,7 @@ class InferenceRequest(BaseModel):
|
||||
top_p: float = 0.95
|
||||
system_prompt: Optional[str] = None
|
||||
stop: Optional[List[str]] = None
|
||||
json_schema: Optional[Dict[str, Any]] = None # JSON Schema 用于约束输出格式
|
||||
|
||||
|
||||
# 响应模型定义
|
||||
|
||||
@@ -135,20 +135,22 @@ class Models_Client:
|
||||
返回:
|
||||
包含推理结果的字典
|
||||
"""
|
||||
# payload = {
|
||||
# "user_prompt": prompt,
|
||||
# "max_tokens": max_tokens,
|
||||
# "temperature": temperature,
|
||||
# "top_p": top_p,
|
||||
# "system_prompt":system_prompt,
|
||||
# "stop": stop
|
||||
# }
|
||||
|
||||
try:
|
||||
# 将 Pydantic 模型转换为 JSON 字符串
|
||||
if hasattr(request, 'model_dump_json'):
|
||||
# Pydantic v2
|
||||
data = request.model_dump_json()
|
||||
elif hasattr(request, 'json'):
|
||||
# Pydantic v1
|
||||
data = request.json()
|
||||
else:
|
||||
# 降级方案:转换为字典再序列化
|
||||
data = json.dumps(request.dict() if hasattr(request, 'dict') else request.model_dump())
|
||||
|
||||
response = requests.post(
|
||||
self.text_endpoint,
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=json.dumps(request)
|
||||
data=data
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
@@ -211,44 +213,22 @@ class Models_Client:
|
||||
返回:
|
||||
包含推理结果的字典
|
||||
"""
|
||||
# # 处理图像
|
||||
# image_data = []
|
||||
|
||||
# if images:
|
||||
# # for img in images:
|
||||
# try:
|
||||
# # image_data.append(PIL_image_to_base64(image=img))
|
||||
# image_data= images
|
||||
# except Exception as e:
|
||||
# return {"error": f"处理图像base64对象失败: {str(e)}"}
|
||||
|
||||
# elif image_paths:
|
||||
# for path in image_paths:
|
||||
# try:
|
||||
# image_data.append(PIL_image_to_base64(image_path=path))
|
||||
# except Exception as e:
|
||||
# return {"error": f"处理图像 {path} 失败: {str(e)}"}
|
||||
|
||||
|
||||
|
||||
# if not image_data:
|
||||
# return {"error": "未提供有效的图像数据"}
|
||||
|
||||
# payload = {
|
||||
# "user_prompt": prompt,
|
||||
# "image_data": image_data,
|
||||
# "max_tokens": max_tokens,
|
||||
# "temperature": temperature,
|
||||
# "top_p": top_p,
|
||||
# "system_prompt": system_prompt,
|
||||
# "stop": stop
|
||||
# }
|
||||
|
||||
try:
|
||||
# 将 Pydantic 模型转换为 JSON 字符串
|
||||
if hasattr(request, 'model_dump_json'):
|
||||
# Pydantic v2
|
||||
data = request.model_dump_json()
|
||||
elif hasattr(request, 'json'):
|
||||
# Pydantic v1
|
||||
data = request.json()
|
||||
else:
|
||||
# 降级方案:转换为字典再序列化
|
||||
data = json.dumps(request.dict() if hasattr(request, 'dict') else request.model_dump())
|
||||
|
||||
response = requests.post(
|
||||
url=self.multimodal_endpoint,
|
||||
headers={"Content-Type": "application/json"},
|
||||
data=json.dumps(request)
|
||||
data=data
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@@ -102,24 +102,45 @@ def get_model_server() -> Models_Server:
|
||||
logger.info(f"=== get_model_server 被触发!Process ID: {os.getpid()} ===")
|
||||
|
||||
if _process_instance is None:
|
||||
config_json_file= f"{project_root}/models/model_config.json"
|
||||
print("config_json_file: ",config_json_file)
|
||||
# global model_server_config
|
||||
model_server_config = read_json_file(config_json_file)["models_server"]
|
||||
config_json_file = f"{project_root}/config/model_config.json" # ✅ 修正路径
|
||||
print("config_json_file: ", config_json_file)
|
||||
|
||||
# 读取配置文件
|
||||
config_data = read_json_file(config_json_file)
|
||||
if config_data is None:
|
||||
logger.error(f"get_model_server: 无法读取配置文件 '{config_json_file}'")
|
||||
raise HTTPException(status_code=500, detail=f"无法读取配置文件: {config_json_file}")
|
||||
|
||||
# 检查 models_server 配置是否存在
|
||||
if "models_server" not in config_data:
|
||||
logger.error("get_model_server: 配置文件中缺少 'models_server' 字段")
|
||||
raise HTTPException(status_code=500, detail="配置文件中缺少 'models_server' 字段")
|
||||
|
||||
# 如果 models_server 是列表,取第一个;如果是字典,直接使用
|
||||
models_server_data = config_data["models_server"]
|
||||
model_server_config = models_server_data[0] if isinstance(models_server_data, list) else models_server_data
|
||||
|
||||
if model_server_config is None:
|
||||
logger.error("get_model_server: 配置未初始化!")
|
||||
raise HTTPException(status_code=500, detail="服务配置未初始化")
|
||||
|
||||
# 获取单例实例
|
||||
model_server = Models_Server(model_server_config)
|
||||
if not model_server.model_loaded:
|
||||
model_server._wait_for_model_load(model_server_config)
|
||||
model_server.init_model(model_server_config["model_inference_framework_type"],
|
||||
model_server_config)
|
||||
# _process_initialized = True
|
||||
else:
|
||||
model_server = _process_instance
|
||||
if not model_server.model_loaded:
|
||||
# 重新读取配置以获取 model_server_config
|
||||
config_json_file = f"{project_root}/config/model_config.json"
|
||||
config_data = read_json_file(config_json_file)
|
||||
if config_data is None or "models_server" not in config_data:
|
||||
logger.error("get_model_server: 无法读取配置文件")
|
||||
raise HTTPException(status_code=500, detail="无法读取配置文件")
|
||||
models_server_data = config_data["models_server"]
|
||||
model_server_config = models_server_data[0] if isinstance(models_server_data, list) else models_server_data
|
||||
model_server._wait_for_model_load(model_server_config)
|
||||
model_server.init_model(model_server_config["model_inference_framework_type"],
|
||||
model_server_config)
|
||||
|
||||
@@ -250,68 +250,69 @@ class Models_Server:
|
||||
config = self.model_server_config
|
||||
try:
|
||||
messages = []
|
||||
has_images = request.image_data and len(request.image_data) > 0
|
||||
chat_handler = config.get("chat_handler", "")
|
||||
|
||||
# 构建 system message
|
||||
if request.system_prompt:
|
||||
system_message={
|
||||
if has_images and "Qwen25VLChatHandler" == chat_handler:
|
||||
# 多模态情况下,system message 的 content 也应该是列表
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": [request.system_prompt]
|
||||
}
|
||||
"content": [{"type": "text", "text": request.system_prompt}]
|
||||
}
|
||||
else:
|
||||
# 纯文本情况下,system message 的 content 应该是字符串
|
||||
system_message = {
|
||||
"role": "system",
|
||||
"content": request.system_prompt
|
||||
}
|
||||
messages.append(system_message)
|
||||
|
||||
user_message = {
|
||||
"role": "user",
|
||||
"content": []
|
||||
}
|
||||
#加入用户消息
|
||||
messages.append(user_message)
|
||||
|
||||
if "Qwen25VLChatHandler" == config["chat_handler"]:
|
||||
for msg in messages:
|
||||
if "user" == msg["role"]:
|
||||
# 添加图像到消息
|
||||
if request.image_data:
|
||||
len_images = len(request.image_data)
|
||||
|
||||
print("len_images: ",len_images)
|
||||
for i in range(0,len_images):
|
||||
logger.info(f"add image {i}")
|
||||
msg["content"].append(
|
||||
{
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{request.image_data[i]}"}
|
||||
#f"data:image/jpeg;base64,{request.image_data[i]}"
|
||||
}
|
||||
)
|
||||
#加入user_prompt
|
||||
msg["content"].append(
|
||||
{"type": "text", "text": request.user_prompt}
|
||||
)
|
||||
|
||||
# 构建 user message
|
||||
if has_images:
|
||||
# 多模态消息:content 是列表
|
||||
user_content = []
|
||||
|
||||
if "Qwen25VLChatHandler" == chat_handler:
|
||||
# Qwen25VL 格式
|
||||
len_images = len(request.image_data)
|
||||
logger.info(f"添加 {len_images} 张图像到消息 (Qwen25VL格式)")
|
||||
for i in range(0, len_images):
|
||||
logger.info(f"add image {i}")
|
||||
user_content.append({
|
||||
"type": "image_url",
|
||||
"image_url": {"url": f"data:image/jpeg;base64,{request.image_data[i]}"}
|
||||
})
|
||||
else:
|
||||
# 其他多模态格式
|
||||
len_images = len(request.image_data)
|
||||
logger.info(f"添加 {len_images} 张图像到消息")
|
||||
for i in range(0, len_images):
|
||||
logger.info(f"add image {i}")
|
||||
user_content.append({
|
||||
"type": "image",
|
||||
"image": f"data:image/jpeg;base64,{request.image_data[i]}"
|
||||
})
|
||||
|
||||
# 添加文本内容
|
||||
user_content.append({"type": "text", "text": request.user_prompt})
|
||||
|
||||
user_message = {
|
||||
"role": "user",
|
||||
"content": user_content
|
||||
}
|
||||
else:
|
||||
for msg in messages:
|
||||
if "user" == msg["role"]:
|
||||
# 添加图像到消息
|
||||
if request.image_data:
|
||||
len_images = len(request.image_data)
|
||||
|
||||
print("len_images: ",len_images)
|
||||
for i in range(0,len_images):
|
||||
logger.info("add image ",i)
|
||||
msg["content"].append(
|
||||
{
|
||||
"type": "image",
|
||||
"image": f"data:image/jpeg;base64,{request.image_data[i]}" #图像数据形式
|
||||
#f"data:image/jpeg;base64,{request.image_data[i]}"
|
||||
}
|
||||
)
|
||||
#加入user_prompt
|
||||
msg["content"].append(
|
||||
{"type": "text", "text": request.user_prompt}
|
||||
)
|
||||
# 纯文本消息:content 是字符串
|
||||
user_message = {
|
||||
"role": "user",
|
||||
"content": request.user_prompt
|
||||
}
|
||||
|
||||
messages.append(user_message)
|
||||
return messages
|
||||
except Exception as e:
|
||||
print(f"进程 {self.process_id} 构建模型推理message: {str(e)}")
|
||||
logger.error(f"进程 {self.process_id} 构建模型推理message失败: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
def text_prompt_inference(self, request: TextInferenceRequest):
|
||||
@@ -375,13 +376,28 @@ class Models_Server:
|
||||
temperature = request.temperature,
|
||||
top_p = request.top_p,
|
||||
system_prompt = request.system_prompt,
|
||||
stop = request.stop
|
||||
stop = request.stop,
|
||||
image_data = None, # 纯文本推理,没有图像
|
||||
json_schema = request.json_schema # 传递 JSON Schema
|
||||
)
|
||||
messages = self._build_message(inference_request,self.model_server_config)
|
||||
messages = self._build_message(inference_request, self.model_server_config)
|
||||
|
||||
# 检查 messages 是否为 None
|
||||
if messages is None:
|
||||
logger.error("构建消息失败:_build_message 返回 None")
|
||||
raise HTTPException(status_code=500, detail="构建推理消息失败")
|
||||
|
||||
logger.info(f"构建的消息: {messages}")
|
||||
inference_start_time = time.time()
|
||||
|
||||
# 构建提示
|
||||
prompt = f"USER: {request.user_prompt}\nASSISTANT:"
|
||||
# 准备 response_format 参数(如果提供了 json_schema)
|
||||
response_format = None
|
||||
if request.json_schema:
|
||||
response_format = {
|
||||
"type": "json_object",
|
||||
"schema": request.json_schema
|
||||
}
|
||||
logger.info(f"使用 JSON Schema 约束输出格式")
|
||||
|
||||
# 生成响应
|
||||
output = self.model.create_chat_completion(
|
||||
@@ -389,8 +405,9 @@ class Models_Server:
|
||||
max_tokens=request.max_tokens,
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
stop=request.stop ,
|
||||
stream=False
|
||||
stop=request.stop,
|
||||
stream=False,
|
||||
response_format=response_format
|
||||
)
|
||||
|
||||
inference_time = time.time() - inference_start_time
|
||||
@@ -406,7 +423,10 @@ class Models_Server:
|
||||
|
||||
}
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"推理失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"推理失败: {str(e)}")
|
||||
|
||||
def multimodal_inference(self,request: MultimodalInferenceRequest):
|
||||
@@ -445,6 +465,15 @@ class Models_Server:
|
||||
len_images = len(request.image_data)
|
||||
# print(f"full_prompt: {messages}")
|
||||
|
||||
# 准备 response_format 参数(如果提供了 json_schema)
|
||||
response_format = None
|
||||
if hasattr(request, 'json_schema') and request.json_schema:
|
||||
response_format = {
|
||||
"type": "json_object",
|
||||
"schema": request.json_schema
|
||||
}
|
||||
logger.info(f"使用 JSON Schema 约束输出格式")
|
||||
|
||||
# 生成响应
|
||||
output = self.model.create_chat_completion(
|
||||
messages=messages,
|
||||
@@ -452,7 +481,8 @@ class Models_Server:
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
stop=request.stop,
|
||||
stream=False
|
||||
stream=False,
|
||||
response_format=response_format
|
||||
)
|
||||
|
||||
inference_time = time.time() - inference_start_time
|
||||
@@ -491,6 +521,15 @@ class Models_Server:
|
||||
len_images = len(request.image_data)
|
||||
# print(f"full_prompt: {messages}")
|
||||
|
||||
# 准备 response_format 参数(如果提供了 json_schema)
|
||||
response_format = None
|
||||
if hasattr(request, 'json_schema') and request.json_schema:
|
||||
response_format = {
|
||||
"type": "json_object",
|
||||
"schema": request.json_schema
|
||||
}
|
||||
logger.info(f"使用 JSON Schema 约束输出格式")
|
||||
|
||||
# 生成响应
|
||||
output = self.model.create_chat_completion(
|
||||
messages=messages,
|
||||
@@ -498,7 +537,8 @@ class Models_Server:
|
||||
temperature=request.temperature,
|
||||
top_p=request.top_p,
|
||||
stop=request.stop,
|
||||
stream=False
|
||||
stream=False,
|
||||
response_format=response_format
|
||||
)
|
||||
|
||||
inference_time = time.time() - inference_start_time
|
||||
@@ -527,27 +567,51 @@ model_router = InferringRouter() # 支持自动推断参数类型
|
||||
# # 依赖注入:获取单例模型服务
|
||||
def get_model_server() -> Models_Server:
|
||||
"""依赖注入:获取当前进程的单例服务实例"""
|
||||
global _process_instance # 必须在函数开始处声明
|
||||
logger.info(f"=== get_model_server 被触发!Process ID: {os.getpid()} ===")
|
||||
|
||||
if _process_instance is None:
|
||||
config_json_file= f"{project_root}/models/model_config.json"
|
||||
print("config_json_file: ",config_json_file)
|
||||
# global model_server_config
|
||||
model_server_config = read_json_file(config_json_file)["models_server"][0]
|
||||
config_json_file = f"{project_root}/config/model_config.json" # ✅ 修正路径
|
||||
print("config_json_file: ", config_json_file)
|
||||
|
||||
# 读取配置文件
|
||||
config_data = read_json_file(config_json_file)
|
||||
if config_data is None:
|
||||
logger.error(f"get_model_server: 无法读取配置文件 '{config_json_file}'")
|
||||
raise HTTPException(status_code=500, detail=f"无法读取配置文件: {config_json_file}")
|
||||
|
||||
# 检查 models_server 配置是否存在
|
||||
if "models_server" not in config_data:
|
||||
logger.error("get_model_server: 配置文件中缺少 'models_server' 字段")
|
||||
raise HTTPException(status_code=500, detail="配置文件中缺少 'models_server' 字段")
|
||||
|
||||
# 如果 models_server 是列表,取第一个;如果是字典,直接使用
|
||||
models_server_data = config_data["models_server"]
|
||||
model_server_config = models_server_data[0] if isinstance(models_server_data, list) else models_server_data
|
||||
|
||||
if model_server_config is None:
|
||||
logger.error("get_model_server: 配置未初始化!")
|
||||
raise HTTPException(status_code=500, detail="服务配置未初始化")
|
||||
|
||||
# 获取单例实例
|
||||
model_server = Models_Server(model_server_config)
|
||||
if not model_server.model_loaded:
|
||||
model_server._wait_for_model_load(model_server_config)
|
||||
model_server.init_model(model_server_config["model_inference_framework_type"],
|
||||
model_server_config)
|
||||
# _process_initialized = True
|
||||
# 保存到全局变量
|
||||
_process_instance = model_server
|
||||
else:
|
||||
model_server = _process_instance
|
||||
if not model_server.model_loaded:
|
||||
# 重新读取配置以获取 model_server_config
|
||||
config_json_file = f"{project_root}/config/model_config.json"
|
||||
config_data = read_json_file(config_json_file)
|
||||
if config_data is None or "models_server" not in config_data:
|
||||
logger.error("get_model_server: 无法读取配置文件")
|
||||
raise HTTPException(status_code=500, detail="无法读取配置文件")
|
||||
models_server_data = config_data["models_server"]
|
||||
model_server_config = models_server_data[0] if isinstance(models_server_data, list) else models_server_data
|
||||
model_server._wait_for_model_load(model_server_config)
|
||||
model_server.init_model(model_server_config["model_inference_framework_type"],
|
||||
model_server_config)
|
||||
@@ -603,8 +667,13 @@ class Models_Controller:
|
||||
@model_router.post("/text/inference", response_model=Dict[str, Any])
|
||||
async def text_inference(self, request: TextInferenceRequest):
|
||||
"""纯文本推理端点"""
|
||||
|
||||
return self.model_server.text_inference(request)
|
||||
try:
|
||||
return self.model_server.text_inference(request)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"端点处理失败: {str(e)}", exc_info=True)
|
||||
raise HTTPException(status_code=500, detail=f"端点处理失败: {str(e)}")
|
||||
|
||||
@model_router.post("/multimodal/inference", response_model=Dict[str, Any])
|
||||
async def multimodal_inference(self,request: MultimodalInferenceRequest):
|
||||
@@ -679,13 +748,12 @@ def create_app() -> FastAPI:
|
||||
|
||||
|
||||
def main():
|
||||
#model server 和client初始化
|
||||
# "/home/ubuntu/Workspace/Projects/VLM_VLA/UAV_AI/src/Model/AI_Agent/scripts/ai_agent/models/model_config.json"
|
||||
print(f"project_root111:{project_root}")
|
||||
config_json_file= f"{project_root}/models/model_config.json"
|
||||
config_json_file= f"{project_root}/config/model_config.json" # ✅ 修正路径
|
||||
print("config_json_file: ",config_json_file)
|
||||
# global model_server_config
|
||||
model_server_config = read_json_file(config_json_file)["models_server"][0]
|
||||
config_data = read_json_file(config_json_file)["models_server"]
|
||||
# 如果 models_server 是列表,取第一个;如果是字典,直接使用
|
||||
model_server_config = config_data[0] if isinstance(config_data, list) else config_data
|
||||
|
||||
uvicorn.run(
|
||||
app="models_server:create_app",
|
||||
@@ -696,7 +764,6 @@ def main():
|
||||
log_level="info"
|
||||
)
|
||||
|
||||
|
||||
def test_inference(inference_type=0):
|
||||
from tools.core.common_functions import read_json_file
|
||||
config_json_file="/home/ubuntu/Workspace/Projects/VLM_VLA/UAV_AI/src/Model/AI_Agent/scripts/ai_agent/models/model_config.json"
|
||||
@@ -751,7 +818,7 @@ def test_inference(inference_type=0):
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# main()
|
||||
main()
|
||||
# test_inference(2)
|
||||
logger.info("work finish")
|
||||
# logger.info("work finish")
|
||||
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -13,10 +13,11 @@ from PIL import Image as PIL_Image
|
||||
import base64
|
||||
import numpy as np
|
||||
|
||||
import rospy # type: ignore
|
||||
from std_msgs.msg import String # type: ignore
|
||||
from sensor_msgs.msg import Image as Sensor_Image # type: ignore
|
||||
from cv_bridge import CvBridge, CvBridgeError # type: ignore
|
||||
# Removed ROS dependencies
|
||||
# import rospy # type: ignore
|
||||
# from std_msgs.msg import String # type: ignore
|
||||
# from sensor_msgs.msg import Image as Sensor_Image # type: ignore
|
||||
# from cv_bridge import CvBridge, CvBridgeError # type: ignore
|
||||
import cv2
|
||||
|
||||
|
||||
@@ -25,13 +26,13 @@ current_script = Path(__file__).resolve()
|
||||
# 向上两级找到项目根目录(根据实际结构调整层级)
|
||||
project_root = current_script.parents[2] # parents[0] 是当前目录,parents[1] 是父目录,以此类推
|
||||
|
||||
print("project_root: ",project_root)
|
||||
# print("project_root: ",project_root)
|
||||
|
||||
#添加到搜索路径
|
||||
sys.path.append(str(project_root))
|
||||
|
||||
#查看系统环境
|
||||
print("sys path:",sys.path)
|
||||
# print("sys path:",sys.path)
|
||||
|
||||
# 配置日志记录器
|
||||
logging.basicConfig(
|
||||
@@ -84,91 +85,7 @@ def PIL_image_to_base64( image_path: Optional[str] = None, image: Optional[PIL_I
|
||||
image.save(buffered, format="JPEG", quality=90)
|
||||
return base64.b64encode(buffered.getvalue()).decode('utf-8')
|
||||
|
||||
def ros_image2pil_image(self, ros_image_msg:Sensor_Image,supported_image_formats):
|
||||
"""回调函数:将ROS Image转为PIL Image并处理"""
|
||||
try:
|
||||
bridge_cv = CvBridge()
|
||||
# 1. 将ROS Image消息转为OpenCV格式(默认BGR8编码)
|
||||
# 若图像编码不同(如rgb8),需指定格式:bridge.imgmsg_to_cv2(ros_image_msg, "rgb8")
|
||||
# 尝试转换为OpenCV格式
|
||||
if ros_image_msg.encoding in supported_image_formats:
|
||||
cv_image = bridge_cv.imgmsg_to_cv2(ros_image_msg, desired_encoding=ros_image_msg.encoding)
|
||||
else:
|
||||
# 尝试默认转换
|
||||
cv_image = bridge_cv.imgmsg_to_cv2(ros_image_msg)
|
||||
logger.warning(f"转换不支持的图像编码: {ros_image_msg.encoding}")
|
||||
|
||||
|
||||
# 2. OpenCV默认是BGR格式,转为PIL需要的RGB格式
|
||||
rgb_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# 3. 转为PIL Image格式
|
||||
pil_image = PIL_Image.fromarray(rgb_image)
|
||||
|
||||
# 此处可添加对PIL Image的处理(如显示、保存等)
|
||||
rospy.loginfo(f"转换成功:PIL Image尺寸 {pil_image.size}")
|
||||
return pil_image
|
||||
# 示例:显示图像(需要PIL的显示支持)
|
||||
# pil_image.show()
|
||||
|
||||
except CvBridgeError as e:
|
||||
rospy.logerr(f"转换失败:{e}")
|
||||
except Exception as e:
|
||||
rospy.logerr(f"处理错误:{e}")
|
||||
|
||||
|
||||
def ros_image2dict(self, image_msg: Sensor_Image,
|
||||
supported_image_formats=['bgr8', 'rgb8', 'mono8'],
|
||||
max_dim:int =2000) :
|
||||
"""
|
||||
将ROS图像消息转换为字典格式,便于在上下文中存储
|
||||
|
||||
参数:
|
||||
image_msg: ROS的Image消息
|
||||
|
||||
返回:
|
||||
包含图像信息的字典,或None(转换失败时)
|
||||
"""
|
||||
try:
|
||||
# 尝试转换为OpenCV格式
|
||||
if image_msg.encoding in self.supported_image_formats:
|
||||
cv_image = self.bridge.imgmsg_to_cv2(image_msg, desired_encoding=image_msg.encoding)
|
||||
else:
|
||||
# 尝试默认转换
|
||||
cv_image = self.bridge.imgmsg_to_cv2(image_msg)
|
||||
logger.warning(f"转换不支持的图像编码: {image_msg.encoding}")
|
||||
|
||||
# 图像预处理:调整大小以减少数据量
|
||||
h, w = cv_image.shape[:2]
|
||||
if max(h, w) > max_dim:
|
||||
scale = max_dim / max(h, w)
|
||||
cv_image = cv2.resize(
|
||||
cv_image,
|
||||
(int(w * scale), int(h * scale)),
|
||||
interpolation=cv2.INTER_AREA
|
||||
)
|
||||
|
||||
# 转换为JPEG并编码为base64
|
||||
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80]
|
||||
_, buffer = cv2.imencode('.jpg', cv_image, encode_param)
|
||||
img_base64 = base64.b64encode(buffer).decode('utf-8')
|
||||
|
||||
return {
|
||||
"type": "image",
|
||||
"format": "jpg",
|
||||
"data": img_base64,
|
||||
"width": cv_image.shape[1],
|
||||
"height": cv_image.shape[0],
|
||||
"original_encoding": image_msg.encoding,
|
||||
"timestamp": image_msg.header.stamp.to_sec()
|
||||
}
|
||||
|
||||
except CvBridgeError as e:
|
||||
logger.error(f"CV桥接错误: {str(e)}")
|
||||
except Exception as e:
|
||||
logger.error(f"图像处理错误: {str(e)}")
|
||||
return None
|
||||
|
||||
# Removed ROS specific functions (ros_image2pil_image, ros_image2dict) as they depend on rospy/cv_bridge
|
||||
|
||||
def PIL_image_to_base64_sizelmt(image_path:Optional[str]=None,image:Optional[PIL_Image.Image] = None,
|
||||
need_size_lmt:bool= False, max_size:tuple=(800, 800), quality:float=90):
|
||||
|
||||
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -101,15 +101,16 @@ def set_embeddings(embedding_model_path,
|
||||
model_config:Optional[dict[str,Any]]):
|
||||
|
||||
if "llamacpp_embeddings" == embedding_type:
|
||||
# 使用 .get() 方法提供默认值,确保即使配置文件中缺少某些字段也能正常工作
|
||||
embeddings = set_embeddings_llamacpp(
|
||||
model_path=embedding_model_path,
|
||||
n_ctx=model_config["n_ctx"],
|
||||
n_threads=model_config["n_threads"],
|
||||
n_gpu_layers=model_config["n_gpu_layers"], # RTX 5090建议设20,充分利用GPU
|
||||
n_seq_max=model_config["n_seq_max"],
|
||||
n_threads_batch = model_config["n_threads_batch"],
|
||||
flash_attn = model_config["flash_attn"],
|
||||
verbose=model_config["verbose"]
|
||||
n_ctx=model_config.get("n_ctx", 512),
|
||||
n_threads=model_config.get("n_threads", 4),
|
||||
n_gpu_layers=model_config.get("n_gpu_layers", 0), # RTX 5090建议设20,充分利用GPU
|
||||
n_seq_max=model_config.get("n_seq_max", 128),
|
||||
n_threads_batch=model_config.get("n_threads_batch", 4),
|
||||
flash_attn=model_config.get("flash_attn", True),
|
||||
verbose=model_config.get("verbose", False)
|
||||
)
|
||||
elif "huggingFace_embeddings" == embedding_type:
|
||||
embeddings = set_embeddings_huggingFace(
|
||||
@@ -156,19 +157,49 @@ def load_vector_database(embeddings,
|
||||
collection_name:str):
|
||||
"""加载已存在的向量数据库"""
|
||||
try:
|
||||
if os.path.exists(path=persist_directory):
|
||||
vector_db = Chroma(
|
||||
persist_directory=persist_directory,
|
||||
embedding_function=embeddings,
|
||||
collection_name=collection_name
|
||||
)
|
||||
print(f"已加载向量数据库 from {persist_directory}")
|
||||
|
||||
return vector_db
|
||||
else:
|
||||
if not os.path.exists(path=persist_directory):
|
||||
logger.warning(f"向量数据库目录不存在: {persist_directory}")
|
||||
return None
|
||||
|
||||
# 先检查集合是否存在
|
||||
try:
|
||||
import chromadb
|
||||
client = chromadb.PersistentClient(path=persist_directory)
|
||||
collections = client.list_collections()
|
||||
collection_names = [col.name for col in collections]
|
||||
|
||||
if collection_name not in collection_names:
|
||||
logger.warning(f"集合 '{collection_name}' 不存在于数据库中。")
|
||||
logger.info(f"可用的集合: {collection_names}")
|
||||
# 尝试查找有数据的集合
|
||||
for col in collections:
|
||||
if col.count() > 0:
|
||||
logger.info(f"发现非空集合: {col.name} (count: {col.count()})")
|
||||
return None
|
||||
|
||||
# 检查集合是否有数据
|
||||
target_collection = client.get_collection(name=collection_name)
|
||||
doc_count = target_collection.count()
|
||||
if doc_count == 0:
|
||||
logger.warning(f"集合 '{collection_name}' 存在但为空 (count: 0)")
|
||||
logger.warning(f"⚠️ 该集合没有数据,RAG 检索将返回空结果")
|
||||
# 仍然返回数据库对象,允许后续操作(如添加数据)
|
||||
else:
|
||||
logger.info(f"✅ 集合 '{collection_name}' 包含 {doc_count} 条文档")
|
||||
except Exception as check_error:
|
||||
logger.warning(f"检查集合时出错: {check_error},继续尝试加载...")
|
||||
|
||||
# 加载向量数据库
|
||||
vector_db = Chroma(
|
||||
persist_directory=persist_directory,
|
||||
embedding_function=embeddings,
|
||||
collection_name=collection_name
|
||||
)
|
||||
logger.info(f"✅ 成功加载向量数据库 from {persist_directory}, 集合: {collection_name}")
|
||||
return vector_db
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"vector_database加载失败: {str(e)}")
|
||||
logger.error(f"vector_database加载失败: {str(e)}", exc_info=True)
|
||||
return None
|
||||
|
||||
#设置文本分割器
|
||||
@@ -206,7 +237,7 @@ def set_document_loaders(self):
|
||||
def retrieve_relevant_info(vectorstore,
|
||||
query: str,
|
||||
k: int = 3,
|
||||
score_threshold: float = 0.2
|
||||
score_threshold: float = None
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
检索与查询相关的信息
|
||||
@@ -214,7 +245,8 @@ def retrieve_relevant_info(vectorstore,
|
||||
参数:
|
||||
query: 查询文本
|
||||
k: 最多返回的结果数量
|
||||
score_threshold: 相关性分数阈值
|
||||
score_threshold: 相关性分数阈值。如果为 None,则使用自适应阈值(前 k 个结果中最高分数的 1.5 倍)
|
||||
对于 L2 距离,分数越小越相似,阈值应该是一个较大的值
|
||||
|
||||
返回:
|
||||
包含相关文档内容和分数的列表
|
||||
@@ -223,21 +255,41 @@ def retrieve_relevant_info(vectorstore,
|
||||
# 执行相似性搜索,返回带分数的结果
|
||||
docs_and_scores = vectorstore.similarity_search_with_score(query, k=k)
|
||||
|
||||
if not docs_and_scores:
|
||||
logger.info(f"检索到 0 条结果 (k={k})")
|
||||
return []
|
||||
|
||||
# 如果没有指定阈值,使用自适应阈值
|
||||
# 对于 L2 距离,取前 k 个结果中最高的分数,然后乘以一个系数作为阈值
|
||||
if score_threshold is None:
|
||||
max_score = max(score for _, score in docs_and_scores)
|
||||
# 使用最高分数的 1.5 倍作为阈值,确保包含所有前 k 个结果
|
||||
score_threshold = max_score * 1.5
|
||||
logger.debug(f"自适应阈值: {score_threshold:.2f} (基于最高分数 {max_score:.2f})")
|
||||
|
||||
# 过滤并格式化结果
|
||||
results = []
|
||||
for doc, score in docs_and_scores:
|
||||
if score < score_threshold: # 分数越低表示越相似
|
||||
# 对于 L2 距离,分数越小越相似
|
||||
# 如果阈值很大(> 1000),说明是 L2 距离,使用 < 比较
|
||||
# 如果阈值很小(< 1),说明可能是余弦距离,使用 < 比较
|
||||
if score < score_threshold:
|
||||
results.append({
|
||||
"content": doc.page_content,
|
||||
"metadata": doc.metadata,
|
||||
"similarity_score": float(score)
|
||||
})
|
||||
else:
|
||||
logger.debug(f"文档因分数 {score:.2f} >= 阈值 {score_threshold:.2f} 被过滤")
|
||||
|
||||
logger.info(f"检索到 {len(results)} 条相关信息 (k={k}, 阈值={score_threshold})")
|
||||
threshold_str = f"{score_threshold:.2f}" if score_threshold is not None else "自适应"
|
||||
logger.info(f"检索到 {len(results)} 条相关信息 (k={k}, 阈值={threshold_str})")
|
||||
if results:
|
||||
logger.debug(f"相似度分数范围: {min(r['similarity_score'] for r in results):.2f} - {max(r['similarity_score'] for r in results):.2f}")
|
||||
return results
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"检索相关信息失败: {str(e)}")
|
||||
logger.error(f"检索相关信息失败: {str(e)}", exc_info=True)
|
||||
return []
|
||||
|
||||
def get_retriever(vectorstore,search_kwargs: Dict[str, Any] = None) -> Any:
|
||||
|
||||
Reference in New Issue
Block a user