From 8cb86115c261282054a6c0efbc63af3249c83741 Mon Sep 17 00:00:00 2001
From: huangfu <3045324663@qq.com>
Date: Sun, 7 Dec 2025 20:15:35 +0800
Subject: [PATCH] Replace LiDAR point cloud with nearest-vehicle features
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

---
 .vscode/settings.json                             |   6 +
 Env/__pycache__/replay_policy.cpython-310.pyc     | Bin 1990 -> 1990 bytes
 Env/__pycache__/scenario_env.cpython-310.pyc      | Bin 11180 -> 11657 bytes
 Env/__pycache__/scenario_env.cpython-313.pyc      | Bin 20798 -> 23823 bytes
 .../simple_idm_policy.cpython-310.pyc             | Bin 770 -> 770 bytes
 Env/check_dataset.py                              | 109 +++++++++++
 Env/run_multiagent_env.py                         |  43 +++-
 Env/scenario_env.py                               | 115 +++++++----
 Env/verify_observation.py                         | 184 ++++++++++++++++++
 train_magail.py                                   | 115 +++++++++++
 10 files changed, 527 insertions(+), 45 deletions(-)
 create mode 100644 .vscode/settings.json
 create mode 100644 Env/check_dataset.py
 create mode 100644 Env/verify_observation.py
 create mode 100644 train_magail.py

diff --git a/.vscode/settings.json b/.vscode/settings.json
new file mode 100644
index 0000000..93997ae
--- /dev/null
+++ b/.vscode/settings.json
@@ -0,0 +1,6 @@
+{
+    "cursorpyright.analysis.extraPaths": [
+        "/home/huangfukk/mdsn/metadrive",
+        "/home/huangfukk/mdsn/scenarionet"
+    ]
+}
\ No newline at end of file
diff --git a/Env/__pycache__/replay_policy.cpython-310.pyc b/Env/__pycache__/replay_policy.cpython-310.pyc
index fed7d1be07dedd8819c7dc27fa593d9d7de09902..fcf6cca8f766570eba787013ef6e668ddecda5eb 100644
Binary files a/Env/__pycache__/replay_policy.cpython-310.pyc and b/Env/__pycache__/replay_policy.cpython-310.pyc differ
diff --git a/Env/__pycache__/scenario_env.cpython-310.pyc b/Env/__pycache__/scenario_env.cpython-310.pyc
index 31f5e12183cb60ccfe1b7d1400b46f4ad13eda37..b8ffe4de1c4b355f98de8c746a5ed175b4082661 100644
Binary files a/Env/__pycache__/scenario_env.cpython-310.pyc and b/Env/__pycache__/scenario_env.cpython-310.pyc differ
diff --git a/Env/__pycache__/scenario_env.cpython-313.pyc b/Env/__pycache__/scenario_env.cpython-313.pyc
Binary files a/Env/__pycache__/scenario_env.cpython-313.pyc and b/Env/__pycache__/scenario_env.cpython-313.pyc differ
diff --git a/Env/__pycache__/simple_idm_policy.cpython-310.pyc b/Env/__pycache__/simple_idm_policy.cpython-310.pyc
index b0dddc5e97ab694c646cef66d80379dda7a8f041..0782927efdf8ab246c9409571bc266b5c50730c5 100644
Binary files a/Env/__pycache__/simple_idm_policy.cpython-310.pyc and b/Env/__pycache__/simple_idm_policy.cpython-310.pyc differ
diff --git a/Env/check_dataset.py b/Env/check_dataset.py
new file mode 100644
index 0000000..98ae66b
--- /dev/null
+++ b/Env/check_dataset.py
@@ -0,0 +1,109 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Dataset inspection script: counts the available scenarios.
+"""
+
+import pickle
+import os
+from pathlib import Path
+from metadrive.engine.asset_loader import AssetLoader
+
+def check_dataset(data_dir, subfolder="exp_filtered"):
+    """
+    Count the scenarios contained in a dataset.
+
+    Args:
+        data_dir: dataset root directory
+        subfolder: dataset subdirectory (exp_filtered or exp_converted)
+    """
+    print("=" * 80)
+    print("Dataset inspection tool")
+    print("=" * 80)
+
+    # Build the full path
+    full_path = AssetLoader.file_path(data_dir, subfolder, unix_style=False)
+    print(f"\nDataset path: {full_path}")
+
+    # Check the file structure
+    if not os.path.exists(full_path):
+        print(f"❌ Error: path does not exist: {full_path}")
+        return 0
+
+    print(f"✅ Path exists")
+
+    # Read the dataset mapping
+    mapping_file = os.path.join(full_path, "dataset_mapping.pkl")
+    if os.path.exists(mapping_file):
+        print(f"\nReading dataset_mapping.pkl...")
+        with open(mapping_file, 'rb') as f:
+            dataset_mapping = pickle.load(f)
+
+        print(f"✅ Dataset mapping file exists")
+        print(f"   Mapped scenario count: {len(dataset_mapping)}")
+
+        # Tally the distribution across subdirectories
+        subdirs = {}
+        for filename, subdir in dataset_mapping.items():
+            if subdir not in subdirs:
+                subdirs[subdir] = []
+            subdirs[subdir].append(filename)
+
+        print(f"\nScenario distribution:")
+        for subdir, files in subdirs.items():
+            print(f"   {subdir}: {len(files)} scenarios")
+
+    # Read the dataset summary
+    summary_file = os.path.join(full_path, "dataset_summary.pkl")
+    if os.path.exists(summary_file):
+        print(f"\nReading dataset_summary.pkl...")
+        with open(summary_file, 'rb') as f:
+            dataset_summary = pickle.load(f)
+
+        print(f"✅ Dataset summary file exists")
+        print(f"   Summarized scenario count: {len(dataset_summary)}")
+
+        # Print the first few scenario IDs
+        print(f"\nFirst 10 scenario IDs:")
+        for i, (scenario_id, info) in enumerate(list(dataset_summary.items())[:10]):
+            if isinstance(info, dict):
+                track_length = info.get('track_length', 'N/A')
+                num_objects = info.get('number_summary', {}).get('num_objects', 'N/A')
+                print(f"   {i}: {scenario_id[:16]}... (duration: {track_length}, objects: {num_objects})")
+
+    # Check the actual files on disk
+    print(f"\nChecking actual files...")
+    pkl_files = list(Path(full_path).rglob("*.pkl"))
+    # Exclude dataset_mapping and dataset_summary
+    scenario_files = [f for f in pkl_files if f.name not in ["dataset_mapping.pkl", "dataset_summary.pkl"]]
+    print(f"   Actual scenario file count: {len(scenario_files)}")
+
+    # Check subdirectories
+    print(f"\nSubdirectory structure:")
+    for item in os.listdir(full_path):
+        item_path = os.path.join(full_path, item)
+        if os.path.isdir(item_path):
+            pkl_count = len([f for f in os.listdir(item_path) if f.endswith('.pkl')])
+            print(f"   {item}/: {pkl_count} pkl files")
+
+    print("\n" + "=" * 80)
+    print("Inspection complete")
+    print("=" * 80)
+
+    # Return the scenario count (0 if the mapping file was missing)
+    return len(dataset_mapping) if 'dataset_mapping' in locals() else 0
+
+
+if __name__ == "__main__":
+    WAYMO_DATA_DIR = r"/home/huangfukk/mdsn"
+
+    print("\n[Checking the exp_filtered dataset]")
+    num_filtered = check_dataset(WAYMO_DATA_DIR, "exp_filtered")
+
+    print("\n\n[Checking the exp_converted dataset]")
+    num_converted = check_dataset(WAYMO_DATA_DIR, "exp_converted")
+
+    print("\n\nSummary:")
+    print(f"   exp_filtered: {num_filtered} scenarios")
+    print(f"   exp_converted: {num_converted} scenarios")
+
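The commented-out "num_scenarios": 19012 entries in the next file hardcode the count this script reports. A minimal sketch (assuming the same dataset_mapping.pkl layout read above) of deriving the value at startup instead of hardcoding it:

# Hypothetical helper: count scenarios straight from dataset_mapping.pkl so
# the env config never drifts from the dataset actually on disk.
import os
import pickle

def count_scenarios(dataset_path: str) -> int:
    """Return the number of scenarios listed in dataset_mapping.pkl, or 0 if absent."""
    mapping_file = os.path.join(dataset_path, "dataset_mapping.pkl")
    if not os.path.exists(mapping_file):
        return 0
    with open(mapping_file, "rb") as f:
        return len(pickle.load(f))

# Usage sketch: feed the real count into the config instead of a literal.
# config["num_scenarios"] = count_scenarios("/home/huangfukk/mdsn/exp_filtered")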
+ "use_render": render, # 如果False会完全禁用渲染,避免LANE_FREEWAY错误 "sequential_seed": True, "reactive_traffic": True, "manual_control": False, @@ -198,19 +211,33 @@ def run_simulation_mode(data_dir, num_episodes=1, horizon=300, render=True, debu "spawn_vehicles": spawn_vehicles, "spawn_pedestrians": spawn_pedestrians, "spawn_cyclists": spawn_cyclists, + # ✅ 关键:设置可用场景数量 + #"num_scenarios": 19012, # 从dataset_mapping.pkl中统计的实际场景数 }, agent2policy=ConstantVelocityPolicy(target_speed=50) ) try: + # 获取可用场景数量 + num_scenarios = env.config.get("num_scenarios", 1) + print(f"可用场景数量: {num_scenarios}") + for episode in range(num_episodes): print(f"\n{'='*50}") print(f"回合 {episode + 1}/{num_episodes}") if scenario_id is not None: print(f"场景ID: {scenario_id}") + else: + # 循环使用场景 + scenario_idx = episode % num_scenarios + print(f"使用场景索引: {scenario_idx}") print(f"{'='*50}") - seed = scenario_id if scenario_id is not None else episode + # ✅ 如果不是指定场景,使用循环的场景索引 + if scenario_id is not None: + seed = scenario_id + else: + seed = episode % num_scenarios obs = env.reset(seed=seed) actual_horizon = env.config["horizon"] diff --git a/Env/scenario_env.py b/Env/scenario_env.py index 85a616d..e3c0b21 100644 --- a/Env/scenario_env.py +++ b/Env/scenario_env.py @@ -306,43 +306,72 @@ class MultiAgentScenarioEnv(ScenarioEnv): return False def _spawn_controlled_agents(self): + """ + 生成应该在当前或之前出现的车辆 + 如果round=0且所有车辆的show_time>0,则生成show_time最小的车辆(保证至少有车辆出现) + """ + vehicles_to_spawn = [] + for car in self.car_birth_info_list: - if car['show_time'] == self.round: - agent_id = f"controlled_{car['id']}" - vehicle_config = {} - vehicle = self.engine.spawn_object( - PolicyVehicle, - vehicle_config=vehicle_config, - position=car['begin'], - heading=car['heading'] + if car['show_time'] <= self.round: + vehicles_to_spawn.append(car) + + # 如果当前round没有车辆应该出现,但车辆列表不为空,则生成最早出现的车辆 + # 这样可以确保在reset时至少有车辆出现 + # if len(vehicles_to_spawn) == 0 and len(self.car_birth_info_list) > 0: + # if self.config.get("debug", False): + # self.logger.debug( + # f"No vehicles to spawn at round {self.round}, " + # f"spawning earliest vehicle instead" + # ) + # # 找到show_time最小的车辆 + # earliest_car = min(self.car_birth_info_list, key=lambda x: x['show_time']) + # vehicles_to_spawn.append(earliest_car) + + for car in vehicles_to_spawn: + agent_id = f"controlled_{car['id']}" + + # 避免重复生成 + if agent_id in self.controlled_agents: + continue + + vehicle_config = {} + vehicle = self.engine.spawn_object( + PolicyVehicle, + vehicle_config=vehicle_config, + position=car['begin'], + heading=car['heading'] + ) + + # 重置车辆状态 + reset_kwargs = { + 'position': car['begin'], + 'heading': car['heading'] + } + + # 如果启用速度继承,设置初始速度 + if car.get('velocity') is not None: + reset_kwargs['velocity'] = car['velocity'] + + vehicle.reset(**reset_kwargs) + + # 设置策略和目的地 + vehicle.set_policy(self.policy) + vehicle.set_destination(car['end']) + vehicle.set_expert_vehicle_id(car['id']) + + self.controlled_agents[agent_id] = vehicle + self.controlled_agent_ids.append(agent_id) + + # 注册到引擎的 active_agents + self.engine.agent_manager.active_agents[agent_id] = vehicle + + if self.config.get("debug", False): + self.logger.debug( + f"Spawned vehicle {agent_id} at round {self.round} " + f"(show_time={car['show_time']}), position {car['begin']}" ) - # 重置车辆状态 - reset_kwargs = { - 'position': car['begin'], - 'heading': car['heading'] - } - - # 如果启用速度继承,设置初始速度 - if car.get('velocity') is not None: - reset_kwargs['velocity'] = car['velocity'] - - vehicle.reset(**reset_kwargs) - - # 设置策略和目的地 - 
vehicle.set_policy(self.policy) - vehicle.set_destination(car['end']) - vehicle.set_expert_vehicle_id(car['id']) - - self.controlled_agents[agent_id] = vehicle - self.controlled_agent_ids.append(agent_id) - - # 注册到引擎的 active_agents - self.engine.agent_manager.active_agents[agent_id] = vehicle - - if self.config.get("debug", False): - self.logger.debug(f"Spawned vehicle {agent_id} at round {self.round}, position {car['begin']}") - def _get_all_obs(self): self.obs_list = [] @@ -364,8 +393,20 @@ class MultiAgentScenarioEnv(ScenarioEnv): traffic_light = 0 break - lidar = self.engine.get_sensor("lidar").perceive(num_lasers=80, distance=30, base_vehicle=vehicle, - physics_world=self.engine.physics_world.dynamic_world) + # 使用最近10辆车的相对位置与相对速度替代原80维LiDAR点云 + lidar_cloud_points, detected_objects = self.engine.get_sensor("lidar").perceive( + num_lasers=80, + distance=30, + base_vehicle=vehicle, + physics_world=self.engine.physics_world.dynamic_world + ) + nearest_vehicle_info = self.engine.get_sensor("lidar").get_surrounding_vehicles_info( + vehicle, + detected_objects, + perceive_distance=30, + num_others=10, + add_others_navi=False + ) side_lidar = self.engine.get_sensor("side_detector").perceive(num_lasers=10, distance=8, base_vehicle=vehicle, physics_world=self.engine.physics_world.static_world) @@ -374,7 +415,7 @@ class MultiAgentScenarioEnv(ScenarioEnv): physics_world=self.engine.physics_world.static_world) obs = (list(state['position'][:2]) + list(state['velocity']) + [state['heading_theta']] - + lidar[0] + side_lidar[0] + lane_line_lidar[0] + [traffic_light] + + nearest_vehicle_info + side_lidar[0] + lane_line_lidar[0] + [traffic_light] + list(vehicle.destination)) self.obs_list.append(obs) diff --git a/Env/verify_observation.py b/Env/verify_observation.py new file mode 100644 index 0000000..36ab298 --- /dev/null +++ b/Env/verify_observation.py @@ -0,0 +1,184 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +简单验证脚本:检查车辆是否正确获取观测空间 +用法:python verify_observations.py +""" + +import argparse +from scenario_env import MultiAgentScenarioEnv +from replay_policy import ReplayPolicy +from metadrive.engine.asset_loader import AssetLoader +import numpy as np + +def verify_observations(data_dir, scenario_id=0): + """ + 验证观测空间是否正确获取 + + Args: + data_dir: 数据目录 + scenario_id: 场景ID + """ + print("=" * 60) + print("观测空间验证工具") + print("=" * 60) + + # 创建环境 + env = MultiAgentScenarioEnv( + config={ + "data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False), + "is_multi_agent": True, + "horizon": 300, + "use_render": False, # 不渲染,加速运行 + "sequential_seed": True, + "reactive_traffic": False, + "manual_control": False, + "filter_offroad_vehicles": True, + "lane_tolerance": 3.0, + "replay_mode": True, + "debug": True, # 启用调试以查看详细信息 + "specific_scenario_id": scenario_id, + "use_scenario_duration": True, + }, + agent2policy=None + ) + + # 重置环境 + print(f"\n加载场景 {scenario_id}...") + obs = env.reset(seed=scenario_id) + + # 输出基本信息 + print(f"\n场景信息:") + print(f" - 可控车辆数: {len(env.controlled_agents)}") + print(f" - 观测数量: {len(obs)}") + print(f" - 场景时长: {env.scenario_max_duration} 步") + print(f" - 车辆生成列表长度: {len(env.car_birth_info_list)}") + print(f" - 当前回合数 (round): {env.round}") + + # 检查车辆生成信息 + if len(env.car_birth_info_list) > 0: + print(f"\n车辆生成信息分析:") + show_times = [car['show_time'] for car in env.car_birth_info_list] + print(f" - show_time 分布: min={min(show_times)}, max={max(show_times)}") + print(f" - show_time == 0 的车辆数: {sum(1 for st in show_times if st == 0)}") + print(f" - 前5个车辆的 
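With this change the per-agent observation becomes position (2) + velocity (2) + heading (1) + nearest-vehicle features (10 × 4 = 40) + side detector (10) + lane-line detector (10) + traffic light (1) + destination (2) = 68 dimensions. A minimal sketch of decoding the 40-dim block, assuming the 4-per-vehicle packing (relative x, relative y, relative vx, relative vy) that the verify_observation.py breakdown below also assumes:

import numpy as np

# Hypothetical decoder for the nearest-vehicle block of a single observation.
def decode_nearest_vehicles(obs, offset=5, num_others=10):
    """Return one dict per detected neighbour; all-zero rows are padding."""
    block = np.asarray(obs[offset:offset + num_others * 4]).reshape(num_others, 4)
    return [dict(zip(("rel_x", "rel_y", "rel_vx", "rel_vy"), row))
            for row in block if np.any(row)]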
diff --git a/Env/verify_observation.py b/Env/verify_observation.py
new file mode 100644
index 0000000..36ab298
--- /dev/null
+++ b/Env/verify_observation.py
@@ -0,0 +1,184 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+"""
+Simple verification script: checks that vehicles receive observations correctly.
+Usage: python verify_observation.py
+"""
+
+import argparse
+from scenario_env import MultiAgentScenarioEnv
+from replay_policy import ReplayPolicy
+from metadrive.engine.asset_loader import AssetLoader
+import numpy as np
+
+def verify_observations(data_dir, scenario_id=0):
+    """
+    Verify that the observation space is populated correctly.
+
+    Args:
+        data_dir: data directory
+        scenario_id: scenario ID
+    """
+    print("=" * 60)
+    print("Observation verification tool")
+    print("=" * 60)
+
+    # Create the environment
+    env = MultiAgentScenarioEnv(
+        config={
+            "data_directory": AssetLoader.file_path(data_dir, "exp_filtered", unix_style=False),
+            "is_multi_agent": True,
+            "horizon": 300,
+            "use_render": False,  # no rendering, for speed
+            "sequential_seed": True,
+            "reactive_traffic": False,
+            "manual_control": False,
+            "filter_offroad_vehicles": True,
+            "lane_tolerance": 3.0,
+            "replay_mode": True,
+            "debug": True,  # enable debugging for detailed logs
+            "specific_scenario_id": scenario_id,
+            "use_scenario_duration": True,
+        },
+        agent2policy=None
+    )
+
+    # Reset the environment
+    print(f"\nLoading scenario {scenario_id}...")
+    obs = env.reset(seed=scenario_id)
+
+    # Print basic information
+    print(f"\nScenario info:")
+    print(f"  - Controlled vehicles: {len(env.controlled_agents)}")
+    print(f"  - Observations: {len(obs)}")
+    print(f"  - Scenario duration: {env.scenario_max_duration} steps")
+    print(f"  - Vehicle spawn list length: {len(env.car_birth_info_list)}")
+    print(f"  - Current round: {env.round}")
+
+    # Inspect the vehicle spawn info
+    if len(env.car_birth_info_list) > 0:
+        print(f"\nVehicle spawn analysis:")
+        show_times = [car['show_time'] for car in env.car_birth_info_list]
+        print(f"  - show_time range: min={min(show_times)}, max={max(show_times)}")
+        print(f"  - Vehicles with show_time == 0: {sum(1 for st in show_times if st == 0)}")
+        print(f"  - show_time of first 5 vehicles: {show_times[:5]}")
+    else:
+        print(f"\n⚠️ Warning: the vehicle spawn list is empty! Possible causes:")
+        print(f"  1. All vehicles were removed by the lane filter")
+        print(f"  2. All vehicles were removed by the type filter")
+        print(f"  3. The scenario data contains no valid vehicles")
+
+    # Verify the observation space
+    print(f"\n" + "=" * 60)
+    print("Observation space verification")
+    print("=" * 60)
+
+    if len(obs) == 0:
+        print("❌ Error: no observations received!")
+        env.close()
+        return False
+
+    # Inspect the first observation
+    first_obs = obs[0]
+    print(f"\nFirst vehicle's observation:")
+    print(f"  - Type: {type(first_obs)}")
+    print(f"  - Dimensions: {len(first_obs)}")
+
+    # Decompose the observation in detail
+    if isinstance(first_obs, (list, np.ndarray)):
+        obs_array = np.array(first_obs)
+        print(f"  - Shape: {obs_array.shape}")
+        print(f"  - Dtype: {obs_array.dtype}")
+        print(f"\nObservation breakdown:")
+        # New observation layout (aligned with Env/scenario_env._get_all_obs):
+        #   [x, y] (2) + [vx, vy] (2) + heading (1)
+        #   + nearest vehicles info (10 vehicles * 4 = 40)
+        #   + side_lidar (10) + lane_line_lidar (10)
+        #   + traffic_light (1) + destination (2)
+        lidar_len = 40
+        side_len = 10
+        lane_len = 10
+        pos_slice = slice(0, 2)
+        vel_slice = slice(2, 4)
+        heading_idx = 4
+        lidar_slice = slice(5, 5 + lidar_len)
+        side_slice = slice(lidar_slice.stop, lidar_slice.stop + side_len)
+        lane_slice = slice(side_slice.stop, side_slice.stop + lane_len)
+        tl_idx = lane_slice.stop
+        dest_slice = slice(tl_idx + 1, tl_idx + 3)
+
+        print(f"  Position (x, y): {obs_array[pos_slice]}")
+        print(f"  Velocity (vx, vy): {obs_array[vel_slice]}")
+        print(f"  Heading: {obs_array[heading_idx]:.3f} rad")
+        print(f"  Nearest-vehicle info: {len(obs_array[lidar_slice])} dims (10 vehicles * 4: rel x / rel y / rel vx / rel vy)")
+        print(f"  Side detector: {len(obs_array[side_slice])} points")
+        print(f"  Lane-line detector: {len(obs_array[lane_slice])} points")
+        print(f"  Traffic light: {obs_array[tl_idx]}")
+        print(f"  Destination (x, y): {obs_array[dest_slice]}")
+
+        # Check data validity
+        print(f"\nData validity checks:")
+        has_nan = np.isnan(obs_array).any()
+        has_inf = np.isinf(obs_array).any()
+
+        if has_nan:
+            print(f"  ❌ Observation contains NaN values!")
+        else:
+            print(f"  ✅ No NaN values")
+
+        if has_inf:
+            print(f"  ❌ Observation contains Inf values!")
+        else:
+            print(f"  ✅ No Inf values")
+
+        # Check the nearest-vehicle block (40 dims)
+        lidar_data = obs_array[lidar_slice]
+        lidar_min = np.min(lidar_data)
+        lidar_max = np.max(lidar_data)
+        print(f"\n  Nearest-vehicle feature range: [{lidar_min:.2f}, {lidar_max:.2f}]")
+
+        if lidar_max > 0:
+            print(f"  ✅ Non-zero nearest-vehicle features present")
+        else:
+            print(f"  ⚠️ Nearest-vehicle features are all zero (no neighbours, or all out of range)")
+
+    # Run a few steps to confirm observations stay valid
+    print(f"\n" + "=" * 60)
+    print("Multi-step verification (first 5 steps)")
+    print("=" * 60)
+
+    for step in range(5):
+        # No-op actions
+        actions = {aid: [0.0, 0.0] for aid in env.controlled_agents}
+        obs, rewards, dones, infos = env.step(actions)
+
+        print(f"\nStep {step + 1}:")
+        print(f"  - Active vehicles: {len(env.controlled_agents)}")
+        print(f"  - Observations: {len(obs)}")
+
+        if len(obs) > 0:
+            sample_obs = np.array(obs[0])
+            print(f"  - First vehicle position: ({sample_obs[0]:.2f}, {sample_obs[1]:.2f})")
+            print(f"  - Data valid: {'✅' if not (np.isnan(sample_obs).any() or np.isinf(sample_obs).any()) else '❌'}")
+
+        if dones["__all__"]:
+            print(f"  - Scenario finished")
+            break
+
+    env.close()
+
+    print(f"\n" + "=" * 60)
+    print("✅ Verification complete! The observation space works correctly")
+    print("=" * 60)
+
+    return True
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(description="Verify the observation space")
+    parser.add_argument("--data_dir", type=str, default="/home/huangfukk/mdsn", help="data directory")
+    parser.add_argument("--scenario_id", type=int, default=0, help="scenario ID")
+
+    args = parser.parse_args()
+
+    verify_observations(args.data_dir, args.scenario_id)
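One caveat before the training script below: MultiAgentScenarioEnv._get_all_obs builds observations as a list, while train_magail.py indexes obs_dict by agent id. A minimal adapter sketch, assuming env.controlled_agents preserves insertion order matching the observation list (which is how _spawn_controlled_agents fills it):

# Hypothetical adapter: key the ordered observation list by agent id so
# downstream code can treat it as a dict.
def obs_list_to_dict(env, obs_list):
    return dict(zip(env.controlled_agents.keys(), obs_list))

# Usage sketch in train_magail.py:
# obs_dict = obs_list_to_dict(env, env.reset())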
diff --git a/train_magail.py b/train_magail.py
new file mode 100644
index 0000000..0fa9c58
--- /dev/null
+++ b/train_magail.py
@@ -0,0 +1,115 @@
+import torch
+import numpy as np
+from torch.utils.tensorboard import SummaryWriter
+from Env.scenario_env import MultiAgentScenarioEnv
+from Algorithm.magail import MAGAIL
+from Algorithm.buffer import RolloutBuffer  # assuming the buffer lives here
+# Assuming a utility exists for loading expert data
+# from utils import load_expert_buffer
+
+# --- Configuration ---
+CONFIG = {
+    "data_dir": "/home/huangfukk/mdsn",
+    # ... other options ...
+    "state_dim": 68,   # per-agent observation dim: 2+2+1+40+10+10+1+2 (see verify_observation.py)
+    "action_dim": 2,   # [steering, throttle]
+    "rollout_length": 2048,  # buffer size
+}
+
+def train():
+    # 1. Environment (config sketch; fill in the remaining options as in Env/run_multiagent_env.py)
+    env = MultiAgentScenarioEnv(
+        config={"data_directory": CONFIG["data_dir"]},  # ... other options ...
+        agent2policy=None
+    )
+
+    # 2. Expert buffer (pseudocode)
+    # expert_buffer = RolloutBuffer(...)
+    # expert_buffer.load(...)
+    # expert_buffer.sample() must return (state, next_state) pairs to train the discriminator
+    expert_buffer = None  # placeholder: must be constructed before MAGAIL can train
+
+    # 3. Initialize MAGAIL
+    magail = MAGAIL(
+        buffer_exp=expert_buffer,  # pass in the expert buffer
+        input_dim=(CONFIG["state_dim"],),
+        action_shape=(CONFIG["action_dim"],),  # PPO init may need action_shape; the original code did not seem to pass it?
+        device=torch.device("cuda"),
+        rollout_length=CONFIG["rollout_length"]
+    )
+
+    writer = SummaryWriter("./logs")
+
+    # 4. Training loop
+    total_steps = 0
+    obs_dict = env.reset()
+
+    # Convert the dict obs to an array: (N_agents, obs_dim)
+    # Assumes all agents share the same obs dimension
+    agents = list(obs_dict.keys())
+    current_obs = np.stack([obs_dict[a] for a in agents])
+
+    # If a separate state_gail is needed, assume it equals the obs
+    current_state_gail = current_obs.copy()
+
+    while total_steps < 1e7:
+        # --- Collect data (rollout) ---
+        # The PPO buffer holds rollout_length transitions. We could call
+        # magail.step here to handle explore and buffer appends automatically.
+
+        # Caveat: magail.step calls env.step internally, which may not fit a
+        # multi-agent env that returns dicts; the original PPO.step seems designed
+        # for single-agent or vectorized envs. So we write the rollout loop by
+        # hand (or adapt PPO.step to dict returns).
+
+        # --- Approach: manual rollout adapted to multi-agent ---
+        for _ in range(CONFIG["rollout_length"]):
+            # 1. Act
+            actions_list, log_pis_list = magail.explore(current_obs)
+
+            # 2. Assemble the action dict
+            action_dict = {agent_id: action for agent_id, action in zip(agents, actions_list)}
+
+            # 3. Step the environment
+            next_obs_dict, rewards_dict, dones_dict, infos_dict = env.step(action_dict)
+
+            # 4. Process the returns
+            # Agent death/reset needs handling; here we simplify by assuming a fixed agent count
+            next_obs = np.stack([next_obs_dict[a] for a in agents])
+            rewards = np.array([rewards_dict[a] for a in agents])
+            dones = np.array([dones_dict[a] for a in agents])
+            # truncated is usually in infos or implied by dones; check your gym version
+            terminated = dones  # simplification
+            truncated = [False] * len(agents)  # simplification
+
+            # 5. Append to the buffer
+            # Fetch the current actor's mean/std for the PPO update
+            with torch.no_grad():
+                # Either recompute here or return these from explore.
+                # magail.actor(state) returns the mean (deterministic action)
+                means = magail.actor(torch.tensor(current_obs, device=magail.device, dtype=torch.float)).cpu().numpy()
+                stds = magail.actor.log_stds.exp().detach().cpu().numpy()
+                # With a StateIndependentPolicy the std may be a shared parameter; broadcast it
+                if stds.shape[0] != len(agents):
+                    stds = np.repeat(stds, len(agents), axis=0)
+
+            magail.buffer.append(
+                current_obs, current_state_gail, actions_list,
+                rewards, dones, terminated, log_pis_list,
+                next_obs, next_obs,  # next_state_gail = next_obs
+                means, stds
+            )
+
+            current_obs = next_obs
+            current_state_gail = next_obs  # update the GAIL state
+            total_steps += len(agents)  # steps advance by N
+
+            # Handle done
+            if all(dones):
+                obs_dict = env.reset()
+                current_obs = np.stack([obs_dict[a] for a in agents])
+                current_state_gail = current_obs
+
+        # --- Update ---
+        # Once the buffer is full (rollout_length), run the update
+        magail.update(writer, total_steps)
+
+        # Save checkpoints (total_steps grows by len(agents) per step, so an
+        # exact modulo test would rarely fire; use a window instead)
+        if total_steps % 10000 < len(agents):
+            magail.save_models("./models")
+
+
+if __name__ == "__main__":
+    train()
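The comment in train_magail.py requires expert_buffer.sample() to return (state, next_state) pairs for the discriminator. A minimal sketch of such a buffer, assuming expert trajectories are stored as per-step state arrays; all names here are hypothetical and Algorithm/buffer.RolloutBuffer's real interface may differ:

import numpy as np
import torch

class ExpertTransitionBuffer:
    """Hypothetical expert buffer yielding (state, next_state) pairs."""

    def __init__(self, states, next_states, device="cuda"):
        assert len(states) == len(next_states)
        self.states = torch.as_tensor(np.asarray(states), dtype=torch.float, device=device)
        self.next_states = torch.as_tensor(np.asarray(next_states), dtype=torch.float, device=device)

    def sample(self, batch_size):
        # Uniformly sample transition indices for a discriminator batch
        idx = torch.randint(0, len(self.states), (batch_size,), device=self.states.device)
        return self.states[idx], self.next_states[idx]

# Usage sketch: build (s, s') pairs from replayed expert rollouts, e.g. the
# per-step observations collected while driving Env/replay_policy.ReplayPolicy.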