这是用户在 2025-7-10 16:07 为 https://zhuanlan.zhihu.com/p/348498294 保存的双语快照页面,由 沉浸式翻译 提供双语支持。了解如何保存?

机器学习方法—优雅的模型(一):变分自编码器(VAE)

3012 人赞同了该文章

1. Introduction

有了之前损失函数(一):交叉熵与KL散度损失函数(二):MSE、0-1 Loss与Logistic Loss两篇文章的基础,我们现在可以开始做一些真的模型了。

中文社区里有一些对于VAE的介绍,但是我感觉这些往往流于表面,没有很好的把每一步的动机解释清楚。这篇文章将详细对VAE模型进行介绍,包括它的产生动机、数学推导、Conditional VAE扩展,以及它的实现与细节讨论。

希望这篇文章能够更加清楚的写出,为什么我们需要VAE、为什么VAE这样设计。文章较长、推导较多,难免有疏忽和错误,也请读者不吝指正。

2. Motivation

在说VAE之前,自然要先说到传统的自编码器 (Autoencoder)。上图即是一个自编码的实例。自编码器类似于一个非线性的PCA,是一个利用神经网络来给复杂数据降维的模型。现在我们记XX为整个数据集的集合,xix_i是数据集中的一个样本。

自编码器包含一个编码器z=g(X)z = g(X),它的输出zz我们称作编码,zz的维度往往远远小于输入XX的维度。它还包含一个解码器X~=f(z)\tilde{X} = f(z),这个解码器能够通过编码zz得到X~\tilde{X}。

我们希望解码器解码得到的X~\tilde{X}能够尽可能的接近XX,所以自编码器一个常用的损失函数是=XX~2\ell = \|X - \tilde{X}\|^2。这样一来,模型训练结束后,我们就可以认为编码zz囊括了输入数据XX的大部分信息,也因此我们可以直接利用zz表达原始数据,从而达到数据降维的目的。

出于方便,假设现在我们的输入XRC×H×WX \in \mathbb{R}^{C \times H \times W}是一些图片,我们可以训练一个自编码器。它的编码器z=g(X)z = g(X)将每个图片编码成zRdz \in \mathbb{R}^d,它的解码器X~=f(z)\tilde{X} = f(z)利用zz将输入的图片重建为X~RC×H×W\tilde{X} \in \mathbb{R}^{C \times H \times W}

我们现在仔细研究一下这个模型的解码器g:RdRC×H×Wg: \mathbb{R}^d \rightarrow \mathbb{R}^{C \times H \times W}。这个解码器只需要输入某些低维向量zz,就能够输出高维的图片数据XX。那我们能否把这个模型直接当做生成模型,在低维空间Rd\mathbb{R}^d中随机生成某些向量zz,再喂给解码器f(z)f(z)来生成图片呢?

答案是,我们可以这么做,运气好的话我们可以得到一些有用的图片,但是对绝大多数随机生成的zzf(z)f(z)只会生成一些没有意义的噪声。

为什么会这样呢?这是因为我们没有显性的对zz的分布p(z)p(z)进行建模,我们并不知道哪些zz能够生成有用的图片。我们用来训练f(z)f(z)的数据是有限的,ff可能只会对极有限的zz有响应。而Rd\mathbb{R}^d又是一个太大的空间,如果我们只在这个空间上随机采样的话,我们自然不能指望总能恰好采样到能够生成有用的图片的zz。

在Autoencoder的基础上,显性的对zz的分布p(z)p(z)进行建模,使得自编码器成为一个合格的生成模型,我们就得到了Variational Autoencoders,即今天的主题,变分自编码器

3. Derivation

我们在这一章正式对VAE进行推导。对于自编码器来说,zz的分布是不确定的,因此只能在Rd\mathbb{R}^d上采样、碰运气。我们为什么不给定zz一个简单的分布,将采样的空间缩的很小呢?

我们不妨假设zN(0,I)z \sim \mathcal{N}(0, I),其中II代表一个单位矩阵。也就是说,我们将zz看作是一个服从标准多元高斯分布的多维随机变量。

现在我们记zzXX为随机变量,ziz_ixix_i代表随机变量的样本。

在这个架构下,我们可以认为数据集是由某个随机过程生成的,而zz是这个随机过程中的一个不可观测到的隐变量。这个生成数据的随机过程包含两个步骤:

  1. 从先验分布p(z)p(z)中采样得到一个ziz_i
  2. 根据ziz_i,从条件分布p(Xzi)p(X \mid z_i)中采样得到一个数据点xix_i

如果我们能基于这个随机过程进行建模,那么我们可能可以更轻易的得到一个生成模型。

3.1 Decoder

首先,让我们从生成模型的角度来考虑Decoder的架构。

上图是一个Decoder架构的示意图。我们给Decoder输入一个从N(0,I)\mathcal{N}(0, I)中采样得到的ziz_i,其实是希望由θ\theta参数化的Decoder能够学会一个映射,输出ziz_i对应的XX的分布,即pθ(Xzi)p_{\theta}(X \mid z_i)。

让我们假设,给定任意ziz_i后,XX都服从某个各维度独立的多元高斯分布,即:

pθ(Xzi)=N(Xμi(zi;θ),σi2(zi;θ)I).p_{\theta}(X \mid z_i) = \mathcal{N}(X \mid \mu_i^{\prime}(z_i; \theta), \sigma_i^{\prime2}(z_i; \theta) * I).\\ \\

这样一来,我们只需要输入ziz_i给Decoder,然后让它拟合出μiσi2\mu_i^{\prime}和\sigma_i^{\prime2},我们就能知道XziX \mid z_i 的具体分布了。

