CDVAE 是 Conditional Diffusion Variational Autoencoder 的缩写,它将扩散模型(Diffusion Models)的核心思想集成到变分自编码器(VAE)的框架中,并增加了条件生成的能力。它主要用于结构化数据(如分子、晶体材料、点云)的生成与优化

1. 原理与核心思想

核心目标: 学习一个条件概率分布 p ( x | c ) p ( x | c ) p(x|c)p(\mathbf{x} | \mathbf{c})p(x|c),其中 x x x\mathbf{x}x 是我们想要生成的结构化对象(例如一个分子的3D构象), c c c\mathbf{c}c 是给定的条件(例如该分子的化学式、目标属性如药物活性、或材料的带隙)。
核心创新: CDVAE 结合了三种范式的优点:
  1. VAE: 提供一个低维、连续的隐空间,允许对生成过程进行平滑的插值和有意义的语义操作。
  2. 扩散模型: 作为一个强大的解码器/生成器,通过一个逐步去噪的过程生成数据,通常比传统VAE的解码器能产生质量更高、更多样化的样本。
  3. 条件生成: 通过将条件信息 c c c\mathbf{c}c 注入到VAE的编码器和扩散解码器的每一步中,实现对生成结果的精确控制。
为什么有效: 传统VAE直接通过一个神经网络解码器从隐变量 z z z\mathbf{z}z 映射到 x x x\mathbf{x}x,对于复杂的结构化数据(如分子的3D原子位置和类型),这个一步到位的映射学习非常困难,容易导致模糊或无效的生成结果。扩散模型将这个过程分解为许多小的、易于学习的去噪步骤,极大地提高了生成质量和训练稳定性。

2. 模型架构

CDVAE 通常包含三个主要组件:

a. 编码器 q ϕ ( z | x , c ) q ϕ ( z | x , c ) q_(phi)(z|x,c)q_{\phi}(\mathbf{z} | \mathbf{x}, \mathbf{c})qϕ(z|x,c)

  • 输入: 原始结构数据 x x x\mathbf{x}x 和条件信息 c c c\mathbf{c}c
  • 输出: 一个多元高斯分布的参数(均值和方差),从中采样得到全局隐变量 z z z\mathbf{z}z
  • 目的: 将高维、离散/连续混合的结构 x x x\mathbf{x}x 压缩成一个低维、连续的语义表示 z z z\mathbf{z}z,并且这个表示与条件 c c c\mathbf{c}c 相关联。
  • 结构: 通常是一个图神经网络(GNN),因为输入 x x x\mathbf{x}x(分子/晶体)天然可以用图表示(原子为节点,化学键为边)。

b. 扩散解码器 p θ ( x | z , c ) p θ ( x | z , c ) p_(theta)(x|z,c)p_{\theta}(\mathbf{x} | \mathbf{z}, \mathbf{c})pθ(x|z,c)

