Generative Modeling by Estimating Gradients of the Data Distribution

alt text

这篇论文发表早于DDPM,但是任务跟DDPM一样——生成图像。涉及到的关键词为:

  1. 朗之万采样
  2. Score Matching

b站上讲的非常好的一个视频

你可能对这两个名词一头雾水,别着急,在具体看论文算法之前,我们先铺垫一下。

1. 朗之万采样

首先,它是一种采样方法。什么意思呢?就是给定一个已知的概率分布p(x)p(x),通过采样的方式,得到这个概率分布的样本。

那你可能会问,干嘛不直接随机均匀采样就好了?这里需要考虑到实际数据分布的一些特点:随机采的数据点x^\hat{x}大概率是p(x^)0p(\hat{x})\approx 0。所以现在的问题就是,我们采样得到的数据点需要尽可能“合理”。

ok,朗之万采样提供了一个解决方案,我们从一个数据点x^\hat{x}开始,经过如下方程的若干次迭代之后,便可以得到一个合理的采样点xxalt text

刨除噪声项,其实朗之万采样就是一个朝着pdf最大的方向进行采样的过程——这样也就能保证采样得到的数据点xx尽可能“合理”。

看到这里,我们给出一个实际例子方便理解。我们从[-3, 3]中均匀采样10000个点,每一个点都按照前面的朗之万迭代的方式更新100次得到最终的点。 可以看到,一开始均匀采样的点并不能很好的体现pdf的形状,但是经过朗之万采样之后,pdf的形状就出现了。

alt text

总结:朗之万采样是一个迭代式地更新一开始不那么“好”的采样点,使之更“合理”的过程。这个过程可以理解成梯度上升的思路。

诶看到这里,你可能会很好奇,前面那个迭代式是怎么算出来的?我这里放出一个博客的证明,我自己就不证明了:

朗之万采样的数学推导

总结:朗之万采样方法告诉我们,如果我们通过某些方式知道了(近似了)xlogp(x)\triangledown_{\mathbf{x}} \log p (\mathbf{x}),那么即使我们不知道p(x)p(\mathbf{x}),我们依旧可以进行“合理”采样。而接下来我们将会介绍如何使用Score Matching来近似得到xlogp(x)\triangledown_{\mathbf{x}} \log p (\mathbf{x})

论文片段: alt text

看点代码:

    def Langevin_dynamics(self, x_mod, scorenet, n_steps=200, step_lr=0.00005):
        images = []

        labels = torch.ones(x_mod.shape[0], device=x_mod.device) * 9
        labels = labels.long()

        with torch.no_grad():
            for _ in range(n_steps):
                images.append(torch.clamp(x_mod, 0.0, 1.0).to('cpu'))
                noise = torch.randn_like(x_mod) * np.sqrt(step_lr * 2)
                grad = scorenet(x_mod, labels)
                x_mod = x_mod + step_lr * grad + noise
                x_mod = x_mod
                print("modulus of grad components: mean {}, max {}".format(grad.abs().mean(), grad.abs().max()))

            return images

2. Score Matching

强烈推荐阅读

第一小节中,我们介绍了朗之万采样,定义s(x)=xlogp(x)s(\mathbf{x}) = \triangledown_{\mathbf{x}} \log p (\mathbf{x})s(x)s(x)便是所谓的score function。假设我们通过某种方式训练得到sθ(x)s_{\theta}(x)来近似s(x)s(x),那么我们便可以使用朗之万采样方式进行采样了。

下图为一个高斯混合分布以及其score function。看看y=0与score function的交点,你发现什么没有? alt text

分析至此,你发现问题没有,我们明明要求的是数据集的分布p(x)p(x),但是你现在把p(x)p(x)当成是已知,这不是很奇怪吗?

别急,我们慢慢来。假设给定数据集D={x1,x2,,xN}D=\{x_1, x_2, \cdots, x_N\},使用核密度估计,定义: 我们使用已知的数据集和qq近似p(x)p(x),则我们训练神经网络sθ(x)s_\theta(x)来拟合s(x)s(x)。其损失为: alt text

用KDE估计的缺点是效果比较差。Estimation of Non-Normalized Statistical Models by Score Matching这篇论文通过一定的等价变形将MSE中的真实score一项消掉了,其的出来的损失函数为(推导见知乎推文,在Score Matching这篇论文中也提到了这个,叫做Sliced Score Matching: alt text

下面介绍另一个方法:Denoising Score Matching,也是论文中使用的方法。我们先讲这个方法,随后在第三小节我们再解释原因。

给数据x\textbf{x}加噪音,获得q(x^)=N(x,σ2I)q(\hat{\textbf{x}}) = \mathcal{N}(\textbf{x}, \sigma^2 \textbf{I}),这一步可以理解成数据增强,即x^=x+σz\hat{\textbf{x}} = \textbf{x} + \sigma \textbf{z},其中zN(0,I)\textbf{z} \sim \mathcal{N}(\textbf{0}, \textbf{I})

于是我们的损失可以化简为: alt text

看点代码

def dsm_score_estimation(scorenet, samples, sigma=0.01):
    perturbed_samples = samples + torch.randn_like(samples) * sigma
    target = - 1 / (sigma ** 2) * (perturbed_samples - samples)
    scores = scorenet(perturbed_samples)
    target = target.view(target.shape[0], -1)
    scores = scores.view(scores.shape[0], -1)
    loss = 1 / 2. * ((scores - target) ** 2).sum(dim=-1).mean(dim=0)

    return loss

3. 解释第二小节,为什么要引入噪声,损失函数为什么那么写?(对样本进行perturb)

第一个原因:

我们的样本空间实际上只是高维空间中的一个流形子空间,高维空间中的绝大多数点都是p(x)=0p(x)=0,如果只是在样本空间上学习score function,这意味着我们只会在样本空间上学好score function,对于稍微偏离的点,score function将会是不准的,而score function不准将会导致我们朗之万采样的采样效果不好。这样子,即使是working space中的点,score function的“准”性也会变差。所以给每一个样本点加噪声主要是使得score function在空间中的“覆盖面”更大,从而能保证score function的“准”性。

第二个原因:

我们之前说真实s(x)并不知道,但是其实通过加噪之后能够恰好进行有效化简,从而消掉真实s(x)项。

4. 实现上,用了退火采样

不同噪声强度都需要进行优化:

alt text

采样时,先迈大步子,再细粒度地走: alt text

results matching ""

    No results matching ""