Transition Models: Rethinking the Generative Learning Objective


TL;DR

本文提出了一种名为过渡模型 (Transition Models, TiM) 的新型生成范式,它通过学习一个能够在任意时间间隔 \($\Delta t\)$ 上进行状态转换的精确动力学方程,成功统一了高效的少步生成与高质量的多步精炼,解决了现有生成模型中普遍存在的“速度-质量”权衡困境。

关键定义

本文的核心是围绕一个新的生成学习目标展开的,其关键概念如下:

相关工作

当前视觉内容生成领域由扩散模型主导,但面临一个根本性的困境:

本文旨在解决上述两难困境,即模型要么保真度高但计算昂贵,要么效率高但牺牲了精炼能力。本文提出的问题是:什么才是生成模型最合适的学习目标? 作者认为,一个理想的目标应该能让模型学习一个通用的、由时间间隔 \($\Delta t\)$ 参数化的去噪算子,从而在少步和多步生成场景下都能表现出色,并随着计算预算的增加而单调提升质量。

不同生成范式图示 图 2: 不同生成范式图示。传统扩散模型学习局部向量场,少步模型学习固定的终点映射(单个大步),而本文的过渡模型 (TiM) 训练用于掌握任意状态间的转换。这种方法使 TiM 能够学习生成过程的整个解流形,统一了少步和多步生成机制。

本文方法

本文首先分析了传统PF-ODE监督的局限性,然后推导出一个适用于任意时间间隔的状态转换恒等式,并基于此构建了一个可扩展且稳定的学习目标。最后,提出了针对性的架构改进。

PF-ODE监督的局限性

扩散模型通过前向过程 \($ \mathbf{x}\_{t}=\alpha\_{t}\mathbf{x}+\sigma\_{t}\mathbf{\varepsilon}\)$ 对数据进行加噪。其生成过程等价于求解一个逆时 PF-ODE:

\[\frac{\mathrm{d}\mathbf{x}_{t}}{\mathrm{d}t} = \mathrm{f}(\mathbf{x}_{t},t)-\frac{1}{2}\mathrm{g}(t)^{2}\nabla_{\mathbf{x}_{t}}\log p_{t}(\mathbf{x}_{t})\]

模型的训练目标 \($\mathbf{f}\_{\theta}(\mathbf{x}\_{t},t)\)$ 实质上是在监督这个微分方程的向量场。在采样时,需要使用数值求解器进行积分。为了保证精度,步长 \($\Delta t\)$ 必须很小,这导致了高昂的 NFE。

状态转换

本文方法的核心是从第一性原理出发,将状态转换视为一个必须对任意时间间隔 \($\Delta t = t - r\)$ 精确成立的恒等式,而非数值近似。

状态转换恒等式

从任意状态 \($\mathbf{x}\_t\)$ 可以预测出 \($\hat{x}\)$ 和 \($\hat{\mathbf{\varepsilon}}\)$,进而可以表示任意先前的状态 \($\mathbf{x}\_r = \alpha\_r \hat{\mathbf{x}} + \sigma\_r \hat{\mathbf{\varepsilon}}\)$。将此过程用一个依赖于起始时间 \(t\) 和目标时间 \(r\) 的网络 \($\mathbf{f}\_{\theta}(\mathbf{x}\_{t},t,r)\)$ 来参数化,可以得到:

\[\mathbf{x}_{r}=\frac{(\alpha_{r}\hat{\sigma}_{t}-\sigma_{r}\hat{\alpha}_{t})\mathbf{x}_{t}+(\sigma_{r}\alpha_{t}-\alpha_{r}\sigma_{t})\mathbf{f}_{\theta}(\mathbf{x}_{t},t,r)}{\hat{\sigma}_{t}\alpha_{t}-\hat{\alpha}_{t}\sigma_{t}}\]

将上式简写为 \($\mathbf{x}\_r = A\_{t,r}\mathbf{x}\_t + B\_{t,r}\mathbf{f}\_{\theta,t,r}\)$。通过对时间 \(t\) 求导并整理,本文推导出一个关键的状态转换恒等式

\[\frac{\mathrm{d}(B_{t,r}\cdot(\hat{\alpha}_{t}\mathbf{x}+\hat{\sigma}_{t}\mathbf{\varepsilon}-\mathbf{f}_{\theta,t,r}))}{\mathrm{d}t}=0\]

该恒等式展开后包含两项:

\[(\underbrace{\hat{\alpha}_{t}\mathbf{x}+\hat{\sigma}_{t}\mathbf{\varepsilon}-\mathbf{f}_{\theta,t,r}}_{\text{PF-ODE supervision}})\frac{\mathrm{d}B_{t,r}}{\mathrm{d}t}+B_{t,r}\underbrace{\frac{\mathrm{d}(\hat{\alpha}_{t}\mathbf{x}+\hat{\sigma}_{t}\mathbf{\varepsilon}-\mathbf{f}_{\theta,t,r})}{\mathrm{d}t}}_{\text{time-slope matching}}=0\]

这个恒等式强加了比传统扩散模型更严格的约束:

  1. 隐式轨迹一致性:它要求加权残差 \($B\_{t,r}h(t)\)$(其中 \($h(t)\)$ 是瞬时残差)对于任何以 \($\mathbf{x}\_r\)$ 为终点的轨迹,其值在整个轨迹上保持不变。这确保了多步采样路径的内在一致性,使得增加采样步骤成为一种精炼而非偏离。
  2. 时间斜率匹配:它不仅要求瞬时残差 \($h(t) \to 0\)$(传统目标),还要求残差的时间导数 \($\frac{\mathrm{d}}{\mathrm{d}t}h(t) \to 0\)$。这种高阶监督使得模型学习到的解流形更平滑,在大步长采样时保持连贯性,在小步长时保证稳定精炼。