这是CDVAE的核心。它是一个以时间为条件的去噪模型
  • 前向过程(固定的加噪过程)
    给定一个从编码器得到的“干净”结构 x 0 x 0 x_(0)\mathbf{x}_0x0(即原始数据),我们按照一个预定义的噪声调度(schedule)逐步添加高斯噪声,生成一系列噪声越来越大的隐变量 x 1 , x 2 , . . . , x T x 1 , x 2 , . . . , x T x_(1),x_(2),...,x_(T)\mathbf{x}_1, \mathbf{x}_2, ..., \mathbf{x}_Tx1,x2,...,xT
    q ( x t | x t 1 ) = N ( x t ; 1 β t x t 1 , β t I ) q ( x t | x t 1 ) = N ( x t ; 1 β t x t 1 , β t I ) q(x_(t)|x_(t-1))=N(x_(t);sqrt(1-beta _(t))x_(t-1),beta _(t)I)q(\mathbf{x}_t | \mathbf{x}_{t-1}) = \mathcal{N}(\mathbf{x}_t; \sqrt{1 - \beta_t} \mathbf{x}_{t-1}, \beta_t \mathbf{I})q(xt|xt1)=N(xt;1βtxt1,βtI)
    其中 β t β t beta _(t)\beta_tβt 是第 t t ttt 步的噪声方差,由调度决定。这个过程的特性是,我们可以直接从 x 0 x 0 x_(0)\mathbf{x}_0x0 采样出任意 t t ttt 时刻的 x t x t x_(t)\mathbf{x}_txt
    x t = α ¯ t x 0 + 1 α ¯ t ϵ , ϵ N ( 0 , I ) x t = α ¯ t x 0 + 1 α ¯ t ϵ , ϵ N ( 0 , I ) x_(t)=sqrt( bar(alpha)_(t))x_(0)+sqrt(1- bar(alpha)_(t))epsilon,quad epsilon∼N(0,I)\mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \epsilon, \quad \epsilon \sim \mathcal{N}(0, \mathbf{I})xt=α¯tx0+1α¯tϵ,ϵN(0,I)
    其中 α t = 1 β t α t = 1 β t alpha _(t)=1-beta _(t)\alpha_t = 1 - \beta_tαt=1βt α ¯ t = s = 1 t α s α ¯ t = s = 1 t α s bar(alpha)_(t)=prod_(s=1)^(t)alpha _(s)\bar{\alpha}_t = \prod_{s=1}^{t} \alpha_sα¯t=s=1tαs
  • 反向过程(学习的去噪过程)
    这是一个神经网络,其任务是预测添加到数据中的噪声 ϵ ϵ epsilon\epsilonϵ
    • 输入
      1. 当前噪声数据 x t x t x_(t)\mathbf{x}_txt(在时间步 t t ttt)。
      2. 时间步 t t ttt 的嵌入向量。
      3. 全局隐变量 z z z\mathbf{z}z(来自编码器)。
      4. 条件信息 c c c\mathbf{c}c
    • 输出: 对噪声 ϵ ϵ epsilon\epsilonϵ 的预测,即 ϵ θ ( x t , t , z , c ) ϵ θ ( x t , t , z , c ) epsilon_(theta)(x_(t),t,z,c)\epsilon_{\theta}(\mathbf{x}_t, t, \mathbf{z}, \mathbf{c})ϵθ(xt,t,z,c)
    • 目的: 给定 x t x t x_(t)\mathbf{x}_txt,利用 z z z\mathbf{z}z c c c\mathbf{c}c 提供的全局语义和目标信息,预测出噪声 ϵ ϵ epsilon\epsilonϵ,从而可以计算出去噪后的 x t 1 x t 1 x_(t-1)\mathbf{x}_{t-1}xt1
    • 结构: 同样是一个GNN,但接收 z z z\mathbf{z}z c c c\mathbf{c}c 作为全局上下文,注入到每个节点/边的特征更新中。

c. 先验网络 p ψ ( z | c ) p ψ ( z | c ) p_(psi)(z|c)p_{\psi}(\mathbf{z} | \mathbf{c})pψ(z|c)

  • 输入: 条件信息 c c c\mathbf{c}c
  • 输出: 先验分布的参数(均值和方差),这是采样阶段用于生成新样本的隐变量分布。
  • 目的: 在训练时,编码器产生的后验分布 q ϕ ( z | x , c ) q ϕ ( z | x , c ) q_(phi)(z|x,c)q_{\phi}(\mathbf{z} | \mathbf{x}, \mathbf{c})qϕ(z|x,c) 会被拉向这个先验分布,以保证隐空间的规整性。在生成新样本时,我们从该先验中采样一个 z z z\mathbf{z}z,然后输入给扩散解码器。
架构流程图
训练时:
(𝐱, 𝐜) → [编码器 q_ϕ] → 𝐳 ~ q_ϕ(𝐳|𝐱,𝐜)
                         ↓
           [扩散解码器 p_θ] 学习从 𝐱_t 预测噪声 ϵ,其中 𝐱_t 由 𝐱_0 加噪得到,解码器接收 (𝐱_t, t, 𝐳, 𝐜)

