CDVAE 是 Conditional Diffusion Variational Autoencoder 的缩写,它将扩散模型(Diffusion Models)的核心思想集成到变分自编码器(VAE)的框架中,并增加了条件生成的能力。它主要用于结构化数据(如分子、晶体材料、点云)的生成与优化。
1. 原理与核心思想
核心目标: 学习一个
条件概率分布 p(x|c)p(\mathbf{x} | \mathbf{c}),其中
x\mathbf{x} 是我们想要生成的
结构化对象(例如一个分子的3D构象),
c\mathbf{c} 是给定的
条件(例如该分子的化学式、目标属性如药物活性、或材料的带隙)。
核心创新: CDVAE 结合了三种范式的优点:
- VAE: 提供一个低维、连续的隐空间,允许对生成过程进行平滑的插值和有意义的语义操作。
- 扩散模型: 作为一个强大的解码器/生成器,通过一个逐步去噪的过程生成数据,通常比传统VAE的解码器能产生质量更高、更多样化的样本。
- 条件生成: 通过将条件信息 c\mathbf{c} 注入到VAE的编码器和扩散解码器的每一步中,实现对生成结果的精确控制。
为什么有效: 传统VAE直接通过一个神经网络解码器从隐变量
z\mathbf{z} 映射到
x\mathbf{x},对于复杂的结构化数据(如分子的3D原子位置和类型),这个一步到位的映射学习非常困难,容易导致模糊或无效的生成结果。扩散模型将这个过程分解为许多小的、易于学习的去噪步骤,极大地提高了生成质量和训练稳定性。
2. 模型架构
CDVAE 通常包含三个主要组件:
a. 编码器 q_(phi)(z|x,c)q_{\phi}(\mathbf{z} | \mathbf{x}, \mathbf{c})
- 输入: 原始结构数据 x\mathbf{x} 和条件信息 c\mathbf{c}。
- 输出: 一个多元高斯分布的参数(均值和方差),从中采样得到全局隐变量 z\mathbf{z}。
- 目的: 将高维、离散/连续混合的结构 x\mathbf{x} 压缩成一个低维、连续的语义表示 z\mathbf{z},并且这个表示与条件 c\mathbf{c} 相关联。
- 结构: 通常是一个图神经网络(GNN),因为输入 x\mathbf{x}(分子/晶体)天然可以用图表示(原子为节点,化学键为边)。
b. 扩散解码器 p_(theta)(x|z,c)p_{\theta}(\mathbf{x} | \mathbf{z}, \mathbf{c})
这是CDVAE的核心。它是一个以时间为条件的去噪模型。
-
前向过程(固定的加噪过程):
给定一个从编码器得到的“干净”结构
x_(0)\mathbf{x}_0(即原始数据),我们按照一个预定义的噪声调度(schedule)逐步添加高斯噪声,生成一系列噪声越来越大的隐变量
x_(1),x_(2),...,x_(T)\mathbf{x}_1, \mathbf{x}_2, ..., \mathbf{x}_T。
q(x_(t)|x_(t-1))=N(x_(t);sqrt(1-beta _(t))x_(t-1),beta _(t)I)q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I})
其中
beta _(t)\beta_t 是第
tt 步的噪声方差,由调度决定。这个过程的特性是,我们可以直接从
x_(0)\mathbf{x}_0 采样出任意
tt 时刻的
x_(t)\mathbf{x}_t:
x_(t)=sqrt( bar(alpha)_(t))x_(0)+sqrt(1- bar(alpha)_(t))epsilon,quad epsilon∼N(0,I)\mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I})
其中
alpha _(t)=1-beta _(t)\alpha_t = 1 - \beta_t,
bar(alpha)_(t)=prod_(s=1)^(t)alpha _(s)\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_s。
-
反向过程(学习的去噪过程):
这是一个神经网络,其任务是
预测添加到数据中的噪声 epsilon\epsilon。
- 输入:
- 当前噪声数据 x_(t)\mathbf{x}_t(在时间步 tt)。
- 时间步 tt 的嵌入向量。
- 全局隐变量 z\mathbf{z}(来自编码器)。
- 条件信息 c\mathbf{c}。
- 输出: 对噪声 epsilon\epsilon 的预测,即 epsilon_(theta)(x_(t),t,z,c)\epsilon_{\theta}(\mathbf{x}_t, t, \mathbf{z}, \mathbf{c})。
- 目的: 给定 x_(t)\mathbf{x}_t,利用 z\mathbf{z} 和 c\mathbf{c} 提供的全局语义和目标信息,预测出噪声 epsilon\epsilon,从而可以计算出去噪后的 x_(t-1)\mathbf{x}_{t-1}。
- 结构: 同样是一个GNN,但接收 z\mathbf{z} 和 c\mathbf{c} 作为全局上下文,注入到每个节点/边的特征更新中。
c. 先验网络 p_(psi)(z|c)p_{\psi}(\mathbf{z} | \mathbf{c})
- 输入: 条件信息 c\mathbf{c}。
- 输出: 先验分布的参数(均值和方差),这是采样阶段用于生成新样本的隐变量分布。
- 目的: 在训练时,编码器产生的后验分布 q_(phi)(z|x,c)q_{\phi}(\mathbf{z} | \mathbf{x}, \mathbf{c}) 会被拉向这个先验分布,以保证隐空间的规整性。在生成新样本时,我们从该先验中采样一个 z\mathbf{z},然后输入给扩散解码器。
架构流程图:
训练时:
(𝐱, 𝐜) → [编码器 q_ϕ] → 𝐳 ~ q_ϕ(𝐳|𝐱,𝐜)
↓
[扩散解码器 p_θ] 学习从 𝐱_t 预测噪声 ϵ,其中 𝐱_t 由 𝐱_0 加噪得到,解码器接收 (𝐱_t, t, 𝐳, 𝐜)
生成时:
𝐜 → [先验网络 p_ψ] → 𝐳 ~ p_ψ(𝐳|𝐜)
↓
[扩散解码器 p_θ] 从纯噪声 𝐱_T ~ N(0,I) 开始,逐步去噪 T 步,每一步都使用 (𝐱_t, t, 𝐳, 𝐜) 预测噪声
↓
𝐱_0 (生成的结构)
3. 训练方法
CDVAE 通过优化一个变分下界 来训练。损失函数由三部分组成:
总损失函数:
L_("CDVAE")=ubrace(E_(q_(phi)(z|x,c))[-log p_(theta)(x|z,c)])_("重建项 "L_("rec"))+ubrace(beta*D_("KL")(q_(phi)(z|x,c)||p_(psi)(z|c)))_("KL正则项 "L_("KL"))+ubrace(lambda*L_("prop"))_("属性预测项")\mathcal{L}_{\text{CDVAE}} = \underbrace{\mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x},\mathbf{c})} \left[ -\log p_{\theta}(\mathbf{x} | \mathbf{z}, \mathbf{c}) \right]}_{\text{重建项 } \mathcal{L}_{\text{rec}}} + \underbrace{\beta \cdot D_{\text{KL}} \left( q_{\phi}(\mathbf{z} | \mathbf{x}, \mathbf{c}) \| p_{\psi}(\mathbf{z} | \mathbf{c}) \right)}_{\text{KL正则项 } \mathcal{L}_{\text{KL}}} + \underbrace{\lambda \cdot \mathcal{L}_{\text{prop}}}_{\text{属性预测项}}
逐项解释:
a. 重建项 L_("rec")\mathcal{L}_{\text{rec}}
这是训练扩散解码器的核心。它衡量模型从隐变量
z\mathbf{z} 和条件
c\mathbf{c} 重建原始数据
x\mathbf{x} 的能力。在扩散模型中,这个项被重参数化为一个
去噪分数匹配目标:
L_("rec")=E_(t∼[1,T],epsilon∼N(0,I))[w_(t)*||epsilon-epsilon_(theta)(x_(t),t,z,c)||^(2)]\mathcal{L}_{\text{rec}} = \mathbb{E}_{t \sim [1,T], \epsilon \sim \mathcal{N}(0,\mathbf{I})} \left[ w_t \cdot \| \epsilon - \epsilon_{\theta}(\mathbf{x}_t, t, \mathbf{z}, \mathbf{c}) \|^2 \right]
- 符号解释:
- tt: 从1到T均匀采样的时间步。
- epsilon\epsilon: 实际添加到原始数据 x_(0)\mathbf{x}_0 中的随机噪声。
- x_(t)=sqrt( bar(alpha)_(t))x_(0)+sqrt(1- bar(alpha)_(t))epsilon\mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon: 根据噪声调度在第 tt 步的带噪数据。
- epsilon_(theta)(...)\epsilon_{\theta}(...): 扩散解码器网络,目标是预测噪声 epsilon\epsilon。
- w_(t)w_t: 时间步相关的权重,通常根据调度设置(如 w_(t)=1w_t = 1 或 w_(t)=1//sqrt(1- bar(alpha)_(t))w_t = 1 / \sqrt{1-\bar{\alpha}_t})。
- 直观理解: 网络学习在任意噪声水平 tt 下,给定全局语义 z\mathbf{z} 和目标 c\mathbf{c},如何从噪声数据中恢复出干净数据。这是一种更稳健的重建目标。
b. KL正则项 L_("KL")\mathcal{L}_{\text{KL}}
L_("KL")=D_("KL")(q_(phi)(z|x,c)||p_(psi)(z|c))\mathcal{L}_{\text{KL}} = D_{\text{KL}} \left( q_{\phi}(\mathbf{z} | \mathbf{x}, \mathbf{c}) \| p_{\psi}(\mathbf{z} | \mathbf{c}) \right)
- 符号解释:
- q_(phi)(z|x,c)=N(mu_(phi),sigma_(phi)^(2))q_{\phi}(\mathbf{z} | \mathbf{x}, \mathbf{c}) = \mathcal{N}(\mu_{\phi}, \sigma_{\phi}^2): 编码器产生的后验分布。
- p_(psi)(z|c)=N(mu_(psi),sigma_(psi)^(2))p_{\psi}(\mathbf{z} | \mathbf{c}) = \mathcal{N}(\mu_{\psi}, \sigma_{\psi}^2): 先验网络产生的条件先验分布。
- D_("KL")D_{\text{KL}}: Kullback-Leibler散度,衡量两个分布的差异。
- 作用:
- 充当正则化器,防止编码器为每个样本产生过于特异的隐变量(避免退化为普通自编码器)。
- 确保在生成时,从先验 p_(psi)(z|c)p_{\psi}(\mathbf{z}|\mathbf{c}) 中采样的 z\mathbf{z} 与训练时编码器产生的 z\mathbf{z} 来自相似的分布,从而保证生成质量。
- beta\beta 是控制正则化强度的超参数(beta\beta-VAE思想)。
c. 属性预测项 L_("prop")\mathcal{L}_{\text{prop}} (可选但常见)
L_("prop")=||f_("pred")(z)-y||^(2)\mathcal{L}_{\text{prop}} = \| f_{\text{pred}}(\mathbf{z}) - y \|^2
- 符号解释:
- f_("pred")f_{\text{pred}}: 一个附加的小型预测网络(MLP),以隐变量 z\mathbf{z} 为输入。
- yy: 与结构 x\mathbf{x} 对应的真实属性值(如能量、溶解度)。
- 作用:
- 增强可控性: 强制隐变量 z\mathbf{z} 编码与目标属性相关的信息,使得在隐空间内沿特定方向移动可以改变生成结构的属性。
- 辅助训练: 提供额外的监督信号,帮助学习更有意义的隐表示。
- lambda\lambda 是其权重超参数。
训练流程:
- 从数据集中采样一个批次的 (x,c,y)(\mathbf{x}, \mathbf{c}, y)。
- 将 (x,c)(\mathbf{x}, \mathbf{c}) 输入编码器,得到后验分布参数,并通过重参数化技巧采样隐变量 z\mathbf{z}。
- 计算KL损失 L_("KL")\mathcal{L}_{\text{KL}},需要将 c\mathbf{c} 输入先验网络得到先验分布参数。
- 为重建损失做准备:随机采样时间步 tt,根据调度为 x\mathbf{x} 加噪得到 x_(t)\mathbf{x}_t。
- 将 (x_(t),t,z,c)(\mathbf{x}_t, t, \mathbf{z}, \mathbf{c}) 输入扩散解码器,预测噪声 epsilon_(theta)\epsilon_{\theta},并与真实噪声 epsilon\epsilon 计算 L_("rec")\mathcal{L}_{\text{rec}}。
- 将 z\mathbf{z} 输入属性预测网络,计算 L_("prop")\mathcal{L}_{\text{prop}}(如果使用)。
- 将三项损失加权求和,反向传播,更新编码器 (phi)(\phi)、扩散解码器 (theta)(\theta)、先验网络 (psi)(\psi) 和属性预测网络的参数。
4. 使用方法
训练好的CDVAE模型主要有三种使用模式:
a. 无条件/条件生成
- 输入: 一个目标条件 c_("target")\mathbf{c}_{\text{target}}(例如,一个特定的化学式或一个属性值范围)。
- 过程:
- 将 c_("target")\mathbf{c}_{\text{target}} 输入先验网络 p_(psi)p_{\psi},得到先验分布 p_(psi)(z|c_("target"))p_{\psi}(\mathbf{z} | \mathbf{c}_{\text{target}})。
- 从该分布中采样一个隐变量 z_("sample")\mathbf{z}_{\text{sample}}。
- 从标准高斯噪声 x_(T)∼N(0,I)\mathbf{x}_T \sim \mathcal{N}(0, \mathbf{I}) 开始。
- 进行 TT 步迭代去噪:对于 t=T,T-1,...,1t = T, T-1, ..., 1,调用扩散解码器 epsilon_(theta)\epsilon_{\theta} 预测噪声,并根据DDPM或DDIM等采样算法计算 x_(t-1)\mathbf{x}_{t-1}。每一步都传入 (x_(t),t,z_("sample"),c_("target"))(\mathbf{x}_t, t, \mathbf{z}_{\text{sample}}, \mathbf{c}_{\text{target}})。
- 最终得到生成的结构 x_(0)\mathbf{x}_0。
b. 重构与插值
- 重构: 给定一个已知结构 x_("input")\mathbf{x}_{\text{input}} 及其条件 c\mathbf{c},通过编码器得到其隐变量 z_("input")\mathbf{z}_{\text{input}},然后用扩散解码器进行重构。这可以测试模型的表示能力。
- 插值: 在两个结构 (x_(A),c_(A))(\mathbf{x}_A, \mathbf{c}_A) 和 (x_(B),c_(B))(\mathbf{x}_B, \mathbf{c}_B) 对应的隐变量 z_(A)\mathbf{z}_A 和 z_(B)\mathbf{z}_B 之间进行线性插值:z_("interp")=(1-alpha)z_(A)+alphaz_(B)\mathbf{z}_{\text{interp}} = (1-\alpha)\mathbf{z}_A + \alpha\mathbf{z}_B。然后固定一个条件(或混合条件),用 z_("interp")\mathbf{z}_{\text{interp}} 进行生成,可以得到在两个结构之间平滑过渡的一系列新结构。
c. 基于属性的优化与搜索
这是CDVAE最强大的应用之一。
- 过程:
- 在隐空间 z\mathbf{z} 中定义或学习一个“属性方向” d\mathbf{d}(例如,通过回归属性预测器 f_("pred")f_{\text{pred}} 的梯度:d=grad_(z)f_("pred")(z)\mathbf{d} = \nabla_{\mathbf{z}} f_{\text{pred}}(\mathbf{z}),指向属性增加的方向)。
- 从一个起点 z_(0)\mathbf{z}_0 开始,沿方向 d\mathbf{d} 移动:z_("new")=z_(0)+eta*d\mathbf{z}_{\text{new}} = \mathbf{z}_0 + \eta \cdot \mathbf{d}。
- 用 z_("new")\mathbf{z}_{\text{new}} 和给定的条件 c\mathbf{c} 进行生成,得到的新结构 x_("new")\mathbf{x}_{\text{new}} 将具有更高的目标属性值。
- 这个过程可以迭代进行,在隐空间中高效地搜索满足特定属性要求的候选结构。
总结
CDVAE 是一个用于复杂结构化数据生成的强大统一框架。它通过扩散过程解决了传统VAE生成质量不高的问题,通过条件机制实现了可控生成,并通过连续的隐空间支持高效的优化和探索。其训练通过一个结合了扩散去噪损失、KL散度正则化和属性预测损失的变分目标来实现,在材料科学和药物发现等领域具有广泛的应用前景。