Files
DronePlanning/tools/build_knowledge_base.py
2025-08-17 22:41:54 +08:00

157 lines
5.6 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import xml.etree.ElementTree as ET
import json
from pathlib import Path
import logging
import os
# --- Logging setup: timestamped, INFO-level messages for the whole script ---
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
def _osm_way_center(node_ids: list, nodes_map: dict) -> "tuple[float, float] | None":
    """Return the mean (lat, lon) of a way's resolvable nodes, or None if none resolve.

    A closed OSM way lists its first node id again as its last; that duplicate is
    dropped so the closing vertex is not counted twice in the average.
    """
    if len(node_ids) > 1 and node_ids[0] == node_ids[-1]:
        node_ids = node_ids[:-1]
    coords = []
    for node_id in node_ids:
        node = nodes_map.get(node_id)
        # Skip nodes absent from the file or lacking coordinates; treating a
        # missing lat/lon as 0 would drag the centre toward (0, 0).
        if node is None or node.get('lat') is None or node.get('lon') is None:
            continue
        coords.append((node['lat'], node['lon']))
    if not coords:
        return None
    count = len(coords)
    return (sum(lat for lat, _ in coords) / count,
            sum(lon for _, lon in coords) / count)


def _describe_osm_tags(tags: dict) -> str:
    """Render every tag except 'name' as "key为'value'", joined by "、" (empty if none)."""
    return "、".join(f"{key}为'{value}'" for key, value in tags.items() if key != 'name')


def process_osm_json(input_path: Path) -> list[str]:
    """
    Convert an OpenStreetMap JSON file into a list of descriptive sentences.

    One sentence is produced per ``way`` element that has a non-empty ``name``
    tag and at least one member node with coordinates. The sentence gives the
    name, the way's other tags, and the mean coordinate of its member nodes.

    Args:
        input_path: Path to the JSON file. It is expected to hold a top-level
            object with an ``elements`` list (Overpass-style export — TODO
            confirm against the exporter actually used).

    Returns:
        The sentences, in element order. The list is empty if the file cannot be
        read or parsed, is not a JSON object, or yields no describable ways.
    """
    logging.info(f"正在以OSM JSON格式处理文件: {input_path.name}")
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError, OSError) as e:
        logging.error(f"读取或解析 {input_path.name} 时出错: {e}")
        return []
    # Valid JSON need not be an object; guard before calling .get().
    if not isinstance(data, dict):
        logging.error(f"{input_path.name} 的顶层结构不是JSON对象,跳过。")
        return []

    elements = data.get('elements', [])
    if not elements:
        return []

    nodes_map = {
        elem['id']: elem
        for elem in elements
        if elem.get('type') == 'node' and 'id' in elem
    }

    descriptions = []
    for way in elements:
        if way.get('type') != 'way':
            continue
        tags = way.get('tags', {})
        way_name = tags.get('name')
        if not way_name:
            continue
        center = _osm_way_center(way.get('nodes', []), nodes_map)
        if center is None:
            continue
        center_lat, center_lon = center

        sentence = f"在地图上有一个名为 '{way_name}' 的地点或区域"
        tag_desc = _describe_osm_tags(tags)
        if tag_desc:
            sentence += f",它的{tag_desc}"
        sentence += f",其中心位置坐标大约在 ({center_lat:.6f}, {center_lon:.6f})。"
        descriptions.append(sentence)

    logging.info(f"{input_path.name} 提取了 {len(descriptions)} 条位置描述。")
    return descriptions
