Diffusion Models Beat GANs on Image Synthesis


Classifier Guidance

Classifier guidance modifies only the sampling procedure of DDPM; everything else stays the same. In other words, what is described here is purely a sampling method.

Pre-step: train a DDPM and a classifier. Note that the classifier is trained on noised images $x_t$ (conditioned on the timestep $t$), so that its gradients are meaningful at every step of the reverse process.
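
A minimal sketch of one training step for that noise-aware classifier, assuming a guided_diffusion-style `diffusion.q_sample(x_start, t)` and a timestep-conditioned `classifier(x_t, t)`; this is an illustration, not the repo's training script:

```python
import torch
import torch.nn.functional as F

def classifier_train_step(classifier, diffusion, x0, y, optimizer):
    # Draw a random timestep per image and noise the clean batch to x_t.
    t = torch.randint(0, diffusion.num_timesteps, (x0.shape[0],), device=x0.device)
    x_t = diffusion.q_sample(x0, t)
    # The classifier sees the noised image together with its timestep.
    loss = F.cross_entropy(classifier(x_t, t), y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()
```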

Sampling method 1 (guided DDPM, Algorithm 1 in the paper): given the model's reverse-step Gaussian $\mathcal{N}(\mu, \Sigma)$ at timestep $t$, shift the mean by the classifier gradient and sample

$$x_{t-1} \sim \mathcal{N}\big(\mu + s\,\Sigma\,\nabla_{x_t}\log p_\phi(y \mid x_t),\ \Sigma\big),$$

where $s$ is the guidance scale.
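
Why shifting the mean by $\Sigma g$ is the right update: following the Sohl-Dickstein-style conditioning argument the paper builds on, take a first-order Taylor expansion of the classifier log-likelihood around the unconditional mean $\mu$:

$$\log p_\phi(y \mid x_t) \approx \log p_\phi(y \mid x_t)\big|_{x_t=\mu} + (x_t - \mu)^\top g, \qquad g = \nabla_{x_t} \log p_\phi(y \mid x_t)\big|_{x_t=\mu}.$$

Multiplying the Gaussian reverse step $\mathcal{N}(x_t;\, \mu, \Sigma)$ by this linearized likelihood and completing the square gives, up to a normalizing constant,

$$p_\theta(x_t \mid x_{t+1})\, p_\phi(y \mid x_t) \approx \mathcal{N}(x_t;\ \mu + \Sigma g,\ \Sigma),$$

which is exactly the shifted-mean rule above (with the gradient additionally multiplied by the scale $s$ in practice).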

Sampling method 2: if we instead sample the DDIM way (Algorithm 2 in the paper), the guidance is applied to the predicted noise rather than the mean:

$$\hat{\epsilon}(x_t) = \epsilon_\theta(x_t) - \sqrt{1 - \bar{\alpha}_t}\; s\, \nabla_{x_t} \log p_\phi(y \mid x_t),$$

after which the usual DDIM update is run with $\hat{\epsilon}$ in place of $\epsilon_\theta$.

We skip the full derivation here; see the original paper. From the score-based view it is easy to show: since $\nabla_{x_t} \log p(x_t) = -\epsilon_\theta(x_t)/\sqrt{1-\bar{\alpha}_t}$, adding the classifier score $\nabla_{x_t} \log p_\phi(y \mid x_t)$ to the model score and converting back to the $\epsilon$-parametrization yields exactly $\hat{\epsilon}$ above.
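
For intuition, here is a minimal sketch of one deterministic ($\eta = 0$) guided DDIM step under the standard $\epsilon$-parametrization. All names here (`eps_model`, `grad_log_p`, `alphas_cumprod` as a 1-D tensor of $\bar{\alpha}$ values) are assumptions for illustration, not the repo's exact code; the repo applies the same shift inside its DDIM sampler.

```python
import torch

def guided_ddim_step(eps_model, grad_log_p, x_t, t, t_prev, alphas_cumprod, s=1.0):
    # grad_log_p(x, t) -> grad_x log p(y | x): hypothetical classifier-gradient helper.
    a_t, a_prev = alphas_cumprod[t], alphas_cumprod[t_prev]
    eps = eps_model(x_t, t)
    # Shift the predicted noise with the (scaled) classifier score.
    eps_hat = eps - (1.0 - a_t).sqrt() * s * grad_log_p(x_t, t)
    # Deterministic DDIM update, run with eps_hat in place of eps.
    x0_pred = (x_t - (1.0 - a_t).sqrt() * eps_hat) / a_t.sqrt()
    return a_prev.sqrt() * x0_pred + (1.0 - a_prev).sqrt() * eps_hat
```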

Code Notes

Reference: OpenAI's official guided-diffusion code release.

guided_diffusion/gaussian_diffusion.py

```python
def p_sample(
    self,
    model,
    x,
    t,
    clip_denoised=True,
    denoised_fn=None,
    cond_fn=None,
    model_kwargs=None,
):
    """
    Sample x_{t-1} from the model at the given timestep.

    :param cond_fn: if not None, this is a gradient function that acts
                    similarly to the model.
    """
    out = self.p_mean_variance(
        model,
        x,
        t,
        clip_denoised=clip_denoised,
        denoised_fn=denoised_fn,
        model_kwargs=model_kwargs,
    )
    noise = th.randn_like(x)
    nonzero_mask = (
        (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
    )  # no noise when t == 0
    if cond_fn is not None:
        out["mean"] = self.condition_mean(
            cond_fn, out, x, t, model_kwargs=model_kwargs
        )
    sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
    return {"sample": sample, "pred_xstart": out["pred_xstart"]}
```

The key addition relative to vanilla DDPM sampling is this branch: when a cond_fn is given, the posterior mean is shifted before the noise is added:

```python
if cond_fn is not None:
    out["mean"] = self.condition_mean(
        cond_fn, out, x, t, model_kwargs=model_kwargs
    )
```

condition_mean

guided_diffusion/gaussian_diffusion.py

```python
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
    """
    Compute the mean for the previous step, given a function cond_fn that
    computes the gradient of a conditional log probability with respect to
    x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
    condition on y.

    This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
    """
    gradient = cond_fn(x, self._scale_timesteps(t), **model_kwargs)
    new_mean = (
        p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
    )
    return new_mean
```
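
Note that this is exactly the shifted mean $\mu + \Sigma g$ from the derivation above. The guidance scale $s$ does not appear here; it is folded into the gradient that cond_fn returns (args.classifier_scale in the snippet below).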

The cond_fn used above is defined in scripts/classifier_sample.py:

```python
def cond_fn(x, t, y=None):
    assert y is not None
    with th.enable_grad():
        x_in = x.detach().requires_grad_(True)
        logits = classifier(x_in, t)
        log_probs = F.log_softmax(logits, dim=-1)
        selected = log_probs[range(len(logits)), y.view(-1)]
        return th.autograd.grad(selected.sum(), x_in)[0] * args.classifier_scale
```
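
Wiring everything together, sampling then looks roughly like this (a sketch following scripts/classifier_sample.py; `batch_size`, `image_size`, and the label tensor `classes` are assumed to be set up elsewhere in the script):

```python
sample = diffusion.p_sample_loop(
    model,
    (batch_size, 3, image_size, image_size),
    clip_denoised=True,
    model_kwargs={"y": classes},  # target labels, one per image in the batch
    cond_fn=cond_fn,              # shifts the posterior mean toward class y at every step
)
```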

results matching ""

    No results matching ""