学习目标

基于状态转换恒等式,本文导出了一个动态的学习目标 \($\hat{\mathbf{f}}\)$:

\[\hat{\mathbf{f}}=\hat{\alpha}_{t}\mathbf{x}+\hat{\sigma}_{t}\mathbf{\varepsilon}+\frac{B_{t,r}}{\frac{\mathrm{d}B_{t,r}}{\mathrm{d}t}}\left(\frac{\mathrm{d}\hat{\alpha}_{t}}{\mathrm{d}t}\mathbf{x}+\frac{\mathrm{d}\hat{\sigma}_{t}}{\mathrm{d}t}\mathbf{\varepsilon}-\frac{\mathrm{d}\mathbf{f}_{\theta^{-},t,r}}{\mathrm{d}t}\right)\]

其中 \($\theta^{-}\)$ 表示固定的网络参数。最终的 TiM 训练目标为:

\[\mathbb{E}_{\mathbf{x},\mathbf{\varepsilon},t,r}\left[w(t,r)\cdot d\left(\mathbf{f}_{\theta}(\mathbf{x}_{t},t,r)-\hat{\mathbf{f}}\right)\right]\]

其中 \($w(t,r)\)$ 是一个为稳定训练而引入的权重函数。

训练的可扩展性与稳定性

  1. 可扩展性 (Scalability): 为解决学习目标中计算时间导数 \($\frac{\mathrm{d}\mathbf{f}\_{\theta^{-},t,r}}{\mathrm{d}t}\)$ 带来的可扩展性瓶颈(传统JVP方法无法与现代训练优化兼容),本文提出了微分推导方程 (DDE) 进行近似:

    \[\frac{\mathrm{d}\mathbf{f}_{\theta^{-},t,r}}{\mathrm{d}t}\approx\frac{\mathbf{f}_{\theta^{-}}(\mathbf{x}_{t+\epsilon},t+\epsilon,r)-\mathbf{f}_{\theta^{-}}(\mathbf{x}_{t-\epsilon},t-\epsilon,r)}{2\epsilon}\]

    DDE 仅需前向传播,计算效率高,且与 FlashAttention 和 FSDP 等分布式训练技术兼容,使得训练大模型成为可能。


方法 算子   训练   FID    
  FLOPs (G) 延迟 (ms) 吞吐量 (/s) 显存 (GiB) NFE=1 NFE=8 NFE=50
JVP 48.29 213.14 1.80 14.89 49.75 26.22 18.11
DDE 24.14 110.08 2.40 15.23 49.91 26.09 17.99


  1. 稳定性 (Stability): 在训练中,过大的时间间隔 \($\Delta t\)$ 可能导致梯度方差过大和训练不稳定。为此,本文设计了一个损失权重函数 \($w(t,r)\)$,它优先考虑较短的时间间隔,为训练提供更稳定的信号。最终采用的权重函数为:

    \[w(t,r)=({\sigma_{\text{data}}+\tan(t)-\tan(r)})^{-\frac{1}{2}}\]

架构改进

为使模型能有效学习状态转换,本文对 DiT 架构进行了两点改进:

TiM across different NFEs, resolutions, and aspect ratios 图 1: TiM 在不同 NFE、分辨率和宽高比下的卓越性能。

实验结论

关键实验结果

TiM 在各类文生图基准测试中展现了最先进的性能、效率和灵活性。

Qualitative Analysis 图 3: 不同 NFE 下的定性比较。TiM 在所有 NFE 下均提供出色的保真度和文本对齐。


方法 NFE=1 NFE=8 NFE=32 NFE=128
SD3.5-Turbo [61] 0.50 0.66 0.70 0.70
FLUX.1-Schnell [6] 0.68 0.67 0.63 0.58
SD3.5-Large [20] 0.00 0.50 0.69 0.70
FLUX.1-Dev [5] 0.00 0.40 0.64 0.65
TiM 0.67 0.76 0.80 0.83

表 5: 在 GenEval 基准上跨 NFE 的生成质量对比 (得分↑)。TiM展现了随 NFE 增加而单调提升的质量。


消融研究

在 ImageNet-256 数据集上的消融实验验证了各项设计的有效性 (Table 4)。


方法 NFE=1 NFE=8 NFE=50
训练目标      
(a) 基线 (SiT-B/4) 309.5 77.26 20.35
(b) TiM-B/4 (使用 JVP) 49.75 26.22 18.11
(c) TiM-B/4 (使用 DDE) 49.91 26.09 17.99
架构      
(d) 原始架构 56.22 28.75 20.37
(e) + 解耦时间嵌入 (De-TE) 49.91 26.09 17.99
(f) + 间隔感知注意力 (IA-Attn) 48.38 26.10 17.85
(g) + De-TE + IA-Attn 48.30 25.05 17.43
训练策略 (在(g)基础上)      
(h) + 时间权重 47.46 24.62 17.10

表 4: 在 ImageNet-256 上的消融研究 (FID↓)。


最终结论

本文提出的过渡模型 (TiM) 是一种更高效、更强大的生成范式。通过一个统一的模型,它不仅解决了生成领域长期存在的速度-质量权衡问题,实现了从单步到多步的质量单调提升,还在一个紧凑的模型(865M)中超越了数倍于其大小的业界模型,并展现了出色的高分辨率生成能力。这项工作为下一代兼具高效、可扩展和创造潜力的基础模型铺平了道路。