《LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels》深度解读

《LeWorldModel: Stable End-to-End Joint-Embedding Predictive Architecture from Pixels》深度解读

论文信息
- 作者:Lucas Maes, Quentin Le Lidec, Damien Sieur, Yann LeCun, Randall Balestriero
- 机构:Mila & 蒙特利尔大学,纽约大学,Samsung SAIL,布朗大学
- 发表时间:2026年3月13日
- arxiv:2603.19312


一句话总结

LeWorldModel (LeWM) 是第一个能够从原始像素端到端稳定训练的Joint-Embedding Predictive Architecture (JEPA) 世界模型。它只用两个损失项(下一帧嵌入预测损失 + SIGReg高斯正则化)就能避免表征坍塌,把需要调优的超参数从六个减少到一个,规划速度比基于预训练foundation model的世界模型快多达48倍,训练可在单GPU几小时内完成,并且隐空间能够涌现出有意义的物理结构理解能力。【原文 Abstract】


背景与动机

人工智能的一个核心目标是开发能够在多样环境和任务中通过统一范式学习的智能体——直接从原始感官输入(像素)学习,从端到端训练直到输出动作,不需要手工工程的状态表示或领域特定校准【原文 §1 Introduction】。

世界模型(World Models, WMs)是一类强大的方法,它学习预测环境中动作的后果。当训练成功后,世界模型允许智能体在隐式想象空间(imagination space)中改进自己的策略,这在离线设定下特别有价值——智能体必须从固定数据集中学习,通过模型生成合成经验并评估反事实动作序列【原文 §1 Introduction】。

JEPA(Joint-Embedding Predictive Architecture)是近来非常受欢迎的世界模型方法【原文 [5]】。JEPA不试图建模环境的每个方面,只聚焦于捕捉预测未来状态最相关的特征——具体来说,JEPA学习将观测编码到紧凑的低维隐空间,并通过预测未来观测的隐表示来建模时间动态【原文 §1 Introduction】。

然而,尽管JEPA概念简洁,现有JEPA方法非常容易发生表征坍塌(representation collapse):模型将所有输入映射到几乎相同的表示,平凡地满足了预测目标,但得到的表示不可用。防止坍塌是训练JEPA的核心挑战之一【原文 §1 Introduction】。

许多已有工作提出了不同方法来解决这个问题,但这些方法通常:
- 依赖启发式正则化和多目标损失函数
- 需要外部信息源
- 引入架构简化(比如使用预训练编码器)

实践中,这些策略通常带来额外的不稳定性,或显著增加训练复杂度【原文 §1 Introduction】。

下图对比了不同类别方法的特性:

方法 特点 LeWM改进
PLDM(端到端) 需要6+超参数,容易过拟合,容易坍塌 ➕ 1个超参数,可证明抗坍塌
DINO-WM(预训练) 被预训练知识限制,非端到端 ➕ 完全端到端学习
Dreamer/TD-MPC(任务特定) 需要图像重建或奖励信号 ➕ 任务无关,无需重建、无需奖励
- - ➕ 纯像素输入

【原文 Fig 2】


核心方法

整体架构

LeWM包含两个核心组件:
1. Encoder(编码器):将观测帧 $o_t$ 映射为紧凑低维隐表示 $z_t$
- 使用Vision Transformer (ViT) tiny配置:约5M参数,patch size=14,12层,3个注意力头,隐维度192
- 从<[BOS_never_used_51bce0c785ca2f68081bfa7d91973934]>token embedding经过MLP投影 + BatchNorm得到最终表示【原文 §3.1】

  1. Predictor(预测器):根据当前隐状态 $z_t$ 和动作 $a_t$ 预测下一帧的隐表示 $\hat{z}_{t+1}$
  2. 6层Transformer,16个注意力头,dropout 10%:约10M参数
  3. 通过Adaptive LayerNorm (AdaLN)将动作注入每一层
  4. 输入为最近N帧的历史表示,自回归预测下一帧,使用因果mask保证不看未来【原文 §3.1】

