import xml.etree.ElementTree as ET
import json
from pathlib import Path
import logging

# --- Logging configuration ---
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')


def process_osm_json(input_path: Path) -> list[str]:
    """
    Process an OpenStreetMap JSON file and return a list of descriptive sentences.
    """
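    # The parser below assumes the Overpass-API-style JSON layout, e.g. (a
    # minimal sketch; ids, coordinates, and tags are illustrative, not taken
    # from a real export):
    #
    #   {"elements": [
    #       {"type": "node", "id": 1, "lat": 31.0, "lon": 121.0},
    #       {"type": "node", "id": 2, "lat": 31.1, "lon": 121.1},
    #       {"type": "way", "id": 10, "nodes": [1, 2],
    #        "tags": {"name": "Main Street", "highway": "residential"}}
    #   ]}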
logging.info(f"正在以OSM JSON格式处理文件: {input_path.name}")
|
||
descriptions = []
|
||
try:
|
||
with open(input_path, 'r', encoding='utf-8') as f:
|
||
data = json.load(f)
|
||
except (json.JSONDecodeError, IOError) as e:
|
||
logging.error(f"读取或解析 {input_path.name} 时出错: {e}")
|
||
return []
|
||
|
||
elements = data.get('elements', [])
|
||
if not elements:
|
||
return []
|
||
|
||
nodes_map = {node['id']: node for node in elements if node.get('type') == 'node'}
|
||
ways = [elem for elem in elements if elem.get('type') == 'way']
|
||
|
||
for way in ways:
|
||
tags = way.get('tags', {})
|
||
if 'name' not in tags:
|
||
continue
|
||
|
||
way_name = tags.get('name')
|
||
way_nodes_ids = way.get('nodes', [])
|
||
if not way_nodes_ids:
|
||
continue
|
||
|
||
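        # Approximate the way's center as the arithmetic mean of its node
        # coordinates. This is a simple sketch of a center point, not a true
        # geometric centroid (it ignores segment lengths and polygon area),
        # but it is adequate for a human-readable description.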
        total_lat, total_lon, node_count = 0, 0, 0
        for node_id in way_nodes_ids:
            node_info = nodes_map.get(node_id)
            if node_info:
                total_lat += node_info.get('lat', 0)
                total_lon += node_info.get('lon', 0)
                node_count += 1

        if node_count == 0:
            continue

        center_lat = total_lat / node_count
        center_lon = total_lon / node_count

        sentence = f"The map contains a place or area named '{way_name}'"
        other_tags = {k: v for k, v in tags.items() if k != 'name'}
        if other_tags:
            tag_descs = [f"its {key} is '{value}'" for key, value in other_tags.items()]
            sentence += f", {', '.join(tag_descs)}"
        sentence += f", and its center is located at approximately ({center_lat:.6f}, {center_lon:.6f})."
        descriptions.append(sentence)

    logging.info(f"Extracted {len(descriptions)} location descriptions from {input_path.name}.")
    return descriptions


def process_gazebo_world(input_path: Path) -> list[str]:
    """
    Process a Gazebo .world file and return a list of descriptive sentences.
    """
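    # A Gazebo .world file is SDF-formatted XML. The extraction below assumes
    # models of the following shape (a minimal sketch; the model name and pose
    # values are illustrative):
    #
    #   <model name="table_1">
    #     <pose>1.0 2.0 0.0 0 0 1.57</pose>
    #     ...
    #   </model>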
logging.info(f"正在以Gazebo World格式处理文件: {input_path.name}")
|
||
descriptions = []
|
||
try:
|
||
tree = ET.parse(input_path)
|
||
root = tree.getroot()
|
||
except ET.ParseError as e:
|
||
logging.error(f"解析XML文件 {input_path.name} 失败: {e}")
|
||
return []
|
||
|
||
models = root.findall('.//model')
|
||
for model in models:
|
||
model_name = model.get('name')
|
||
pose_element = model.find('pose')
|
||
|
||
if model_name and pose_element is not None and pose_element.text:
|
||
try:
|
||
pose_values = [float(p) for p in pose_element.text.strip().split()]
|
||
sentence = (
|
||
f"仿真环境中有一个名为 '{model_name}' 的物体,"
|
||
f"其位置和姿态(x, y, z, roll, pitch, yaw)为: {pose_values}。"
|
||
)
|
||
descriptions.append(sentence)
|
||
except (ValueError, IndexError):
|
||
logging.warning(f"跳过模型 '{model_name}',因其pose格式不正确。")
|
||
|
||
logging.info(f"从 {input_path.name} 提取了 {len(descriptions)} 个物体信息。")
|
||
return descriptions
|
||
|
||
|
||
def main():
    """
    Main entry point: scan the source data directory and generate a separate
    NDJSON knowledge base file for each input file.
    """
    script_dir = Path(__file__).resolve().parent
    # Input source: tools/map/
    source_data_dir = script_dir / 'map'
    # Output directory: tools/knowledge_base/
    output_knowledge_base_dir = script_dir / 'knowledge_base'

    if not source_data_dir.exists():
        logging.error(f"Source data directory does not exist: {source_data_dir}")
        return

    # Make sure the output directory exists
    output_knowledge_base_dir.mkdir(parents=True, exist_ok=True)

    total_files_processed = 0
    logging.info(f"--- Scanning source data directory: {source_data_dir} ---")

    for file_path in source_data_dir.iterdir():
        if not file_path.is_file():
            continue

        # Dispatch on file extension to the matching parser.
        descriptions = []
        if file_path.suffix == '.json':
            descriptions = process_osm_json(file_path)
        elif file_path.suffix == '.world':
            descriptions = process_gazebo_world(file_path)
        else:
            logging.warning(f"Skipping unsupported file type: {file_path.name}")
            continue

        if not descriptions:
            logging.warning(f"No usable information extracted from {file_path.name}; skipping output.")
            continue

        output_filename = file_path.stem + '_knowledge.ndjson'
        output_path = output_knowledge_base_dir / output_filename

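        # Write one JSON object per line (NDJSON). Each record looks like
        # (illustrative example):
        #   {"text": "The map contains a place or area named 'Main Street' ..."}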
        try:
            with open(output_path, 'w', encoding='utf-8') as f:
                for sentence in descriptions:
                    json_record = {"text": sentence}
                    f.write(json.dumps(json_record, ensure_ascii=False) + '\n')
            logging.info(f"Generated knowledge base file for '{file_path.name}': {output_path.name}")
            total_files_processed += 1
        except IOError as e:
            logging.error(f"Failed to write output file '{output_path.name}': {e}")

    logging.info("--- Data processing complete ---")
    if total_files_processed > 0:
        logging.info(f"Generated knowledge bases for {total_files_processed} source files.")
    else:
        logging.warning("No knowledge base files were generated.")


if __name__ == "__main__":
    main()