Kernel Distances
for Better Deep Generative Models

Dougal J. Sutherland
Gatsby Computational Neuroscience Unit, University College London

Based on “Demystifying MMD GANs” [ICLR-18]
and “On gradient regularizers for MMD GANs” [NIPS-18], with:

Gatsby, UCL
Imperial College London
Gatsby, UCL

GPSS 2018 - Advances in Kernel Methods, 6 Sep 2018


\DeclareMathOperator{\D}{\mathcal{D}} \DeclareMathOperator*{\E}{\mathbb{E}} \newcommand{\F}{\mathcal{F}} \newcommand{\h}{\mathcal{H}} \DeclareMathOperator{\mean}{mean} \newcommand{\PP}{\mathbb P} \newcommand{\QQ}{\mathbb Q} \newcommand{\R}{\mathbb R} \DeclareMathOperator{\W}{\mathcal{W}} \newcommand{\X}{\mathcal X} \newcommand{\Z}{\mathcal Z} \newcommand{\ZZ}{\mathbb Z} \DeclareMathOperator*{\argmin}{argmin} \DeclareMathOperator*{\argmax}{argmax} \DeclareMathOperator{\mmd}{MMD} \DeclareMathOperator{\smmd}{SMMD} \DeclareMathOperator{\mmdhat}{\widehat{MMD}} \DeclareMathOperator{\optmmd}{\mathcal{D}_\mathrm{MMD}} \DeclareMathOperator{\optmmdhat}{\hat{\mathcal{D}}_\mathrm{MMD}} \newcommand{\target}{\mathrm{target}} \newcommand{\ktop}{k_\mathrm{top}}

Implicit generative models

Given samples from a distribution \PP over \X ,
want a model that can produce new samples from \QQ_\theta \approx \PP



Don't necessarily care about likelihoods, interpretability, …

Model: Generator network

Deep network (params \theta ) mapping from noise \ZZ to images \X

DCGAN generator [Radford+ ICLR-16]

\ZZ is e.g. uniform on [-1, 1]^{100}

Choose \theta by minimizing…something

Loss function

  • Can't evaluate likelihood of samples under model
  • Likelihood maybe not the best choice anyway [Theis+ ICLR-16]
    • Mixing a model with 99% white noise barely lowers its likelihood, even though 99% of its samples are then noise
  • Instead, we'll minimize some \D(\PP, \QQ_\theta)
Max-likelihood objective vs WGAN objective [Danihelka+ 2017]
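The 99%-noise point from [Theis+ ICLR-16] follows from one inequality. Writing p_\theta for a model density and q for a white-noise density (both names introduced here for illustration), the 1%/99% mixture satisfies, for every x:

```latex
\log\left(0.01\, p_\theta(x) + 0.99\, q(x)\right)
  \;\ge\; \log\left(0.01\, p_\theta(x)\right)
  = \log p_\theta(x) - \log 100
  \approx \log p_\theta(x) - 4.61
```

So the mixture's average log-likelihood is at most \log 100 \approx 4.61 nats worse than the model's, a small change next to the total log-likelihood of a high-dimensional image.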

Maximum Mean Discrepancy [Gretton+ 2012]

\mmd(\PP, \QQ) = \sup_{\lVert f \rVert_\h \le 1} \E_{X \sim \PP}[f(X)] - \E_{Y \sim \QQ}[f(Y)] ( \h is RKHS with kernel k : \X \times \X \to \R )

Can do optimization in closed form: \begin{gather} f^*_k(t) \propto \E_{X \sim \PP} k(X, t) - \E_{Y \sim \QQ} k(Y, t) \\ \mmd^2 = \E k(X, X') + \E k(Y, Y') - 2 \E k(X, Y) \end{gather}
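The closed-form critic can be estimated by replacing each expectation with a sample mean. A minimal sketch, assuming scalar data and a Gaussian kernel with bandwidth 1 (the helper names `gaussian_kernel` and `witness` are hypothetical):

```python
import math


def gaussian_kernel(x, y, bandwidth=1.0):
    """Gaussian (RBF) kernel on scalars: exp(-(x - y)^2 / (2 * bandwidth^2))."""
    return math.exp(-((x - y) ** 2) / (2 * bandwidth ** 2))


def witness(t, xs, ys, kernel=gaussian_kernel):
    """Empirical, unnormalized witness: mean_i k(x_i, t) - mean_j k(y_j, t)."""
    mean_p = sum(kernel(x, t) for x in xs) / len(xs)
    mean_q = sum(kernel(y, t) for y in ys) / len(ys)
    return mean_p - mean_q


xs = [-0.5, 0.0, 0.5]  # samples from P, centred at 0
ys = [2.5, 3.0, 3.5]   # samples from Q, centred at 3
print(witness(0.0, xs, ys))  # positive: P has more mass near 0
print(witness(3.0, xs, ys))  # negative: Q has more mass near 3
```

The witness is large where \PP puts more mass than \QQ and negative where \QQ does; dividing by its RKHS norm gives the maximizing f^*_k.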

Deriving MMD

\begin{align} \mmd(\PP, \QQ) &= \sup_{\lVert f \rVert_\h \le 1} \E_{X \sim \PP}[f(X)] - \E_{Y \sim \QQ}[f(Y)] \\&\fragment{ = \sup_{\lVert f \rVert_\h \le 1} \E_{X \sim \PP}[\langle f, k(X, \cdot) \rangle_\h] - \E_{Y \sim \QQ}[\langle f, k(Y, \cdot) \rangle_\h] } \\&\fragment{ = \sup_{\lVert f \rVert_\h \le 1} \left\langle f, \E_{X \sim \PP}[k(X, \cdot)] - \E_{Y \sim \QQ}[k(Y, \cdot)] \right\rangle_\h } \\&\fragment{= \left\lVert \E_{X \sim \PP}[k(X, \cdot)] - \E_{Y \sim \QQ}[k(Y, \cdot)] \right\rVert_\h } \end{align}
Steps: reproducing property f(x) = \langle f, k(x, \cdot) \rangle_\h ; linearity of expectation and inner product; Cauchy–Schwarz, with equality at f \propto \E_{X \sim \PP}[k(X, \cdot)] - \E_{Y \sim \QQ}[k(Y, \cdot)]

Unbiased estimator of \mmd^2

\fragment[0][highlight-current-blue]{\operatorname{mean}(K_{XX})} \fragment[1][highlight-current-blue]{+ \operatorname{mean}(K_{YY})} \fragment[2][highlight-current-blue]{- 2 \operatorname{mean}(K_{XY})} , where the means of K_{XX} and K_{YY} skip the diagonal ( i = j ) entries
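A direct sketch of this estimator on scalar samples, assuming a Gaussian kernel with bandwidth 1 (function names are hypothetical). The within-sample sums skip i = j, which is what makes it unbiased; as a consequence the estimate can be negative.

```python
import math
import random


def gaussian_kernel(x, y, bandwidth=1.0):
    """Gaussian (RBF) kernel on scalars."""
    return math.exp(-((x - y) ** 2) / (2 * bandwidth ** 2))


def mmd2_unbiased(xs, ys, kernel=gaussian_kernel):
    """Unbiased estimate of MMD^2 from xs ~ P and ys ~ Q.

    Within-sample means exclude the diagonal k(x_i, x_i); the cross term
    averages over all m * n pairs.
    """
    m, n = len(xs), len(ys)
    k_xx = sum(kernel(xs[i], xs[j]) for i in range(m) for j in range(m) if i != j)
    k_yy = sum(kernel(ys[i], ys[j]) for i in range(n) for j in range(n) if i != j)
    k_xy = sum(kernel(x, y) for x in xs for y in ys)
    return k_xx / (m * (m - 1)) + k_yy / (n * (n - 1)) - 2 * k_xy / (m * n)


rng = random.Random(0)
same = mmd2_unbiased([rng.gauss(0, 1) for _ in range(200)],
                     [rng.gauss(0, 1) for _ in range(200)])
diff = mmd2_unbiased([rng.gauss(0, 1) for _ in range(200)],
                     [rng.gauss(1, 1) for _ in range(200)])
print(same, diff)  # same is near 0 (possibly negative); diff is clearly positive
```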

MMD as loss [Li+ ICML-15, Dziugaite+ UAI-15]

Fit generator parameters \hat\theta by SGD on minibatches
X \sim \PP^m , Z \sim \ZZ^n , with loss \mmdhat^2(X, G_\theta(Z))

Hard to pick a good kernel for images
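A toy version of this training loop, under assumptions of our own: 1-D data \PP = N(2, 1), a location "generator" G_\theta(z) = \theta + z with Gaussian noise (instead of uniform), and a fixed Gaussian kernel with bandwidth 1. For this generator only the cross term of \mmdhat^2 depends on \theta, so the gradient can be written by hand; all names are hypothetical.

```python
import math
import random


def mmd2_grad_theta(xs, zs, theta):
    """Gradient of the unbiased MMD^2 estimate w.r.t. theta for G_theta(z) = theta + z.

    K_XX ignores theta and K_YY is shift-invariant (y_i - y_j = z_i - z_j), so only
    the cross term matters: d/dtheta [-2 mean k(x, y)] = -2 mean k(x, y) * (x - y).
    """
    ys = [theta + z for z in zs]
    total = sum(math.exp(-0.5 * (x - y) ** 2) * (x - y) for x in xs for y in ys)
    return -2 * total / (len(xs) * len(ys))


def train_mmd_generator(target_mean=2.0, theta=-1.0, lr=0.5, steps=300, batch=64, seed=0):
    rng = random.Random(seed)
    for _ in range(steps):
        xs = [rng.gauss(target_mean, 1.0) for _ in range(batch)]  # minibatch X ~ P^m
        zs = [rng.gauss(0.0, 1.0) for _ in range(batch)]           # noise Z ~ Z^n
        theta -= lr * mmd2_grad_theta(xs, zs, theta)             # SGD step on theta
    return theta


final_theta = train_mmd_generator()
print(final_theta)  # close to the target mean 2.0
```

With a fixed kernel this works for a location shift in 1-D; the difficulty the slide points to is that no single fixed kernel is a good similarity measure for images.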

MMD GANs: Deep kernels [Li+ NIPS-17]

  • Use a class of deep kernels: k_\psi(x, y) = \ktop(\phi_\psi(x), \phi_\psi(y))
  • Choose the most discriminative of those kernels: \optmmd(\PP, \QQ) = \sup_{\psi} \mmd_{\psi}(\PP, \QQ)
  • Initialize random generator G_\theta and representation \phi_\psi
  • Goal: minimize \hat{\mathcal{D}}^2_\mathrm{MMD}(\PP, G_\theta(\ZZ)) in \theta ; in practice, repeat:
    • 5 times: take an SGD step in \psi to maximize \mmdhat^2_\psi(\PP, G_\theta(\ZZ))
    • Take an SGD step in \theta to minimize \mmdhat^2_\psi(\PP, G_\theta(\ZZ))
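The alternating loop can be sketched on a 1-D toy of our own choosing (not from the papers): \PP = N(2, 1), \QQ_\theta = N(\theta, 1), \phi_\psi(x) = \psi x and the Gaussian top kernel \exp(-\tfrac12 (a - b)^2). To keep it deterministic, the sketch uses the closed-form population \mmd^2 for two Gaussians instead of minibatch estimates, and central finite differences instead of autodiff; all names are hypothetical.

```python
import math


def mmd2_gaussians(theta, psi, target=2.0):
    """Population MMD^2 between P = N(target, 1) and Q_theta = N(theta, 1) under
    k_psi(x, y) = exp(-(psi * x - psi * y)^2 / 2), i.e. phi_psi(x) = psi * x.

    With r2 = 1 + 2 psi^2: E k(X, X') = E k(Y, Y') = 1 / sqrt(r2) and
    E k(X, Y) = exp(-psi^2 (theta - target)^2 / (2 r2)) / sqrt(r2).
    """
    r2 = 1 + 2 * psi ** 2
    delta2 = (theta - target) ** 2
    return 2 / math.sqrt(r2) * (1 - math.exp(-psi ** 2 * delta2 / (2 * r2)))


def derivative(f, x, h=1e-5):
    """Central finite-difference derivative of a scalar function f at x."""
    return (f(x + h) - f(x - h)) / (2 * h)


def train_alternating(theta=-1.0, psi=0.5, lr_critic=0.5, lr_gen=1.0,
                      iters=200, critic_steps=5):
    for _ in range(iters):
        # "5 times": gradient ascent in psi makes the kernel more discriminative
        for _ in range(critic_steps):
            psi += lr_critic * derivative(lambda p: mmd2_gaussians(theta, p), psi)
        # then one gradient descent step in theta on the same loss
        theta -= lr_gen * derivative(lambda t: mmd2_gaussians(t, psi), theta)
    return theta, psi


theta_final, psi_final = train_alternating()
print(theta_final, psi_final)  # theta ends near the target mean 2.0
```

Here the best \psi stays finite because both distributions have spread; the point-mass case on the next slide is where this breaks.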

Smoothness of \mathcal{D}_\mathrm{MMD}

  • Toy problem in \R , DiracGAN [Mescheder+ ICML-18]:
    • Point mass target \PP = \delta_0 , model \QQ_\theta = \delta_\theta
    • Representation \phi_\psi(x) = \psi x , \psi \in \R
    • Kernel \ktop(a, b) = \exp\left( - \tfrac12 (a - b)^2 \right)

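Plain \optmmd doesn't work [Sutherland+ ICLR-17]: for any \theta \ne 0 , letting \psi \to \infty pushes \mmd_\psi^2(\delta_0, \delta_\theta) up to 2, while at \theta = 0 it is 0, so \optmmd jumps at the optimum and gives no useful gradient in \theta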
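A numerical look at the DiracGAN toy, using the closed form \mmd_\psi^2(\delta_0, \delta_\theta) = k(0, 0) + k(\theta, \theta) - 2 k(0, \theta) = 2 - 2\exp(-\tfrac12 \psi^2 \theta^2) . A finite grid of \psi values stands in for the supremum (an assumption of the sketch), and the derivative in \theta is evaluated at the maximizing grid point.

```python
import math


def dirac_mmd2(theta, psi):
    """MMD^2 between delta_0 and delta_theta for phi_psi(x) = psi * x and the Gaussian
    top kernel: 2 - 2 exp(-(psi * theta)^2 / 2)."""
    return 2 - 2 * math.exp(-0.5 * (psi * theta) ** 2)


def dirac_mmd2_grad_theta(theta, psi):
    """d/dtheta of dirac_mmd2: 2 psi^2 theta exp(-(psi * theta)^2 / 2)."""
    return 2 * psi ** 2 * theta * math.exp(-0.5 * (psi * theta) ** 2)


psis = [10.0 ** e for e in range(-2, 7)]  # grid standing in for the sup over psi
for theta in [1.0, 0.1, 0.001, 0.0]:
    best = max(psis, key=lambda p: dirac_mmd2(theta, p))
    print(theta, dirac_mmd2(theta, best), dirac_mmd2_grad_theta(theta, best))
```

For every \theta \ne 0 on this grid the maximized value is 2.0 and the gradient at the maximizer is numerically zero, while at \theta = 0 the value is 0.0: the loss is a step function of \theta .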

Wasserstein and WGANs

\W(\PP, \QQ) = \sup_{\lVert f \rVert_L \le 1} \E_{X \sim \PP}[ f(X) ] - \E_{Y \sim \QQ}[ f(Y) ]

\lVert f \rVert_L = \sup_{x \ne x'} \frac{\lvert f(x) - f(x') \rvert}{\lVert x - x' \rVert} = \sup_x \lVert \nabla f(x) \rVert (the last equality for differentiable f )

  • WGANs [Arjovsky+ ICML-17], WGAN-GPs [Gulrajani+ NIPS-17]:
    • Update a critic neural network f_\psi
      \hat\psi_\theta \approx \argmax_\psi \E_{X \sim \PP} f_\psi(X) - \E_{Z \sim \ZZ} f_\psi(G_\theta(Z))
      • Enforce Lipschitz constraint on f_\psi (more shortly)
    • Run SGD to minimize estimate of \mathcal W with that critic
      \hat\W_{\hat\psi_\theta}(X, G_\theta(Z)) = \hat{\mathbb E}_{X \sim \PP} f_{\hat\psi_\theta}(X) - \hat{\mathbb E}_{Z \sim \ZZ} f_{\hat\psi_\theta}(G_\theta(Z)) = \frac{1}{m} \sum_{i=1}^m f_{\hat\psi_\theta}(X_i) - \frac1n \sum_{j=1}^n f_{\hat\psi_\theta}(G_\theta(Z_j))

Enforcing Lipschitz constraint

  • One strategy: only optimize over critics f_\psi that are 1-Lipschitz
    • Hard to tightly control the Lipschitz constant of deep nets
    • WGAN [Arjovsky+ ICML-17] tried this with a simple box constraint (weight clipping)
    • So did the original MMD GAN paper [Li+ NIPS-17]
  • WGAN-GP [Gulrajani+ NIPS-17]: penalize non-Lipschitzness with the regularizer \lambda \E_{\tilde X} \left( \lVert \nabla_{\tilde X} f_\psi(\tilde X) \rVert - 1 \right)^2 , with \tilde X drawn uniformly on line segments between points of the X and G_\theta(Z) minibatches
    • Looser constraint than a hard Lipschitz bound
    • Tends to work better in practice
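The WGAN-GP critic objective can be sketched in 1-D with a hypothetical critic f_\psi(x) = a \tanh(b x) , \psi = (a, b) , whose input derivative is available in closed form. Each real sample is paired with a generated one and \tilde X is a uniform random point on the segment between them; \lambda = 10 is the value commonly used with WGAN-GP. Names and the critic form are illustrative assumptions.

```python
import math
import random


def critic(x, a, b):
    """Hypothetical toy critic f_psi(x) = a * tanh(b * x), psi = (a, b)."""
    return a * math.tanh(b * x)


def critic_grad_x(x, a, b):
    """Derivative of the toy critic with respect to its input x."""
    return a * b * (1 - math.tanh(b * x) ** 2)


def wgan_gp_critic_loss(xs, gen_samples, a, b, lam=10.0, rng=None):
    """Loss the critic minimizes:
    -(mean f(X) - mean f(G(Z))) + lam * mean (|f'(X~)| - 1)^2,
    with X~ = eps * x + (1 - eps) * g, eps ~ Uniform[0, 1], per (x, g) pair."""
    rng = rng if rng is not None else random.Random()
    gap = (sum(critic(x, a, b) for x in xs) / len(xs)
           - sum(critic(g, a, b) for g in gen_samples) / len(gen_samples))
    penalty = 0.0
    for x, g in zip(xs, gen_samples):
        eps = rng.random()
        x_tilde = eps * x + (1 - eps) * g
        penalty += (abs(critic_grad_x(x_tilde, a, b)) - 1) ** 2
    return -gap + lam * penalty / len(xs)
```

For example, with both minibatches at 0 and (a, b) = (2, 1) , the critic has slope 2 at \tilde X = 0 , so the loss is 10 \cdot (2 - 1)^2 = 10 .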

MMD-GAN with gradient control

  • If we optimize MMD over kernels giving uniformly Lipschitz critics, the loss will be continuous and a.e. differentiable
  • Original MMD GAN paper [Li+ NIPS-17] used a box constraint
  • We [Bińkowski+ ICLR-18] used a gradient penalty on the critic instead
    • Better in practice, but doesn't fix the Dirac problem…


  • Consider linear-kernel MMD GAN, k(x, y) = \phi(x) \phi(y) : \begin{gather} \text{loss} = \mmd_\phi(\PP, \QQ) = \lvert \E_\PP \phi(X) - \E_\QQ \phi(Y) \rvert \\ \operatorname{critic}(t) = \operatorname{sign}\left( \E_\PP \phi(X) - \E_\QQ \phi(Y) \right) \phi(t) \end{gather}
  • WGAN has: \begin{gather} \text{loss} = \E_\PP \phi(X) - \E_\QQ \phi(Y) \\ \operatorname{critic}(t) = \phi(t) \end{gather}
  • Linear-kernel MMD GAN-GP and WGAN-GP are almost the same
  • MMD GAN “offloads” some of the critic's work to closed-form optimization in the RKHS
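With a linear kernel on scalar features, the biased (V-statistic) \mmd^2 estimate collapses to the squared gap between feature means, which is why the MMD GAN loss here is the absolute value of the WGAN loss. A minimal check with random stand-in features (names hypothetical); the unbiased estimator differs slightly because it drops diagonal terms.

```python
import random


def mmd2_biased_linear(phi_x, phi_y):
    """Biased MMD^2 with linear kernel k(x, y) = phi(x) * phi(y) on scalar features."""
    m, n = len(phi_x), len(phi_y)
    k_xx = sum(a * b for a in phi_x for b in phi_x) / m ** 2
    k_yy = sum(a * b for a in phi_y for b in phi_y) / n ** 2
    k_xy = sum(a * b for a in phi_x for b in phi_y) / (m * n)
    return k_xx + k_yy - 2 * k_xy


rng = random.Random(0)
phi_x = [rng.gauss(1.0, 1.0) for _ in range(50)]  # stand-in for phi(X), X ~ P
phi_y = [rng.gauss(0.0, 1.0) for _ in range(40)]  # stand-in for phi(Y), Y ~ Q
mean_gap = sum(phi_x) / len(phi_x) - sum(phi_y) / len(phi_y)
print(mmd2_biased_linear(phi_x, phi_y), mean_gap ** 2)  # equal up to rounding
```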

Built-in gradient constraints

Lipschitz MMD ( \lambda = 1 )

  • \mmd_k = \sup_{\lVert f \rVert_{\h_k}^2 \le 1} \E_{X \sim \PP}[f(X)] - \E_{Y \sim \QQ}[f(Y)]
  • Lipschitz MMD: \mathcal L_{k,\lambda} = \sup_{\lVert f \rVert_L^2 + \lambda \lVert f \rVert_{\h_k}^2 \le 1} \E_{X \sim \PP}[f(X)] - \E_{Y \sim \QQ}[f(Y)]
  • Approximation: Gradient-Constrained MMD (GC-MMD): \begin{gather} S_{k,\mu,\lambda} = \sup_{\lVert f \rVert_{S(\mu),\lambda} \le 1} \E_{X \sim \PP}[f(X)] - \E_{Y \sim \QQ}[f(Y)] \\ \lVert f \rVert_{S(\mu),\lambda}^2 := \E_{X \sim \mu} \left[ f(X)^2 + \lVert \nabla f(X) \rVert^2 \right] + \lambda \lVert f \rVert^2_{\h_k} \end{gather}
    • The \E_\mu f(X)^2 term makes it like a Sobolev norm
    • Constrains gradients only on average under \mu , so doesn't quite constrain the Lipschitz constant

Lipschitz MMD ( \lambda = 1 )

GC-MMD, \mu = \mathrm{Uniform}[0, \theta]

Gradient-Constrained MMD on MNIST

  • It's a reasonable distance to optimize:
  • …but this took days to run

Estimating Gradient-Constrained MMD

  • Say we have m samples \tilde X \sim \mu
  • Let \eta(t) = \E_{X \sim \PP} k(X, t) - \E_{Y \sim \QQ} k(Y, t)
  • Then S_{k,\hat\mu,\lambda}^2 = \frac1\lambda \left( \mmd^2 - \bar P \right) with \bar P = \begin{bmatrix} \eta(\tilde X) \\ \nabla\eta(\tilde X) \end{bmatrix}^\top \left( \begin{bmatrix} K & G^\top \\ G & H \end{bmatrix} + m \lambda I \right)^{-1} \begin{bmatrix} \eta(\tilde X) \\ \nabla\eta(\tilde X) \end{bmatrix} ; K has entries k(\tilde X_a, \tilde X_b) , G has \partial_i k(\tilde X_a, \tilde X_b) , H has \partial_i \partial_{j+d} k(\tilde X_a, \tilde X_b) ( \partial_i : derivative in the i -th coordinate of the first argument; \partial_{j+d} : in the j -th coordinate of the second)
  • Dropping the kernel matrices (keeping only m \lambda I ) gives \bar P \approx \frac1{\lambda} \hat{\mathbb E}_{\tilde X}[ \eta(\tilde X)^2 + \lVert \nabla \eta(\tilde X) \rVert^2]
  • Solving the full (m + md) -dimensional linear system takes O(m^3 d^3) time!

Efficient approximation: Scaled MMD

  • Using a bit of RKHS theory, we can write \lVert f \rVert_{S(\mu),\lambda}^2 = \langle f, D_\lambda f \rangle_\h \fragment{\le \lVert D_\lambda \rVert_\mathrm{op} \lVert f \rVert_\h^2} \fragment{\le \sigma_{k,\mu,\lambda}^{-2} \lVert f \rVert_\h^2} , where \sigma_{k,\mu,\lambda}^{-2} = \E_{X \sim \mu} k(X, X) + \sum_{i=1}^d \E_{X \sim \mu} \partial_i \partial_{i+d} k(X, X) + \lambda
  • Define scaled MMD as \begin{align} \smmd_{k,\mu,\lambda}(\PP, \QQ) &= \sup_{\sigma_{k,\mu,\lambda}^{-1} \lVert f \rVert_\h \le 1} \E_{X \sim \PP}[f(X)] - \E_{Y \sim \QQ}[f(Y)] \\ &\fragment{= \sigma_{k,\mu,\lambda} \mmd_k(\PP, \QQ)} \end{align}
  • Hence \smmd_{k,\mu,\lambda} \le S_{k,\mu,\lambda}

Deriving the Scaled MMD: \lVert D_\lambda \rVert_\mathrm{op}

  • Operator D_\lambda : \h \to \h given by D_\lambda = C + D + \lambda I : \begin{align} \langle f, C g \rangle_\h &= \E_{X \sim \mu} f(X) g(X) \\&\fragment[1]{= \langle f, \left[\E_{X \sim \mu} k(X, \cdot) \otimes k(X, \cdot)\right] g \rangle_\h} \\ \fragment[2]{\lVert C \rVert_{\mathrm{op}}}\, &\fragment[2]{\le \E_{X \sim \mu} \lVert k(X, \cdot) \otimes k(X, \cdot) \rVert_{\mathrm{op}}} \fragment[3]{= \E_{X \sim \mu} k(X, X)} \\ \langle f, D g \rangle_\h &= \E_{X \sim \mu} \sum_{i=1}^d \partial_i f(X) \, \partial_i g(X) \\ \fragment[4]{ \lVert D \rVert_{\mathrm{op}}\,} &\fragment[4]{ \le \E_{X \sim \mu} \sum_{i=1}^d \partial_i \partial_{i+d} k(X, X)} \end{align}

Deriving the Scaled MMD: Lower Bound

\begin{align} S_{k,\mu,\lambda} &= \sup_{\lVert f \rVert_{S(\mu),\lambda} \le 1} \E_{X \sim \PP}[f(X)] - \E_{Y \sim \QQ}[f(Y)] \\&\ge \sup_{\sigma_{k,\mu,\lambda}^{-1} \lVert f \rVert_\h \le 1} \left\langle f, \left[\E_{X \sim \PP} k(X, \cdot) - \E_{Y \sim \QQ} k(Y, \cdot) \right] \right\rangle_\h \\&= \sup_{\lVert f \rVert_\h \le \sigma_{k,\mu,\lambda}} \left\langle f, \left[\E_{X \sim \PP} k(X, \cdot) - \E_{Y \sim \QQ} k(Y, \cdot) \right] \right\rangle_\h \\&= \sigma_{k,\mu,\lambda} \mmd \end{align}
(The inequality holds because \lVert f \rVert_{S(\mu),\lambda} \le \sigma_{k,\mu,\lambda}^{-1} \lVert f \rVert_\h , so the second supremum is over a subset of the first.)

Scaled MMD vs MMD with Gradient Penalty

  • When k_\psi(x, y) = \exp\left( - \frac12 \lVert \phi_\psi(x) - \phi_\psi(y) \rVert^2 \right) , \smmd_{k,\mu,\lambda}^2 = \frac{\mmd_{k}^2}{1 + \E_\mu[ \lVert \nabla \phi_\psi(X) \rVert_F^2] + \lambda} ( \lVert \cdot \rVert_F : Frobenius norm of the Jacobian)
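Back on the DiracGAN toy ( \phi_\psi(x) = \psi x , Gaussian top kernel): k(x, x) = 1 and \partial_x \partial_y k(x, y) at x = y equals \psi^2 , so \sigma_{k,\mu,\lambda}^{-2} = 1 + \psi^2 + \lambda for any \mu , matching the formula above with \nabla \phi_\psi = \psi . A sketch, assuming \lambda = 1 and a finite grid of \psi in place of the supremum (names hypothetical):

```python
import math


def smmd2_dirac(theta, psi, lam=1.0):
    """SMMD^2 between delta_0 and delta_theta for phi_psi(x) = psi * x:
    MMD^2 / (1 + psi^2 + lam), with MMD^2 = 2 - 2 exp(-(psi * theta)^2 / 2)."""
    mmd2 = 2 - 2 * math.exp(-0.5 * (psi * theta) ** 2)
    return mmd2 / (1 + psi ** 2 + lam)


psis = [10.0 ** (e / 20) for e in range(-60, 141)]  # grid from 1e-3 to 1e7


def sup_smmd2(theta):
    """Grid maximum standing in for the sup over psi."""
    return max(smmd2_dirac(theta, p) for p in psis)


for theta in [1.0, 0.1, 0.01, 0.0]:
    print(theta, sup_smmd2(theta))
```

Unlike plain \optmmd , which is 2 for every \theta \ne 0 , the maximized SMMD shrinks smoothly as \theta \to 0 (roughly like \theta^2 for small \theta ), so the generator gets a usable signal.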

Rank collapse

  • Optimization failure we sometimes see with SMMD and GC-MMD:
    • Generator doing reasonably well
    • Critic filters become low-rank
    • Generator corrects it by breaking everything else
    • Generator gets stuck

Spectral parameterization [Miyato+ ICLR-18]

  • W = \gamma \bar W / \lVert \bar W \rVert_\text{op} ; learn \gamma and \bar W freely
  • Encourages diversity without limiting representation
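A from-scratch sketch of this parameterization, with \lVert \bar W \rVert_\text{op} estimated by power iteration on \bar W^\top \bar W . Pure-Python matrices and all names are assumptions of the sketch; as we understand it, Miyato+ instead run a single power-iteration step per training update and reuse the vector across updates.

```python
import math
import random


def matvec(w, v):
    """W v for a matrix stored as a list of rows."""
    return [sum(w_ij * v_j for w_ij, v_j in zip(row, v)) for row in w]


def rmatvec(w, u):
    """W^T u."""
    return [sum(w[i][j] * u[i] for i in range(len(w))) for j in range(len(w[0]))]


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def spectral_norm(w, iters=100, seed=0):
    """Estimate ||W||_op, the largest singular value, by power iteration on W^T W."""
    rng = random.Random(seed)
    v = normalize([rng.gauss(0.0, 1.0) for _ in range(len(w[0]))])
    for _ in range(iters):
        u = normalize(matvec(w, v))
        v = normalize(rmatvec(w, u))
    return math.sqrt(sum(x * x for x in matvec(w, v)))


def spectral_parameterization(w_bar, gamma=1.0):
    """W = gamma * W_bar / ||W_bar||_op; gamma = 1 is plain spectral normalization."""
    s = spectral_norm(w_bar)
    return [[gamma * x / s for x in row] for row in w_bar]


w = spectral_parameterization([[1.0, 2.0], [3.0, 4.0]], gamma=2.0)
print(spectral_norm(w))  # approximately 2.0
```

Learning \gamma separately lets each layer keep any scale while the direction of \bar W stays well conditioned.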

What if we just did spectral normalization?

  • W = \bar W / \lVert \bar W \rVert_\text{op} , so that \lVert W \rVert_\text{op} = 1 and (with 1-Lipschitz activations) \lVert \phi_\psi \rVert_L \le 1
  • Works well for original GANs [Miyato+ ICLR-18]
  • …but doesn't work at all as the only constraint in a WGAN
  • Limits representation too much
    • In the toy problem, constrains to \phi_\psi(x) = x
    • \lVert x \mapsto \sigma(W_n \cdots \sigma(W_1 x)) \rVert_L is often \ll \lVert W_n \rVert_\text{op} \cdots \lVert W_1 \rVert_\text{op} = 1 , so the constraint is much tighter than 1-Lipschitz

Implicit generative model evaluation

  • No likelihoods, so…how to compare models?
  • Main approach:
    look at a bunch of pictures and see if they're pretty or not
    • Easy to find (really) bad samples
    • Hard to see if modes are missing / have wrong probabilities
    • Hard to compare models beyond certain threshold
  • Need better, quantitative methods

Inception score [Salimans+ NIPS-16]

  • Previously standard quantitative method
  • Based on ImageNet classifier label predictions
    • Classifier should be confident on individual images
    • Predicted labels should be diverse across sample
  • No notion of target distribution \PP
  • Scores completely meaningless on LSUN, Celeb-A, SVHN, …
  • Not great on CIFAR-10 either

Fréchet Inception Distance (FID) [Heusel+ NIPS-17]

  • Fit normals to Inception hidden layer activations of \PP and \QQ
  • Compute squared Fréchet (Wasserstein-2) distance between the fits
  • Meaningful on not-ImageNet datasets
  • Estimator extremely biased, tiny variance
  • Can have \operatorname{FID}(\PP_1, \QQ) < \operatorname{FID}(\PP_2, \QQ) but \E \operatorname{FID}(\hat \PP_1, \QQ) > \E \operatorname{FID}(\hat \PP_2, \QQ) : the bias can reverse model rankings
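The bias is easy to see in one dimension, where the general formula \lVert m_1 - m_2 \rVert^2 + \operatorname{Tr}(\Sigma_1 + \Sigma_2 - 2 (\Sigma_1 \Sigma_2)^{1/2}) reduces to (m_1 - m_2)^2 + (s_1 - s_2)^2 . A sketch on raw 1-D samples rather than Inception activations (a simplifying assumption; names hypothetical): two samples from the same distribution give a positive estimate on average, though the true value is 0.

```python
import math
import random


def fit_normal(xs):
    """Sample mean and standard deviation (n - 1 denominator)."""
    m = sum(xs) / len(xs)
    var = sum((x - m) ** 2 for x in xs) / (len(xs) - 1)
    return m, math.sqrt(var)


def frechet_1d(m1, s1, m2, s2):
    """Squared Frechet / Wasserstein-2 distance between N(m1, s1^2) and N(m2, s2^2)."""
    return (m1 - m2) ** 2 + (s1 - s2) ** 2


def fid_1d(xs, ys):
    """Fit a normal to each sample, then compare the fits."""
    return frechet_1d(*fit_normal(xs), *fit_normal(ys))


rng = random.Random(0)
trials = [fid_1d([rng.gauss(0, 1) for _ in range(50)],
                 [rng.gauss(0, 1) for _ in range(50)])
          for _ in range(200)]
avg_same = sum(trials) / len(trials)
print(avg_same)  # positive, although P = Q gives a true value of 0
```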

New method: Kernel Inception Distance (KID)

  • \mmdhat^2 between Inception hidden layer activations
  • Default kernel: cubic polynomial k(x, y) = \left( \frac1d \langle x, y \rangle + 1 \right)^3 , with d the feature dimension
  • Unbiased estimator, reasonable with few samples
  • (Almost) in tensorflow.contrib.gan.eval (tensorflow#21066)
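KID is the unbiased \mmdhat^2 estimator with this cubic kernel. A sketch that uses random Gaussian vectors as stand-ins for Inception activations (an assumption for illustration; the function names are hypothetical and not the TensorFlow API):

```python
import random


def poly_kernel(x, y):
    """Cubic polynomial kernel (<x, y> / d + 1)^3, with d the feature dimension."""
    d = len(x)
    return (sum(a * b for a, b in zip(x, y)) / d + 1) ** 3


def kid(feats_p, feats_q):
    """Unbiased MMD^2 estimate with the cubic polynomial kernel (KID-style)."""
    m, n = len(feats_p), len(feats_q)
    k_pp = sum(poly_kernel(feats_p[i], feats_p[j])
               for i in range(m) for j in range(m) if i != j)
    k_qq = sum(poly_kernel(feats_q[i], feats_q[j])
               for i in range(n) for j in range(n) if i != j)
    k_pq = sum(poly_kernel(a, b) for a in feats_p for b in feats_q)
    return k_pp / (m * (m - 1)) + k_qq / (n * (n - 1)) - 2 * k_pq / (m * n)


rng = random.Random(0)
dim, n = 8, 100


def sample(mean):
    """n stand-in feature vectors with i.i.d. N(mean, 1) coordinates."""
    return [[rng.gauss(mean, 1.0) for _ in range(dim)] for _ in range(n)]


feats_real = sample(0.0)
kid_same = kid(feats_real, sample(0.0))   # same distribution: near 0
kid_shift = kid(feats_real, sample(1.0))  # shifted distribution: clearly positive
print(kid_same, kid_shift)
```

Because the estimator is unbiased, averaging KID over several small sample subsets does not drift with subset size, which is the property FID lacks.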

Automatic learning rate adaptation with KID

  • Models need appropriate learning rate schedule to work well
  • Automate with three-sample MMD test [Bounliphone+ ICLR-16]:
.116 ± .002
.026 ± .001
.032 ± .001
.027 ± .001
LSUN Bedrooms | Small | Big
.370 ± .003
.039 ± .002
.091 ± .002
.028 ± .002

Training on 160 \times 160 CelebA

160 \times 160 CelebA Samples


KID: 0.006

KID: 0.022

64 \times 64 ImageNet Samples


KID: 0.035

KID: 0.044

KID: 0.047

Demystifying MMD GANs [ICLR-18]
Mikołaj Bińkowski*, Dougal J. Sutherland*, Michael Arbel, Arthur Gretton

  • MMD GANs do some optimization in closed form
    • Can handle smaller critic networks
  • Not in this talk: study of gradient bias for WGAN, MMD GAN
  • Unbiased evaluation and learning rate adaptation with KID

On gradient regularizers for MMD GANs [NIPS-18]
Michael Arbel, Dougal J. Sutherland, Mikołaj Bińkowski, Arthur Gretton

  • Gradient control is important
    • Scaled MMD does it in closed form, seems to help a lot
  • Spectral normalization plays nice with SMMD

Estimator bias

  • Bellemare+ [2017] say that:
    • WGANs have biased gradients
    • which can lead SGD to wrong minimum even in expectation
    • but Cramér GANs have unbiased gradients
      • Cramér GAN \approx MMD GAN with particular kernel
  • We show:
    • Gradients of a fixed critic, \nabla_{\psi,\theta} \hat\D_\psi(X, G_\theta(Z)) , are unbiased
    • Gradients of the optimized critic's loss, \nabla_\theta \hat\D(X, G_\theta(Z)) , are biased
    • Exact same situation for WGAN and MMD GAN

Unbiasedness theorem for fixed critic

  • For almost all feedforward architectures G , h
    • Works for ReLU, max-pooling, …
  • For any distributions \PP , \ZZ with \E[ \lVert X \rVert^2 ] < \infty ,
  • For most kernels k_\mathrm{top} used in practice,
    • Includes linear kernel, RBF, RQ, distance kernel, …
  • For Lebesgue-almost all parameters (\theta, \psi) :
  • \E \nabla_\psi \phi_\psi(X) = \nabla_\psi \E \phi_\psi(X) and \E \nabla_{\psi,\theta} \phi_\psi(G_\theta(Z)) = \nabla_{\psi,\theta} \E \phi_\psi(G_\theta(Z)) , so:
    • \nabla_{\psi} \phi_\psi(X) , \nabla_{\psi,\theta} \phi_\psi(G_\theta(Z)) unbiased (WGANs)
    • \nabla_{\psi,\theta} \mmdhat^2(\phi_\psi(X), \phi_\psi(G_\theta(Z))) unbiased

Proof of unbiasedness theorem

  • Can't use the standard argument: it needs \phi_\psi(X) differentiable in a neighborhood of \psi for almost all inputs X
  • But take \phi_\psi(X) = \max(0, \psi \cdot X)
  • For fixed \psi : differentiable in \psi except where \psi \cdot X = 0
  • But \{ X \mid \psi' \cdot X = 0 \text{ for some } \psi' \in B(\psi, r) \} might have probability > 0

Proof of unbiasedness theorem

  • Instead: show gradient and expectation can be exchanged when \phi_\psi is differentiable in \psi for almost all inputs, using Lipschitz properties of the network
  • Then show the set of non-differentiable \psi has measure 0, with a bit of geometry
  • Example: for h_{\psi}(X) = \max(0,\psi X) , \psi = 0 is a critical point: non-differentiable for every X \ne 0 , but a single point has measure 0

Proof of unbiasedness theorem

  • By Fubini's theorem, only need to show:
    • For fixed input X , \phi_{\psi}(X) differentiable for almost all \psi
  • Recall: \phi_{\psi}^1(X) = \sigma (\phi_{\psi}^0(X)) , with piecewise smooth \sigma :
  • case 1: \phi_{\psi}^{0}(X) inside domain of analyticity
  • case 2: \phi_{\psi}^{0}(X) on boundary
  • case 3: \phi_{\psi}^{0}(X) crosses the boundary:
    • \psi in a union of manifolds of 0 measure

Gradient bias for “full” loss

  • Recall \D(\PP, \QQ) = \sup_\psi \D_\psi(\PP, \QQ)
    • Estimator splits data:
      • Pick \hat\psi on train set, estimate \D_{\hat\psi} on test set
      • GAN test set: current minibatch
  • Showed: \nabla_{(\psi,\theta)} \hat\D_\psi(X, G_\theta(Z)) unbiased for any fixed \psi
  • Now: \nabla_\theta \hat\D(X, G_\theta(Z)) is biased
    • Estimators have non-constant bias iff gradients are biased
    • Will show \hat\D(X, G_\theta(Z)) is biased

Bias of \hat\D

  • Eval on test set is unbiased: \E[\hat\D_\psi(X, Y)] = \D_\psi(\PP, \QQ)
  • But training introduces bias:
    • \D_\psi(\PP, \QQ) \le \sup_\psi \D_\psi(\PP, \QQ) = \D(\PP, \QQ)
    • If \Pr(\hat\psi \text{ is optimal}) \ne 1 , estimator must be biased down
  • Probably not a big deal in practice
    • No (direct) bias due to minibatch size
    • Can decrease bias by training critic longer
  • But informs theory
    • Makes convergence analysis based on SGD of \hat\D difficult

Non-existence of an unbiased IPM estimator

  • Beautiful argument of Bickel & Lehmann [Ann. Math. Stat. 1969]:
  • Let \PP_0 \ne \PP_1 , \PP_\alpha = (1 - \alpha) \PP_0 + \alpha \PP_1
  • Suppose \E_{\substack{\mathbf X \sim \PP_\alpha^m \\ \mathbf Y \sim \QQ^n}} \hat\D(\mathbf X, \mathbf Y) = \D(\PP_\alpha, \QQ) ; then \begin{align} \D(\PP_\alpha, \QQ) &= \int \cdots \int \hat\D(\mathbf X, \mathbf Y) \, [(1 - \alpha) \mathrm{d}\PP_0(X_1) + \alpha \mathrm{d}\PP_1(X_1)] \cdots [(1 - \alpha) \mathrm{d}\PP_0(X_m) + \alpha \mathrm{d}\PP_1(X_m)] \, \mathrm{d}\QQ^n(\mathbf Y) \\&\fragment{ = (1 - \alpha)^m \E_{\substack{\mathbf X \sim \PP_0^m \\ \mathbf Y \sim \QQ^n}}[ \hat\D(\mathbf X, \mathbf Y) ] + \cdots + \alpha^m \E_{\substack{\mathbf X \sim \PP_1^m \\ \mathbf Y \sim \QQ^n}}[ \hat\D(\mathbf X, \mathbf Y) ] } \end{align}
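  • But with \QQ = \PP_{1/2} , \D(\PP_\alpha, \PP_{1/2}) = \lvert \frac12 - \alpha \rvert \D(\PP_0, \PP_1) isn't a polynomial in \alpha , while the right-hand side above is a polynomial of degree at most m
  • So no unbiased \hat\D can exist; for MMD, though, an unbiased estimator of \D^2 does ( \mmdhat^2 )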
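For MMD the contrast can be checked directly. Writing \mu_\PP := \E_{X \sim \PP}[k(X, \cdot)] for the mean embedding (notation introduced here), the mixture's embedding is linear in \alpha , so the squared distance along the path is a degree-2 polynomial, compatible with the polynomial expansion above whenever m \ge 2 , while its square root \lvert \tfrac12 - \alpha \rvert is not:

```latex
\begin{align}
\mu_{\PP_\alpha} &= (1 - \alpha)\,\mu_{\PP_0} + \alpha\,\mu_{\PP_1} \\
\mu_{\PP_\alpha} - \mu_{\PP_{1/2}}
  &= (\tfrac12 - \alpha)\,\mu_{\PP_0} + (\alpha - \tfrac12)\,\mu_{\PP_1}
   = (\tfrac12 - \alpha)\left(\mu_{\PP_0} - \mu_{\PP_1}\right) \\
\mmd^2(\PP_\alpha, \PP_{1/2})
  &= \left\lVert \mu_{\PP_\alpha} - \mu_{\PP_{1/2}} \right\rVert_\h^2
   = (\tfrac12 - \alpha)^2 \, \mmd^2(\PP_0, \PP_1)
\end{align}
```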