我们举个例子来更直观的理解这个过程。我们根据分布p(z)p(z)采样出一个ziz_i,一个ziz_i应当对应图片集XX中的某一部分类似的图片。比如说,我们的图片集XX可能是世界上所有的猫,那么抽样得到的一个ziz_i可能代表颜色为橘色,耳朵为立耳的猫;而下次抽样得到的另一个zjz_j可能代表颜色为白色,耳朵为折耳的猫。

我们再假设,在这个ziz_i下,这类立耳橘猫的图片像素值的分布XziX\mid z_i服从一个多元高斯分布N(μi,σi2I)\mathcal{N}(\mu_i^{\prime}, \sigma_i^{\prime2} * I)。这样一来,我们的Decoder只需要通过神经网络,将ziz_i变换为适当的μiσi2{\mu}_i^{\prime}和 {\sigma}_i^{\prime2},我们就得到了这个多元高斯分布。之后我们就可以从N(μi,σi2I)\mathcal{N}(\mu_i^{\prime}, \sigma_i^{\prime2} * I)中采样,得到立耳橘猫的图片了!

3.2 Objective

因为本质上我们希望训练一个生成模型,我们也不妨以一个更加统计的视角来看一个生成模型的目标函数。

对于一个生成模型,我们的终极目标是什么?对,我们就是想对数据本身的分布p(X)p(X)进行建模。如果能成功得到一个逼近真实分布p(X)p(X)pθ(X)p_{\theta}(X),那么我们就能从中进行采样,生成一些可能的数据点。

如上图,我们举个当XX代表所有宝可梦的图片的例子。在得到pθ(X)p_{\theta}(X)后,我们就可以生成一些令pθ(xi)p_{\theta}(x_i)比较大的xix_i,这些xix_i就很可能会是正常的宝可梦的图片。

现在的问题就是,我们怎么对pθ(X)p_{\theta}(X)进行建模呢?

有了之前的铺垫,现在我们有p(z)=N(0,I)pθ(Xz)=N(Xμi(z;θ),σi2(z;θ)I)p(z) = \mathcal{N}(0, I),p_{\theta}(X \mid z) = \mathcal{N}(X \mid \mu_i^{\prime}(z; \theta), \sigma_i^{\prime2}(z; \theta) * I)。易得,

pθ(X)=zpθ(Xz)p(z)dz1mj=1mpθ(Xzj).\begin{aligned} p_{\theta}(X) &= \int_z p_{\theta}(X \mid z) p(z) d z \\ &\approx \frac{1}{m} \sum_{j=1}^m p_{\theta}(X \mid z_j). \end{aligned}\\

这样问题是不是就解决了呢?我们只要从p(z)=N(z0,I)p(z) = \mathcal{N}(z \mid 0, I)里采样许多ziz_i出来,就能算出pθ(X)p_{\theta}(X)。在之前的文章机器学习理论—统计:MLE与MAP中,我们已经介绍过了MLE。在这里,我们就可以利用MLE的思想,让数据集出现的概率最大化,也就是:

θ=argminθi=1nlogpθ(xi)=argminθi=1nlog(1mj=1mpθ(xizj)).\begin{aligned} \theta^* &= \operatorname{argmin}_{\theta} - \sum_{i=1}^n \log p_{\theta}(x_i) \\ &= \operatorname{argmin}_{\theta} - \sum_{i=1}^n \log \left( \frac{1}{m} \sum_{j=1}^m p_{\theta}(x_i \mid z_j) \right). \end{aligned}\\

我们确实可以这样做,但是这样做的代价是极大的。因为往往xix_i的维度会很大,ziz_i的维度也不会很低,并且,对于某个xix_i而言,与之强相关的ziz_i的数量是相对有限的,但是为了找到这些有限的ziz_i,我们可能要进行大量的采样。

所以如果我们希望较为准确的估计pθ(X)p_{\theta}(X)的话,我们可能需要采样极大量ziz_i,只有这样,我们才能让模型知道究竟哪一些ziz_i是与哪一些xix_i对应着的。

因此,直接从p(z)p(z)中采样ziz_i,用来估计pθ(X)p_{\theta}(X)的策略几乎是不可行的。不过解决这个问题的思路也很直觉,那就是在Encoder中引入后验分布pθ(zxi)p_{\theta}(z \mid x_i)。

3.3 Encoder

具体来说,我们怎么在Encoder中利用后验分布呢?假设我们现在有后验分布pθ(zxi)p_{\theta}(z \mid x_i),这样的话,如下图,每次前向传播的时候,我们可以先将xix_i喂给Encoder,算出zxiz\mid x_i服从的分布。之后,我们就可以直接在这个分布中采样出ziz_i,喂给Decoder,然后得到XziX\mid z_i的分布,最后基于MLE优化模型。

在这个策略下,从pθ(zxi)p_{\theta}(z \mid x_i)中采样出来的ziz_i几乎都会和xix_i相关的,对比之前,我们可能就省去了很多采样的步骤,极大的提高了效率。

那现在的问题就是,我们怎么计算pθ(zxi)p_{\theta}(z \mid x_i)呢?我们不妨先尝试下贝叶斯公式:

pθ(zxi)=pθ(xiz)p(z)pθ(xi)=pθ(xiz)p(z)z^pθ(xiz^)p(z^)dz^.\begin{aligned} p_{\theta}(z \mid x_i) &= \frac{p_{\theta}(x_i \mid z) p(z)}{p_{\theta}(x_i)} \\ &= \frac{p_{\theta}(x_i \mid z) p(z)}{\int_{\hat{z}} p_{\theta}(x_i \mid \hat{z}) p(\hat{z}) d \hat{z}}. \end{aligned}\\

