first commit

2025-11-24 20:10:33 +08:00
parent 4a7cfb1cee
commit 12cfcc2681
20 changed files with 668 additions and 247 deletions


@@ -2,16 +2,16 @@
"models_server":
{
"model_inference_framework_type":"llama_cpp",
"multi_modal":1,
"multi_modal":0,
"chat_handler": "Qwen25VLChatHandler",
"vlm_model_path":"/home/ubuntu/Workspace/Projects/VLM_VLA/Models/Qwen2.5-VL-3B-Instruct-GGUF/Qwen2.5-VL-3B-Instruct-Q8_0.gguf",
"mmproj_model_path":"/home/ubuntu/Workspace/Projects/VLM_VLA/Models/Qwen2.5-VL-3B-Instruct-GGUF/mmproj-model-f16.gguf",
"vlm_model_path":"/home/huangfukk/models/gguf/Qwen/Qwen3-4B/Qwen3-4B-Q5_K_M.gguf",
"mmproj_model_path":"",
"n_ctx":30720,
"n_threads":4,
"n_gpu_layers":40,
"n_batch":24,
"n_ubatch":24,
"verbose":0,
"verbose":1,
"model_server_host":"localhost",
"model_server_port":8000,
"model_controller_workers":1


@@ -1,16 +1,16 @@
{
"rag_mag":{
"embedding_model_path":"/home/ubuntu/Workspace/Projects/VLM_VLA/Models/Qwen/Qwen3-Embedding-0.6B-GGUF/Qwen3-Embedding-0.6B-f16.gguf",
"vectorstore_persist_directory": "/home/ubuntu/Workspace/Projects/VLM_VLA/UAV_AI/src/Model/AI_Agent/memory/knowledge_base/map/vector_store/osm_map1",
"embedding_model_path":"/home/huangfukk/models/gguf/Qwen/Qwen3-Embedding-4B/Qwen3-Embedding-4B-Q4_K_M.gguf",
"vectorstore_persist_directory": "/home/huangfukk/AI_Agent/scripts/ai_agent/memory/knowledge_base/map/vector_store/osm_map1",
"embedding_framework_type":"llamacpp_embedding",
"collection_name":"osm_map_docs",
"collection_name":"drone_docs",
"model_config_llamacpp":
{
"n_ctx":512,
"n_threads":4,
"n_gpu_layers":36,
"n_seq_max":256,
"n_threads_batch":4,
"flash_attn":1,
"verbose":0
},
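A small sketch of how the "model_config_llamacpp" block above is consumed by set_embeddings() later in this commit, which now reads each field with .get() defaults; the config file name is an assumption.

rag_cfg = read_json_file("config/rag_config.json")["rag_mag"]  # hypothetical file name
llamacpp_cfg = rag_cfg.get("model_config_llamacpp", {})
n_ctx = llamacpp_cfg.get("n_ctx", 512)              # falls back to the defaults used in set_embeddings
n_gpu_layers = llamacpp_cfg.get("n_gpu_layers", 0)
flash_attn = llamacpp_cfg.get("flash_attn", True)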

Submodule scripts/ai_agent/groundcontrol added at d026107bc2


@@ -14,6 +14,7 @@ class TextInferenceRequest(BaseModel):
top_p: float = 0.95
system_prompt: Optional[str] = None
stop: Optional[List[str]] = None
json_schema: Optional[Dict[str, Any]] = None # JSON Schema used to constrain the output format
class MultimodalInferenceRequest(BaseModel):
user_prompt: str
@@ -32,6 +33,7 @@ class InferenceRequest(BaseModel):
top_p: float = 0.95
system_prompt: Optional[str] = None
stop: Optional[List[str]] = None
json_schema: Optional[Dict[str, Any]] = None # JSON Schema used to constrain the output format
# Response model definitions
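A hedged example of a request body using the new json_schema field (any field not shown in the hunk above is an assumption; written as a Python dict for brevity):

request_body = {
    "user_prompt": "Report the drone battery level as JSON.",  # illustrative prompt
    "max_tokens": 256,
    "json_schema": {  # constrains the model output to this shape
        "type": "object",
        "properties": {"battery_percent": {"type": "number"}},
        "required": ["battery_percent"],
    },
}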


@@ -135,20 +135,22 @@ class Models_Client:
Returns:
A dict containing the inference result
"""
# payload = {
# "user_prompt": prompt,
# "max_tokens": max_tokens,
# "temperature": temperature,
# "top_p": top_p,
# "system_prompt":system_prompt,
# "stop": stop
# }
try:
# Convert the Pydantic model to a JSON string
if hasattr(request, 'model_dump_json'):
# Pydantic v2
data = request.model_dump_json()
elif hasattr(request, 'json'):
# Pydantic v1
data = request.json()
else:
# Fallback: convert to a dict, then serialize
data = json.dumps(request.dict() if hasattr(request, 'dict') else request.model_dump())
response = requests.post(
self.text_endpoint,
headers={"Content-Type": "application/json"},
data=json.dumps(request)
data=data
)
response.raise_for_status()
return response.json()
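The Pydantic v2/v1 serialization fallback above (used by both the text and multimodal client paths) could be factored into a small helper; a sketch, not part of this commit:

import json

def request_to_json(request) -> str:
    if hasattr(request, "model_dump_json"):  # Pydantic v2
        return request.model_dump_json()
    if hasattr(request, "json"):  # Pydantic v1
        return request.json()
    # last resort: convert to a dict, then serialize
    return json.dumps(request.dict() if hasattr(request, "dict") else request.model_dump())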
@@ -211,44 +213,22 @@ class Models_Client:
Returns:
A dict containing the inference result
"""
# # Process images
# image_data = []
# if images:
# # for img in images:
# try:
# # image_data.append(PIL_image_to_base64(image=img))
# image_data= images
# except Exception as e:
# return {"error": f"处理图像base64对象失败: {str(e)}"}
# elif image_paths:
# for path in image_paths:
# try:
# image_data.append(PIL_image_to_base64(image_path=path))
# except Exception as e:
# return {"error": f"处理图像 {path} 失败: {str(e)}"}
# if not image_data:
# return {"error": "未提供有效的图像数据"}
# payload = {
# "user_prompt": prompt,
# "image_data": image_data,
# "max_tokens": max_tokens,
# "temperature": temperature,
# "top_p": top_p,
# "system_prompt": system_prompt,
# "stop": stop
# }
try:
# Convert the Pydantic model to a JSON string
if hasattr(request, 'model_dump_json'):
# Pydantic v2
data = request.model_dump_json()
elif hasattr(request, 'json'):
# Pydantic v1
data = request.json()
else:
# Fallback: convert to a dict, then serialize
data = json.dumps(request.dict() if hasattr(request, 'dict') else request.model_dump())
response = requests.post(
url=self.multimodal_endpoint,
headers={"Content-Type": "application/json"},
data=json.dumps(request)
data=data
)
response.raise_for_status()
return response.json()


@@ -102,24 +102,45 @@ def get_model_server() -> Models_Server:
logger.info(f"=== get_model_server 被触发Process ID: {os.getpid()} ===")
if _process_instance is None:
config_json_file= f"{project_root}/models/model_config.json"
print("config_json_file: ",config_json_file)
# global model_server_config
model_server_config = read_json_file(config_json_file)["models_server"]
config_json_file = f"{project_root}/config/model_config.json" # ✅ 修正路径
print("config_json_file: ", config_json_file)
# Read the config file
config_data = read_json_file(config_json_file)
if config_data is None:
logger.error(f"get_model_server: failed to read config file '{config_json_file}'")
raise HTTPException(status_code=500, detail=f"Failed to read config file: {config_json_file}")
# Check that the models_server config exists
if "models_server" not in config_data:
logger.error("get_model_server: config file is missing the 'models_server' field")
raise HTTPException(status_code=500, detail="Config file is missing the 'models_server' field")
# If models_server is a list, take the first entry; if it is a dict, use it directly
models_server_data = config_data["models_server"]
model_server_config = models_server_data[0] if isinstance(models_server_data, list) else models_server_data
if model_server_config is None:
logger.error("get_model_server: config not initialized!")
raise HTTPException(status_code=500, detail="Service config not initialized")
# Get the singleton instance
model_server = Models_Server(model_server_config)
if not model_server.model_loaded:
model_server._wait_for_model_load(model_server_config)
model_server.init_model(model_server_config["model_inference_framework_type"],
model_server_config)
# _process_initialized = True
else:
model_server = _process_instance
if not model_server.model_loaded:
# Re-read the config to obtain model_server_config
config_json_file = f"{project_root}/config/model_config.json"
config_data = read_json_file(config_json_file)
if config_data is None or "models_server" not in config_data:
logger.error("get_model_server: 无法读取配置文件")
raise HTTPException(status_code=500, detail="无法读取配置文件")
models_server_data = config_data["models_server"]
model_server_config = models_server_data[0] if isinstance(models_server_data, list) else models_server_data
model_server._wait_for_model_load(model_server_config)
model_server.init_model(model_server_config["model_inference_framework_type"],
model_server_config)
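Both of the following config shapes are accepted by the normalization above; a minimal sketch (field values are illustrative):

config_a = {"models_server": {"model_inference_framework_type": "llama_cpp", "model_server_port": 8000}}
config_b = {"models_server": [{"model_inference_framework_type": "llama_cpp", "model_server_port": 8000}]}
for cfg in (config_a, config_b):
    data = cfg["models_server"]
    server_cfg = data[0] if isinstance(data, list) else data
    assert server_cfg["model_server_port"] == 8000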


@@ -250,68 +250,69 @@ class Models_Server:
config = self.model_server_config
try:
messages = []
has_images = request.image_data and len(request.image_data) > 0
chat_handler = config.get("chat_handler", "")
# Build the system message
if request.system_prompt:
system_message={
if has_images and "Qwen25VLChatHandler" == chat_handler:
# In the multimodal case, the system message content should also be a list
system_message = {
"role": "system",
"content": [request.system_prompt]
}
"content": [{"type": "text", "text": request.system_prompt}]
}
else:
# In the text-only case, the system message content should be a string
system_message = {
"role": "system",
"content": request.system_prompt
}
messages.append(system_message)
user_message = {
"role": "user",
"content": []
}
# Append the user message
messages.append(user_message)
if "Qwen25VLChatHandler" == config["chat_handler"]:
for msg in messages:
if "user" == msg["role"]:
# Add images to the message
if request.image_data:
len_images = len(request.image_data)
print("len_images: ",len_images)
for i in range(0,len_images):
logger.info(f"add image {i}")
msg["content"].append(
{
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{request.image_data[i]}"}
#f"data:image/jpeg;base64,{request.image_data[i]}"
}
)
# Append the user_prompt
msg["content"].append(
{"type": "text", "text": request.user_prompt}
)
# Build the user message
if has_images:
# Multimodal message: content is a list
user_content = []
if "Qwen25VLChatHandler" == chat_handler:
# Qwen25VL format
len_images = len(request.image_data)
logger.info(f"Adding {len_images} image(s) to the message (Qwen25VL format)")
for i in range(0, len_images):
logger.info(f"add image {i}")
user_content.append({
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{request.image_data[i]}"}
})
else:
# Other multimodal formats
len_images = len(request.image_data)
logger.info(f"Adding {len_images} image(s) to the message")
for i in range(0, len_images):
logger.info(f"add image {i}")
user_content.append({
"type": "image",
"image": f"data:image/jpeg;base64,{request.image_data[i]}"
})
# Add the text content
user_content.append({"type": "text", "text": request.user_prompt})
user_message = {
"role": "user",
"content": user_content
}
else:
for msg in messages:
if "user" == msg["role"]:
# Add images to the message
if request.image_data:
len_images = len(request.image_data)
print("len_images: ",len_images)
for i in range(0,len_images):
logger.info("add image ",i)
msg["content"].append(
{
"type": "image",
"image": f"data:image/jpeg;base64,{request.image_data[i]}" #图像数据形式
#f"data:image/jpeg;base64,{request.image_data[i]}"
}
)
# Append the user_prompt
msg["content"].append(
{"type": "text", "text": request.user_prompt}
)
# Text-only message: content is a string
user_message = {
"role": "user",
"content": request.user_prompt
}
messages.append(user_message)
return messages
except Exception as e:
print(f"进程 {self.process_id} 构建模型推理message: {str(e)}")
logger.error(f"进程 {self.process_id} 构建模型推理message失败: {str(e)}", exc_info=True)
return None
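For a single-image request handled by the Qwen25VLChatHandler branch, the messages list built above would look roughly like this (values abbreviated):

messages = [
    {"role": "system", "content": [{"type": "text", "text": "You are a UAV assistant."}]},
    {"role": "user", "content": [
        {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<...>"}},
        {"type": "text", "text": "Describe the scene."},
    ]},
]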
def text_prompt_inference(self, request: TextInferenceRequest):
@@ -375,13 +376,28 @@ class Models_Server:
temperature = request.temperature,
top_p = request.top_p,
system_prompt = request.system_prompt,
stop = request.stop
stop = request.stop,
image_data = None, # text-only inference, no images
json_schema = request.json_schema # pass the JSON Schema through
)
messages = self._build_message(inference_request,self.model_server_config)
messages = self._build_message(inference_request, self.model_server_config)
# Check whether messages is None
if messages is None:
logger.error("构建消息失败_build_message 返回 None")
raise HTTPException(status_code=500, detail="构建推理消息失败")
logger.info(f"构建的消息: {messages}")
inference_start_time = time.time()
# Build the prompt
prompt = f"USER: {request.user_prompt}\nASSISTANT:"
# Prepare the response_format parameter (if a json_schema was provided)
response_format = None
if request.json_schema:
response_format = {
"type": "json_object",
"schema": request.json_schema
}
logger.info(f"使用 JSON Schema 约束输出格式")
# 生成响应
output = self.model.create_chat_completion(
@@ -389,8 +405,9 @@ class Models_Server:
max_tokens=request.max_tokens,
temperature=request.temperature,
top_p=request.top_p,
stop=request.stop ,
stream=False
stop=request.stop,
stream=False,
response_format=response_format
)
inference_time = time.time() - inference_start_time
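A minimal sketch of the constrained-output call in isolation, assuming a llama-cpp-python Llama instance named model (the schema is illustrative; the schema is enforced via grammar-constrained decoding):

schema = {"type": "object", "properties": {"answer": {"type": "string"}}, "required": ["answer"]}
output = model.create_chat_completion(
    messages=[{"role": "user", "content": "Reply in JSON."}],
    max_tokens=128,
    response_format={"type": "json_object", "schema": schema},
    stream=False,
)
print(output["choices"][0]["message"]["content"])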
@@ -406,7 +423,10 @@ class Models_Server:
}
}
except HTTPException:
raise
except Exception as e:
logger.error(f"推理失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"推理失败: {str(e)}")
def multimodal_inference(self,request: MultimodalInferenceRequest):
@@ -445,6 +465,15 @@ class Models_Server:
len_images = len(request.image_data)
# print(f"full_prompt: {messages}")
# Prepare the response_format parameter (if a json_schema was provided)
response_format = None
if hasattr(request, 'json_schema') and request.json_schema:
response_format = {
"type": "json_object",
"schema": request.json_schema
}
logger.info(f"使用 JSON Schema 约束输出格式")
# 生成响应
output = self.model.create_chat_completion(
messages=messages,
@@ -452,7 +481,8 @@ class Models_Server:
temperature=request.temperature,
top_p=request.top_p,
stop=request.stop,
stream=False
stream=False,
response_format=response_format
)
inference_time = time.time() - inference_start_time
@@ -491,6 +521,15 @@ class Models_Server:
len_images = len(request.image_data)
# print(f"full_prompt: {messages}")
# Prepare the response_format parameter (if a json_schema was provided)
response_format = None
if hasattr(request, 'json_schema') and request.json_schema:
response_format = {
"type": "json_object",
"schema": request.json_schema
}
logger.info(f"使用 JSON Schema 约束输出格式")
# 生成响应
output = self.model.create_chat_completion(
messages=messages,
@@ -498,7 +537,8 @@ class Models_Server:
temperature=request.temperature,
top_p=request.top_p,
stop=request.stop,
stream=False
stream=False,
response_format=response_format
)
inference_time = time.time() - inference_start_time
@@ -527,27 +567,51 @@ model_router = InferringRouter() # supports automatic inference of parameter types
# # Dependency injection: get the singleton model service
def get_model_server() -> Models_Server:
"""Dependency injection: get the singleton service instance for the current process"""
global _process_instance # must be declared at the start of the function
logger.info(f"=== get_model_server triggered, Process ID: {os.getpid()} ===")
if _process_instance is None:
config_json_file= f"{project_root}/models/model_config.json"
print("config_json_file: ",config_json_file)
# global model_server_config
model_server_config = read_json_file(config_json_file)["models_server"][0]
config_json_file = f"{project_root}/config/model_config.json" # ✅ 修正路径
print("config_json_file: ", config_json_file)
# Read the config file
config_data = read_json_file(config_json_file)
if config_data is None:
logger.error(f"get_model_server: failed to read config file '{config_json_file}'")
raise HTTPException(status_code=500, detail=f"Failed to read config file: {config_json_file}")
# Check that the models_server config exists
if "models_server" not in config_data:
logger.error("get_model_server: config file is missing the 'models_server' field")
raise HTTPException(status_code=500, detail="Config file is missing the 'models_server' field")
# If models_server is a list, take the first entry; if it is a dict, use it directly
models_server_data = config_data["models_server"]
model_server_config = models_server_data[0] if isinstance(models_server_data, list) else models_server_data
if model_server_config is None:
logger.error("get_model_server: config not initialized!")
raise HTTPException(status_code=500, detail="Service config not initialized")
# Get the singleton instance
model_server = Models_Server(model_server_config)
if not model_server.model_loaded:
model_server._wait_for_model_load(model_server_config)
model_server.init_model(model_server_config["model_inference_framework_type"],
model_server_config)
# _process_initialized = True
# Save to the global variable
_process_instance = model_server
else:
model_server = _process_instance
if not model_server.model_loaded:
# Re-read the config to obtain model_server_config
config_json_file = f"{project_root}/config/model_config.json"
config_data = read_json_file(config_json_file)
if config_data is None or "models_server" not in config_data:
logger.error("get_model_server: 无法读取配置文件")
raise HTTPException(status_code=500, detail="无法读取配置文件")
models_server_data = config_data["models_server"]
model_server_config = models_server_data[0] if isinstance(models_server_data, list) else models_server_data
model_server._wait_for_model_load(model_server_config)
model_server.init_model(model_server_config["model_inference_framework_type"],
model_server_config)
@@ -603,8 +667,13 @@ class Models_Controller:
@model_router.post("/text/inference", response_model=Dict[str, Any])
async def text_inference(self, request: TextInferenceRequest):
"""纯文本推理端点"""
return self.model_server.text_inference(request)
try:
return self.model_server.text_inference(request)
except HTTPException:
raise
except Exception as e:
logger.error(f"端点处理失败: {str(e)}", exc_info=True)
raise HTTPException(status_code=500, detail=f"端点处理失败: {str(e)}")
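A hedged example of exercising the wrapped endpoint over HTTP, assuming the router is mounted without a prefix and using the host/port from the models_server config above:

import requests
resp = requests.post(
    "http://localhost:8000/text/inference",
    json={"user_prompt": "ping", "max_tokens": 16},
)
resp.raise_for_status()  # server-side failures now surface as HTTPException details
print(resp.json())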
@model_router.post("/multimodal/inference", response_model=Dict[str, Any])
async def multimodal_inference(self,request: MultimodalInferenceRequest):
@@ -679,13 +748,12 @@ def create_app() -> FastAPI:
def main():
# Model server and client initialization
# "/home/ubuntu/Workspace/Projects/VLM_VLA/UAV_AI/src/Model/AI_Agent/scripts/ai_agent/models/model_config.json"
print(f"project_root111:{project_root}")
config_json_file= f"{project_root}/models/model_config.json"
config_json_file= f"{project_root}/config/model_config.json" # ✅ 修正路径
print("config_json_file: ",config_json_file)
# global model_server_config
model_server_config = read_json_file(config_json_file)["models_server"][0]
config_data = read_json_file(config_json_file)["models_server"]
# If models_server is a list, take the first entry; if it is a dict, use it directly
model_server_config = config_data[0] if isinstance(config_data, list) else config_data
uvicorn.run(
app="models_server:create_app",
@@ -696,7 +764,6 @@ def main():
log_level="info"
)
def test_inference(inference_type=0):
from tools.core.common_functions import read_json_file
config_json_file="/home/ubuntu/Workspace/Projects/VLM_VLA/UAV_AI/src/Model/AI_Agent/scripts/ai_agent/models/model_config.json"
@@ -751,7 +818,7 @@ def test_inference(inference_type=0):
if __name__ == "__main__":
# main()
main()
# test_inference(2)
logger.info("work finish")
# logger.info("work finish")


@@ -13,10 +13,11 @@ from PIL import Image as PIL_Image
import base64
import numpy as np
import rospy # type: ignore
from std_msgs.msg import String # type: ignore
from sensor_msgs.msg import Image as Sensor_Image # type: ignore
from cv_bridge import CvBridge, CvBridgeError # type: ignore
# Removed ROS dependencies
# import rospy # type: ignore
# from std_msgs.msg import String # type: ignore
# from sensor_msgs.msg import Image as Sensor_Image # type: ignore
# from cv_bridge import CvBridge, CvBridgeError # type: ignore
import cv2
@@ -25,13 +26,13 @@ current_script = Path(__file__).resolve()
# Go up two levels to reach the project root (adjust the level to match the actual structure)
project_root = current_script.parents[2] # parents[0] is the current directory, parents[1] is the parent directory, and so on
print("project_root: ",project_root)
# print("project_root: ",project_root)
# Add to the module search path
sys.path.append(str(project_root))
# Inspect the system environment
print("sys path:",sys.path)
# print("sys path:",sys.path)
# Configure the logger
logging.basicConfig(
@@ -84,91 +85,7 @@ def PIL_image_to_base64( image_path: Optional[str] = None, image: Optional[PIL_I
image.save(buffered, format="JPEG", quality=90)
return base64.b64encode(buffered.getvalue()).decode('utf-8')
def ros_image2pil_image(self, ros_image_msg:Sensor_Image,supported_image_formats):
"""回调函数将ROS Image转为PIL Image并处理"""
try:
bridge_cv = CvBridge()
# 1. Convert the ROS Image message to OpenCV format (default BGR8 encoding)
# If the image encoding differs (e.g. rgb8), specify the format: bridge.imgmsg_to_cv2(ros_image_msg, "rgb8")
# Try converting to OpenCV format
if ros_image_msg.encoding in supported_image_formats:
cv_image = bridge_cv.imgmsg_to_cv2(ros_image_msg, desired_encoding=ros_image_msg.encoding)
else:
# Try the default conversion
cv_image = bridge_cv.imgmsg_to_cv2(ros_image_msg)
logger.warning(f"转换不支持的图像编码: {ros_image_msg.encoding}")
# 2. OpenCV默认是BGR格式转为PIL需要的RGB格式
rgb_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
# 3. Convert to a PIL Image
pil_image = PIL_Image.fromarray(rgb_image)
# Further processing of the PIL Image (display, save, etc.) could be added here
rospy.loginfo(f"Conversion succeeded, PIL Image size: {pil_image.size}")
return pil_image
# Example: display the image (requires PIL display support)
# pil_image.show()
except CvBridgeError as e:
rospy.logerr(f"转换失败:{e}")
except Exception as e:
rospy.logerr(f"处理错误:{e}")
def ros_image2dict(self, image_msg: Sensor_Image,
supported_image_formats=['bgr8', 'rgb8', 'mono8'],
max_dim:int =2000) :
"""
Convert a ROS image message to a dict so it can be stored in context
Args:
image_msg: the ROS Image message
Returns:
A dict containing the image info, or None if the conversion fails
"""
try:
# Try converting to OpenCV format
if image_msg.encoding in self.supported_image_formats:
cv_image = self.bridge.imgmsg_to_cv2(image_msg, desired_encoding=image_msg.encoding)
else:
# Try the default conversion
cv_image = self.bridge.imgmsg_to_cv2(image_msg)
logger.warning(f"Converting unsupported image encoding: {image_msg.encoding}")
# Image preprocessing: resize to reduce the data volume
h, w = cv_image.shape[:2]
if max(h, w) > max_dim:
scale = max_dim / max(h, w)
cv_image = cv2.resize(
cv_image,
(int(w * scale), int(h * scale)),
interpolation=cv2.INTER_AREA
)
# Encode as JPEG and then base64
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80]
_, buffer = cv2.imencode('.jpg', cv_image, encode_param)
img_base64 = base64.b64encode(buffer).decode('utf-8')
return {
"type": "image",
"format": "jpg",
"data": img_base64,
"width": cv_image.shape[1],
"height": cv_image.shape[0],
"original_encoding": image_msg.encoding,
"timestamp": image_msg.header.stamp.to_sec()
}
except CvBridgeError as e:
logger.error(f"CV桥接错误: {str(e)}")
except Exception as e:
logger.error(f"图像处理错误: {str(e)}")
return None
# Removed ROS specific functions (ros_image2pil_image, ros_image2dict) as they depend on rospy/cv_bridge
def PIL_image_to_base64_sizelmt(image_path:Optional[str]=None,image:Optional[PIL_Image.Image] = None,
need_size_lmt:bool= False, max_size:tuple=(800, 800), quality:float=90):


@@ -101,15 +101,16 @@ def set_embeddings(embedding_model_path,
model_config:Optional[dict[str,Any]]):
if "llamacpp_embeddings" == embedding_type:
# Use .get() to supply default values, so things still work even if some fields are missing from the config file
embeddings = set_embeddings_llamacpp(
model_path=embedding_model_path,
n_ctx=model_config["n_ctx"],
n_threads=model_config["n_threads"],
n_gpu_layers=model_config["n_gpu_layers"], # RTX 5090建议设20充分利用GPU
n_seq_max=model_config["n_seq_max"],
n_threads_batch = model_config["n_threads_batch"],
flash_attn = model_config["flash_attn"],
verbose=model_config["verbose"]
n_ctx=model_config.get("n_ctx", 512),
n_threads=model_config.get("n_threads", 4),
n_gpu_layers=model_config.get("n_gpu_layers", 0), # on an RTX 5090, 20 is recommended to make full use of the GPU
n_seq_max=model_config.get("n_seq_max", 128),
n_threads_batch=model_config.get("n_threads_batch", 4),
flash_attn=model_config.get("flash_attn", True),
verbose=model_config.get("verbose", False)
)
elif "huggingFace_embeddings" == embedding_type:
embeddings = set_embeddings_huggingFace(
@@ -156,19 +157,49 @@ def load_vector_database(embeddings,
collection_name:str):
"""加载已存在的向量数据库"""
try:
if os.path.exists(path=persist_directory):
vector_db = Chroma(
persist_directory=persist_directory,
embedding_function=embeddings,
collection_name=collection_name
)
print(f"已加载向量数据库 from {persist_directory}")
return vector_db
else:
if not os.path.exists(path=persist_directory):
logger.warning(f"向量数据库目录不存在: {persist_directory}")
return None
# First check whether the collection exists
try:
import chromadb
client = chromadb.PersistentClient(path=persist_directory)
collections = client.list_collections()
collection_names = [col.name for col in collections]
if collection_name not in collection_names:
logger.warning(f"集合 '{collection_name}' 不存在于数据库中。")
logger.info(f"可用的集合: {collection_names}")
# 尝试查找有数据的集合
for col in collections:
if col.count() > 0:
logger.info(f"发现非空集合: {col.name} (count: {col.count()})")
return None
# Check whether the collection has data
target_collection = client.get_collection(name=collection_name)
doc_count = target_collection.count()
if doc_count == 0:
logger.warning(f"集合 '{collection_name}' 存在但为空 (count: 0)")
logger.warning(f"⚠️ 该集合没有数据RAG 检索将返回空结果")
# 仍然返回数据库对象,允许后续操作(如添加数据)
else:
logger.info(f"✅ 集合 '{collection_name}' 包含 {doc_count} 条文档")
except Exception as check_error:
logger.warning(f"检查集合时出错: {check_error},继续尝试加载...")
# 加载向量数据库
vector_db = Chroma(
persist_directory=persist_directory,
embedding_function=embeddings,
collection_name=collection_name
)
logger.info(f"✅ 成功加载向量数据库 from {persist_directory}, 集合: {collection_name}")
return vector_db
except Exception as e:
logger.error(f"vector_database加载失败: {str(e)}")
logger.error(f"vector_database加载失败: {str(e)}", exc_info=True)
return None
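A short usage sketch tying this file together (the directory and collection name mirror the rag config earlier in this commit; keyword names follow how the parameters are used in the function body, and an already-configured embeddings object is assumed):

# given an `embeddings` object configured via set_embeddings() above
vector_db = load_vector_database(
    embeddings=embeddings,
    persist_directory="/home/huangfukk/AI_Agent/scripts/ai_agent/memory/knowledge_base/map/vector_store/osm_map1",
    collection_name="drone_docs",
)
if vector_db is not None:
    hits = retrieve_relevant_info(vector_db, "nearest landing site", k=3)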
# Set up the text splitter
@@ -206,7 +237,7 @@ def set_document_loaders(self):
def retrieve_relevant_info(vectorstore,
query: str,
k: int = 3,
score_threshold: float = 0.2
score_threshold: float = None
) -> List[Dict[str, Any]]:
"""
Retrieve information relevant to the query
@@ -214,7 +245,8 @@ def retrieve_relevant_info(vectorstore,
Args:
query: the query text
k: maximum number of results to return
score_threshold: relevance score threshold
score_threshold: relevance score threshold. If None, an adaptive threshold is used (1.5x the highest score among the top k results)
For L2 distance, smaller scores mean more similar, so the threshold should be a relatively large value
Returns:
A list of relevant document contents and scores
@@ -223,21 +255,41 @@ def retrieve_relevant_info(vectorstore,
# Run the similarity search and return scored results
docs_and_scores = vectorstore.similarity_search_with_score(query, k=k)
if not docs_and_scores:
logger.info(f"检索到 0 条结果 (k={k})")
return []
# If no threshold is given, use an adaptive threshold
# For L2 distance, take the highest score among the top k results and multiply it by a factor to get the threshold
if score_threshold is None:
max_score = max(score for _, score in docs_and_scores)
# Use 1.5x the highest score as the threshold so that all top-k results are included
score_threshold = max_score * 1.5
logger.debug(f"自适应阈值: {score_threshold:.2f} (基于最高分数 {max_score:.2f})")
# Filter and format the results
results = []
for doc, score in docs_and_scores:
if score < score_threshold: # lower score means more similar
# For L2 distance, smaller scores mean more similar
# If the threshold is large (> 1000), this indicates L2 distance; compare with <
# If the threshold is small (< 1), it is probably cosine distance; also compare with <
if score < score_threshold:
results.append({
"content": doc.page_content,
"metadata": doc.metadata,
"similarity_score": float(score)
})
else:
logger.debug(f"文档因分数 {score:.2f} >= 阈值 {score_threshold:.2f} 被过滤")
logger.info(f"检索到 {len(results)} 条相关信息 (k={k}, 阈值={score_threshold})")
threshold_str = f"{score_threshold:.2f}" if score_threshold is not None else "自适应"
logger.info(f"检索到 {len(results)} 条相关信息 (k={k}, 阈值={threshold_str})")
if results:
logger.debug(f"相似度分数范围: {min(r['similarity_score'] for r in results):.2f} - {max(r['similarity_score'] for r in results):.2f}")
return results
except Exception as e:
logger.error(f"检索相关信息失败: {str(e)}")
logger.error(f"检索相关信息失败: {str(e)}", exc_info=True)
return []
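A tiny worked example of the adaptive threshold implemented above: with top-3 L2 scores of 0.42, 0.57 and 0.80, the threshold becomes 0.80 * 1.5 = 1.20, so all three results pass the score < threshold filter.

docs_and_scores = [("doc_a", 0.42), ("doc_b", 0.57), ("doc_c", 0.80)]  # toy L2 distances
threshold = max(score for _, score in docs_and_scores) * 1.5           # 1.20
kept = [doc for doc, score in docs_and_scores if score < threshold]    # all three survive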
def get_retriever(vectorstore,search_kwargs: Dict[str, Any] = None) -> Any: