first commit
@@ -2,16 +2,16 @@
     "models_server":
     {
         "model_inference_framework_type":"llama_cpp",
-        "multi_modal":1,
+        "multi_modal":0,
         "chat_handler": "Qwen25VLChatHandler",
-        "vlm_model_path":"/home/ubuntu/Workspace/Projects/VLM_VLA/Models/Qwen2.5-VL-3B-Instruct-GGUF/Qwen2.5-VL-3B-Instruct-Q8_0.gguf",
-        "mmproj_model_path":"/home/ubuntu/Workspace/Projects/VLM_VLA/Models/Qwen2.5-VL-3B-Instruct-GGUF/mmproj-model-f16.gguf",
+        "vlm_model_path":"/home/huangfukk/models/gguf/Qwen/Qwen3-4B/Qwen3-4B-Q5_K_M.gguf",
+        "mmproj_model_path":"",
         "n_ctx":30720,
         "n_threads":4,
         "n_gpu_layers":40,
         "n_batch":24,
         "n_ubatch":24,
-        "verbose":0,
+        "verbose":1,
         "model_server_host":"localhost",
         "model_server_port":8000,
         "model_controller_workers":1
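Note (not part of the commit): these fields map directly onto llama-cpp-python's Llama constructor. A minimal loading sketch, assuming the file is named model_config.json and using the list-or-dict normalization this commit introduces later:

    import json
    from llama_cpp import Llama

    with open("model_config.json") as f:
        cfg = json.load(f)["models_server"]
    cfg = cfg[0] if isinstance(cfg, list) else cfg  # list-or-dict, as normalized below

    llm = Llama(
        model_path=cfg["vlm_model_path"],
        n_ctx=cfg["n_ctx"],
        n_threads=cfg["n_threads"],
        n_gpu_layers=cfg["n_gpu_layers"],
        n_batch=cfg["n_batch"],
        verbose=bool(cfg["verbose"]),
    )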
@@ -1,16 +1,16 @@
 {
     "rag_mag":{
-        "embedding_model_path":"/home/ubuntu/Workspace/Projects/VLM_VLA/Models/Qwen/Qwen3-Embedding-0.6B-GGUF/Qwen3-Embedding-0.6B-f16.gguf",
-        "vectorstore_persist_directory": "/home/ubuntu/Workspace/Projects/VLM_VLA/UAV_AI/src/Model/AI_Agent/memory/knowledge_base/map/vector_store/osm_map1",
+        "embedding_model_path":"/home/huangfukk/models/gguf/Qwen/Qwen3-Embedding-4B/Qwen3-Embedding-4B-Q4_K_M.gguf",
+        "vectorstore_persist_directory": "/home/huangfukk/AI_Agent/scripts/ai_agent/memory/knowledge_base/map/vector_store/osm_map1",
         "embedding_framework_type":"llamacpp_embedding",
-        "collection_name":"osm_map_docs",
+        "collection_name":"drone_docs",
         "model_config_llamacpp":
         {
            "n_ctx":512,
            "n_threads":4,
            "n_gpu_layers":36,
            "n_seq_max":256,
            "n_threads_batch":4,
            "flash_attn":1,
            "verbose":0
         },
Submodule scripts/ai_agent/groundcontrol added at d026107bc2
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -14,6 +14,7 @@ class TextInferenceRequest(BaseModel):
     top_p: float = 0.95
     system_prompt: Optional[str] = None
     stop: Optional[List[str]] = None
+    json_schema: Optional[Dict[str, Any]] = None  # JSON Schema used to constrain the output format
 
 class MultimodalInferenceRequest(BaseModel):
     user_prompt: str
@@ -32,6 +33,7 @@ class InferenceRequest(BaseModel):
     top_p: float = 0.95
     system_prompt: Optional[str] = None
     stop: Optional[List[str]] = None
+    json_schema: Optional[Dict[str, Any]] = None  # JSON Schema used to constrain the output format
 
 
 # Response model definitions
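For context (not part of the commit), a request carrying the new field might be built like this; it assumes the TextInferenceRequest class above, and the schema itself is only an illustration:

    req = TextInferenceRequest(
        user_prompt="List the detected waypoints as JSON.",
        json_schema={
            "type": "object",
            "properties": {"waypoints": {"type": "array", "items": {"type": "string"}}},
            "required": ["waypoints"],
        },
    )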
@@ -135,20 +135,22 @@ class Models_Client:
         Returns:
             A dict containing the inference result
         """
-        # payload = {
-        #     "user_prompt": prompt,
-        #     "max_tokens": max_tokens,
-        #     "temperature": temperature,
-        #     "top_p": top_p,
-        #     "system_prompt": system_prompt,
-        #     "stop": stop
-        # }
-
         try:
+            # Serialize the Pydantic model to a JSON string
+            if hasattr(request, 'model_dump_json'):
+                # Pydantic v2
+                data = request.model_dump_json()
+            elif hasattr(request, 'json'):
+                # Pydantic v1
+                data = request.json()
+            else:
+                # Fallback: convert to a dict first, then serialize
+                data = json.dumps(request.dict() if hasattr(request, 'dict') else request.model_dump())
+
             response = requests.post(
                 self.text_endpoint,
                 headers={"Content-Type": "application/json"},
-                data=json.dumps(request)
+                data=data
             )
             response.raise_for_status()
             return response.json()
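A note on the bug being fixed here (not part of the commit): json.dumps(request) raises TypeError because a Pydantic model is not directly JSON-serializable. With requests, an equivalent alternative to hand-serializing is to pass the parsed dict and let requests set the Content-Type header itself:

    payload = request.model_dump()  # Pydantic v2; request.dict() on v1
    response = requests.post(self.text_endpoint, json=payload)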
@@ -211,44 +213,22 @@ class Models_Client:
         Returns:
             A dict containing the inference result
         """
-        # # Process images
-        # image_data = []
-
-        # if images:
-        #     # for img in images:
-        #     try:
-        #         # image_data.append(PIL_image_to_base64(image=img))
-        #         image_data = images
-        #     except Exception as e:
-        #         return {"error": f"Failed to process base64 image objects: {str(e)}"}
-
-        # elif image_paths:
-        #     for path in image_paths:
-        #         try:
-        #             image_data.append(PIL_image_to_base64(image_path=path))
-        #         except Exception as e:
-        #             return {"error": f"Failed to process image {path}: {str(e)}"}
-
-
-
-        # if not image_data:
-        #     return {"error": "No valid image data provided"}
-
-        # payload = {
-        #     "user_prompt": prompt,
-        #     "image_data": image_data,
-        #     "max_tokens": max_tokens,
-        #     "temperature": temperature,
-        #     "top_p": top_p,
-        #     "system_prompt": system_prompt,
-        #     "stop": stop
-        # }
-
         try:
+            # Serialize the Pydantic model to a JSON string
+            if hasattr(request, 'model_dump_json'):
+                # Pydantic v2
+                data = request.model_dump_json()
+            elif hasattr(request, 'json'):
+                # Pydantic v1
+                data = request.json()
+            else:
+                # Fallback: convert to a dict first, then serialize
+                data = json.dumps(request.dict() if hasattr(request, 'dict') else request.model_dump())
+
             response = requests.post(
                 url=self.multimodal_endpoint,
                 headers={"Content-Type": "application/json"},
-                data=json.dumps(request)
+                data=data
             )
             response.raise_for_status()
             return response.json()
@@ -102,24 +102,45 @@ def get_model_server() -> Models_Server:
     logger.info(f"=== get_model_server triggered! Process ID: {os.getpid()} ===")
 
     if _process_instance is None:
-        config_json_file= f"{project_root}/models/model_config.json"
-        print("config_json_file: ",config_json_file)
-        # global model_server_config
-        model_server_config = read_json_file(config_json_file)["models_server"]
+        config_json_file = f"{project_root}/config/model_config.json"  # ✅ corrected path
+        print("config_json_file: ", config_json_file)
+
+        # Read the config file
+        config_data = read_json_file(config_json_file)
+        if config_data is None:
+            logger.error(f"get_model_server: cannot read config file '{config_json_file}'")
+            raise HTTPException(status_code=500, detail=f"Cannot read config file: {config_json_file}")
+
+        # Check that the 'models_server' section exists
+        if "models_server" not in config_data:
+            logger.error("get_model_server: config file is missing the 'models_server' field")
+            raise HTTPException(status_code=500, detail="Config file is missing the 'models_server' field")
+
+        # If models_server is a list, take the first entry; if it is a dict, use it directly
+        models_server_data = config_data["models_server"]
+        model_server_config = models_server_data[0] if isinstance(models_server_data, list) else models_server_data
 
         if model_server_config is None:
             logger.error("get_model_server: config not initialized!")
             raise HTTPException(status_code=500, detail="Service config not initialized")
 
         # Get the singleton instance
         model_server = Models_Server(model_server_config)
         if not model_server.model_loaded:
             model_server._wait_for_model_load(model_server_config)
             model_server.init_model(model_server_config["model_inference_framework_type"],
                                     model_server_config)
         # _process_initialized = True
     else:
         model_server = _process_instance
         if not model_server.model_loaded:
+            # Re-read the config to obtain model_server_config
+            config_json_file = f"{project_root}/config/model_config.json"
+            config_data = read_json_file(config_json_file)
+            if config_data is None or "models_server" not in config_data:
+                logger.error("get_model_server: cannot read config file")
+                raise HTTPException(status_code=500, detail="Cannot read config file")
+            models_server_data = config_data["models_server"]
+            model_server_config = models_server_data[0] if isinstance(models_server_data, list) else models_server_data
             model_server._wait_for_model_load(model_server_config)
             model_server.init_model(model_server_config["model_inference_framework_type"],
                                     model_server_config)
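A minimal sketch (assumption, not repo code) of the per-process singleton pattern this function implements: each uvicorn worker caches one Models_Server in a module-level global, so the model loads at most once per process. Note that the assignment back into the global, which makes the cache effective, only appears in the second hunk of this function further below (`_process_instance = model_server`):

    _process_instance = None

    def get_singleton(cfg):
        global _process_instance
        if _process_instance is None:
            _process_instance = Models_Server(cfg)  # first request in this worker pays the load cost
        return _process_instance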
@@ -250,68 +250,69 @@ class Models_Server:
         config = self.model_server_config
         try:
             messages = []
+            has_images = request.image_data and len(request.image_data) > 0
+            chat_handler = config.get("chat_handler", "")
 
+            # Build the system message
             if request.system_prompt:
-                system_message={
-                    "role": "system",
-                    "content": [request.system_prompt]
-                }
+                if has_images and "Qwen25VLChatHandler" == chat_handler:
+                    # In the multimodal case, the system message content must also be a list
+                    system_message = {
+                        "role": "system",
+                        "content": [{"type": "text", "text": request.system_prompt}]
+                    }
+                else:
+                    # In the text-only case, the system message content must be a string
+                    system_message = {
+                        "role": "system",
+                        "content": request.system_prompt
+                    }
                 messages.append(system_message)
 
-            user_message = {
-                "role": "user",
-                "content": []
-            }
-            # Append the user message
-            messages.append(user_message)
-
-            if "Qwen25VLChatHandler" == config["chat_handler"]:
-                for msg in messages:
-                    if "user" == msg["role"]:
-                        # Add images to the message
-                        if request.image_data:
-                            len_images = len(request.image_data)
-
-                            print("len_images: ", len_images)
-                            for i in range(0, len_images):
-                                logger.info(f"add image {i}")
-                                msg["content"].append(
-                                    {
-                                        "type": "image_url",
-                                        "image_url": {"url": f"data:image/jpeg;base64,{request.image_data[i]}"}
-                                        # f"data:image/jpeg;base64,{request.image_data[i]}"
-                                    }
-                                )
-                        # Append the user_prompt
-                        msg["content"].append(
-                            {"type": "text", "text": request.user_prompt}
-                        )
+            # Build the user message
+            if has_images:
+                # Multimodal message: content is a list
+                user_content = []
+
+                if "Qwen25VLChatHandler" == chat_handler:
+                    # Qwen2.5-VL format
+                    len_images = len(request.image_data)
+                    logger.info(f"Adding {len_images} images to the message (Qwen2.5-VL format)")
+                    for i in range(0, len_images):
+                        logger.info(f"add image {i}")
+                        user_content.append({
+                            "type": "image_url",
+                            "image_url": {"url": f"data:image/jpeg;base64,{request.image_data[i]}"}
+                        })
+                else:
+                    # Other multimodal formats
+                    len_images = len(request.image_data)
+                    logger.info(f"Adding {len_images} images to the message")
+                    for i in range(0, len_images):
+                        logger.info(f"add image {i}")
+                        user_content.append({
+                            "type": "image",
+                            "image": f"data:image/jpeg;base64,{request.image_data[i]}"
+                        })
+
+                # Append the text content
+                user_content.append({"type": "text", "text": request.user_prompt})
+
+                user_message = {
+                    "role": "user",
+                    "content": user_content
+                }
             else:
-                for msg in messages:
-                    if "user" == msg["role"]:
-                        # Add images to the message
-                        if request.image_data:
-                            len_images = len(request.image_data)
-
-                            print("len_images: ", len_images)
-                            for i in range(0, len_images):
-                                logger.info("add image ", i)
-                                msg["content"].append(
-                                    {
-                                        "type": "image",
-                                        "image": f"data:image/jpeg;base64,{request.image_data[i]}"  # image data form
-                                        # f"data:image/jpeg;base64,{request.image_data[i]}"
-                                    }
-                                )
-                        # Append the user_prompt
-                        msg["content"].append(
-                            {"type": "text", "text": request.user_prompt}
-                        )
+                # Text-only message: content is a string
+                user_message = {
+                    "role": "user",
+                    "content": request.user_prompt
+                }
 
+            messages.append(user_message)
             return messages
         except Exception as e:
             print(f"Process {self.process_id} failed to build inference messages: {str(e)}")
             logger.error(f"Process {self.process_id} failed to build inference messages: {str(e)}", exc_info=True)
             return None
 
     def text_prompt_inference(self, request: TextInferenceRequest):
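For reference (not part of the commit), the messages list produced for a Qwen2.5-VL request with one image and a system prompt would look roughly like this, following the branches above:

    [
      {"role": "system", "content": [{"type": "text", "text": "<system_prompt>"}]},
      {"role": "user", "content": [
          {"type": "image_url", "image_url": {"url": "data:image/jpeg;base64,<...>"}},
          {"type": "text", "text": "<user_prompt>"}
      ]}
    ]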
@@ -375,13 +376,28 @@ class Models_Server:
                 temperature = request.temperature,
                 top_p = request.top_p,
                 system_prompt = request.system_prompt,
-                stop = request.stop
+                stop = request.stop,
+                image_data = None,  # text-only inference, no images
+                json_schema = request.json_schema  # pass the JSON Schema through
             )
-            messages = self._build_message(inference_request,self.model_server_config)
+            messages = self._build_message(inference_request, self.model_server_config)
+
+            # Check whether messages is None
+            if messages is None:
+                logger.error("Failed to build messages: _build_message returned None")
+                raise HTTPException(status_code=500, detail="Failed to build inference messages")
+
+            logger.info(f"Built messages: {messages}")
             inference_start_time = time.time()
 
-            # Build the prompt
-            prompt = f"USER: {request.user_prompt}\nASSISTANT:"
+            # Prepare the response_format parameter (when a json_schema is provided)
+            response_format = None
+            if request.json_schema:
+                response_format = {
+                    "type": "json_object",
+                    "schema": request.json_schema
+                }
+                logger.info("Constraining the output format with a JSON Schema")
 
             # Generate the response
             output = self.model.create_chat_completion(
@@ -389,8 +405,9 @@ class Models_Server:
                 max_tokens=request.max_tokens,
                 temperature=request.temperature,
                 top_p=request.top_p,
-                stop=request.stop ,
-                stream=False
+                stop=request.stop,
+                stream=False,
+                response_format=response_format
             )
 
             inference_time = time.time() - inference_start_time
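A standalone sketch (not from the repo) of the response_format mechanism used above: llama-cpp-python turns the schema into a grammar that constrains decoding. The model path is a placeholder:

    from llama_cpp import Llama

    llm = Llama(model_path="/path/to/model.gguf", n_ctx=2048, verbose=False)  # placeholder path
    out = llm.create_chat_completion(
        messages=[{"role": "user", "content": "Name one waypoint as JSON."}],
        response_format={
            "type": "json_object",
            "schema": {"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]},
        },
        max_tokens=128,
    )
    print(out["choices"][0]["message"]["content"])  # parses as JSON matching the schema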
@@ -406,7 +423,10 @@ class Models_Server:
 
                 }
             }
+        except HTTPException:
+            raise
         except Exception as e:
             logger.error(f"Inference failed: {str(e)}", exc_info=True)
             raise HTTPException(status_code=500, detail=f"Inference failed: {str(e)}")
 
     def multimodal_inference(self, request: MultimodalInferenceRequest):
@@ -445,6 +465,15 @@ class Models_Server:
             len_images = len(request.image_data)
             # print(f"full_prompt: {messages}")
 
+            # Prepare the response_format parameter (when a json_schema is provided)
+            response_format = None
+            if hasattr(request, 'json_schema') and request.json_schema:
+                response_format = {
+                    "type": "json_object",
+                    "schema": request.json_schema
+                }
+                logger.info("Constraining the output format with a JSON Schema")
+
             # Generate the response
             output = self.model.create_chat_completion(
                 messages=messages,
@@ -452,7 +481,8 @@ class Models_Server:
                 temperature=request.temperature,
                 top_p=request.top_p,
                 stop=request.stop,
-                stream=False
+                stream=False,
+                response_format=response_format
             )
 
             inference_time = time.time() - inference_start_time
@@ -491,6 +521,15 @@ class Models_Server:
             len_images = len(request.image_data)
             # print(f"full_prompt: {messages}")
 
+            # Prepare the response_format parameter (when a json_schema is provided)
+            response_format = None
+            if hasattr(request, 'json_schema') and request.json_schema:
+                response_format = {
+                    "type": "json_object",
+                    "schema": request.json_schema
+                }
+                logger.info("Constraining the output format with a JSON Schema")
+
             # Generate the response
             output = self.model.create_chat_completion(
                 messages=messages,
@@ -498,7 +537,8 @@ class Models_Server:
                 temperature=request.temperature,
                 top_p=request.top_p,
                 stop=request.stop,
-                stream=False
+                stream=False,
+                response_format=response_format
             )
 
             inference_time = time.time() - inference_start_time
@@ -527,27 +567,51 @@ model_router = InferringRouter()  # supports automatic parameter-type inference
 # # Dependency injection: obtain the singleton model service
 def get_model_server() -> Models_Server:
     """Dependency injection: obtain the singleton service instance for the current process"""
     global _process_instance  # must be declared at the top of the function
     logger.info(f"=== get_model_server triggered! Process ID: {os.getpid()} ===")
 
     if _process_instance is None:
-        config_json_file= f"{project_root}/models/model_config.json"
-        print("config_json_file: ",config_json_file)
-        # global model_server_config
-        model_server_config = read_json_file(config_json_file)["models_server"][0]
+        config_json_file = f"{project_root}/config/model_config.json"  # ✅ corrected path
+        print("config_json_file: ", config_json_file)
+
+        # Read the config file
+        config_data = read_json_file(config_json_file)
+        if config_data is None:
+            logger.error(f"get_model_server: cannot read config file '{config_json_file}'")
+            raise HTTPException(status_code=500, detail=f"Cannot read config file: {config_json_file}")
+
+        # Check that the 'models_server' section exists
+        if "models_server" not in config_data:
+            logger.error("get_model_server: config file is missing the 'models_server' field")
+            raise HTTPException(status_code=500, detail="Config file is missing the 'models_server' field")
+
+        # If models_server is a list, take the first entry; if it is a dict, use it directly
+        models_server_data = config_data["models_server"]
+        model_server_config = models_server_data[0] if isinstance(models_server_data, list) else models_server_data
 
         if model_server_config is None:
             logger.error("get_model_server: config not initialized!")
             raise HTTPException(status_code=500, detail="Service config not initialized")
 
         # Get the singleton instance
         model_server = Models_Server(model_server_config)
         if not model_server.model_loaded:
             model_server._wait_for_model_load(model_server_config)
             model_server.init_model(model_server_config["model_inference_framework_type"],
                                     model_server_config)
         # _process_initialized = True
+        # Save to the global variable
+        _process_instance = model_server
     else:
         model_server = _process_instance
         if not model_server.model_loaded:
+            # Re-read the config to obtain model_server_config
+            config_json_file = f"{project_root}/config/model_config.json"
+            config_data = read_json_file(config_json_file)
+            if config_data is None or "models_server" not in config_data:
+                logger.error("get_model_server: cannot read config file")
+                raise HTTPException(status_code=500, detail="Cannot read config file")
+            models_server_data = config_data["models_server"]
+            model_server_config = models_server_data[0] if isinstance(models_server_data, list) else models_server_data
             model_server._wait_for_model_load(model_server_config)
             model_server.init_model(model_server_config["model_inference_framework_type"],
                                     model_server_config)
@@ -603,8 +667,13 @@ class Models_Controller:
     @model_router.post("/text/inference", response_model=Dict[str, Any])
     async def text_inference(self, request: TextInferenceRequest):
         """Text-only inference endpoint"""
-
-        return self.model_server.text_inference(request)
+        try:
+            return self.model_server.text_inference(request)
+        except HTTPException:
+            raise
+        except Exception as e:
+            logger.error(f"Endpoint handling failed: {str(e)}", exc_info=True)
+            raise HTTPException(status_code=500, detail=f"Endpoint handling failed: {str(e)}")
 
     @model_router.post("/multimodal/inference", response_model=Dict[str, Any])
     async def multimodal_inference(self, request: MultimodalInferenceRequest):
@@ -679,13 +748,12 @@ def create_app() -> FastAPI:
 
 
 def main():
     # Initialize the model server and client
     # "/home/ubuntu/Workspace/Projects/VLM_VLA/UAV_AI/src/Model/AI_Agent/scripts/ai_agent/models/model_config.json"
     print(f"project_root111:{project_root}")
-    config_json_file= f"{project_root}/models/model_config.json"
+    config_json_file= f"{project_root}/config/model_config.json"  # ✅ corrected path
     print("config_json_file: ",config_json_file)
     # global model_server_config
-    model_server_config = read_json_file(config_json_file)["models_server"][0]
+    config_data = read_json_file(config_json_file)["models_server"]
+    # If models_server is a list, take the first entry; if it is a dict, use it directly
+    model_server_config = config_data[0] if isinstance(config_data, list) else config_data
 
     uvicorn.run(
         app="models_server:create_app",
@@ -696,7 +764,6 @@ def main():
         log_level="info"
     )
 
-
 def test_inference(inference_type=0):
     from tools.core.common_functions import read_json_file
     config_json_file="/home/ubuntu/Workspace/Projects/VLM_VLA/UAV_AI/src/Model/AI_Agent/scripts/ai_agent/models/model_config.json"
@@ -751,7 +818,7 @@ def test_inference(inference_type=0):
 
 
 if __name__ == "__main__":
-    # main()
+    main()
     # test_inference(2)
-    logger.info("work finish")
+    # logger.info("work finish")
 
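Once main() is wired up, the fixed endpoint can be exercised like this (a sketch: host and port come from the config above, and the route prefix is assumed to be empty):

    import requests

    resp = requests.post(
        "http://localhost:8000/text/inference",
        json={"user_prompt": "hello", "max_tokens": 64},
    )
    print(resp.status_code, resp.json())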
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -13,10 +13,11 @@ from PIL import Image as PIL_Image
 import base64
 import numpy as np
 
-import rospy  # type: ignore
-from std_msgs.msg import String  # type: ignore
-from sensor_msgs.msg import Image as Sensor_Image  # type: ignore
-from cv_bridge import CvBridge, CvBridgeError  # type: ignore
+# Removed ROS dependencies
+# import rospy  # type: ignore
+# from std_msgs.msg import String  # type: ignore
+# from sensor_msgs.msg import Image as Sensor_Image  # type: ignore
+# from cv_bridge import CvBridge, CvBridgeError  # type: ignore
 import cv2
 
 
@@ -25,13 +26,13 @@ current_script = Path(__file__).resolve()
 # Walk up two levels to the project root (adjust the level to the actual layout)
 project_root = current_script.parents[2]  # parents[0] is the current directory, parents[1] its parent, and so on
 
-print("project_root: ",project_root)
+# print("project_root: ",project_root)
 
 # Add to the module search path
 sys.path.append(str(project_root))
 
 # Inspect the system environment
-print("sys path:",sys.path)
+# print("sys path:",sys.path)
 
 # Configure the logger
 logging.basicConfig(
@@ -84,91 +85,7 @@ def PIL_image_to_base64( image_path: Optional[str] = None, image: Optional[PIL_I
     image.save(buffered, format="JPEG", quality=90)
     return base64.b64encode(buffered.getvalue()).decode('utf-8')
 
-def ros_image2pil_image(self, ros_image_msg: Sensor_Image, supported_image_formats):
-    """Callback: convert a ROS Image into a PIL Image and process it"""
-    try:
-        bridge_cv = CvBridge()
-        # 1. Convert the ROS Image message to OpenCV format (default BGR8 encoding)
-        # For other encodings (e.g. rgb8), specify the format: bridge.imgmsg_to_cv2(ros_image_msg, "rgb8")
-        # Try to convert to OpenCV format
-        if ros_image_msg.encoding in supported_image_formats:
-            cv_image = bridge_cv.imgmsg_to_cv2(ros_image_msg, desired_encoding=ros_image_msg.encoding)
-        else:
-            # Fall back to the default conversion
-            cv_image = bridge_cv.imgmsg_to_cv2(ros_image_msg)
-            logger.warning(f"Converting unsupported image encoding: {ros_image_msg.encoding}")
-
-
-        # 2. OpenCV defaults to BGR; convert to the RGB that PIL expects
-        rgb_image = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
-
-        # 3. Convert to a PIL Image
-        pil_image = PIL_Image.fromarray(rgb_image)
-
-        # Further processing of the PIL Image (display, save, etc.) could go here
-        rospy.loginfo(f"Conversion succeeded: PIL Image size {pil_image.size}")
-        return pil_image
-        # Example: display the image (requires PIL display support)
-        # pil_image.show()
-
-    except CvBridgeError as e:
-        rospy.logerr(f"Conversion failed: {e}")
-    except Exception as e:
-        rospy.logerr(f"Processing error: {e}")
-
-
-def ros_image2dict(self, image_msg: Sensor_Image,
-                   supported_image_formats=['bgr8', 'rgb8', 'mono8'],
-                   max_dim: int = 2000):
-    """
-    Convert a ROS image message into a dict for storage in the context
-
-    Args:
-        image_msg: the ROS Image message
-
-    Returns:
-        A dict with the image information, or None on conversion failure
-    """
-    try:
-        # Try to convert to OpenCV format
-        if image_msg.encoding in self.supported_image_formats:
-            cv_image = self.bridge.imgmsg_to_cv2(image_msg, desired_encoding=image_msg.encoding)
-        else:
-            # Fall back to the default conversion
-            cv_image = self.bridge.imgmsg_to_cv2(image_msg)
-            logger.warning(f"Converting unsupported image encoding: {image_msg.encoding}")
-
-        # Preprocessing: resize to reduce the data volume
-        h, w = cv_image.shape[:2]
-        if max(h, w) > max_dim:
-            scale = max_dim / max(h, w)
-            cv_image = cv2.resize(
-                cv_image,
-                (int(w * scale), int(h * scale)),
-                interpolation=cv2.INTER_AREA
-            )
-
-        # Encode as JPEG, then base64
-        encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 80]
-        _, buffer = cv2.imencode('.jpg', cv_image, encode_param)
-        img_base64 = base64.b64encode(buffer).decode('utf-8')
-
-        return {
-            "type": "image",
-            "format": "jpg",
-            "data": img_base64,
-            "width": cv_image.shape[1],
-            "height": cv_image.shape[0],
-            "original_encoding": image_msg.encoding,
-            "timestamp": image_msg.header.stamp.to_sec()
-        }
-
-    except CvBridgeError as e:
-        logger.error(f"CV bridge error: {str(e)}")
-    except Exception as e:
-        logger.error(f"Image processing error: {str(e)}")
-    return None
-
+# Removed ROS specific functions (ros_image2pil_image, ros_image2dict) as they depend on rospy/cv_bridge
 
 def PIL_image_to_base64_sizelmt(image_path:Optional[str]=None,image:Optional[PIL_Image.Image] = None,
                                 need_size_lmt:bool= False, max_size:tuple=(800, 800), quality:float=90):
 
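With ROS gone, the surviving conversion path is PIL to base64 JPEG; a minimal sketch of that step, mirroring PIL_image_to_base64 above:

    import base64
    from io import BytesIO
    from PIL import Image

    def pil_to_b64(img: Image.Image) -> str:
        buf = BytesIO()
        img.save(buf, format="JPEG", quality=90)  # same save parameters as PIL_image_to_base64
        return base64.b64encode(buf.getvalue()).decode("utf-8")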
Binary file not shown.
Binary file not shown.
Binary file not shown.
@@ -101,15 +101,16 @@ def set_embeddings(embedding_model_path,
                    model_config:Optional[dict[str,Any]]):
 
     if "llamacpp_embeddings" == embedding_type:
+        # Use .get() with defaults so the call still works when fields are missing from the config file
        embeddings = set_embeddings_llamacpp(
             model_path=embedding_model_path,
-            n_ctx=model_config["n_ctx"],
-            n_threads=model_config["n_threads"],
-            n_gpu_layers=model_config["n_gpu_layers"],  # on an RTX 5090, ~20 is suggested to use the GPU fully
-            n_seq_max=model_config["n_seq_max"],
-            n_threads_batch=model_config["n_threads_batch"],
-            flash_attn=model_config["flash_attn"],
-            verbose=model_config["verbose"]
+            n_ctx=model_config.get("n_ctx", 512),
+            n_threads=model_config.get("n_threads", 4),
+            n_gpu_layers=model_config.get("n_gpu_layers", 0),  # on an RTX 5090, ~20 is suggested to use the GPU fully
+            n_seq_max=model_config.get("n_seq_max", 128),
+            n_threads_batch=model_config.get("n_threads_batch", 4),
+            flash_attn=model_config.get("flash_attn", True),
+            verbose=model_config.get("verbose", False)
         )
     elif "huggingFace_embeddings" == embedding_type:
         embeddings = set_embeddings_huggingFace(
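For orientation (an assumption, not repo code): set_embeddings_llamacpp presumably wraps a llama.cpp embedding binding; with LangChain's LlamaCppEmbeddings the same defaults pattern would look like this:

    from langchain_community.embeddings import LlamaCppEmbeddings

    def build_embeddings(model_path: str, cfg: dict) -> LlamaCppEmbeddings:
        # Missing keys fall back to safe defaults, mirroring the change above
        return LlamaCppEmbeddings(
            model_path=model_path,
            n_ctx=cfg.get("n_ctx", 512),
            n_threads=cfg.get("n_threads", 4),
            n_gpu_layers=cfg.get("n_gpu_layers", 0),
            verbose=cfg.get("verbose", False),
        )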
@@ -156,19 +157,49 @@ def load_vector_database(embeddings,
                          collection_name:str):
     """Load an existing vector database"""
     try:
-        if os.path.exists(path=persist_directory):
-            vector_db = Chroma(
-                persist_directory=persist_directory,
-                embedding_function=embeddings,
-                collection_name=collection_name
-            )
-            print(f"Loaded vector database from {persist_directory}")
-
-            return vector_db
-        else:
+        if not os.path.exists(path=persist_directory):
             logger.warning(f"Vector database directory does not exist: {persist_directory}")
             return None
+
+        # First check whether the collection exists
+        try:
+            import chromadb
+            client = chromadb.PersistentClient(path=persist_directory)
+            collections = client.list_collections()
+            collection_names = [col.name for col in collections]
+
+            if collection_name not in collection_names:
+                logger.warning(f"Collection '{collection_name}' does not exist in the database.")
+                logger.info(f"Available collections: {collection_names}")
+                # Try to find collections that do hold data
+                for col in collections:
+                    if col.count() > 0:
+                        logger.info(f"Found non-empty collection: {col.name} (count: {col.count()})")
+                return None
+
+            # Check whether the collection holds any data
+            target_collection = client.get_collection(name=collection_name)
+            doc_count = target_collection.count()
+            if doc_count == 0:
+                logger.warning(f"Collection '{collection_name}' exists but is empty (count: 0)")
+                logger.warning("⚠️ The collection holds no data; RAG retrieval will return empty results")
+                # Still return the database object so later operations (e.g. adding data) can proceed
+            else:
+                logger.info(f"✅ Collection '{collection_name}' holds {doc_count} documents")
+        except Exception as check_error:
+            logger.warning(f"Error while checking collections: {check_error}; attempting to load anyway...")
+
+        # Load the vector database
+        vector_db = Chroma(
+            persist_directory=persist_directory,
+            embedding_function=embeddings,
+            collection_name=collection_name
+        )
+        logger.info(f"✅ Loaded vector database from {persist_directory}, collection: {collection_name}")
+        return vector_db
 
     except Exception as e:
-        logger.error(f"Failed to load vector_database: {str(e)}")
+        logger.error(f"Failed to load vector_database: {str(e)}", exc_info=True)
         return None
 
 # Set up the text splitter
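The collection check added above can also be run standalone to debug an empty RAG store (a sketch; the path is a placeholder):

    import chromadb

    client = chromadb.PersistentClient(path="/path/to/vector_store")  # placeholder
    for col in client.list_collections():
        print(col.name, col.count())  # spot empty collections before wiring them into RAG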
@@ -206,7 +237,7 @@ def set_document_loaders(self):
 def retrieve_relevant_info(vectorstore,
                            query: str,
                            k: int = 3,
-                           score_threshold: float = 0.2
+                           score_threshold: float = None
                            ) -> List[Dict[str, Any]]:
     """
     Retrieve information relevant to the query
@@ -214,7 +245,8 @@ def retrieve_relevant_info(vectorstore,
     Args:
         query: the query text
         k: the maximum number of results to return
-        score_threshold: relevance score threshold
+        score_threshold: relevance score threshold. If None, an adaptive threshold is used (1.5x the highest score among the top k results)
+                         For L2 distance, smaller scores mean more similar, so the threshold should be a fairly large value
 
     Returns:
         A list of relevant document contents and scores
@@ -223,21 +255,41 @@ def retrieve_relevant_info(vectorstore,
         # Run the similarity search, returning scored results
         docs_and_scores = vectorstore.similarity_search_with_score(query, k=k)
 
+        if not docs_and_scores:
+            logger.info(f"Retrieved 0 results (k={k})")
+            return []
+
+        # If no threshold was given, use an adaptive one
+        # For L2 distance, take the highest score among the top k results, times a coefficient
+        if score_threshold is None:
+            max_score = max(score for _, score in docs_and_scores)
+            # Use 1.5x the highest score so all top-k results are kept
+            score_threshold = max_score * 1.5
+            logger.debug(f"Adaptive threshold: {score_threshold:.2f} (based on max score {max_score:.2f})")
+
         # Filter and format the results
         results = []
         for doc, score in docs_and_scores:
-            if score < score_threshold:  # lower scores mean more similar
+            # For L2 distance, smaller scores are more similar
+            # A very large threshold (> 1000) suggests L2 distance: compare with <
+            # A very small threshold (< 1) suggests cosine distance: also compare with <
+            if score < score_threshold:
                 results.append({
                     "content": doc.page_content,
                     "metadata": doc.metadata,
                     "similarity_score": float(score)
                 })
+            else:
+                logger.debug(f"Document filtered out: score {score:.2f} >= threshold {score_threshold:.2f}")
 
-        logger.info(f"Retrieved {len(results)} relevant items (k={k}, threshold={score_threshold})")
+        threshold_str = f"{score_threshold:.2f}" if score_threshold is not None else "adaptive"
+        logger.info(f"Retrieved {len(results)} relevant items (k={k}, threshold={threshold_str})")
+        if results:
+            logger.debug(f"Similarity score range: {min(r['similarity_score'] for r in results):.2f} - {max(r['similarity_score'] for r in results):.2f}")
         return results
 
     except Exception as e:
-        logger.error(f"Failed to retrieve relevant information: {str(e)}")
+        logger.error(f"Failed to retrieve relevant information: {str(e)}", exc_info=True)
         return []
 
 def get_retriever(vectorstore, search_kwargs: Dict[str, Any] = None) -> Any:
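A quick worked example of the adaptive threshold (illustrative numbers only):

    scores = [0.8, 1.4, 2.1]            # L2 distances, smaller = more similar
    threshold = max(scores) * 1.5       # 3.15
    kept = [s for s in scores if s < threshold]
    assert kept == scores               # the old fixed threshold of 0.2 would have kept none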