辛运的是,我们之前已经假设了pθ(Xz)p(z)p_{\theta}(X \mid z)和p(z)的分布,所以对于上式的分子,我们是可以直接算出来的。不幸的是,上式的分母又有一个积分,如果去估计这个积分的话,又会需要从p(z)p(z)中采样大量的ziz_i。这显然是代价极大,不太可行的。

这时候我们就可以应用变分贝叶斯算法了!我们不妨令由ϕ\phi参数化的Encoder去拟合对任意xix_i的分布qϕ(zxi)q_{\phi}(z \mid x_i),我们希望这个分布能够尽可能的逼近真实的后验分布pθ(zxi)p_{\theta}(z \mid x_i)。如果qϕ(zxi)q_{\phi}(z \mid x_i)能够足够逼近真实的后验分布的话,我们就可以直接通过Encoder得到zxiz \mid x_i的分布了!

我们怎么用神经网络去拟合后验分布呢?和之前一样,我们只要知道这个后验是服从的什么分布,然后让模型拟合这个分布所需的参数就行了。举个例子,如果这个后验分布本质上是一个多元高斯分布,那么我们让Encoder输出μ\muΣ2\Sigma^2就能拟合这个分布了。

回忆一下,我们之前已经对似然pθ(Xz)p_{\theta}(X \mid z)和先验p(z)p(z)的分布做了假设——它们都服从高斯分布。在这种情况下,不难证明,真实的后验分布pθ(zX)p_{\theta}(z \mid X)也服从高斯分布。

那不妨令近似后验分布对任意xix_i都有

qϕ(zxi)=N(zμ(xi;ϕ),σ2(xi;ϕ)I),q_{\phi}(z \mid x_i) = \mathcal{N}(z \mid \mu(x_i; \phi), \sigma^2(x_i; \phi) * I),\\ \\

即,它也是一个各维度独立的多元高斯分布。这样一来,整个VAE的架构就非常明了了。

3.4 Architecture

下图即是VAE的架构示例。其中xi(j)x_i^{(j)}代表第ii个数据点的第jj的特征。

总结一下VAE的架构:

  1. 我们首先给Encoder输入一个数据点xix_i,通过神经网络,我们得到隐变量zz服从的近似后验分布qϕ(zxi)q_{\phi}(z \mid x_i)的参数。我们往往认为后验分布是一个各维度独立的高斯分布,因此令Encoder输出zxiz\mid x_i服从的高斯分布的参数σi2\sigma_i^2μi\mu_i即可。
  2. 有了zxiz \mid x_i分布的参数σi2\sigma_i^2μi\mu_i后,我们从对应的高斯分布中采样出一个ziz_i,这个ziz_i应当代表与xix_i相似的一类样本。
  3. 我们令Decoder拟合似然的分布pθ(Xzi)p_{\theta}(X \mid z_i)。喂给Decoder一个ziz_i,它应当返回XziX \mid z_i服从的分布的参数。我们往往认为似然也服从一个各维度独立的高斯分布,因此令Decoder输出XziX\mid z_i服从的高斯分布的参数σi2\sigma^{\prime2}_iμi\mu^{\prime}_i即可。
  4. 在得到XziX\mid z_i的分布的参数后,理论上我们需要从这个分布中进行采样,来生成可能的数据点xix_i

上述第四点中值得注意的是,在大部分实现中,人们往往不进行采样,而是直接将模型输出的μi\mu^{\prime}_i当作是给定ziz_i生成的数据点xix_i。

除此之外,人们也往往认为pθ(Xzi)p_{\theta}(X \mid z_i)是一个固定方差的各维度独立的多元高斯分布,即pθ(Xzi)=N(Xμi(zi;θ),σ2I)p_{\theta}(X \mid z_i) = \mathcal{N}(X \mid \mu_i^{\prime}(z_i; \theta), \sigma^{\prime2} * I),其中σ2\sigma^{\prime2}是一个人为给定的超参数。这意味着我们实际中并不真的让模型输出σi2\sigma^{\prime2}_i,模型只要输出μi\mu_i^{\prime}就行了。

3.5 Reparameterization Trick

上述VAE的架构应该是比较清晰的,但让我们再仔细研究一下这个架构。尽管现在我们还没有推导得到最终的损失函数,但让我们先假设,在上述步骤4后,我们会接某个损失函数L\mathcal{L}来训练神经网络。

这样的话,从神经网络训练的角度来看,这个架构的前向传播过程是没有问题的,上述步骤1-4均可顺利的进行前向传播,然后计算出损失的值。

然而,令人在意的一个问题是:我们在前向传播的第2步,居然调用了一个"采样函数",从zxiz\mid x_i中采样出来了ziz_i喂给Decoder!那采样函数能够进行反向传播吗?

答案显然是不能的。因此,为了让整个网络能够正常的训练,作者们提出了Reparameterization Trick。这一技巧将上述第2步改为:

  1. 有了zxiz \mid x_i分布的参数σi2\sigma_i^2μi\mu_i后,我们先从N(0,I)\mathcal{N}(0,I)中采样得到一个ϵi\epsilon_i,然后我们令zi=μi+σiϵiz_i = \mu_i + \sigma_i \odot \epsilon_i,这个ziz_i应当代表与xix_i相似的一类样本。

其中,\odot代表逐元素相乘操作。不难证明,此时ziz_i背后的分布依然是由σi2\sigma_i^2μi\mu_i参数化的一个高斯分布。

利用了Reparameterization Trick后,VAE的架构变成了下图中的模样,其中ϵi\epsilon_i可以看作是伴随ziz_i喂给Decoder的一个特征。这样一来,这个架构的前向、反向传播就都能跑通了。

3.6 Evidence Lower Bound

好了,我们已经把VAE的架构定下来了。现在我们只要顺着3.2节中MLE的思想,然后在最大化logpθ(X)\log p_{\theta}(X)时,加入变分推断的思想,引入ELBO (Evidence Lower Bound),我们就能得到一个靠谱的目标函数了。

