Optimize the interactive test-verification script; update the prompt and code for scenario 4
39
tools/rag/README.md
Normal file
@@ -0,0 +1,39 @@
# RAG & Map Tools

This directory contains the tools for map building, knowledge-base generation, and vector-database management. The expected layout is sketched right after the list below.

## Directory Structure

- **knowledge_base/**: Source documents for the knowledge base.
  - Supported input formats: `.txt`, `.md`, `.pdf`
  - Generated formats: `.json`, `.ndjson` (produced by `build_knowledge_base.py`)

- **map/**: Raw map data.
  - `.osm` (OpenStreetMap data)
  - `.world` (Gazebo simulation environments)

- **vector_store/**: Persistent storage for the ChromaDB vector database.
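For orientation, this is roughly what the directory should look like once both scripts have been run (a sketch based only on the paths mentioned in this README):

```bash
ls tools/rag/
# build_knowledge_base.py  ingest.py  knowledge_base/  map/  vector_store/
```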

## Scripts

### 1. `build_knowledge_base.py`

**Purpose**: Processes the map files under `map/`, extracts geographic information and semantic descriptions, and writes the resulting knowledge-base files to `knowledge_base/`.

**Usage**:

```bash
python build_knowledge_base.py
```
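Each source map yields one `<map-name>_knowledge.ndjson` file with one `{"text": ...}` record per line. A quick way to sanity-check the output, using the `export.json` map shipped in `map/` as the example source (the printed record is only illustrative):

```bash
# Show the first record generated for map/export.json
head -n 1 knowledge_base/export_knowledge.ndjson
# {"text": "On the map there is a place or area named '…' …"}
```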

### 2. `ingest.py`

**Purpose**: Reads every document in `knowledge_base/`, vectorizes it with the embedding model, and stores the vectors in the ChromaDB database under `vector_store/`.

**Usage**:

```bash
python ingest.py
```

**Dependencies**: The embedding backend (e.g. `llama-server`) must already be running, or the `ORIN_IP` environment variable must point to the correct host.
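`ingest.py` posts to an OpenAI-compatible embeddings endpoint at `http://$ORIN_IP:8090/v1/embeddings` (defaulting to `localhost`). A minimal sketch of running against a remote backend; the model path and IP address are placeholders, and the flags assume a llama.cpp `llama-server` build with embedding support:

```bash
# Placeholder model file; any embedding-capable GGUF model served by llama-server would do
llama-server -m ./bge-small-zh-v1.5.gguf --embedding --host 0.0.0.0 --port 8090 &

# Point ingest.py at that host (placeholder address) and ingest the knowledge base
ORIN_IP=192.168.1.100 python ingest.py
```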

## Workflow

1. Put the map files into `map/`.
2. Run `build_knowledge_base.py` to generate the textual descriptions.
3. Put any supplementary documents into `knowledge_base/`.
4. Run `ingest.py` to build the vector index (an end-to-end sketch follows this list).
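Put together, a typical run from `tools/rag/` looks like this (a sketch; the copied file names and the `ORIN_IP` address are placeholders):

```bash
# 1. Drop raw map data into map/ (OSM JSON exports and/or Gazebo .world files)
cp ~/maps/export.json ~/sim/my_world.world map/

# 2. Generate *_knowledge.ndjson descriptions under knowledge_base/
python build_knowledge_base.py

# 3. Optionally add extra documents (.txt / .md / .pdf) next to the generated files
cp ~/docs/flight_notes.md knowledge_base/

# 4. Embed everything into the ChromaDB store under vector_store/
ORIN_IP=192.168.1.100 python ingest.py
```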
157
tools/rag/build_knowledge_base.py
Normal file
@@ -0,0 +1,157 @@
import xml.etree.ElementTree as ET
import json
from pathlib import Path
import logging

# --- Logging configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def process_osm_json(input_path: Path) -> list[str]:
    """
    Process an OpenStreetMap JSON export and return a list of descriptive sentences.
    """
    logging.info(f"Processing file as OSM JSON: {input_path.name}")
    descriptions = []
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (json.JSONDecodeError, IOError) as e:
        logging.error(f"Failed to read or parse {input_path.name}: {e}")
        return []

    elements = data.get('elements', [])
    if not elements:
        return []

    # Index nodes by id so way geometries can be resolved to coordinates.
    nodes_map = {node['id']: node for node in elements if node.get('type') == 'node'}
    ways = [elem for elem in elements if elem.get('type') == 'way']

    for way in ways:
        tags = way.get('tags', {})
        if 'name' not in tags:
            continue

        way_name = tags.get('name')
        way_nodes_ids = way.get('nodes', [])
        if not way_nodes_ids:
            continue

        # Average the coordinates of the way's nodes to approximate its centre.
        total_lat, total_lon, node_count = 0, 0, 0
        for node_id in way_nodes_ids:
            node_info = nodes_map.get(node_id)
            if node_info:
                total_lat += node_info.get('lat', 0)
                total_lon += node_info.get('lon', 0)
                node_count += 1

        if node_count == 0:
            continue

        center_lat = total_lat / node_count
        center_lon = total_lon / node_count

        sentence = f"On the map there is a place or area named '{way_name}'"
        other_tags = {k: v for k, v in tags.items() if k != 'name'}
        if other_tags:
            tag_descs = [f"its {key} is '{value}'" for key, value in other_tags.items()]
            sentence += f", {', '.join(tag_descs)}"
        sentence += f", and its centre is located at approximately ({center_lat:.6f}, {center_lon:.6f})."
        descriptions.append(sentence)

    logging.info(f"Extracted {len(descriptions)} location descriptions from {input_path.name}.")
    return descriptions


def process_gazebo_world(input_path: Path) -> list[str]:
    """
    Process a Gazebo .world file and return a list of descriptive sentences.
    """
    logging.info(f"Processing file as Gazebo world: {input_path.name}")
    descriptions = []
    try:
        tree = ET.parse(input_path)
        root = tree.getroot()
    except ET.ParseError as e:
        logging.error(f"Failed to parse XML file {input_path.name}: {e}")
        return []

    models = root.findall('.//model')
    for model in models:
        model_name = model.get('name')
        pose_element = model.find('pose')

        if model_name and pose_element is not None and pose_element.text:
            try:
                pose_values = [float(p) for p in pose_element.text.strip().split()]
                sentence = (
                    f"The simulation environment contains an object named '{model_name}', "
                    f"whose pose (x, y, z, roll, pitch, yaw) is: {pose_values}."
                )
                descriptions.append(sentence)
            except (ValueError, IndexError):
                logging.warning(f"Skipping model '{model_name}' because its pose is malformed.")

    logging.info(f"Extracted {len(descriptions)} object descriptions from {input_path.name}.")
    return descriptions


def main():
    """
    Scan the source data directory and generate a standalone NDJSON knowledge base for each file.
    """
    script_dir = Path(__file__).resolve().parent
    # Input source: tools/rag/map/
    source_data_dir = script_dir / 'map'
    # Output directory: tools/rag/knowledge_base/
    output_knowledge_base_dir = script_dir / 'knowledge_base'

    if not source_data_dir.exists():
        logging.error(f"Source data directory does not exist: {source_data_dir}")
        return

    # Make sure the output directory exists.
    output_knowledge_base_dir.mkdir(parents=True, exist_ok=True)

    total_files_processed = 0
    logging.info(f"--- Scanning source data directory: {source_data_dir} ---")

    for file_path in source_data_dir.iterdir():
        if not file_path.is_file():
            continue

        descriptions = []
        if file_path.suffix == '.json':
            descriptions = process_osm_json(file_path)
        elif file_path.suffix == '.world':
            descriptions = process_gazebo_world(file_path)
        else:
            logging.warning(f"Skipping unsupported file type: {file_path.name}")
            continue

        if not descriptions:
            logging.warning(f"No usable information extracted from {file_path.name}; skipping output file.")
            continue

        output_filename = file_path.stem + '_knowledge.ndjson'
        output_path = output_knowledge_base_dir / output_filename

        try:
            with open(output_path, 'w', encoding='utf-8') as f:
                for sentence in descriptions:
                    json_record = {"text": sentence}
                    f.write(json.dumps(json_record, ensure_ascii=False) + '\n')
            logging.info(f"Generated knowledge-base file for '{file_path.name}': {output_path.name}")
            total_files_processed += 1
        except IOError as e:
            logging.error(f"Failed to write output file '{output_path.name}': {e}")

    logging.info("--- Data processing finished ---")
    if total_files_processed > 0:
        logging.info(f"Generated knowledge bases for {total_files_processed} source file(s).")
    else:
        logging.warning("No knowledge-base files were generated.")


if __name__ == "__main__":
    main()
194
tools/rag/ingest.py
Normal file
@@ -0,0 +1,194 @@
# Ingest the documents from the local knowledge base into ChromaDB,
# using a remote embedding model for vectorization.
import os
from pathlib import Path
import chromadb
# from chromadb.utils import embedding_functions  # no longer needed
from chromadb.api.types import Documents, EmbeddingFunction, Embeddings, Embeddable
from unstructured.partition.auto import partition
from rich.progress import track
import logging
import requests
import json

# Logging configuration
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# --- Configuration ---
# Resolve paths relative to the script's own directory so they are always correct.
SCRIPT_DIR = Path(__file__).resolve().parent
KNOWLEDGE_BASE_DIR = SCRIPT_DIR / "knowledge_base"
VECTOR_STORE_DIR = SCRIPT_DIR / "vector_store"
COLLECTION_NAME = "drone_docs"
# EMBEDDING_MODEL_NAME = "bge-small-zh-v1.5"  # no longer needed; the model is handled by the embedding service


# --- Custom remote embedding function ---
class RemoteEmbeddingFunction(EmbeddingFunction[Embeddable]):
    """
    An embedding function backed by a remote, OpenAI-compatible embedding service.
    """
    def __init__(self, api_url: str):
        self._api_url = api_url
        logging.info(f"Custom embedding function initialized; it will connect to: {self._api_url}")

    def __call__(self, input: Embeddable) -> Embeddings:
        """
        Embed the given documents.
        """
        # The service only handles text, so the input must be a list of strings.
        if not isinstance(input, list) or not all(isinstance(doc, str) for doc in input):
            logging.error("This embedding function only accepts a list of strings (documents) as input.")
            return []

        try:
            # The "model" parameter is omitted because the embedding service may not need it.
            response = requests.post(
                self._api_url,
                json={"input": input},
                headers={"Content-Type": "application/json"}
            )
            response.raise_for_status()  # Raise an HTTPError if the request failed.

            # Parse the returned embeddings in the OpenAI API response format.
            data = response.json().get("data", [])
            if not data:
                raise ValueError("The API response contains no 'data' field, or 'data' is empty")

            embeddings = [item['embedding'] for item in data]
            return embeddings

        except requests.RequestException as e:
            logging.error(f"Call to the embedding API failed: {e}")
            # Return an empty list; adjust the error handling as needed.
            return []
        except (ValueError, KeyError) as e:
            logging.error(f"Failed to parse the API response: {e}")
            logging.error(f"Response body received: {response.text}")
            return []


def get_documents(directory: Path):
    """Load all documents from the knowledge-base directory and split them."""
    documents = []
    logging.info(f"Loading documents from '{directory}'...")
    for file_path in directory.rglob("*"):
        if file_path.is_file() and not file_path.name.startswith('.'):
            try:
                # Read plain-text files directly.
                if file_path.suffix in ['.txt', '.md']:
                    text = file_path.read_text(encoding='utf-8')
                    documents.append({
                        "text": text,
                        "metadata": {"source": str(file_path.name)}
                    })
                    logging.info(f"Processed text file: {file_path.name}")
                # Handle regular JSON files explicitly.
                elif file_path.suffix == '.json':
                    with open(file_path, 'r', encoding='utf-8') as f:
                        data = json.load(f)
                    if 'elements' in data and isinstance(data['elements'], list):
                        # OSM-style export: index every element as its own document.
                        for element in data['elements']:
                            documents.append({
                                "text": json.dumps(element, ensure_ascii=False),
                                "metadata": {"source": str(file_path.name)}
                            })
                        logging.info(f"Processed JSON file: {file_path.name}, extracted {len(data['elements'])} elements.")
                    else:
                        documents.append({
                            "text": json.dumps(data, ensure_ascii=False),
                            "metadata": {"source": str(file_path.name)}
                        })
                        logging.info(f"Processed JSON file: {file_path.name} (as a single document)")
                # Handle the NDJSON files generated by build_knowledge_base.py.
                elif file_path.suffix == '.ndjson':
                    with open(file_path, 'r', encoding='utf-8') as f:
                        count = 0
                        for line in f:
                            try:
                                record = json.loads(line)
                                if 'text' in record and isinstance(record['text'], str):
                                    documents.append({
                                        "text": record['text'],
                                        "metadata": {"source": str(file_path.name)}
                                    })
                                    count += 1
                            except json.JSONDecodeError:
                                logging.warning(f"Skipping invalid JSON line: {line.strip()}")
                    if count > 0:
                        logging.info(f"Processed NDJSON file: {file_path.name}, extracted {count} documents.")
                # Fall back to unstructured for every other file type.
                else:
                    elements = partition(filename=str(file_path))
                    for element in elements:
                        documents.append({
                            "text": element.text,
                            "metadata": {"source": str(file_path.name)}
                        })
                    logging.info(f"Processed file: {file_path.name} (via unstructured)")
            except Exception as e:
                logging.error(f"Failed to process file {file_path.name}: {e}")
    return documents


def main():
    """Run the document ingestion pipeline."""
    if not KNOWLEDGE_BASE_DIR.exists():
        KNOWLEDGE_BASE_DIR.mkdir(parents=True)
        logging.warning(f"Knowledge-base directory did not exist and has been created: {KNOWLEDGE_BASE_DIR}")
        logging.warning("Please add your knowledge files (e.g. .txt, .pdf, .md) to this directory.")
        return

    # 1. Load and split the documents.
    docs_to_ingest = get_documents(KNOWLEDGE_BASE_DIR)
    if not docs_to_ingest:
        logging.warning("No processable documents were found in the knowledge base.")
        return

    # 2. Initialize the ChromaDB client and the remote embedding function.
    orin_ip = os.getenv("ORIN_IP", "localhost")
    embedding_api_url = f"http://{orin_ip}:8090/v1/embeddings"

    logging.info(f"Initializing the remote embedding function, target service: {embedding_api_url}")
    embedding_func = RemoteEmbeddingFunction(api_url=embedding_api_url)

    client = chromadb.PersistentClient(path=str(VECTOR_STORE_DIR))

    # 3. Create or fetch the collection.
    logging.info(f"Accessing ChromaDB collection: {COLLECTION_NAME}")
    collection = client.get_or_create_collection(
        name=COLLECTION_NAME,
        embedding_function=embedding_func
    )

    # 4. Vectorize the documents and store them in the database.
    logging.info(f"Ingesting {len(docs_to_ingest)} document chunks...")

    # To avoid duplicate entries, existing ids could be checked first.
    # (For simplicity everything is re-added each run; production use needs more careful logic.)

    doc_texts = [doc['text'] for doc in docs_to_ingest]
    metadatas = [doc['metadata'] for doc in docs_to_ingest]
    ids = [f"doc_{KNOWLEDGE_BASE_DIR.name}_{i}" for i in range(len(doc_texts))]

    try:
        # ChromaDB's add() computes the embeddings automatically.
        collection.add(
            documents=doc_texts,
            metadatas=metadatas,
            ids=ids
        )
        logging.info("All document chunks have been ingested successfully!")
    except Exception as e:
        logging.error(f"Error while adding documents to ChromaDB: {e}")

    # Sanity check.
    count = collection.count()
    logging.info(f"The database now contains {count} entries.")

    print("\n✅ Ingestion finished!")
    print(f"Knowledge base: {KNOWLEDGE_BASE_DIR}")
    print(f"Vector store: {VECTOR_STORE_DIR}")


if __name__ == "__main__":
    main()
6
tools/rag/knowledge_base/export_knowledge.ndjson
Normal file
@@ -0,0 +1,6 @@
{"text": "On the map there is a place or area named '跷跷板', its leisure is 'playground', and its centre is located at approximately (x:15, y:-8.5, z:1.2)."}
{"text": "On the map there is a place or area named 'A地', its building is 'commercial', and its centre is located at approximately (x:10, y:-10, z:2)."}
{"text": "On the map there is a place or area named '学生宿舍', its building is 'dormitory', and its centre is located at approximately (x:5, y:3, z:2)."}
{"text": "Location: '研究所正大门'. Aliases: '大门', '入口'. Coordinates: (x:-23.8, y:292.8, z:14). Recommended hover altitude: 14 m. Suitable tasks: point reconnaissance, photography."}
{"text": "Location: '研究所广场'. Attribute: open area. Coordinates: (x:-24.0, y:241.8, z:14). Recommended search radius: 30 m. Suitable tasks: searching for people, rotating search."}
{"text": "Route: '研究所外围巡逻'. Key waypoint sequence: [(x:-24.0, y:241.8), (x:-107.8, y:289.8), (x:-106.5, y:241.3), (x:-23.80, y:292.80)]. Altitude: 14 m. Suitable for circling reconnaissance tasks."}
70349
tools/rag/map/export.json
Normal file
File diff suppressed because it is too large
BIN
tools/rag/vector_store/chroma.sqlite3
Normal file
Binary file not shown.