公式表示:
$$ z_t = \text{enc}_\theta(o_t), \quad \hat{z}_{t+1} = \text{pred}_\phi(z_t, a_t) $$
【原文 §3.1 Eq (LeWM)】

总体模型只有约15M参数,非常轻量。

训练目标

LeWM的训练目标只有两项,结构极其简洁:
$$ \mathcal{L}_{\text{LeWM}} = \mathcal{L}_{\text{pred}} + \lambda \cdot \text{SIGReg}(Z) $$
【原文 Eq 3】

1. 预测损失 $\mathcal{L}_{\text{pred}}$:预测下一帧隐嵌入与目标隐嵌入之间的均方误差(MSE)
$$ \mathcal{L}_{\text{pred}} \triangleq \|\hat{z}_{t+1} - z_{t+1}\|_2^2, \quad \hat{z}_{t+1} = \text{pred}_\phi(z_t, a_t) $$
【原文 Eq 1】

通过预测损失,编码器被激励学习出对预测器可预测的表示。但仅有这个损失会导致表征坍塌——编码器把所有输入都映射为相同的常数表示,预测损失也能降到很低,但表示完全无用。

2. SIGReg正则化(抗坍塌):为防止平凡坍塌,作者引入Sketched-Isotropic-Gaussian Regularizer(SIGReg),它强制隐嵌入分布匹配各向同性高斯分布,从而促进特征多样性【原文 §3.1】。

SIGReg的巧妙之处在于:高维空间直接检验正态分布很难,所以它通过统计技巧:
- 将隐嵌入投影到 $M$ 个随机单位方向上
- 对每个一维投影,用Epps-Pulley检验统计量衡量分布与标准正态的距离
- 对所有随机投影取平均

$$ \text{SIGReg}(Z) \triangleq \frac{1}{M} \sum_{m=1}^M T(h^{(m)}) $$
【原文 Eq 2】

根据Cramér-Wold定理:匹配所有一维边缘分布等价于匹配整个联合分布。因此,当SIGReg损失趋近于0时,隐嵌入分布必然趋近于各向同性高斯分布【原文 附录 A】。

训练伪代码:

def LeWorldModel(obs, actions, lambd=0.1):
    """
    obs: (B, T, C, H, W) raw pixels sequence
    actions: (B, T, A) action sequence
    lambd: SIGReg loss weight
    """
    emb = encoder(obs)      # (B, T, D)
    next_emb = predictor(emb, actions)  # (B, T, D)

    # Next-embedding prediction loss
    pred_loss = F.mse_loss(emb[:, 1:], next_emb[:, :-1])

    # Step-wise SIGReg (anti-collapse)
    sigreg_loss = mean(SIGReg(emb.transpose(0, 1)))

    return pred_loss + lambd * sigreg_loss

【原文 Algorithm 1】

超参数简化

这是LeWM的一大优势:
- 只需要调优一个有效超参数:正则化权重 $\lambda$(默认取0.1通常效果良好)
- 随机投影数量 $M$ 和积分节点数量对下游性能影响可以忽略,不需要精细调优
- 相比之下,PLDM需要调优6个损失系数,搜索复杂度是 $O(n^6)$;而LeWM只需要对数二分搜索 $O(\log n)$

LeWM不使用stop-gradient、EMA等启发式稳定技巧,梯度回传经过所有组件,所有参数联合端到端优化,训练流程简洁易实现【原文 §3.1】。

隐空间规划推理