让我们来推一下:

logpθ(X)=zqϕ(zX)logpθ(X)dz=zqϕ(zX)logpθ(X,z)pθ(zX)dz=zqϕ(zX)log(pθ(X,z)qϕ(zX)qϕ(zX)pθ(zX))dz=zqϕ(zX)logpθ(X,z)qϕ(zX)dz+zqϕ(zX)logqϕ(zX)pθ(zX)dz=(pθ,qϕ)+DKL(qϕ,pθ)(pθ,qϕ)KL.\begin{aligned} \log p_{\theta}(X) &=\int_{z} q_{\phi}(z \mid X) \log p_{\theta}(X) dz \quad 全概率定理\\ &=\int_{z} q_{\phi}(z \mid X) \log \frac{p_{\theta}(X, z)}{p_{\theta}(z \mid X)} dz \quad 贝叶斯定理\\ &=\int_{z} q_{\phi}(z \mid X) \log \left(\frac{p_{\theta}(X, z)}{q_{\phi}(z \mid X)} \cdot \frac{q_{\phi}(z \mid X)}{p_{\theta}(z \mid X)}\right) dz\\ &=\int_{z} q_{\phi}(z \mid X) \log \frac{p_{\theta}(X, z)}{q_{\phi}(z \mid X)} dz + \int_{z} q_{\phi}(z \mid X) \log \frac{q_{\phi}(z \mid X)}{p_{\theta}(z \mid X)} dz\\ &=\ell\left(p_{\theta}, q_{\phi}\right)+D_{K L}\left(q_{\phi}, p_{\theta}\right) \\ & \geq \ell\left(p_{\theta}, q_{\phi}\right) \quad KL散度非负. \end{aligned}\\

我们已经在之前的文章机器学习理论—信息论:自信息、熵、交叉熵与KL散度 中的第四章证明了KL散度是恒大于等于零的,因此显然上式中(pθ,qϕ)\ell\left(p_{\theta}, q_{\phi}\right)logpθ(X)\log p_{\theta}(X)的一个下界,也因此我们称\ell为ELBO (Evidence Lower Bound)。

我们不妨在把上式变换一下,易得:

(pθ,qϕ)=logpθ(X)DKL(qϕ,pθ).\ell\left(p_{\theta}, q_{\phi}\right) = \log p_{\theta}(X) - D_{K L}\left(q_{\phi}, p_{\theta}\right).\\ \\

这个式子实在是太完美了!这个式子告诉我们,我们只需要最大化\ell,就能最大化logpθ(X)\log p_{\theta}(X),并且最小化DKL(qϕ,pθ)D_{K L}\left(q_{\phi}, p_{\theta}\right)。

最大化logpθ(X)\log p_{\theta}(X)的理由是显然的,因为我们希望最大化似然。我们为什么希望最小化DKL(qϕ,pθ)D_{K L}\left(q_{\phi}, p_{\theta}\right)呢?其实原因也是显然的,因为我们希望近似后验qϕ(zX)q_{\phi}(z\mid X)能够逼近真实后验pθ(zX)p_{\theta}(z \mid X),否则的话Encoder可能只能输出一些无意义的分布。

既然我们希望最大化\ell,现在我们进一步对其进行展开,不难得到:

(pθ,qϕ)=zqϕ(zX)logpθ(X,z)qϕ(zX)dz=zqϕ(zX)logpθ(Xz)p(z)qϕ(zX)dz=zqϕ(zX)logp(z)qϕ(zX)dz+zqϕ(zX)logpθ(Xz)dz=DKL(qϕ,p)+Eqϕ[logpθ(Xz)].\begin{aligned} \ell\left(p_{\theta}, q_{\phi}\right) &= \int_{z} q_{\phi}(z \mid X) \log \frac{p_{\theta}(X, z)}{q_{\phi}(z \mid X)} dz \\ &=\int_{z} q_{\phi}(z \mid X) \log \frac{p_{\theta}(X \mid z) p(z)}{q_{\phi}(z \mid X)} dz \quad 贝叶斯定理 \\ &=\int_{z} q_{\phi}(z \mid X) \log \frac{p(z)}{q_{\phi}(z \mid X)} dz + \int_{z} q_{\phi}(z \mid X) \log p_{\theta}(X \mid z) dz \\ &=-D_{K L}\left(q_{\phi}, p\right)+\mathbb{E}_{q_{\phi}}\left[\log p_{\theta}(X \mid z)\right]. \end{aligned}\\

让我们再将上述两项分别展开。

首先,让我们看下DKL(qϕ,p)-D_{K L}\left(q_{\phi}, p\right)这一项。人们通常称这一项为Latent Loss或者将其看做一个正则项。回忆一下,我们之前已经假设了qϕ(zX)q_{\phi}(z \mid X) p(z)p(z)均服从高斯分布,辛运的是,在这种情况下,我们能够得到DKL(qϕ,p)D_{K L}\left(q_{\phi}, p\right)的解析解。

更加幸运的是,我们把它们都设成了各维度独立的高斯分布,所以我们可以直接从一维的情况进行推导:

