WAV:让世界模型自我改进 — World Action Verifier 深度解读
论文解读:世界模型如何通过前向-逆向不对称性自我验证和自我提升
论文信息
- 标题:World Action Verifier: Self-Improving World Models via Forward-Inverse Asymmetry
- 作者:Yuejiang Liu, Fan Feng, Lingjing Kong, Weifeng Lu, Jinzhou Tang, Kun Zhang, Kevin Murphy, Chelsea Finn, Yilun Du
- 机构:斯坦福大学、加州大学圣地亚哥分校、卡内基梅隆大学、Google DeepMind、哈佛大学
- 发表:ICLR 2026 Workshop on World Models (Oral), ICLR 2026 Workshop on Self-Improvement (Spotlight)
- 链接:arXiv:2604.01985 | 项目主页 | 代码
一句话总结
本文提出 WAV (World Action Verifier),利用前向-逆向不对称性(forward-inverse asymmetry)让世界模型能够自我识别预测错误并自动迭代改进,在有限标注数据情况下将样本效率提高2倍,下游策略性能提升18%【原文 Abstract】。

背景与动机
世界模型(World Model) 是机器人学习领域的核心基础设施——它学习以动作条件的前向动力学模型(action-conditioned forward dynamics model),可以用于策略评估、优化和规划。有了准确的世界模型,机器人就可以在"脑子里"想象未来,不需要每次都在真实世界试错。
但现有的世界模型存在一个根本问题:
与只关注最优动作的策略学习不同,世界模型需要对更广泛范围的次优动作都保持可靠性,而这些次优动作在有动作标注的数据中往往覆盖不足【原文 §1 Introduction】。
具体来说:
- 训练数据中只有少量被专家或当前策略访问过的(状态,动作)对
- 大量未探索过的状态-动作组合,预测结果不可靠
- 在长轨迹规划中,预测误差会逐步累积放大,最终导致规划失败
研究的核心问题就是:如何让世界模型在数据有限、探索不充分的情况下,自我识别哪些预测不可靠,并自我改进?【官方博客】
核心方法
核心洞察:前向-逆向不对称性
作者提出一个关键观察:如果把验证任务分解成两个子问题,每个子问题都能更容易解决,因为它们可以利用不对称的数据资源:
- 状态合理性验证(State Plausibility):判断一个未来状态是否可能存在
- 这个问题可以用无动作的视频数据来训练验证器
-
而无标注视频在机器人领域很容易大量获取,不需要 costly 的动作标注
-
动作可及性验证(Action Reachability):判断一个动作能否从当前状态到达目标状态
- 这个问题只需要低维度的动作相关特征,比完整状态预测维度低得多
这种不对称性就是方法得名的原因——前向模型需要做完整预测,而分解后的两个验证任务各自利用不对称的数据资源和建模优势【原文 §2 Method】。
WAV 整体架构
WAV 由四个关键组件构成,形成一个闭环自改进系统:
graph LR
A[当前状态] --> B[子目标生成器<br/>从视频语料库生成]
B --> C[候选子目标]
C --> D[稀疏逆向模型<br/>推断动作]
D --> E[推断动作]
A --> F[前向世界模型<br/>预测下一状态]
E --> F
F --> G[循环一致性检查]
G -->|通过| H[保留预测]
G -->|不通过| I[识别为错误]
I --> J[加入训练集]
J --> K[微调世界模型]
K --> A
四个核心组件:
-
多样化子目标生成器(Diverse Subgoal Generator):从大规模视频语料库中学习生成多样的子目标状态,覆盖大量未探索过的潜在状态【原文 §3 Approach】。
-
稀疏逆向模型(Sparse Inverse Model):不同于完整逆模型需要从完整状态推断动作,WAV 的逆向模型只从状态特征的子集推断动作,这样降低了问题维度,更容易训练【原文 §3 Approach】。
-
循环一致性验证(Cycle Consistency Verification):验证逻辑很优雅:
- 生成子目标 → 逆向模型推断动作 → 前向世界模型用推断的动作做rollout
-
如果 rollout 结果能回到原目标,则循环闭合,预测可信;否则识别为错误
这一步完全是自监督的,不需要额外标注【原文 §3 Approach】。 -
自改进循环(Self-Improvement Cycle):被验证出来的错误预测会被添加到训练集,重新微调世界模型,形成闭环——模型不断发现自己的错误,然后修正自己【官方博客】。
Video Model
在 WAV (World Action Verifier) 框架中,Video Model(视频模型)扮演着多样化子目标生成器 (Diverse Subgoal Generator) 的角色,是验证世界模型预测准确性的核心组件之一。
以下是关于该视频模型的详细介绍:
- 核心功能与定义
- 子目标生成器: 视频模型被定义为一个生成概率模型 $p_\phi$,其主要任务是根据当前状态 $s_t$ 采样生成 $K$ 个候选的未来状态(即子目标 $s_{t+1}$)。
-
状态合理性验证 (State Plausibility): 它的核心作用是验证世界模型预测的下一状态是否在视觉和物理上是真实的。它提供了一个关于合理转换的“先验知识”,确保预测不会脱离现实的数据流形(Manifold)。
-
训练与数据来源
- 利用无标签视频: 该模型通常在大规模无动作标签的视频数据集 ($D_{vid}$)上进行预训练,例如互联网视频。
-
数据丰度优势: 这种设计利用了“分布式不对称性”:即无标签视频数据的数量远超带动作标签的机器人交互数据。这使得视频模型能够学习到比世界模型更广泛的环境演化规律。
-
在自改进循环中的作用
在 WAV 的自改进循环(Self-Improving Cycle)中,视频模型是“逆向循环”的第一步: - 提出目标: 首先由视频模型提议一个物理上合理的未来状态(子目标)。
-
引导验证: 随后由逆向模型推断达成该目标所需的动作,最后再让世界模型尝试生成预测,通过对比预测值与视频模型提出的“目标值”之间的一致性来发现错误。
-
实验中的应用
- 作为验证基准: 在 MiniGrid 实验中,研究者利用一半的交互序列(无动作标签)来训练这个“视频先验”,从而为后续的采样提供锚点。
- 提升效率: 通过视频模型提供的合理参考,WAV 能够更有效地识别世界模型在未充分探索区域的预测错误,从而将样本采集效率提高 2 倍。
IDM
IDM (Inverse Dynamics Model),即逆向动力学模型,在本文档中是指一种通过观察状态转换来推断动作的模型。
以下是根据提供的来源对 IDM 的详细介绍:
- 基本定义与功能
- 核心定义:IDM 的基本功能是从状态转换中推断动作(infer actions from state transitions)。
- 输入输出:它通常接收两个连续的状态或观察值($s_t, s_{t+1}$)作为输入,并预测导致这一转换的动作 $\hat{a}_t$,。
-
主要用途:它在机器人学习中可作为策略模型(将计划转化为动作)、正则化工具、标注器(为无标签视频添加动作标签),以及在 WAV 框架中作为验证器。
-
WAV 中的稀疏逆向模型 (Sparse IDM)
在 World Action Verifier (WAV) 框架中,作者特别强调了“稀疏” IDM 的重要性: - 特征选择:Sparse IDM 使用一个可学习的掩码(Mask)来选择与动作相关的状态特征(例如机械臂末端的位姿或被操作物体的运动),从而过滤掉与动作无关的高维噪声(如复杂的背景),,。
-
验证角色:它负责验证“动作可达性 (Action Reachability)”,即验证世界模型预测的状态转换是否真的可以通过特定动作实现,。
-
IDM 的优势(相对于前向模型)
来源指出,逆向验证在很多情况下比前向预测更容易且更稳健,主要基于以下不对称性: - 维度不对称性 (Dimensionality Asymmetry):在许多机器人任务中,动作仅由极少数关键特征决定,其维度远低于世界模型必须预测的完整状态空间,。
- 样本效率:实验表明,在数据有限的情况下,IDM 的准确率提升速度远快于前向世界模型。
-
抗噪能力:Sparse IDM 对环境中的随机干扰(如地板颜色闪烁)具有很强的免疫力,因为这些干扰不包含动作信息,。
-
变体对比
- Vanilla IDM:接收完整的观察帧和机器人本体感觉状态,但在处理复杂交互或未见过的物体(OOD)时容易失效。
- Sparse IDM:通过强制稀疏性,模型能更好地泛化到未见过的物体和交互,具有更强的分步泛化能力,。
关键设计选择
- 分解验证:不直接验证前向模型的完整预测,而是分解为两个更容易的子问题
- 利用无标注数据:状态合理性验证直接受益于大量无动作标注的视频数据,正好契合机器人领域数据标注昂贵的现状
- 稀疏逆模型:只需要状态子集就能推断动作,缓解维度灾难
- 自监督自改进:不需要额外人工标注,模型自己发现错误自己改进
实验与结果
作者在 9 个任务上进行了全面评估,横跨三个不同领域验证方法的通用性:
| 领域 | 任务类型 |
|---|---|
| MiniGrid | 网格世界导航 |
| RoboMimic | 机器人操作 |
| ManiSkill | 铰接物体操作 |
主要实验结论:
| 指标 | 结果 |
|---|---|
| 样本效率 | 达到相同性能所需标注数据量降为原来的 1/2 → 2× 更高样本效率【原文 §4 Experiments】 |
| 下游性能 | 相比于现有方法,下游策略性能 平均提升 18%【原文 §4 Experiments】 |
关键发现:
- 在数据稀缺场景下(只有少量标注数据),提升尤为显著,这正好是实际机器人应用中最常见的场景【官方博客】
- 随着收集更多数据,WAV 持续自我改进,性能稳步提升
- 验证机制能有效识别出不可靠的预测,选择性重训练这些例子比全数据训练更有效【原文 §4 Experiments】
亮点与局限
亮点
- 【方法创新】提出利用前向-逆向不对称性进行自验证的新思路,不同于以往直接训练世界模型的范式
- 【数据高效】能够利用无动作标注的视频数据,这正好契合机器人领域数据标注昂贵的现状
- 【闭环自改进】真正实现了闭环:模型自己找错 → 自己修正,不需要人工介入
- 【通用性】在网格世界、机器人模拟、复杂操作任务三个不同领域都取得一致提升【原文 §4 Experiments】
局限(原文承认的)
- 计算开销更大:WAV 需要三次推理(生成子目标 → 逆向推理 → 前向rollout),比单一步骤的前向模型计算成本更高【原文 §5 Limitations】
- 当前论文只在模拟环境验证,真实机器人实验尚未开展(项目页说明机器人代码即将开源)【官方博客】
思考(此为解读者观点,非原文)
- 循环一致性验证虽然巧妙,但验证质量依赖于逆模型的准确性,如果逆模型本身不准,验证也会出错
- 自改进是增量微调(incremental fine-tuning),可能存在灾难性遗忘问题,需要更稳定的训练机制
- 这个思路和大语言模型中的"自我校正"(self-correction)精神是相通的——让模型发现自己的错误,然后改进,这可能是通向更通用自改进系统的方向
对从业者的启示
- 数据不对称利用:在机器人学习中,不同类型的数据(有动作标注 vs 无标注视频)不对称,方法设计上要考虑如何利用这种不对称性
- 自监督验证:不需要额外标注,通过循环一致性就能做自验证,这在数据稀缺领域非常有价值
- 自改进闭环:让模型自己发现错误自己改进,这比人类去挖错高效得多,可能是未来大模型+机器人学习的一个重要方向
延伸阅读
- 世界模型基础:《Planning with Uncertainty in World Models》
- 机器人自监督学习:利用大规模无标注视频预训练机器人基础模型
- 自改进学习:LLM 领域的 Self-Refine 也有类似思想
代码结构分析(Minigrid 实验版)
官方开源的 wav_minigrid 是一个精心设计的受控验证实验,代码结构清晰模块化。
项目结构
wav_minigrid/
├── src/wav_minigrid/ # 核心库代码
│ ├── models/
│ │ ├── wm.py # 前向世界模型 (WorldModel)
│ │ └── idm.py # 逆动力学模型 (DenseIDM, SparseIDM)
│ ├── dataset.py # 数据集加载与包装类
│ ├── al_utils.py # 主动学习工具函数(数据选择策略)
│ ├── utils.py # 训练/评估工具函数
│ ├── evaluate_generation.py # MiniGrid 物理预言机(仿真)
│ └── config.py # 实验配置
├── exps/ # 实验脚本
│ ├── idm_comparison.py # SparseIDM vs DenseIDM 对比
│ ├── data_efficiency_gap.py # 数据效率对比
│ ├── state_complexity_gap.py # 复杂度缩放实验
│ ├── noise_robustness.py # 噪声鲁棒性实验
│ └── wm_active_learning.py # WAV 主动学习实验
├── env/ # 环境数据收集代码
├── data/ # 预处理好的数据集 (.npz)
└── checkpoints/ # 预训练检查点
核心模块实现细节
1. SparseIDM:稀疏逆动力学模型(核心创新)
SparseIDM 是论文的关键贡献,主要创新在于通过可学习的注意力 mask 只关注变化区域:
# 1. Mask 生成:输入当前帧和下一帧,输出变化区域 mask
mask, mask_logits = mask_gen(combined_for_mask, tau=tau)
# 2. 稀疏编码:只对 mask 区域池化
masked_features = features * mask
sum_features = torch.sum(masked_features, dim=(2, 3))
# 3. 稀疏正则:鼓励只选择少量单元
target_cells = 2.0 # 正好对应智能体+前方物体两个格子
sparsity_loss = torch.mean(torch.abs(cells_selected - target_cells))
loss = ce_loss + 0.1 * sparsity_loss
关键洞察:在网格世界中,一步动作通常只改变少数几个格子,全帧编码会引入大量噪声。让模型学习自动选关注变化区域,本质上是引入了正确的空间归纳偏置。
2. WorldModel:前向世界模型
世界模型使用 FiLM 调制来注入动作条件:
# FiLM 条件化注入动作信息
film_params = film_gen(z) # z = action embedding
gamma, beta = split(film_params)
h_delta = (1 + gamma) * dyn_conv(h_t) + beta
h_next_pred = h_t + h_delta # 残差预测
- 输入表示:每个格子 3 通道 =
物体类型ID + 颜色ID + 状态ID,各自嵌入后拼接 - 输出:三个分类头预测下一帧每像素,一个回归头预测携带物体信息
- 残差预测:只预测变化量,更容易训练
3. WAV 主动采集策略
WAV 的数据选择策略基于预测不一致性:
# 1. 视频模型(动作自由)预测下一状态
gen_out = video_gen_model(inputs_flat, mode='inference')
# 2. 逆模型根据当前 + 预测下一状态推断动作
inv_inputs = {frame: [curr_frame, s_gen_frame], ...}
pred_actions = argmax(inverse_model(inv_inputs))
# 3. 世界模型用推断动作再次预测下一状态
wm_out = world_model(inputs_flat, mode='predict_with_action', gt_actions=pred_actions)
# 4. 计算预测不一致性得分,选择得分最高的样本
score = frame_loss + 10.0 * (carried_loss)
选择不一致性最大的样本加入训练集 —— 这些就是当前模型最不确定、对模型改进最有帮助的样本。
支持的五种采集策略对比
| 策略 | 思想 | 说明 |
|---|---|---|
Random |
随机选择 | 基线 |
Hard-Oracle |
选择当前损失最大的样本 | 基于当前模型损失 |
Uncertainty |
基于 MC Dropout 不确定性 | 用 Bald 熵衡量不确定性 |
Progress |
选择损失下降最大的样本 | 基于两轮模型间的损失差 |
WAV |
选择前向-逆向预测不一致性最大 | 本文方法 |
MiniGrid 物理预言机
MiniGridPhysicsOracle 是一个轻量的仿真器,用于在不需要实际环境交互的情况下生成下一状态真值,方便离线主动学习:
# 支持所有基本 MiniGrid 动作
# 0: Turn Left, 1: Turn Right, 2: Move Forward
# 3: Pickup, 4: Drop, 5: Toggle, 6: Done (Swap)
next_frame, next_c_col, next_c_obj = oracle.step(frame, carried_col, carried_obj, action)
主要实验结论(复现)
代码库中已经包含了所有实验数据和预训练权重,可以直接运行验证论文结论:
- IDM 鲁棒性对比:
python exps/idm_comparison.py -
SparseIDM 在分布偏移下,交互动作准确率显著高于 DenseIDM
-
样本效率:
python exps/data_efficiency_gap.py -
SparseIDM 用更少数据就能达到和前向世界模型相当的性能
-
状态复杂度缩放:
python exps/state_complexity_gap.py -
随着环境中物体数量增加,SparseIDM 性能下降更慢
-
噪声鲁棒性:
python exps/noise_robustness.py -
随着环境中随机噪声 tiles 增加,SparseIDM 保持更好鲁棒性
-
主动学习:
python exps/wm_active_learning.py - WAV 采集策略比其他策略更高效,相同标注预算下获得更低预测误差
附录理论总结
论文附录提供了完整的理论分析和算法伪代码,证明 WAV 的有效性。
两个核心命题
Proposition 3.1(分布级鲁棒性):
在满足三个条件时(存在因果独立的动作相关子集、子集依然在支持集、动作可识别),稀疏逆模型在完整状态-动作对离分布(OOS)时依然能正确恢复动作。前向-逆向不一致性可以准确定位前向模型错误。
意义:这从理论上保证了 WAV 在未充分探索区域依然能有效工作。
Proposition 3.2(样本效率优势):
在线性高斯假设下,稠密前向模型与稀疏逆模型的期望误差比满足:
$$ \frac{\mathbb{E}[\mathcal{E}_F]}{\mathbb{E}[\mathcal{E}_I]} \geq \underbrace{\Bigl(\frac{d_s+d_a}{2d_z} \cdot \frac{d_s}{d_a}\Bigr)}_{\text{维度效应}} \cdot \underbrace{\Bigl(\frac{\sigma_s}{\lambda \sigma_a}\Bigr)^2}_{\text{噪声效应}} \cdot \underbrace{\Bigl(\frac{n-2d_z-1}{n-(d_s+d_a)-1}\Bigr)}_{\text{样本量效应}} $$
结论:WAV 优势在以下情况最显著:
1. 大 $\frac{d_s}{d_z}$ → 全状态维度远大于动作相关子集维度(大多数机器人场景成立)
2. 大 $\frac{\sigma_s}{\sigma_a}$ → 状态观测噪声远大于动作推断噪声
3. 小 $n$ → 标注样本量少(数据稀缺场景)
这正好对应论文实验结论:数据稀缺时 WAV 提升最大。
完整 WAV 算法伪代码
# Algorithm 1: WAV-Guided Exploration
# s: current state, f: world model
# v: subgoal generator, h: inverse model
# D: current dataset, K: number of candidates
for each exploration iteration:
s_g = v.sample(s, K) # 从子目标生成器采样 K 个候选
a = h.inverse(s, s_g) # 逆模型推断到达每个子目标需要的动作
s_p = f.predict(s, a) # 前向世界模型用推断动作预测下一状态
scores = dist(s_g, s_p) # 计算不一致性得分(距离)
idx = argmax(scores) # 选择不一致性最大的样本
s_n = env.step(a[idx]) # 在真实环境执行,获取真值
D.append((s, a[idx], s_n)) # 加入数据集
f.update(D), h.update(D) # 更新模型
逆向验证循环:
$$s^t \xrightarrow{p_\phi} \tilde{s}^{t+1} \xrightarrow{h_\psi} \hat{a}^t \xrightarrow{f_\theta} \hat{s}^{t+1}$$
验证信号:差异 $\hat{\varepsilon} = \ell(\tilde{s}^{t+1}, \hat{s}^{t+1})$ 越大,说明模型对这个样本越不确定,越应该优先学习。
组件分工总结
| 组件 | 训练数据 | 作用 |
|---|---|---|
| Subgoal generator $p_\phi$ | $\mathcal{D}_{\text{vid}}$(无动作标注视频) | 采样 K 个候选子目标,作为流形上的参考点 |
| Sparse inverse model $h_\psi$ | $\mathcal{D}_{\text{act}}$(带动作标注交互数据) | 学习 mask $M$ 选择动作相关特征:$\hat{a}^t = h_\psi(M \odot s^t, M \odot s^{t+1})$ |
| Forward world model $f_\theta$ | $\mathcal{D}_{\text{act}}$ | 给定 $(s^t, a^t)$ 预测下一状态 |
深度架构分析:核心模块源码阅读
从实际阅读源码来看,WAV 的代码实现非常清晰干净,每个关键设计决策都直接对应论文的创新点。我们来深入看看两个核心模型的架构细节。
WorldModel:前向世界模型完整架构
WorldModel 类位于 models/wm.py,是一个带 Vector Quantization 的残差预测网络:
输入表示层:
# MiniGrid 输入是离散网格:每个格子有三个属性
self.obj_embedding = nn.Embedding(self.NUM_OBJ_CLASSES, self.emb_dim) # 20 类物体
self.col_embedding = nn.Embedding(self.NUM_COL_CLASSES, self.emb_dim) # 10 类颜色
self.state_embedding = nn.Embedding(self.NUM_STATE_CLASSES, self.emb_dim) # 10 类状态
每个格子输出 3 × 8 = 24 维,然后显式添加坐标通道(xx, yy 归一化到 [0,1]),总共 26 通道输入 CNN。添加坐标通道让网络直接获得位置信息,不用自己学习。
FiLM 动作条件化 —— 这是整个架构最精彩的设计:
film_params = self.film_gen(z) # z = 动作潜变量 / 动作嵌入
gamma, beta = torch.split(film_params, self.feature_dim, dim=-1)
gamma = gamma.unsqueeze(-1).unsqueeze(-1)
beta = beta.unsqueeze(-1).unsqueeze(-1)
h_mid = self.dyn_conv(h_t_map) # 卷积提取特征
h_delta = (1 + gamma) * h_mid + beta # FiLM 特征-wise 线性调制
h_next_pred = h_t_map + h_delta # 残差连接:只预测变化量
设计优势:
- FiLM(Feature-wise Linear Modulation)比简单地在通道维度拼接动作更好地保留空间结构
- 残差预测:网络只需要预测 h_t 到 h_next 的增量,学习任务更容易
- 初始化技巧:film_gen 最后一层初始化为全 0,训练开始时 h_delta = 0,相当于恒等映射,训练更稳定
输出头设计:
self.head_obj_cls = Conv2d(32, NUM_OBJ_CLASSES, 1) # 物体类型分类
self.head_col_cls = Conv2d(32, NUM_COL_CLASSES, 1) # 颜色分类
self.head_state_cls = Conv2d(32, NUM_STATE_CLASSES, 1) # 状态分类
self.carried_head = MLP(...) -> 2维回归 # 携带物体信息
因为每个格子是多分类任务,所以用 1×1 卷积输出每个空间位置的分类 logits。
SparseIDM:稀疏逆动力学模型架构
SparseIDM 是论文的核心创新,位于 models/idm.py。相比普通稠密逆模型,它通过可学习的稀疏 mask 只关注变化区域:
Mask 生成器:
class MaskGenerator(nn.Module):
def __init__(self, in_channels=6):
self.net = nn.Sequential(
nn.Conv2d(in_channels, 16, kernel_size=3, padding=1),
nn.ReLU(),
nn.Conv2d(16, 2, kernel_size=1) # 二分类:变化/不变化
)
def forward(self, x, tau=1.0):
logits = self.net(x)
# Gumbel-Softmax 使得离散选择可微
probs = F.gumbel_softmax(logits, tau=tau, hard=True, dim=1)
return probs[:, 1:2, :, :], logits # 返回变化区域的 mask
稀疏编码过程:
# 输入:当前帧 + 下一帧
curr_p = curr_frame.permute(0, 3, 1, 2).float()
next_p = next_frame.permute(0, 3, 1, 2).float()
combined_for_mask = torch.cat([curr_p, next_p], dim=1) # [B, 6, H, W]
mask, mask_logits = self.mask_gen(combined_for_mask, tau=tau)
features = self.feature_extractor(curr_p) # [B, 32, H, W]
masked_features = features * mask # 只保留 mask 区域
sum_features = torch.sum(masked_features, dim=(2, 3)) # 全局池化,空间维度消失
关键稀疏正则化(训练损失):
target_cells = 2.0 # 一步动作通常只改变两个格子:智能体位置 + 目标物体
cells_selected = mask.sum(dim=(1,2,3)) # 统计选中多少个cell
sparsity_loss = torch.mean(torch.abs(cells_selected - target_cells))
loss = ce_loss + 0.1 * sparsity_loss # 鼓励稀疏性
这个正则化非常关键——它引入了正确的归纳偏置:一步动作只会改变少数格子。训练完后,模型确实学会了只关注变化区域。
三种逆模型对比:论文实现了三种设计用于 ablation 研究:
| 模型 | 思想 | 代码位置 |
|---|---|---|
DenseIDM |
全帧 CNN 编码,池化后接 MLP | idm.py L11-L151 |
SparseIDM |
可学习稀疏 mask + 池化(本文方法) | idm.py L170-L286 |
OracleSparseIDM |
手工提取智能位+前方格子(先知基线) | idm.py L287-L419 |
OracleSparseIDM 作为性能上限参考,验证稀疏性本身带来的收益。实验结果:学习出来的 mask 已经非常接近手工先知的性能。
几何特征工程
三种逆模型都使用了相似的几何特征编码,这对区分左转/右转等动作非常重要:
# 方向增量:用 sin/cos 编码,因为方向是环形的(0 ↔ 3 相邻)
delta = (next_dir - curr_dir + 4) % 4
angle = delta * (2 * π / 4)
dir_delta_encoded = [sin(angle), cos(angle)]
# 位置增量:直接坐标差
pos_delta = next_pos - curr_pos
# 最终输入 = [curr_emb, next_emb, curr_carried, next_carried, dir_delta, pos_delta]
- 方向用 sin/cos 编码处理环形周期性,否则模型会认为 0 → 3 比 0 → 1 距离更大,这不正确
- 显式提供增量信息帮助模型快速学习到运动规律
WAV 主动学习数据选择核心逻辑
整个自改进闭环的核心在于如何选择下一批标注样本:
# 1. 视频模型(动作自由建模)生成多个候选子目标
gen_out = video_gen_model(inputs_flat, mode='inference')
# 2. 逆模型根据 (当前状态, 生成子目标) 推断动作
inv_inputs = {'frame': [curr_frame, s_gen_frame], 'carried_col': ..., 'carried_obj': ...}
pred_actions = argmax(inverse_model(inv_inputs))
# 3. 前向世界模型用推断动作预测下一状态
wm_out = world_model(inputs_flat, mode='predict_with_action', gt_actions=pred_actions)
# 4. 计算预测不一致性得分
score = frame_loss(obj+col+state) + 10.0 * carried_loss
选择策略:选择不一致性最大的样本加入训练集。
直觉:不一致性大 → 当前模型前向-逆向循环闭合不了 → 模型对这个样本很不确定 → 学习这个样本收益最大。
五种采集策略对比:
| 策略 | 思想 |
|---|---|
Random |
随机选择(基线) |
Hard-Oracle |
选择当前损失最大的样本 |
Uncertainty |
MC Dropout 熵衡量不确定性 |
Progress |
选择两轮训练间损失下降最大 |
WAV |
前向-逆向预测不一致性(本文) |
论文实验:WAV 在相同标注预算下,获得最低预测误差,样本效率提升一倍。
代码设计亮点
阅读完源码后,几个工程细节值得称赞:
-
正交初始化:所有 CNN/Linear 都用 orthogonal 初始化,ReLU 层计算正确增益:
python init_relu = lambda m: init(m, nn.init.orthogonal_, lambda x: nn.init.constant_(x, 0), nn.init.calculate_gain('relu')) -
坐标通道显式添加:不依赖网络学习位置信息,直接给归一化坐标,省事有效。
-
多种运行模式:WorldModel 支持三种模式,复用同一网络:
posterior:训练时后验编码inference:生成时从先验采样-
predict_with_action:规划时给定动作预测 -
VQ 处理技巧:当有真值动作时,直接用动作索引码本,不需要 VQ 学习:
python if gt_actions is not None: indices = gt_actions.long() z_q = self.vq_layer._embedding(indices) vq_loss = torch.tensor(0.0)
设计选择总结
| 设计选择 | 原因 |
|---|---|
| FiLM 而非通道拼接 | 更好保留空间结构,参数效率更高 |
| 残差预测 | 只学习增量,优化更容易 |
| Gumbel-Softmax mask | 离散选择可微训练 |
| 稀疏性正则 | 引入正确归纳偏置:一步只改变少数格子 |
| 方向增量 sin/cos 编码 | 处理环形周期性,左转/右转可区分 |
来源:
- 论文原文:arXiv:2604.01985
- 项目主页:world-action-verifier.github.io
- 代码:github.com/world-action-verifier/wav_minigrid
- 源码分析:基于本地 ~/Documents/RL/wav/ 实际代码阅读
Comments (0)
Please sign in to leave a comment.