生成时:
𝐜 → [先验网络 p_ψ] → 𝐳 ~ p_ψ(𝐳|𝐜)
                     ↓
        [扩散解码器 p_θ] 从纯噪声 𝐱_T ~ N(0,I) 开始,逐步去噪 T 步,每一步都使用 (𝐱_t, t, 𝐳, 𝐜) 预测噪声
                     ↓
                    𝐱_0 (生成的结构)

3. 训练方法

CDVAE 通过优化一个变分下界 来训练。损失函数由三部分组成:

总损失函数:

L CDVAE = E q ϕ ( z | x , c ) [ log p θ ( x | z , c ) ] 重建项 L rec + β D KL ( q ϕ ( z | x , c ) p ψ ( z | c ) ) KL正则项 L KL + λ L prop 属性预测项 L CDVAE = E q ϕ ( z | x , c ) log p θ ( x | z , c ) 重建项  L rec + β D KL q ϕ ( z | x , c ) p ψ ( z | c ) KL正则项  L KL + λ L prop 属性预测项 L_("CDVAE")=ubrace(E_(q_(phi)(z|x,c))[-log p_(theta)(x|z,c)])_("重建项 "L_("rec"))+ubrace(beta*D_("KL")(q_(phi)(z|x,c)||p_(psi)(z|c)))_("KL正则项 "L_("KL"))+ubrace(lambda*L_("prop"))_("属性预测项")\mathcal{L}_{\text{CDVAE}} = \underbrace{\mathbb{E}_{q_{\phi}(\mathbf{z}|\mathbf{x},\mathbf{c})} \left[ -\log p_{\theta}(\mathbf{x} | \mathbf{z}, \mathbf{c}) \right]}_{\text{重建项 } \mathcal{L}_{\text{rec}}} + \underbrace{\beta \cdot D_{\text{KL}} \left( q_{\phi}(\mathbf{z} | \mathbf{x}, \mathbf{c}) \| p_{\psi}(\mathbf{z} | \mathbf{c}) \right)}_{\text{KL正则项 } \mathcal{L}_{\text{KL}}} + \underbrace{\lambda \cdot \mathcal{L}_{\text{prop}}}_{\text{属性预测项}}LCDVAE=Eqϕ(z|x,c)[logpθ(x|z,c)]重建项 Lrec+βDKL(qϕ(z|x,c)pψ(z|c))KL正则项 LKL+λLprop属性预测项

逐项解释:

a. 重建项 L rec L rec L_("rec")\mathcal{L}_{\text{rec}}Lrec
这是训练扩散解码器的核心。它衡量模型从隐变量 z z z\mathbf{z}z 和条件 c c c\mathbf{c}c 重建原始数据 x x x\mathbf{x}x 的能力。在扩散模型中,这个项被重参数化为一个去噪分数匹配目标
L rec = E t [ 1 , T ] , ϵ N ( 0 , I ) [ w t ϵ ϵ θ ( x t , t , z , c ) 2 ] L rec = E t [ 1 , T ] , ϵ N ( 0 , I ) w t ϵ ϵ θ ( x t , t , z , c ) 2 L_("rec")=E_(t∼[1,T],epsilon∼N(0,I))[w_(t)*||epsilon-epsilon_(theta)(x_(t),t,z,c)||^(2)]\mathcal{L}_{\text{rec}} = \mathbb{E}_{t \sim [1,T], \epsilon \sim \mathcal{N}(0,\mathbf{I})} \left[ w_t \cdot \| \epsilon - \epsilon_{\theta}(\mathbf{x}_t, t, \mathbf{z}, \mathbf{c}) \|^2 \right]Lrec=Et[1,T],ϵN(0,I)[wtϵϵθ(xt,t,z,c)2]
  • 符号解释
    • t t ttt: 从1到T均匀采样的时间步。
    • ϵ ϵ epsilon\epsilonϵ: 实际添加到原始数据 x 0 x 0 x_(0)\mathbf{x}_0x0 中的随机噪声。
    • x t = α ¯ t x 0 + 1 α ¯ t ϵ x t = α ¯ t x 0 + 1 α ¯ t ϵ x_(t)=sqrt( bar(alpha)_(t))x_(0)+sqrt(1- bar(alpha)_(t))epsilon\mathbf{x}_t = \sqrt{\bar{\alpha}_t} \mathbf{x}_0 + \sqrt{1 - \bar{\alpha}_t} \epsilonxt=α¯tx0+1α¯tϵ: 根据噪声调度在第 t t ttt 步的带噪数据。
    • ϵ θ ( . . . ) ϵ θ ( . . . ) epsilon_(theta)(...)\epsilon_{\theta}(...)ϵθ(...): 扩散解码器网络,目标是预测噪声 ϵ ϵ epsilon\epsilonϵ
    • w t w t w_(t)w_twt: 时间步相关的权重,通常根据调度设置(如 w t = 1 w t = 1 w_(t)=1w_t = 1wt=1 w t = 1 / 1 α ¯ t w t = 1 / 1 α ¯ t w_(t)=1//sqrt(1- bar(alpha)_(t))w_t = 1 / \sqrt{1-\bar{\alpha}_t}wt=1/1α¯t)。
  • 直观理解: 网络学习在任意噪声水平 t t ttt 下,给定全局语义 z z z\mathbf{z}z 和目标 c c c\mathbf{c}c,如何从噪声数据中恢复出干净数据。这是一种更稳健的重建目标。
b. KL正则项 L KL L KL L_("KL")\mathcal{L}_{\text{KL}}LKL
L KL = D KL ( q ϕ ( z | x , c ) p ψ ( z | c ) ) L KL = D KL q ϕ ( z | x , c ) p ψ ( z | c ) L_("KL")=D_("KL")(q_(phi)(z|x,c)||p_(psi)(z|c))\mathcal{L}_{\text{KL}} = D_{\text{KL}} \left( q_{\phi}(\mathbf{z} | \mathbf{x}, \mathbf{c}) \| p_{\psi}(\mathbf{z} | \mathbf{c}) \right)LKL=DKL(qϕ(z|x,c)pψ(z|c))
  • 符号解释
    • q ϕ ( z | x , c ) = N ( μ ϕ , σ ϕ 2 ) q ϕ ( z | x , c ) = N ( μ ϕ , σ ϕ 2 ) q_(phi)(z|x,c)=N(mu_(phi),sigma_(phi)^(2))q_{\phi}(\mathbf{z} | \mathbf{x}, \mathbf{c}) = \mathcal{N}(\mu_{\phi}, \sigma_{\phi}^2)qϕ(z|x,c)=N(μϕ,σϕ2): 编码器产生的后验分布
    • p ψ ( z | c ) = N ( μ ψ , σ ψ 2 ) p ψ ( z | c ) = N ( μ ψ , σ ψ 2 ) p_(psi)(z|c)=N(mu_(psi),sigma_(psi)^(2))p_{\psi}(\mathbf{z} | \mathbf{c}) = \mathcal{N}(\mu_{\psi}, \sigma_{\psi}^2)pψ(z|c)=N(μψ,σψ2): 先验网络产生的条件先验分布
    • D KL D KL D_("KL")D_{\text{KL}}DKL: Kullback-Leibler散度,衡量两个分布的差异。
  • 作用
    1. 充当正则化器,防止编码器为每个样本产生过于特异的隐变量(避免退化为普通自编码器)。
    2. 确保在生成时,从先验 p ψ ( z | c ) p ψ ( z | c ) p_(psi)(z|c)p_{\psi}(\mathbf{z}|\mathbf{c})pψ(z|c) 中采样的 z z z\mathbf{z}z 与训练时编码器产生的 z z z\mathbf{z}z 来自相似的分布,从而保证生成质量。
    3. β β beta\betaβ 是控制正则化强度的超参数( β β beta\betaβ-VAE思想)。
c. 属性预测项 L prop L prop L_("prop")\mathcal{L}_{\text{prop}}Lprop (可选但常见)
L prop = f pred ( z ) y 2 L prop = f pred ( z ) y 2 L_("prop")=||f_("pred")(z)-y||^(2)\mathcal{L}_{\text{prop}} = \| f_{\text{pred}}(\mathbf{z}) - y \|^2Lprop=fpred(z)y2
  • 符号解释
    • f pred f pred f_("pred")f_{\text{pred}}fpred: 一个附加的小型预测网络(MLP),以隐变量 z z z\mathbf{z}z 为输入。
    • y y yyy: 与结构 x x x\mathbf{x}x 对应的真实属性值(如能量、溶解度)。
  • 作用
    1. 增强可控性: 强制隐变量 z z z\mathbf{z}z 编码与目标属性相关的信息,使得在隐空间内沿特定方向移动可以改变生成结构的属性。
    2. 辅助训练: 提供额外的监督信号,帮助学习更有意义的隐表示。
    3. λ λ lambda\lambdaλ 是其权重超参数。
训练流程
  1. 从数据集中采样一个批次的 ( x , c , y ) ( x , c , y ) (x,c,y)(\mathbf{x}, \mathbf{c}, y)(x,c,y)
  2. ( x , c ) ( x , c ) (x,c)(\mathbf{x}, \mathbf{c})(x,c) 输入编码器,得到后验分布参数,并通过重参数化技巧采样隐变量 z z z\mathbf{z}z
  3. 计算KL损失 L KL L KL L_("KL")\mathcal{L}_{\text{KL}}LKL,需要将 c c c\mathbf{c}c 输入先验网络得到先验分布参数。
  4. 为重建损失做准备:随机采样时间步 t t ttt,根据调度为 x x x\mathbf{x}x 加噪得到 x t x t x_(t)\mathbf{x}_txt
  5. ( x t , t , z , c ) ( x t , t , z , c ) (x_(t),t,z,c)(\mathbf{x}_t, t, \mathbf{z}, \mathbf{c})(xt,t,z,c) 输入扩散解码器,预测噪声 ϵ θ ϵ θ epsilon_(theta)\epsilon_{\theta}ϵθ,并与真实噪声 ϵ ϵ epsilon\epsilonϵ 计算 L rec L rec L_("rec")\mathcal{L}_{\text{rec}}Lrec
  6. z z z\mathbf{z}z 输入属性预测网络,计算 L prop L prop L_("prop")\mathcal{L}_{\text{prop}}Lprop(如果使用)。
  7. 将三项损失加权求和,反向传播,更新编码器 ( ϕ ) ( ϕ ) (phi)(\phi)(ϕ)、扩散解码器 ( θ ) ( θ ) (theta)(\theta)(θ)、先验网络 ( ψ ) ( ψ ) (psi)(\psi)(ψ) 和属性预测网络的参数。

4. 使用方法

训练好的CDVAE模型主要有三种使用模式:

a. 无条件/条件生成

  • 输入: 一个目标条件 c target c target c_("target")\mathbf{c}_{\text{target}}ctarget(例如,一个特定的化学式或一个属性值范围)。
  • 过程
    1. c target c target c_("target")\mathbf{c}_{\text{target}}ctarget 输入先验网络 p ψ p ψ p_(psi)p_{\psi}pψ,得到先验分布 p ψ ( z | c target ) p ψ ( z | c target ) p_(psi)(z|c_("target"))p_{\psi}(\mathbf{z} | \mathbf{c}_{\text{target}})pψ(z|ctarget)
    2. 从该分布中采样一个隐变量 z sample z sample z_("sample")\mathbf{z}_{\text{sample}}zsample
    3. 从标准高斯噪声 x T N ( 0 , I ) x T N ( 0 , I ) x_(T)∼N(0,I)\mathbf{x}_T \sim \mathcal{N}(0, \mathbf{I})xTN(0,I) 开始。
    4. 进行 T T TTT 步迭代去噪:对于 t = T , T 1 , . . . , 1 t = T , T 1 , . . . , 1 t=T,T-1,...,1t = T, T-1, ..., 1t=T,T1,...,1,调用扩散解码器 ϵ θ ϵ θ epsilon_(theta)\epsilon_{\theta}ϵθ 预测噪声,并根据DDPM或DDIM等采样算法计算 x t 1 x t 1 x_(t-1)\mathbf{x}_{t-1}xt1。每一步都传入 ( x t , t , z sample , c target ) ( x t , t , z sample , c target ) (x_(t),t,z_("sample"),c_("target"))(\mathbf{x}_t, t, \mathbf{z}_{\text{sample}}, \mathbf{c}_{\text{target}})(xt,t,zsample,ctarget)
    5. 最终得到生成的结构 x 0 x 0 x_(0)\mathbf{x}_0x0

b. 重构与插值

  • 重构: 给定一个已知结构 x input x input x_("input")\mathbf{x}_{\text{input}}xinput 及其条件 c c c\mathbf{c}c,通过编码器得到其隐变量 z input z input z_("input")\mathbf{z}_{\text{input}}zinput,然后用扩散解码器进行重构。这可以测试模型的表示能力。
  • 插值: 在两个结构 ( x A , c A ) ( x A , c A ) (x_(A),c_(A))(\mathbf{x}_A, \mathbf{c}_A)(xA,cA) ( x B , c B ) ( x B , c B ) (x_(B),c_(B))(\mathbf{x}_B, \mathbf{c}_B)(xB,cB) 对应的隐变量 z A z A z_(A)\mathbf{z}_AzA z B z B z_(B)\mathbf{z}_BzB 之间进行线性插值: z interp = ( 1 α ) z A + α z B z interp = ( 1 α ) z A + α z B z_("interp")=(1-alpha)z_(A)+alphaz_(B)\mathbf{z}_{\text{interp}} = (1-\alpha)\mathbf{z}_A + \alpha\mathbf{z}_Bzinterp=(1α)zA+αzB。然后固定一个条件(或混合条件),用 z interp z interp z_("interp")\mathbf{z}_{\text{interp}}zinterp 进行生成,可以得到在两个结构之间平滑过渡的一系列新结构。

c. 基于属性的优化与搜索

这是CDVAE最强大的应用之一。
  • 过程
    1. 在隐空间 z z z\mathbf{z}z 中定义或学习一个“属性方向” d d d\mathbf{d}d(例如,通过回归属性预测器 f pred f pred f_("pred")f_{\text{pred}}fpred 的梯度: d = z f pred ( z ) d = z f pred ( z ) d=grad_(z)f_("pred")(z)\mathbf{d} = \nabla_{\mathbf{z}} f_{\text{pred}}(\mathbf{z})d=zfpred(z),指向属性增加的方向)。
    2. 从一个起点 z 0 z 0 z_(0)\mathbf{z}_0z0 开始,沿方向 d d d\mathbf{d}d 移动: z new = z 0 + η d z new = z 0 + η d z_("new")=z_(0)+eta*d\mathbf{z}_{\text{new}} = \mathbf{z}_0 + \eta \cdot \mathbf{d}znew=z0+ηd
    3. z new z new z_("new")\mathbf{z}_{\text{new}}znew 和给定的条件 c c c\mathbf{c}c 进行生成,得到的新结构 x new x new x_("new")\mathbf{x}_{\text{new}}xnew 将具有更高的目标属性值。
    4. 这个过程可以迭代进行,在隐空间中高效地搜索满足特定属性要求的候选结构。

总结

CDVAE 是一个用于复杂结构化数据生成的强大统一框架。它通过扩散过程解决了传统VAE生成质量不高的问题,通过条件机制实现了可控生成,并通过连续的隐空间支持高效的优化和探索。其训练通过一个结合了扩散去噪损失KL散度正则化属性预测损失的变分目标来实现,在材料科学和药物发现等领域具有广泛的应用前景。


用AI生成所需内容的提示技巧:

AI技术、CDVAE 模型、扩散模型:


>> AI热点技术目录