DKL(N(μ,σ2)N(0,1))=z12πσ2exp((zμ)22σ2)log12πσ2exp((zμ)22σ2)12πexp(z22)dz=z((zμ)22σ2+z22logσ)N(μ,σ2)dz=z(zμ)22σ2N(μ,σ2)dz+zz22N(μ,σ2)dzzlogσN(μ,σ2)dz=E[(zμ)2]2σ2+E[z2]2logσ=12(1+σ2+μ2logσ2).\begin{aligned} D_{K L}(\mathcal{N}\left(\mu, \sigma^{2}\right)\| \mathcal{N}(0,1)) &=\int_{z} \frac{1}{\sqrt{2 \pi \sigma^{2}}} \exp \left(-\frac{\left(z-\mu\right)^{2}}{2 \sigma^{2}}\right) \log \frac{\frac{1}{\sqrt{2 \pi \sigma^{2}}} \exp \left(-\frac{\left(z-\mu\right)^{2}}{2 \sigma^{2}}\right)}{\frac{1}{\sqrt{2 \pi}} \exp \left(-\frac{z^{2}}{2}\right)} d z \\ &=\int_{z}\left(\frac{-\left(z-\mu\right)^{2}}{2 \sigma^{2}}+\frac{z^{2}}{2}-\log \sigma\right) \mathcal{N}\left(\mu, \sigma^{2}\right) d z \\ &=-\int_{z} \frac{\left(z-\mu\right)^{2}}{2 \sigma^{2}} \mathcal{N}\left(\mu, \sigma^{2}\right) d z+\int_{z} \frac{z^{2}}{2} \mathcal{N}\left(\mu, \sigma^{2}\right) d z-\int_{z} \log \sigma \mathcal{N}\left(\mu, \sigma^{2}\right) d z \\ &=-\frac{\mathbb{E}\left[\left(z-\mu\right)^{2}\right]}{2 \sigma^{2}}+\frac{\mathbb{E}\left[z^{2}\right]}{2}-\log \sigma \\ &= \frac{1}{2}(-1 + \sigma^2 + \mu^2 - \log \sigma^2). \end{aligned}

当它们都是dd元高斯分布时,易得:

DKL(qϕ(zX),p(z))=j=1d12(1+σ(j)2+μ(j)2logσ(j)2).D_{K L}\left(q_{\phi}(z\mid X), p(z)\right) = \sum_{j=1}^d \frac{1}{2}(-1 + {\sigma^{(j)}}^{2} + {\mu^{(j)}}^{2} - \log {\sigma^{(j)}}^{2}).\\ \\

其中a(j)2{a^{(j)}}^{2}代表向量aa的第jj个元素的平方。

至此,最后的问题就是,Eqϕ[logpθ(Xz)]\mathbb{E}_{q_{\phi}}\left[\log p_{\theta}(X \mid z)\right]怎么求呢?这一项往往被称为Reconstruction Loss,人们通常从qϕ(zX)q_{\phi}(z\mid X)中采样多个ziz_i来近似求解这一项,即:

Eqϕ[logpθ(Xz)]1mi=1mlogpθ(Xzi),\mathbb{E}_{q_{\phi}}\left[\log p_{\theta}(X \mid z)\right] \approx \frac{1}{m} \sum_{i=1}^{m} \log p_{\theta}\left(X \mid z_{i}\right), \\

其中,ziqϕ(zxi)=N(zμ(xi;ϕ),σ2(xi;ϕ)I)z_{i} \sim q_{\phi}\left(z \mid x_{i}\right)=\mathcal{N}\left(z \mid \mu\left(x_{i} ; \phi\right), \sigma^2\left(x_{i} ; \phi\right) * I\right)。

现在我们来看logpθ(Xzi)\log p_{\theta}\left(X \mid z_{i}\right)这一项怎么展开。我们之前已经假设过XziX\mid z_i服从一个固定方差的各维度独立的多元高斯分布,即pθ(Xzi)=N(Xμi(zi;θ),σ2I)p_{\theta}(X \mid z_i) = \mathcal{N}(X \mid \mu_i^{\prime}(z_i; \theta), \sigma^{\prime2} * I)。

有了之前的文章损失函数(二):MSE、0-1 Loss与Logistic Loss中的2.2节的基础后,我们知道,若假设数据为固定方差的高斯分布,MLE后得到的目标函数,等价于MSE。但我们这里还是先把它写开,设每个数据点xix_i的维度为KK,即XziX\mid z_i服从一个KK维高斯分布,易得:

logpθ(Xzi)=logexp(12(Xμ)TΣ1(Xμ))(2π)k|Σ|=12(Xμ)TΣ1(Xμ)log(2π)k|Σ|=12k=1K(X(k)μ(k))2σ(k)log(2π)Kk=1Kσ(k).\begin{aligned} \log p_{\theta}\left(X \mid z_{i}\right) &= \log \frac{\exp \left(-\frac{1}{2}(X-\mu^{\prime})^{\mathrm{T}} {\Sigma}^{\prime-1}({X}-{\mu^{\prime}})\right)}{\sqrt{(2 \pi)^{k}|{\Sigma^{\prime}}|}} \\ &= -\frac{1}{2}(X-\mu^{\prime})^{\mathrm{T}} {\Sigma}^{\prime-1}({X}-{\mu^{\prime}}) - \log \sqrt{(2 \pi)^{k}|\Sigma^{\prime}|} \\ &= -\frac{1}{2} \sum_{k=1}^K \frac{(X^{(k)}-\mu^{\prime(k)})^2}{\sigma^{\prime(k)}} - \log \sqrt{(2 \pi)^{K}\prod_{k=1}^{K} \sigma^{\prime(k)}}. \end{aligned}\\

这样,我们就有了最终的损失函数所需要的所有模块了。

3.7 Loss Function

让我们把上一节中的推导整合起来。现在希望最小化损失函数:

L=1ni=1n(pθ,qϕ)=1ni=1nDKL(qϕ,p)1ni=1nEqϕ[logpθ(xiz)]=1ni=1nDKL(qϕ,p)1nmi=1nj=1mlogpθ(xizj).\begin{aligned} \mathcal{L} &= - \frac{1}{n} \sum_{i=1}^n \ell(p_{\theta}, q_{\phi}) \\ &= \frac{1}{n} \sum_{i=1}^nD_{K L}\left(q_{\phi}, p\right) - \frac{1}{n} \sum_{i=1}^n \mathbb{E}_{q_{\phi}}\left[\log p_{\theta}(x_i \mid z)\right] \\ &= \frac{1}{n} \sum_{i=1}^nD_{K L}\left(q_{\phi}, p\right) - \frac{1}{nm} \sum_{i=1}^n \sum_{j=1}^{m} \log p_{\theta}\left(x_i \mid z_{j}\right). \end{aligned}\\

上式即是通过从qϕ(zxi)q_{\phi}(z \mid x_i)中采样mmzjz_j,来逼近Eqϕ[logpθ(xiz)]\mathbb{E}_{q_{\phi}}\left[\log p_{\theta}(x_i \mid z)\right]。也许我们会好奇,之前两次我们都说积分太难求了,采样逼近代价太大了,所以不能采样逼近,为什么这里又可以采样逼近了呢?

答案就是:之前我们都只能从p(z)p(z)中采样zjz_j,这样的话,采样到和xix_i有关联的zjz_j的概率实在是很低,所以为了更好的逼近积分只能采样大量的zjz_j,这样的代价自然是极大的;然而,在上式中,我们其实是从qϕ(zxi)q_{\phi}(z \mid x_i)中采样得到zjz_j。随着网络的训练,近似后验qϕ(zxi)q_{\phi}(z \mid x_i),很快就会比较接近真实的后验分布。这样一来,我们有很大可能能够在有限次数的采样中,采样到与xix_i关联的zjz_j。

事实上,从经验来看,从qϕ(zxi)q_{\phi}(z \mid x_i)中采样zjz_j估计Eqϕ[logpθ(xiz)]\mathbb{E}_{q_{\phi}}\left[\log p_{\theta}(x_i \mid z)\right]是比较高效的。在实践中我们往往对一个xix_i只采样一个zjz_j,即m=1m=1,就能达到可观的效果。所以我们可以将损失改写,并继续往下展开:

L=1ni=1nDKL(qϕ,p)1ni=1nlogpθ(xizi)=1ni=1nj=1d12(1+σi(j)2+μi(j)2logσi(j)2)1ni=1n(12k=1K(xi(k)μi(k))2σi(k)log(2π)Kk=1Kσi(k)).\begin{aligned} \mathcal{L} &= \frac{1}{n} \sum_{i=1}^nD_{K L}\left(q_{\phi}, p\right) - \frac{1}{n} \sum_{i=1}^n \log p_{\theta}\left(x_i \mid z_{i}\right) \\ &= \frac{1}{n} \sum_{i=1}^n \sum_{j=1}^d \frac{1}{2}(-1 + {\sigma_i^{(j)}}^{2} + {\mu_i^{(j)}}^{2} - \log {\sigma_i^{(j)}}^{2}) \\ &\quad\quad\quad - \frac{1}{n} \sum_{i=1}^n \left( -\frac{1}{2} \sum_{k=1}^K \frac{(x_i^{(k)}-\mu_i^{\prime(k)})^2}{\sigma_i^{\prime(k)}} - \log \sqrt{(2 \pi)^{K}\prod_{k=1}^{K} \sigma_i^{\prime(k)}} \right). \end{aligned}\\

值得注意的是,我们已经假设了pθ(Xzi)p_{\theta}(X \mid z_i)对任意ziz_i均是方差固定的各维度独立的KK维高斯分布,我们不妨令超参数σ\sigma^{\prime}为元素值全为12\frac{1}{2}KK维向量。这样一来,损失可以改写为:

(1)L=1ni=1nj=1d12(1+σi(j)2+μi(j)2logσi(j)2)+1ni=1nxiμi2.\mathcal{L} = \frac{1}{n} \sum_{i=1}^n \sum_{j=1}^d \frac{1}{2}(-1 + {\sigma_i^{(j)}}^{2} + {\mu_i^{(j)}}^{2} - \log {\sigma_i^{(j)}}^{2}) + \frac{1}{n} \sum_{i=1}^n \|x_i - \mu_i^{\prime}\|^2. \tag{1}\\

其中,xix_i代表第ii个样本,是Encoder的输入。μi\mu_iσi2\sigma_i^2是Encoder的输出,代表zxiz \mid x_i的分布的参数。ziz_i是从zxiz\mid x_i中采样得到的一个样本,它是Decoder的输入。μi\mu_i^{\prime}是Decoder的输出,代表利用ziz_i解码后对应的数据点x~i\tilde{x}_i。

到这里,我们终于得到了在假设先验、后验、似然均是高斯分布的情况下,VAE最终的损失函数。值得一提的是,通常人们采用高斯分布只是因为其简便性。我们也可以根据数据的情况,假设更加复杂分布来推导、训练VAE。在这种情况下,VAE可能计算会更加复杂,但也可能会得到更强的表达能力。

4. Conditional VAE

根据上面的推导,我们已经可以训练得到一个原版的VAE了。模型训练结束后,我们从p(z)p(z)中采样得到ziz_i,再喂给Decoder,就能生成可能的数据点了。

但这里有个问题,尽管现在我们几乎可以确保从p(z)p(z)中采样得到的ziz_i,总能重建出某个xix_i,但是我们无法控制生成的是哪一类xix_i。

举个MNIST手写数字的例子,原版VAE只能采样得到ziz_i后随机生成数字。然而,更多的时候我们可能会希望模型能够生成我们指定的数字。这就引出了CVAE (Conditional VAE)。

假设我们现在的数据集为XX,我们现在希望利用它的标注YY来控制生成的结果。在MNIST的场景下,就是我们希望能够告诉Decoder:我现在想生成一个标注为"7"的数字,你帮我生成一个看看。