训练完成后,LeWM通过模型预测控制(MPC)在隐空间做规划:

  1. 给定初始观测 $o_1$ 和目标观测 $o_g$,encoder分别得到初始隐状态 $z_1$ 和目标隐状态 $z_g$
  2. 用Cross-Entropy Method (CEM)优化动作序列:
  3. 随机初始化候选动作序列分布(高斯)
  4. 迭代:采样候选序列 → rollout预测未来隐状态 → 计算与目标的距离 → 选择最优的elite候选 → 更新分布参数
  5. 输出最优动作序列【原文 §3.2 Algorithm 2】

  6. 目标函数:最小化最终预测隐状态与目标隐状态的距离
    $$ \min_{a_{1:H}} \mathcal{C}(\hat{z}_H) = \|\hat{z}_H - z_g\|_2^2 $$
    【原文 Eq 4-5】

  7. 采用receding-horizon MPC:只执行前K步规划动作,然后根据更新后的观测重新规划【原文 §3.2】。


实验与结果

作者在四个 diverse 任务上评估了LeWM,包括2D和3D环境,涵盖导航、操作:
- Two-Room:简单2D导航,智能体从一个房间穿过门到另一房间目标位置
- Reacher:2关节机械臂到达任务
- PushT:经典机器人基准,2D操作,推T形块到目标位姿
- OGBench-Cube:视觉更丰富的3D操作,机械臂抓取立方体放到目标位置【原文 Fig 5, 附录 E】

规划性能对比

环境 LeWM PLDM DINO-WM
PushT 90% 72% 13%
OGBench-Cube 74% - 48%

【原文 Fig 6】

关键结论:
- 在更具挑战性的任务(PushT、Reacher)上,LeWM持续优于PLDM和DINO-WM,PushT成功率比PLDM高18%【原文 §4.2】
- 即使只使用像素输入,在PushT上LeWM也超过了拥有额外本体感知输入的DINO-WM,证明LeWM能够从纯像素中捕捉任务相关物理量【原文 §4.2】
- 在最简单的Two-Room环境性能稍差,作者分析这是因为环境内在维度低,SIGReg强制高维高斯先验可能不匹配,这是SIGReg的一个已知局限【原文 §4.2】

规划速度对比

方法 完整规划时间
LeWM 0.98s
DINO-WM 47s

【原文 Fig 3】

LeWM使用紧凑的192维隐表示,相比DINO-WM的token化表示,编码观测少约200倍,规划速度快 48×,完整规划可以在1秒内完成,显著缩小了与实时控制之间的差距【原文 §4.2】。

训练稳定性

  • LeWM的两项损失展现平滑单调收敛:预测损失稳步下降,SIGReg在训练早期快速下降,很快就使隐分布接近各向同性高斯【原文 §4.3】
  • 相比之下,PLDM的七项损失表现出噪声大、非单调行为,多个分量之间存在梯度竞争【原文 §4.3】

消融实验显示:
- SIGReg对内部参数(随机投影数量、积分节点)不敏感,性能基本不受影响
- 嵌入维数超过一定阈值后性能很快饱和,方法对具体选择鲁棒
- 换成ResNet-18编码器依然能得到有竞争力的性能,方法对编码器架构选择很大程度上不敏感【原文 §4.3】

物理理解探测

作者通过probe实验检验隐空间编码了哪些物理信息:训练线性探针和MLP探针从隐embedding预测物理量。

PushT环境探测结果:

物理量 模型 Linear MSE ↓ Linear r ↑ MLP MSE ↓ MLP r ↑
Agent Location DINO-WM 1.888 ± 0.500 0.977 0.003 ± 0.022 0.999
Agent Location PLDM 0.090 ± 0.311 0.955 0.014 ± 0.119 0.993
Agent Location LeWM 0.052 ± 0.149 0.974 0.004 ± 0.056 0.998
Block Location DINO-WM 0.006 ± 0.007 0.997 0.002 ± 0.006 0.999
Block Location PLDM 0.122 ± 0.341 0.938 0.011 ± 0.063 0.994
Block Location LeWM 0.029 ± 0.073 0.986 0.001 ± 0.006 0.999
Block Angle DINO-WM 0.050 ± 0.101 0.979 0.009 ± 0.052 0.995
Block Angle PLDM 0.446 ± 0.625 0.745 0.056 ± 0.184 0.972
Block Angle LeWM 0.187 ± 0.359 0.902 0.021 ± 0.139 0.990

