优化交互式测试验证脚本,针对场景4修改提示词以及代码

This commit is contained in:
2026-01-02 16:28:58 +08:00
parent c08cdfb339
commit 6f990e645d
31 changed files with 71855 additions and 184 deletions

39
tools/rag/README.md Normal file
View File

@@ -0,0 +1,39 @@
# RAG & Map Tools
该目录包含了地图构建、知识库生成和向量数据库管理的相关工具。
## 目录结构
- **knowledge_base/**: 存放源文档数据。
- 支持格式: `.txt`, `.md`, `.pdf`
- 生成格式: `.json`, `.ndjson` (由 `build_knowledge_base.py` 生成)
- **map/**: 存放地图原始数据。
- `.osm` (OpenStreetMap 数据)
- `.world` (Gazebo 仿真环境数据)
- **vector_store/**: ChromaDB 向量数据库的持久化存储目录。
## 脚本说明
### 1. `build_knowledge_base.py`
**功能**: 处理 `map/` 目录下的地图文件,提取地理信息和语义描述,生成知识库文件到 `knowledge_base/` 目录。
**使用方法**:
```bash
python build_knowledge_base.py
```
### 2. `ingest.py`
**功能**: 读取 `knowledge_base/` 中的所有文档调用嵌入模型Embedding Model将其向量化并存入 `vector_store/` 中的 ChromaDB 数据库。
**使用方法**:
```bash
python ingest.py
```
**依赖**: 需要确保后端嵌入服务(如 `llama-server`)已启动,或者配置正确的 `ORIN_IP` 环境变量。
## 工作流
1. 将地图文件放入 `map/`
2. 运行 `build_knowledge_base.py` 生成文本描述。
3. 将其他补充文档放入 `knowledge_base/`
4. 运行 `ingest.py` 构建向量索引。

View File

@@ -0,0 +1,157 @@
import xml.etree.ElementTree as ET
import json
from pathlib import Path
import logging
import os
# --- 配置日志 ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
def process_osm_json(input_path: Path) -> list[str]:
"""
处理OpenStreetMap的JSON文件返回描述性句子列表。
"""
logging.info(f"正在以OSM JSON格式处理文件: {input_path.name}")
descriptions = []
try:
with open(input_path, 'r', encoding='utf-8') as f:
data = json.load(f)
except (json.JSONDecodeError, IOError) as e:
logging.error(f"读取或解析 {input_path.name} 时出错: {e}")
return []
elements = data.get('elements', [])
if not elements:
return []
nodes_map = {node['id']: node for node in elements if node.get('type') == 'node'}
ways = [elem for elem in elements if elem.get('type') == 'way']
for way in ways:
tags = way.get('tags', {})
if 'name' not in tags:
continue
way_name = tags.get('name')
way_nodes_ids = way.get('nodes', [])
if not way_nodes_ids:
continue
total_lat, total_lon, node_count = 0, 0, 0
for node_id in way_nodes_ids:
node_info = nodes_map.get(node_id)
if node_info:
total_lat += node_info.get('lat', 0)
total_lon += node_info.get('lon', 0)
node_count += 1
if node_count == 0:
continue
center_lat = total_lat / node_count
center_lon = total_lon / node_count
sentence = f"在地图上有一个名为 '{way_name}' 的地点或区域"
other_tags = {k: v for k, v in tags.items() if k != 'name'}
if other_tags:
tag_descs = [f"{key}'{value}'" for key, value in other_tags.items()]
sentence += f",它的{ ''.join(tag_descs) }"
sentence += f",其中心位置坐标大约在 ({center_lat:.6f}, {center_lon:.6f})。"
descriptions.append(sentence)
logging.info(f"{input_path.name} 提取了 {len(descriptions)} 条位置描述。")
return descriptions
def process_gazebo_world(input_path: Path) -> list[str]:
"""
处理Gazebo的.world文件返回描述性句子列表。
"""
logging.info(f"正在以Gazebo World格式处理文件: {input_path.name}")
descriptions = []
try:
tree = ET.parse(input_path)
root = tree.getroot()
except ET.ParseError as e:
logging.error(f"解析XML文件 {input_path.name} 失败: {e}")
return []
models = root.findall('.//model')
for model in models:
model_name = model.get('name')
pose_element = model.find('pose')
if model_name and pose_element is not None and pose_element.text:
try:
pose_values = [float(p) for p in pose_element.text.strip().split()]
sentence = (
f"仿真环境中有一个名为 '{model_name}' 的物体,"
f"其位置和姿态(x, y, z, roll, pitch, yaw)为: {pose_values}"
)
descriptions.append(sentence)
except (ValueError, IndexError):
logging.warning(f"跳过模型 '{model_name}'因其pose格式不正确。")
logging.info(f"{input_path.name} 提取了 {len(descriptions)} 个物体信息。")
return descriptions
def main():
"""
主函数扫描源数据目录为每个文件生成独立的NDJSON知识库。
"""
script_dir = Path(__file__).resolve().parent
# 输入源: tools/map/
source_data_dir = script_dir / 'map'
# 输出目录: tools/knowledge_base/
output_knowledge_base_dir = script_dir / 'knowledge_base'
if not source_data_dir.exists():
logging.error(f"源数据目录不存在: {source_data_dir}")
return
# 确保输出目录存在
output_knowledge_base_dir.mkdir(parents=True, exist_ok=True)
total_files_processed = 0
logging.info(f"--- 开始扫描源数据目录: {source_data_dir} ---")
for file_path in source_data_dir.iterdir():
if not file_path.is_file():
continue
descriptions = []
if file_path.suffix == '.json':
descriptions = process_osm_json(file_path)
elif file_path.suffix == '.world':
descriptions = process_gazebo_world(file_path)
else:
logging.warning(f"跳过不支持的文件类型: {file_path.name}")
continue
if not descriptions:
logging.warning(f"未能从 {file_path.name} 提取有效信息,跳过生成文件。")
continue
output_filename = file_path.stem + '_knowledge.ndjson'
output_path = output_knowledge_base_dir / output_filename
try:
with open(output_path, 'w', encoding='utf-8') as f:
for sentence in descriptions:
json_record = {"text": sentence}
f.write(json.dumps(json_record, ensure_ascii=False) + '\n')
logging.info(f"成功为 '{file_path.name}' 生成知识库文件: {output_path.name}")
total_files_processed += 1
except IOError as e:
logging.error(f"写入输出文件 '{output_path.name}' 失败: {e}")
logging.info("--- 数据处理完成 ---")
if total_files_processed > 0:
logging.info(f"共为 {total_files_processed} 个源文件生成了知识库。")
else:
logging.warning("未生成任何知识库文件。")
if __name__ == "__main__":
main()

194
tools/rag/ingest.py Normal file
View File

@@ -0,0 +1,194 @@
# 该代码用于将本地知识库中的文档导入到ChromaDB中并使用远程嵌入模型进行向量化
import os
from pathlib import Path
import chromadb
# from chromadb.utils import embedding_functions - 不再需要
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
from unstructured.partition.auto import partition
from rich.progress import track
import logging
import requests # 导入requests
import json # 导入json模块
# 配置日志
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
# --- 配置 ---
# 获取脚本所在目录,确保路径的正确性
SCRIPT_DIR = Path(__file__).resolve().parent
KNOWLEDGE_BASE_DIR = SCRIPT_DIR / "knowledge_base"
VECTOR_STORE_DIR = SCRIPT_DIR / "vector_store"
COLLECTION_NAME = "drone_docs"
# EMBEDDING_MODEL_NAME = "bge-small-zh-v1.5" # 不再需要,模型名在函数内部处理
# --- 自定义远程嵌入函数 ---
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
"""
一个使用远程、兼容OpenAI API的嵌入服务的嵌入函数。
"""
def __init__(self, api_url: str):
self._api_url = api_url
logging.info(f"自定义嵌入函数已初始化,将连接到: {self._api_url}")
def __call__(self, input: Embeddable) -> Embeddings:
"""
对输入的文档进行嵌入。
"""
# 我们的服务只能处理文本,所以检查输入是否为字符串列表
if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
logging.error("此嵌入函数仅支持字符串列表(文档)作为输入。")
return []
try:
# 移除 "model" 参数因为embedding服务可能不需要它
response = requests.post(
self._api_url,
json={"input": input},
headers={"Content-Type": "application/json"}
)
response.raise_for_status() # 如果请求失败则抛出HTTPError
# 按照OpenAI API的格式解析返回的嵌入向量
data = response.json().get("data", [])
if not data:
raise ValueError("API响应中没有找到'data'字段或'data'为空")
embeddings = [item['embedding'] for item in data]
return embeddings
except requests.RequestException as e:
logging.error(f"调用嵌入API失败: {e}")
# 返回一个空列表或根据需要处理错误
return []
except (ValueError, KeyError) as e:
logging.error(f"解析API响应失败: {e}")
logging.error(f"收到的响应内容: {response.text}")
return []
def get_documents(directory: Path):
"""从知识库目录加载所有文档并进行切分"""
documents = []
logging.info(f"'{directory}' 加载文档...")
for file_path in directory.rglob("*"):
if file_path.is_file() and not file_path.name.startswith('.'):
try:
# 对简单文本文件直接读取
if file_path.suffix in ['.txt', '.md']:
text = file_path.read_text(encoding='utf-8')
documents.append({
"text": text,
"metadata": {"source": str(file_path.name)}
})
logging.info(f"成功处理文本文件: {file_path.name}")
# 特别处理常规的JSON文件
elif file_path.suffix == '.json':
with open(file_path, 'r', encoding='utf-8') as f:
data = json.load(f)
if 'elements' in data and isinstance(data['elements'], list):
for element in data['elements']:
documents.append({
"text": json.dumps(element, ensure_ascii=False),
"metadata": {"source": str(file_path.name)}
})
logging.info(f"成功处理JSON文件: {file_path.name}, 提取了 {len(data['elements'])} 个元素。")
else:
documents.append({
"text": json.dumps(data, ensure_ascii=False),
"metadata": {"source": str(file_path.name)}
})
logging.info(f"成功处理JSON文件: {file_path.name} (作为单个文档)")
# 新增:专门处理我们生成的 NDJSON 文件
elif file_path.suffix == '.ndjson':
with open(file_path, 'r', encoding='utf-8') as f:
count = 0
for line in f:
try:
record = json.loads(line)
if 'text' in record and isinstance(record['text'], str):
documents.append({
"text": record['text'],
"metadata": {"source": str(file_path.name)}
})
count += 1
except json.JSONDecodeError:
logging.warning(f"跳过无效的JSON行: {line.strip()}")
if count > 0:
logging.info(f"成功处理NDJSON文件: {file_path.name}, 提取了 {count} 个文档。")
# 对其他所有文件类型使用unstructured
else:
elements = partition(filename=str(file_path))
for element in elements:
documents.append({
"text": element.text,
"metadata": {"source": str(file_path.name)}
})
logging.info(f"成功处理文件: {file_path.name} (使用unstructured)")
except Exception as e:
logging.error(f"处理文件 {file_path.name} 失败: {e}")
return documents
def main():
"""主函数,执行文档入库流程"""
if not KNOWLEDGE_BASE_DIR.exists():
KNOWLEDGE_BASE_DIR.mkdir(parents=True)
logging.warning(f"知识库目录不存在,已自动创建: {KNOWLEDGE_BASE_DIR}")
logging.warning("请向该目录中添加您的知识文件(如 .txt, .pdf, .md)。")
return
# 1. 加载并切分文档
docs_to_ingest = get_documents(KNOWLEDGE_BASE_DIR)
if not docs_to_ingest:
logging.warning("在知识库中未找到可处理的文档。")
return
# 2. 初始化ChromaDB客户端和远程嵌入函数
orin_ip = os.getenv("ORIN_IP", "localhost")
embedding_api_url = f"http://{orin_ip}:8090/v1/embeddings"
logging.info(f"正在初始化远程嵌入函数,目标服务地址: {embedding_api_url}")
embedding_func = RemoteEmbeddingFunction(api_url=embedding_api_url)
client = chromadb.PersistentClient(path=str(VECTOR_STORE_DIR))
# 3. 创建或获取集合
logging.info(f"正在访问ChromaDB集合: {COLLECTION_NAME}")
collection = client.get_or_create_collection(
name=COLLECTION_NAME,
embedding_function=embedding_func
)
# 4. 将文档向量化并存入数据库
logging.info(f"开始将 {len(docs_to_ingest)} 个文档块入库...")
# 为了避免重复添加,可以先检查
# (这里为了简单,我们每次都重新添加,生产环境需要更复杂的逻辑)
doc_texts = [doc['text'] for doc in docs_to_ingest]
metadatas = [doc['metadata'] for doc in docs_to_ingest]
ids = [f"doc_{KNOWLEDGE_BASE_DIR.name}_{i}" for i in range(len(doc_texts))]
try:
# ChromaDB的add方法会自动处理嵌入
collection.add(
documents=doc_texts,
metadatas=metadatas,
ids=ids
)
logging.info("所有文档块已成功入库!")
except Exception as e:
logging.error(f"向ChromaDB添加文档时出错: {e}")
# 验证一下
count = collection.count()
logging.info(f"数据库中现在有 {count} 个条目。")
print("\n✅ 数据入库完成!")
print(f"知识库位于: {KNOWLEDGE_BASE_DIR}")
print(f"向量数据库位于: {VECTOR_STORE_DIR}")
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,6 @@
{"text": "在地图上有一个名为 '跷跷板' 的地点或区域它的leisure是'playground',其中心位置坐标大约在 (x:15, y:-8.5, z:1.2)。"}
{"text": "在地图上有一个名为 'A地' 的地点或区域它的building是'commercial',其中心位置坐标大约在 (x:10, y:-10, z:2)。"}
{"text": "在地图上有一个名为 '学生宿舍' 的地点或区域它的building是'dormitory',其中心位置坐标大约在 (x:5, y:3, z:2)。"}
{"text": "地点:'研究所正大门'。别名:'大门'、'入口'。坐标:(x:-23.8, y:292.8, z:14)。建议悬停高度14米。适合任务定点侦察、拍照。"}
{"text": "地点:'研究所广场'。属性:开阔区域。坐标:(x:-24.0, y:241.8, z:14)。建议搜索半径30米。适合任务寻找人员、旋转搜索。"}
{"text": "路线:'研究所外围巡逻'。关键航点序列:[(x:-24.0, y:241.8), (x:-107.8, y:289.8), (x:-106.5, y:241.3), (x:-23.80, y:292.80)]。高度14米。适合环绕侦察任务。"}

70349
tools/rag/map/export.json Normal file

File diff suppressed because it is too large Load Diff

Binary file not shown.