def process_gazebo_world(input_path: Path) -> list[str]:
    """
    Convert a Gazebo ``.world`` XML file into a list of descriptive sentences.

    One sentence is produced per ``<model>`` element found at any depth below
    the document root, so nested models are included. A model qualifies when it
    has a non-empty ``name`` attribute and a direct ``<pose>`` child holding
    exactly six numbers (x, y, z, roll, pitch, yaw).

    NOTE(review): a world saved from the Gazebo GUI may also contain
    ``<model>`` copies under ``<state>``. Those would be described as well;
    confirm whether de-duplication is wanted.

    Args:
        input_path: Path to the ``.world`` file.

    Returns:
        The sentences, in document order. The list is empty if the file cannot
        be read or parsed.
    """
    logging.info(f"正在以Gazebo World格式处理文件: {input_path.name}")
    try:
        root = ET.parse(input_path).getroot()
    except (ET.ParseError, OSError) as e:
        # OSError covers missing/unreadable files, matching process_osm_json.
        logging.error(f"解析XML文件 {input_path.name} 失败: {e}")
        return []

    descriptions = []
    for model in root.findall('.//model'):
        model_name = model.get('name')
        pose_element = model.find('pose')
        if not model_name or pose_element is None or not pose_element.text:
            continue
        try:
            pose_values = [float(p) for p in pose_element.text.split()]
        except ValueError:
            pose_values = []
        # The sentence labels the values as (x, y, z, roll, pitch, yaw), so
        # anything other than six numbers (e.g. whitespace-only) is rejected.
        if len(pose_values) != 6:
            logging.warning(f"跳过模型 '{model_name}'因其pose格式不正确。")
            continue
        sentence = (
            f"仿真环境中有一个名为 '{model_name}' 的物体,"
            f"其位置和姿态(x, y, z, roll, pitch, yaw)为: {pose_values}"
        )
        descriptions.append(sentence)

    logging.info(f"{input_path.name} 提取了 {len(descriptions)} 个物体信息。")
    return descriptions
def _write_ndjson(output_path: Path, descriptions: list[str]) -> None:
    """Write one ``{"text": sentence}`` JSON object per line to *output_path* (UTF-8)."""
    with open(output_path, 'w', encoding='utf-8') as f:
        f.writelines(
            json.dumps({"text": sentence}, ensure_ascii=False) + '\n'
            for sentence in descriptions
        )


def main():
    """
    Build one NDJSON knowledge-base file per supported source file.

    The script reads every regular file in ``<script dir>/map``. Files with a
    ``.json`` suffix go to ``process_osm_json`` and files with ``.world`` go to
    ``process_gazebo_world``; the suffix match is case-insensitive. Each
    non-empty result is written to
    ``<script dir>/knowledge_base/<stem>_knowledge.ndjson``.

    If two sources share a stem (e.g. ``x.json`` and ``x.world``), the later one
    is written to ``<stem>_<suffix>_knowledge.ndjson`` instead of overwriting
    the earlier output.
    """
    # Lower-cased file suffix -> parser returning a list of sentences.
    processors = {
        '.json': process_osm_json,
        '.world': process_gazebo_world,
    }

    script_dir = Path(__file__).resolve().parent
    # Input: tools/map/
    source_data_dir = script_dir / 'map'
    # Output: tools/knowledge_base/
    output_knowledge_base_dir = script_dir / 'knowledge_base'

    # is_dir() rather than exists(): iterdir() fails on a plain file.
    if not source_data_dir.is_dir():
        logging.error(f"源数据目录不存在: {source_data_dir}")
        return
    output_knowledge_base_dir.mkdir(parents=True, exist_ok=True)

    total_files_processed = 0
    used_output_names = set()
    logging.info(f"--- 开始扫描源数据目录: {source_data_dir} ---")
    # iterdir() order is arbitrary; sort so runs are reproducible.
    for file_path in sorted(source_data_dir.iterdir()):
        if not file_path.is_file():
            continue
        processor = processors.get(file_path.suffix.lower())
        if processor is None:
            logging.warning(f"跳过不支持的文件类型: {file_path.name}")
            continue

        descriptions = processor(file_path)
        if not descriptions:
            logging.warning(f"未能从 {file_path.name} 提取有效信息,跳过生成文件。")
            continue

        output_filename = file_path.stem + '_knowledge.ndjson'
        if output_filename in used_output_names:
            # Same stem, different suffix: keep both outputs instead of
            # silently overwriting the earlier one.
            suffix_tag = file_path.suffix.lstrip('.').lower()
            output_filename = f"{file_path.stem}_{suffix_tag}_knowledge.ndjson"
        output_path = output_knowledge_base_dir / output_filename

        try:
            _write_ndjson(output_path, descriptions)
        except OSError as e:
            logging.error(f"写入输出文件 '{output_path.name}' 失败: {e}")
            continue
        used_output_names.add(output_filename)
        logging.info(f"成功为 '{file_path.name}' 生成知识库文件: {output_path.name}")
        total_files_processed += 1

    logging.info("--- 数据处理完成 ---")
    if total_files_processed > 0:
        logging.info(f"共为 {total_files_processed} 个源文件生成了知识库。")
    else:
        logging.warning("未生成任何知识库文件。")
if __name__ == '__main__':
    # Run the knowledge-base build only when executed as a script, not on import.
    main()