【原文 Table 1】

结论:LeWM一致优于PLDM,性能与基于大规模预训练DINOv2的DINO-WM保持竞争力。 这说明端到端训练得到的隐空间已经能够编码有意义的物理结构【原文 §5.1】。

进一步可视化:
- 解码器(训练后仅用于可视化)可以从192维隐表示重建出像素观测,证明低维紧凑表示仍然保留了足够的环境信息【原文 Fig 8】
- t-SNE可视化显示,隐空间保留了空间邻接关系,相似物理状态在隐空间也靠近【原文 Fig 9】
- 训练过程中,LeWM的隐轨迹变得越来越直(temporal latent path straightening),这一现象在神经科学中被认为是理解的标志,LeWM即使不做专门正则化也能涌现出这一性质,并且比PLDM效果更好【原文 §5.1】

违背期望检测(物理异常检测)

作者借鉴发展心理学中的Violation-of-Expectation (VoE)范式测试:模型对不符合物理规律的事件是否会分配更高的"surprise"。

测试了两种扰动:
1. 视觉扰动:物体颜色突然改变
2. 物理扰动:物体瞬间传送到随机位置(违背物理连续性)

实验结果:
- 对所有三个环境,物理扰动(传送)带来的surprise显著高于无扰动(配对t-test p < 0.01)【原文 §5.2 Fig 10】
- 颜色变化带来的surprise提升较弱且不显著
- 结论:模型对物理扰动比视觉扰动更敏感,说明学习到了合理的物理常识,能够可靠检测物理不可能事件【原文 Abstract】


亮点与局限

亮点:
- 方法论:第一个真正稳定端到端训练的JEPA世界模型,仅用两个损失项就解决了表征坍塌问题【原文 §1】
- 简洁性:超参数从六个减少到一个,大大降低调优复杂度,训练流程易复现【原文 §3.1】
- 效率:15M参数轻量模型,规划速度快48倍,单GPU几小时完成训练【原文 Abstract, §4.2】
- 涌现性:隐空间自动涌现出物理结构理解,能够检测物理异常事件【原文 §5】
- 理论基础:SIGReg基于统计理论(Cramér-Wold定理),可证明能防止表征坍塌【原文 附录 A】

局限(原文承认的):
- 当前规划仅限于短视域,长视域规划需要分层世界建模【原文 §6 Conclusion】
- 仍然依赖离线数据集,需要足够的交互覆盖,数据收集成本高【原文 §6】
- 需要动作标签来预测未来状态,动作标签获取成本高【原文 §6】
- 在低内在维度的简单环境中,SIGReg的高斯先验假设可能不匹配,性能略有下降【原文 §4.2】

局限(解读者补充):
- 实验主要在仿真环境验证,尚未在真实机器人硬件上测试(此为解读者推断,非原文明确表述)
- SIGReg强制各向同性高斯先验可能丢弃数据中固有的低秩结构,在简单低维任务上这可能反而有害(原文已指出这一点,此为解读者总结)


意义与讨论

这篇论文是JEPA路线图上的重要一步:它证明了你不需要预训练大模型、不需要复杂的多目标启发式,只需要一个原则性好的正则化,就可以从纯像素端到端训练出一个可用的世界模型。

JEPA的核心思想本来就是"只预测需要预测的,不需要重建像素",LeWM终于让这个思想在训练稳定的前提下落地了。相比于生成式世界模型(如Dreamer系列),JEPA不需要像素重建,计算更高效;相比于基于预训练ViT的方法(如DINO-WM),LeWM完全端到端,可以针对特定环境数据调整编码器,不需要依赖大模型的通用特征。

隐空间涌现出物理理解这一点非常有趣——这和近来"Intuitive physics emerges from self-supervised pre-training"的发现一致【原文 [44]】,进一步说明预测性自监督学习能够让模型学到类人的物理常识,这对机器人在开放环境中鲁棒决策很有价值。


延伸阅读

Comments (0)