CVAE的思路非常简单,这里我们简单介绍一下。

  1. 原来MLE是最大化数据集出现的概率,也就是对pθ(X)p_{\theta}(X)建模,那么现在我们需要对pθ(X|Y)p_{\theta}(X|Y)建模。
  2. 原来我们对p(z)p(z)进行建模,现在对p(zyi)p(z\mid y_i)建模。
  3. 原来Decoder是对似然pθ(Xzi)p_{\theta}\left(X \mid z_{i}\right)建模,现在即是对pθ(Xzi,yi)p_{\theta}\left(X \mid z_{i}, y_i\right)建模。
  4. 原来Encoder是对近似后验qϕ(zxi)q_{\phi}\left(z \mid x_{i}\right)建模,现在则需要对qϕ(zxi,yi)q_{\phi}\left(z \mid x_{i}, y_i\right)建模。

顺着推导,到最后我们其实只需要让Encoder和Decoder由yiy_i"参数化"就好。这里做法就很多了,一个直观的做法是将yiy_i作为Encoder和Decoder的输入,这样它们不就等于被yiy_i"参数化"了嘛。

5. Implementation

我们在VAE.py中实现了VAE和CVAE。VAE的实现非常简单,主要就是损失函数的实现。我们在代码中的变量名与该文章中的符号是一致的。

下图是我在MNIST上跑的一组示例。

也许我们会注意到,VAE的实现中,人们往往令Encoder输出logσ2\log \sigma^2,而不直接输出σ\sigma。这是因为根据定义,我们必须让模型输出σ0\sigma \geq 0。出于方便,我们通过取对数后再取指数的方法,获得σ\sigma。而取平方只是为了计算损失的时候不再需要取平方。

除此之外,在VAE损失函数的实现中,有一个更需要注意的地方。我们先把之前推的损失函数抄下来:

L=1ni=1nj=1d12(1+σi(j)2+μi(j)2logσi(j)2)+1ni=1nxiμi2.\mathcal{L} = \frac{1}{n} \sum_{i=1}^n \sum_{j=1}^d \frac{1}{2}(-1 + {\sigma_i^{(j)}}^{2} + {\mu_i^{(j)}}^{2} - \log {\sigma_i^{(j)}}^{2}) + \frac{1}{n} \sum_{i=1}^n \|x_i - \mu_i^{\prime}\|^2.\\ \\

可见,上式中第二部分有一个类似MSE的项1ni=1nxiμi2\frac{1}{n} \sum_{i=1}^n \|x_i - \mu_i^{\prime}\|^2。也因此,很多基于Pytorch实现VAE的Repo,直接采用F.mse_loss(mu_prime, x, reduction='mean')来计算这一项。这是错误的!

xix_i的维度为KK,Pytorch中的F.mse_loss等价于:

1nKi=1nxiμi2.\frac{1}{nK} \sum_{i=1}^n \|x_i - \mu_i^{\prime}\|^2.\\ \\

如果单纯的使用MSE损失训练模型的话,常数项的改变并不会影响模型的结果。但是在VAE中,Reconstruction Loss这一项的常数项是有意义的。

直观的来说,这一的常数项控制Reconstruction Loss和Latent Loss之间的权重。如果利用F.mse_loss实现的话,等价于将Reconstruction Loss的权重降的很低,Decoder将无法准确重建xix_i。

抽象的来说,这一常数项代表Decoder拟合的分布pθ(Xzi)p_{\theta}\left(X \mid z_{i}\right)的方差σ2\sigma^{\prime2}。对于图片生成模型,KK往往非常大,比如MNIST里K=28×28K=28 \times 28。平白无故的多除以了个KK等价于我们将pθ(Xzi)p_{\theta}\left(X \mid z_{i}\right)的方差设的非常大,那它生成的图片全都是噪声也不会令人惊讶。也因此,我们往往在设置超参数σ\sigma^{\prime}的时候,必然将其设置的较小。

6. Discussion

VAE中最老生常谈的问题就是,它为什么生成的图片是模糊的?

我在寻找这个问题的答案的时候,从Reddit的一个Post上看到一个高赞回答:

Vanilla VAEs with Gaussian posteriors / priors and factorized pixel distributions aren't blurry, they're noisy. People tend to show the mean value of p(x|z) rather than drawing samples from it. Hence the reported blurry samples aren't actually samples from the model, and they don't reveal the extent to which variability is captured by pixel noise. Real samples would typically demonstrate salt and pepper noise due to independent samples from the pixel distributions.

知乎上也有引用这段话的关于VAE的文章

这一类回答的意思是:高斯分布假设下,VAE生成的图像并不模糊,而是因为有噪声。为什么呢?因为我们本应该利用Decoder拟合一个高斯分布,然后从这个分布中采样得到xix_i的。但是人们偷懒,直接认为拟合出的高斯分布的均值μ\mu^{\prime}就是生成的数据xix_i。你想,本来Decoder告诉你的是给定ziz_iXX可能的分布,你到好,直接把这个分布的均值作为生成的图像了,那能不模糊吗?

知乎上另一类回答说,VAE产生的图像之所以模糊,就是因为高斯分布的假设,比如回答1回答2。这类回答的点在于:如果对pθ(Xzi)p_{\theta}\left(X \mid z_{i}\right)进行高斯分布的假设,那么我们等同于假设数据是一个单峰分布,但是现实中数据往往的多峰 (Multimodal) 的,你用单峰的分布去拟合多峰的分布,那模型只能把多峰进行平均来降低损失了,这样一来,拟合的分布对应的图像,自然也就是模糊的了。

