import xml.etree.ElementTree as ET
import json
from pathlib import Path
import logging
import os

# --- Logging configuration ---
logging.basicConfig(level=logging.INFO,
                    format='%(asctime)s - %(levelname)s - %(message)s')


def process_osm_json(input_path: Path) -> list[str]:
    """Parse an OpenStreetMap JSON export into descriptive sentences.

    Expects the Overpass-style layout: a top-level ``elements`` array of
    ``node`` and ``way`` objects. For every *named* way, a Chinese
    sentence is produced describing its tags and the centroid of its
    member nodes.

    Args:
        input_path: Path to the OSM JSON file.

    Returns:
        One sentence per named way; an empty list when the file cannot
        be read/parsed or contains no usable elements.
    """
    logging.info("正在以OSM JSON格式处理文件: %s", input_path.name)
    descriptions = []
    try:
        with open(input_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (json.JSONDecodeError, IOError) as e:
        logging.error("读取或解析 %s 时出错: %s", input_path.name, e)
        return []

    elements = data.get('elements', [])
    if not elements:
        return []

    # Index nodes by id so each way member resolves in O(1).
    nodes_map = {node['id']: node for node in elements
                 if node.get('type') == 'node'}
    ways = [elem for elem in elements if elem.get('type') == 'way']

    for way in ways:
        tags = way.get('tags', {})
        if 'name' not in tags:
            continue  # only named ways are worth describing
        way_name = tags.get('name')
        way_nodes_ids = way.get('nodes', [])
        if not way_nodes_ids:
            continue

        # Average member-node coordinates to get a rough centroid.
        # Bug fix: nodes missing 'lat'/'lon' previously contributed a
        # default of 0, dragging the centroid toward (0, 0); such nodes
        # are now skipped entirely.
        total_lat, total_lon, node_count = 0.0, 0.0, 0
        for node_id in way_nodes_ids:
            node_info = nodes_map.get(node_id)
            if node_info and 'lat' in node_info and 'lon' in node_info:
                total_lat += node_info['lat']
                total_lon += node_info['lon']
                node_count += 1
        if node_count == 0:
            continue  # no resolvable coordinates -> no sentence
        center_lat = total_lat / node_count
        center_lon = total_lon / node_count

        sentence = f"在地图上有一个名为 '{way_name}' 的地点或区域"
        # Append every tag except 'name' as "key是'value'", joined by '、'.
        other_tags = {k: v for k, v in tags.items() if k != 'name'}
        if other_tags:
            tag_descs = [f"{key}是'{value}'" for key, value in other_tags.items()]
            sentence += f",它的{'、'.join(tag_descs)}"
        sentence += f",其中心位置坐标大约在 ({center_lat:.6f}, {center_lon:.6f})。"
        descriptions.append(sentence)

    logging.info("从 %s 提取了 %s 条位置描述。", input_path.name, len(descriptions))
    return descriptions


def process_gazebo_world(input_path: Path) -> list[str]:
    """Parse a Gazebo ``.world`` (SDF/XML) file into descriptive sentences.

    Every ``<model>`` element (at any depth) that carries a ``name``
    attribute and a non-empty ``<pose>`` child yields one sentence
    reporting its 6-DOF pose.

    Args:
        input_path: Path to the Gazebo world file.

    Returns:
        One sentence per valid model; an empty list on XML parse failure.
    """
    logging.info("正在以Gazebo World格式处理文件: %s", input_path.name)
    descriptions = []
    try:
        tree = ET.parse(input_path)
        root = tree.getroot()
    except ET.ParseError as e:
        logging.error("解析XML文件 %s 失败: %s", input_path.name, e)
        return []

    for model in root.findall('.//model'):
        model_name = model.get('name')
        pose_element = model.find('pose')
        if model_name and pose_element is not None and pose_element.text:
            try:
                # Gazebo poses are whitespace-separated floats:
                # x y z roll pitch yaw.
                pose_values = [float(p) for p in pose_element.text.strip().split()]
                sentence = (
                    f"仿真环境中有一个名为 '{model_name}' 的物体,"
                    f"其位置和姿态(x, y, z, roll, pitch, yaw)为: {pose_values}。"
                )
                descriptions.append(sentence)
            except (ValueError, IndexError):
                logging.warning("跳过模型 '%s',因其pose格式不正确。", model_name)

    logging.info("从 %s 提取了 %s 个物体信息。", input_path.name, len(descriptions))
    return descriptions


def main():
    """Scan the source data directory and write one NDJSON knowledge base per file."""
    script_dir = Path(__file__).resolve().parent
    # Input source: tools/map/
    source_data_dir = script_dir / 'map'
    # Output directory: tools/knowledge_base/
    output_knowledge_base_dir = script_dir / 'knowledge_base'

    if not source_data_dir.exists():
        logging.error("源数据目录不存在: %s", source_data_dir)
        return

    # Make sure the output directory exists.
    output_knowledge_base_dir.mkdir(parents=True, exist_ok=True)

    total_files_processed = 0
    logging.info("--- 开始扫描源数据目录: %s ---", source_data_dir)
    for file_path in source_data_dir.iterdir():
        if not file_path.is_file():
            continue

        # Case-insensitive extension match so '.JSON' / '.WORLD' are
        # handled too (previously they were skipped as unsupported).
        suffix = file_path.suffix.lower()
        if suffix == '.json':
            descriptions = process_osm_json(file_path)
        elif suffix == '.world':
            descriptions = process_gazebo_world(file_path)
        else:
            logging.warning("跳过不支持的文件类型: %s", file_path.name)
            continue

        if not descriptions:
            logging.warning("未能从 %s 提取有效信息,跳过生成文件。", file_path.name)
            continue

        output_path = output_knowledge_base_dir / (file_path.stem + '_knowledge.ndjson')
        try:
            # NDJSON: one {"text": ...} record per line.
            with open(output_path, 'w', encoding='utf-8') as f:
                for sentence in descriptions:
                    json_record = {"text": sentence}
                    f.write(json.dumps(json_record, ensure_ascii=False) + '\n')
            logging.info("成功为 '%s' 生成知识库文件: %s",
                         file_path.name, output_path.name)
            total_files_processed += 1
        except IOError as e:
            logging.error("写入输出文件 '%s' 失败: %s", output_path.name, e)

    logging.info("--- 数据处理完成 ---")
    if total_files_processed > 0:
        logging.info("共为 %s 个源文件生成了知识库。", total_files_processed)
    else:
        logging.warning("未生成任何知识库文件。")


if __name__ == "__main__":
    main()