这两类回答看问题的角度是不一样的。但我觉得它们都存在一定的问题,至少不能把我完全说服,我列一下各自可能的疑点:

  1. 对于第一类回答:确实,本来让你采样,现在你直接拿分布的均值出来,似乎图片注定会变得模糊。但是给你一个模糊的均值,再给你一个高斯的方差,你去采样不依然很可能是模糊的?
  2. 对于第二类答案:我数据肯定是多峰的,但我假设pθ(Xzi)p_{\theta}\left(X \mid z_{i}\right)是单峰的为什么不行?只要我能确保每个ziz_i对应的XX的分布是单峰的不就行了?那这样来看,这个问题的本质是因为模型拟合能力不行,输出隐变量无法捕捉充分的信息,而高斯分布也只是受害者?

综上,目前最能说服我的观点是这样的:

模型拟合能力就是没那么强,模型习得的隐变量就是无法完全对应出单峰的XX来。在这种前提下,你再假设是高斯分布,那模型只能把多峰的分布给平均了。所以一个更成功的生成模型,就是允许XziX\mid z_i是一个更复杂的分布,从而使得模型容错率变高:就算你ziz_i对应的XX是多峰的,我这个复杂的分布也能拟合这个多峰的分布。

至于直接取均值而不采样,故而导致模糊的观点,我觉得只能是非常次要的原因。毕竟你输出的均值就已经是模糊的了,再采样也没有意义。

7. References

[1] Doersch, Carl. "Tutorial on variational autoencoders." arXiv preprint arXiv:1606.05908 (2016).

[2] Slides from UIUC CS446: Machine Learning

[3] Slides from Hung-yi Lee's ML Lecture

[4] Zhao, Shengjia, Jiaming Song, and Stefano Ermon. "Towards deeper understanding of variational autoencoding models." arXiv preprint arXiv:1702.08658 (2017).

送礼物用户
送礼物用户
送礼物用户
4 人已送礼物
编辑于 2022-08-25 02:17
理性发言,友善互动

160 条评论
默认
最新
林中鸟
看了三天VAE了,这是讲的最清楚的一篇,作者真强!!!!
2021-07-13
飞毛腿

"回忆一下,我们之前已经对似然pθ(X∣z)和先验p(z)的分布做了假设——它们都服从高斯分布。在这种情况下,不难证明,真实的后验分布pθ(z∣X)也服从高斯分布”。

其实这是不正确的。初看可能认为似然函数和先验均为高斯分布,后验根据共轭的性质也应当是高斯。但是似然函数的mean实际上为经过非线性decoder变化之后的\sigma(z),而非z本身,且z和\sigma(z)的维度都未必相同。若将两者的概率密度函数写出来就非常明了了,不可能能将不同维度的两个指数相整理成关于z的正态分布。或者换一种思路,如果后验分布真的是正态分布的话,为什么还要用encoder去拟合它呢?直接将它写出来就可以了。其实用正态分布去拟合后验分布是VAE的最大假设,且个人认为它拟合的不精确也是VAE效果不好的原因之一。

2023-03-27
在水一方

确实,VAE的后验不一定是正太分布的,我们强行用正太分布去拟合后验,个人认为这也是VAE效果不好的原因。像Diffusion 因为堆叠了很多层加噪的过程,每个过程都是微小的正太分布加噪,所以用正太分布来近似拟合去噪逆过程更较为合理。这也是为什么diffusion效果会好于VAE的原因

2023-05-30
夕心木
说的太棒了,我也发现这个地方错误了,后验分布的真实分布根本就不会受到前面假设的影响,它该是什么分布就是什么分布,而不是假设了似然是高斯分布,那么后验分布的真实分布也会跟着改变
2024-03-16
阳光

Hi~你的讲解很棒!

但我有一个小问题:3.6节 log P_theta(X) = \sum q_phi(Z|X) log P_theta(X) dz 这一步应该不是全概率公式吧?

应该是 log P_theta(X) = 1 * log P_theta(X) = \sum q_phi(Z|X) dz * log P_theta(X), 因为log P_theta(X)与z 无关,所以可以进一步放在积分号内。

不知道我说的对不对,请指教!

2021-03-25
苗思奇
作者

你是对的,我也记不得当时是怎么想的了。。感谢指正!我偷偷改一下去

2021-03-25
发财有双大耳朵

这个不能叫全概率公式,就是概率和为1

2023-03-04
再见孙悟空
牛逼,把从直觉到数学推导打通了
2021-12-19
煮茶论英雄
真实的后验分布未必是高斯分布,这才是推断的意义
2021-12-06
物理路上的小废物
我的理解是这里多个高斯分布组一组去近似它的一个真实分布[捂脸]
04-27
知乎用户902mcp

机器学习中理论优美的模型不一定效果好,比如流形方法、VAE模型等,反倒是CNN、GAN这种理论上不严谨、直觉更多的模型,效果更好。数据的真实分布难以建模是主因,真实的数据,未必遵循严格的分布规律。理论和效果都好的方法,我感觉就是SVM了。

2023-09-25
Ricky
说因为取均值所以模糊的观点应该是错的。均值也是一个采样点,而且对高斯分布来说是概率最大的采样点
2021-12-03
这玩意好难
VAE为什么模糊那块,解释得太NICE了
2021-10-11
机器学习废物
为什么p(x|z),p(z)都服从高斯分布,p(z|x)也是服从高斯分布呢?
2021-12-30
肯肯星
两个高斯分布相乘,得到的分布的kernel是高斯分布,可以自己推一下
2022-04-04
知乎用户H7u3m0

也可以是其他分布,原论文是以高斯分布为例

2023-12-10
知乎用户UthrXz

您好,是不是VAE中的latent variable不是必须要服从Gaussian distribution,理论上可以改为多峰的distribution[好奇]

2022-08-17
点击查看全部评论
想来知乎工作?请发送邮件到 jobs@zhihu.com