Better GANs by using the MMD

Dougal J. Sutherland
Gatsby Computational Neuroscience Unit, University College London

Based on “Demystifying MMD GANs” [ICLR-18]
and “On gradient regularizers for MMD GANs” [arXiv:1805.11565], with:

Michael Arbel (Gatsby, UCL)
Mikołaj Bińkowski (Imperial College London)
Arthur Gretton (Gatsby, UCL)

Google New York, 14 June 2018


\DeclareMathOperator{\D}{\mathcal{D}} \DeclareMathOperator*{\E}{\mathbb{E}} \newcommand{\F}{\mathcal{F}} \newcommand{\h}{\mathcal{H}} \DeclareMathOperator{\mean}{mean} \newcommand{\PP}{\mathbb P} \newcommand{\QQ}{\mathbb Q} \newcommand{\R}{\mathbb R} \DeclareMathOperator{\W}{\mathcal{W}} \newcommand{\X}{\mathcal X} \newcommand{\Z}{\mathcal Z} \newcommand{\ZZ}{\mathbb Z} \DeclareMathOperator*{\argmin}{argmin} \DeclareMathOperator*{\argmax}{argmax} \DeclareMathOperator{\mmd}{MMD} \DeclareMathOperator{\smmd}{SMMD} \DeclareMathOperator{\mmdhat}{\widehat{MMD}} \DeclareMathOperator{\optmmd}{\mathcal{D}_\mathrm{MMD}} \DeclareMathOperator{\optmmdhat}{\hat{\mathcal{D}}_\mathrm{MMD}} \newcommand{\target}{\mathrm{target}} \newcommand{\ktop}{k_\mathrm{top}}

Implicit generative models

Given samples from distribution \PP over \X
Want model that can produce new samples from \QQ_\theta \approx \PP



Don't necessarily care about likelihoods, interpretability, …

Model: Generator network

Deep network (params \theta ) mapping from noise \ZZ to images \X

DCGAN generator [Radford+ ICLR-16]

\ZZ is e.g. uniform on [-1, 1]^{100}

Choose \theta by minimizing…something

Loss function

  • Can't evaluate likelihood of samples under model
  • Likelihood maybe not the best choice anyway [Theis+ ICLR-16]
    • Mixing a model with 99% white noise costs at most \log 100 \approx 4.6 nats of log-likelihood
  • Instead, we'll minimize some \D(\PP, \QQ_\theta)
Max-likelihood objective vs WGAN objective [Danihelka+ 2017]

Maximum Mean Discrepancy [Gretton+ 2012]

\mmd(\PP, \QQ) = \sup_{\lVert f \rVert_\h \le 1} \E_{X \sim \PP}[f(X)] - \E_{Y \sim \QQ}[f(Y)] ( \h is RKHS with kernel k : \X \times \X \to \R )

Can do optimization in closed form: \begin{gather} f^*_k(t) \propto \E_{X \sim \PP} k(X, t) - \E_{Y \sim \QQ} k(Y, t) \\ \mmd^2 = \E k(X, X') + \E k(Y, Y') - 2 \E k(X, Y) \end{gather}

Unbiased estimator of \mmd^2

\operatorname{mean}(K_{XX}) + \operatorname{mean}(K_{YY}) - 2 \operatorname{mean}(K_{XY})
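A minimal NumPy sketch of this estimator, under assumptions not on the slide: a Gaussian kernel with a hypothetical `bandwidth` parameter, and toy 2-d Gaussian samples. For the estimate to be unbiased, the means over K_{XX} and K_{YY} skip the diagonal entries.

```python
import numpy as np

def gaussian_kernel(a, b, bandwidth=1.0):
    """Gaussian kernel matrix between rows of a (m, d) and b (n, d)."""
    sq_dists = (
        np.sum(a**2, axis=1)[:, None]
        + np.sum(b**2, axis=1)[None, :]
        - 2.0 * a @ b.T
    )
    # Clip tiny negative values caused by floating-point cancellation.
    return np.exp(-np.maximum(sq_dists, 0.0) / (2.0 * bandwidth**2))

def mmd2_unbiased(x, y, kernel=gaussian_kernel):
    """Unbiased estimate of MMD^2 from samples x (m, d) and y (n, d)."""
    m, n = len(x), len(y)
    k_xx, k_yy, k_xy = kernel(x, x), kernel(y, y), kernel(x, y)
    # Drop the diagonals: k(X_i, X_i) does not come from an independent pair.
    mean_xx = (k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
    mean_yy = (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
    return mean_xx + mean_yy - 2.0 * k_xy.mean()

rng = np.random.default_rng(0)
same = mmd2_unbiased(rng.normal(size=(500, 2)), rng.normal(size=(500, 2)))
shifted = mmd2_unbiased(rng.normal(size=(500, 2)), rng.normal(size=(500, 2)) + 1.0)
print(f"same distribution:    {same:+.4f}")
print(f"shifted distribution: {shifted:+.4f}")
```

Because the estimator is unbiased, it can come out slightly negative when the two distributions match.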

MMD as loss [Li+ ICML-15, Dziugaite+ UAI-15]

Estimate generator \hat\theta based on SGD with minibatches
X \sim \PP^m , Z \sim \ZZ^n , loss \mmdhat^2(X, G_\theta(Z))

Hard to pick a good kernel for images

MMD GANs: Deep kernels [Li+ NIPS-17]

  • Use a class of deep kernels: k_\psi(x, y) = \ktop(\phi_\psi(x), \phi_\psi(y))
  • Choose most-discriminative out of those kernels: \optmmd(\PP, \QQ) = \sup_{\psi} \mmd_{\psi}(\PP, \QQ)
  • Initialize random generator G_\theta and representation \phi_\psi
    • Repeat: SGD step in \theta to minimize \hat{\mathcal{D}}^2_\mathrm{MMD}(\PP, G_\theta(\ZZ)) , approximated by:
      • 5 times:
        • Take SGD step in \psi to maximize \mmdhat^2_\psi(\PP, G_\theta(\ZZ))
      • Take SGD step in \theta to minimize \mmdhat^2_\psi(\PP, G_\theta(\ZZ))

Wasserstein and WGANs

\W(\PP, \QQ) = \sup_{\lVert f \rVert_L \le 1} \E_{X \sim \PP}[ f(X) ] - \E_{Y \sim \QQ}[ f(Y) ]

  • WGANs [Arjovsky+ ICML-17], WGAN-GPs [Gulrajani+ NIPS-17]:
    • Train a neural network f_\psi for critic
      \hat\psi_\theta \approx \argmax_\psi \E_{X \sim \PP} f_\psi(X) - \E_{Z \sim \ZZ} f_\psi(G_\theta(Z))
    • Enforce Lipschitz constraint on f_\psi (more on this later) \textstyle \lVert f \rVert_L = \sup_{x, y} \frac{\lvert f(x) - f(y) \rvert}{\lVert x - y \rVert} = \sup_x \lVert \nabla f(x) \rVert
    • Run SGD on minibatches X \sim \PP , Z \sim \ZZ with that critic
      \hat\W_{\hat\psi_\theta} = \hat{\mathbb E}_{X \sim \PP} f_{\hat\psi_\theta}(X) - \hat{\mathbb E}_{Z \sim \ZZ} f_{\hat\psi_\theta}(G_\theta(Z))


  • Consider linear-kernel MMD GAN, k(x, y) = \phi(x) \phi(y) : \begin{gather} \text{loss} = \lvert \E_\PP \phi(X) - \E_\QQ \phi(Y) \rvert \\ \operatorname{critic}(t) = \operatorname{sign}\left( \E_\PP \phi(X) - \E_\QQ \phi(Y) \right) \phi(t) \end{gather}
  • WGAN has: \begin{gather} \text{loss} = \E_\PP \phi(X) - \E_\QQ \phi(Y) \\ \operatorname{critic}(t) = \phi(t) \end{gather}
  • Linear-kernel MMD GAN and WGAN almost the same
  • MMD GAN “offloads” some of the critic's work to closed-form optimization in the RKHS
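The comparison above can be checked numerically. This is only a sketch: the linear map `phi(x) = <w, x>` and the toy Gaussian samples are illustrative assumptions standing in for a trained critic network and real data.

```python
import numpy as np

def linear_kernel_losses(phi_x, phi_y):
    """Losses from scalar critic outputs phi(X), shape (m,), and phi(Y), shape (n,)."""
    mean_diff = phi_x.mean() - phi_y.mean()
    wgan_loss = mean_diff          # critic(t) = phi(t)
    mmd_loss = abs(mean_diff)      # critic(t) = sign(mean_diff) * phi(t)
    return wgan_loss, mmd_loss

rng = np.random.default_rng(0)
w = rng.normal(size=3)                     # hypothetical phi(x) = <w, x>
x = rng.normal(size=(1000, 3)) + 0.5       # stand-in for real samples
y = rng.normal(size=(1000, 3))             # stand-in for generated samples

wgan, mmd = linear_kernel_losses(x @ w, y @ w)
# Negating phi flips the WGAN loss, but the closed-form sign absorbs it for MMD.
wgan_flipped, mmd_flipped = linear_kernel_losses(-(x @ w), -(y @ w))
print(wgan, mmd, wgan_flipped, mmd_flipped)
```

The sign is the only part of the critic that the linear-kernel MMD solves in closed form; richer kernels offload more.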

Smooth critics in MMD GANs

  • Toy problem in \R , DiracGAN [Mescheder+ ICML-18]:
    • Point mass target \PP = \delta_0 , model \QQ_\theta = \delta_\theta
    • Representation \phi_\psi(x) = \psi x , \psi \in \R
    • Gaussian kernel \ktop(a, b) = \exp\left( - (a - b)^2 / 2 \right)

  • Taking \psi \to \infty gives \optmmd(\delta_0, \delta_\theta) = \begin{cases} \sqrt{2} & \theta \ne 0 \\ 0 & \theta = 0 \end{cases}
  • But if we restrain ourselves to \psi where optimal critic f^*_\psi is bounded Lipschitz, \QQ_\theta \stackrel{D}{\to} \PP implies \optmmd(\QQ_\theta, \PP) \to 0 , and \optmmd is continuous and a.e. differentiable
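A small numerical sketch of the toy problem. Here \mmd_\psi^2(\delta_0, \delta_\theta) = 2 - 2 \exp(-(\psi\theta)^2 / 2) follows from the Gaussian kernel above. Capping |\psi| at a hypothetical `psi_max` is used as a crude stand-in for the Lipschitz restriction (an assumption of this sketch), and the supremum is approximated on a grid.

```python
import numpy as np

def dirac_mmd(theta, psi):
    """MMD between delta_0 and delta_theta with k(x, y) = exp(-(psi*x - psi*y)^2 / 2)."""
    # MMD^2 = k(0,0) + k(theta,theta) - 2 k(0,theta) = 2 - 2 exp(-(psi*theta)^2 / 2);
    # expm1 avoids cancellation when psi * theta is tiny.
    return np.sqrt(-2.0 * np.expm1(-0.5 * (psi * theta) ** 2))

def optimized_mmd(theta, psi_max, grid_size=1001):
    """Grid approximation of the sup over 0 <= psi <= psi_max."""
    psis = np.linspace(0.0, psi_max, grid_size)
    return dirac_mmd(theta, psis).max()

for psi_max in (1.0, 1e2, 1e4):
    values = [optimized_mmd(theta, psi_max) for theta in (0.0, 1e-3, 1e-1, 1.0)]
    print(f"psi_max={psi_max:g}:", np.round(values, 4))
```

With a large cap the value jumps to about \sqrt{2} for any nonzero \theta, while with a small cap it shrinks smoothly as \theta \to 0.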

Enforcing Lipschitz constraint

  • First attempt: only optimize over f_\psi that are 1 -Lipschitz
    • Hard to closely specify Lipschitz constant of deep nets
    • WGAN [Arjovsky+ ICML-17] tried with simple box constraint
    • Also original MMD GAN paper [Li+ NIPS-17]
  • WGAN-GP [Gulrajani+ NIPS-17]: penalize non-Lipschitzness \E_{X \sim \PP} f_\psi(X) - \E_{Z \sim \ZZ} f_\psi(G_\theta(Z)) - \lambda \E_{\tilde X} \left( \lVert \nabla_{\tilde X} f_\psi(\tilde X) \rVert - 1 \right)^2 (critic maximizes this; \tilde X drawn in between the X and G_\theta(Z) minibatches)
    • Looser constraint than Lipschitz
    • Tends to work better in practice
    • But doesn't fix the Dirac problem…



Built-in gradient constraints

  • \mmd_k = \sup_{\lVert f \rVert_{\h_k}^2 \le 1} \E_{X \sim \PP}[f(X)] - \E_{Y \sim \QQ}[f(Y)]
  • Lipschitz MMD: \mathcal L_{k,\lambda} = \sup_{\lVert f \rVert_L^2 + \lVert f \rVert_{\h_k}^2 \le 1} \E_{X \sim \PP}[f(X)] - \E_{Y \sim \QQ}[f(Y)]
  • Approximation, Gradient-Constrained MMD: \begin{gather} S_{k,\mu,\lambda} = \sup_{\lVert f \rVert_{S(\mu),\lambda} \le 1} \E_{X \sim \PP}[f(X)] - \E_{Y \sim \QQ}[f(Y)] \\ \lVert f \rVert_{S(\mu),\lambda}^2 := \E_{X \sim \mu} \left[ f(X)^2 + \lVert \nabla f(X) \rVert^2 \right] + \lambda \lVert f \rVert^2_{\h_k} \end{gather}
    • Variance constraint makes it like a Sobolev norm
    • Doesn't quite constrain the Lipschitz constant

Gradient-Constrained MMD on MNIST

  • It's a reasonable distance to optimize:
  • …but this took days to run

Estimating Gradient-Constrained MMD

  • Say we have m samples \tilde X \sim \mu
  • Let \eta(t) = \E_{X \sim \PP} k(X, t) - \E_{Y \sim \QQ} k(Y, t)
  • Then S_{k,\hat\mu,\lambda}^2 = \frac1\lambda \left( \mmd^2 - \bar P \right) with \bar P = \begin{bmatrix} \eta(\tilde X) \\ \nabla\eta(\tilde X) \end{bmatrix}^\top \left( \begin{bmatrix} K & G^\top \\ G & H \end{bmatrix} + m \lambda I \right)^{-1} \begin{bmatrix} \eta(\tilde X) \\ \nabla\eta(\tilde X) \end{bmatrix} ; K has k(\tilde X, \tilde X') , G has \partial_i k(\tilde X, \tilde X') , H has \partial_i \partial_{j+d} k(\tilde X, \tilde X')
  • Dropping kernel matrix gives \frac1{\lambda} \hat{\mathbb E}_{\tilde X}[ \eta(\tilde X)^2 + \lVert \nabla \eta(\tilde X) \rVert^2]
  • Solving this linear system takes O(m^3 d^3) time!

The Scaled MMD

  • Using a bit of RKHS theory, can write \lVert f \rVert_{S(\mu),\lambda}^2 = \langle f, D_\lambda f \rangle_\h \le \lVert D_\lambda \rVert_\mathrm{HS} \lVert f \rVert_\h^2 \le \sigma_{k,\mu,\lambda}^{-2} \lVert f \rVert_\h^2 , where \sigma_{k,\mu,\lambda}^{-2} = \int k(x, x) \mu(\mathrm{d}x) + \sum_{i=1}^d \int \partial_i \partial_{i+d} k(x, x) \mu(\mathrm{d}x) + \lambda
  • Define lower bound on S_{k,\mu,\lambda} of \begin{align} \smmd_{k,\mu,\lambda}(\PP, \QQ) &= \sup_{\sigma_{k,\mu,\lambda}^{-1} \lVert f \rVert_\h \le 1} \E_{X \sim \PP}[f(X)] - \E_{Y \sim \QQ}[f(Y)] \\ &= \sigma_{k,\mu,\lambda} \mmd_k(\PP, \QQ) \end{align}

Scaled MMD vs MMD with Gradient Penalty

  • When k_\psi(x, y) = \exp\left( - \frac12 \lVert \phi_\psi(x) - \phi_\psi(y) \rVert^2 \right) , \smmd_{k,\mu,\lambda}^2 = \frac{\mmd_{k}^2}{1 + \E_\mu[ \lVert \nabla \phi_\psi(X) \rVert^2] + \lambda}
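A sketch of this formula in the simplest case, assuming a linear feature map \phi(x) = W x (an illustrative choice, not the deep network of the talk). Its Jacobian is W at every x, so \E_\mu[ \lVert \nabla \phi(X) \rVert^2 ] = \lVert W \rVert_F^2 for any \mu. The function names and toy data are hypothetical.

```python
import numpy as np

def mmd2_gaussian(fx, fy):
    """Unbiased MMD^2 with kernel exp(-||a - b||^2 / 2) on features fx, fy."""
    def gram(a, b):
        sq_dists = ((a[:, None, :] - b[None, :, :]) ** 2).sum(axis=-1)
        return np.exp(-0.5 * sq_dists)
    m, n = len(fx), len(fy)
    k_xx, k_yy, k_xy = gram(fx, fx), gram(fy, fy), gram(fx, fy)
    return ((k_xx.sum() - np.trace(k_xx)) / (m * (m - 1))
            + (k_yy.sum() - np.trace(k_yy)) / (n * (n - 1))
            - 2.0 * k_xy.mean())

def smmd2_linear_features(x, y, w, lam=1.0):
    """SMMD^2 estimate for phi(x) = w @ x with the Gaussian top-level kernel."""
    grad_term = np.sum(w ** 2)   # ||grad phi(x)||_F^2 = ||w||_F^2 for every x
    return mmd2_gaussian(x @ w.T, y @ w.T) / (1.0 + grad_term + lam)

rng = np.random.default_rng(0)
x = rng.normal(size=(300, 2)) + 1.0
y = rng.normal(size=(300, 2))
for scale in (0.1, 1.0, 10.0):
    w = scale * np.eye(2)
    print(scale, mmd2_gaussian(x @ w.T, y @ w.T), smmd2_linear_features(x, y, w))
```

Rescaling W changes the plain MMD estimate, but the denominator grows with \lVert W \rVert_F^2, so the critic cannot inflate SMMD just by sharpening its features.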



Rank collapse

  • Optimization failure we sometimes see on SMMD, GC-MMD:
    • Generator doing reasonably well
    • Critic filters become low-rank
    • Generator corrects it by breaking everything else
    • Generator gets stuck

Spectral parameterization [Miyato+ ICLR-18]

  • W = \gamma \bar W / \lVert \bar W \rVert_\text{op} ; learn \gamma and \bar W freely
  • Encourages diversity without limiting representation
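A minimal sketch of this parameterization for a single dense weight matrix. It computes the operator norm exactly through NumPy's singular-value-based matrix 2-norm; that is a simplification chosen here, not necessarily how a training implementation estimates it. The function name is hypothetical.

```python
import numpy as np

def spectral_parameterize(w_bar, gamma):
    """Return W = gamma * w_bar / ||w_bar||_op, so that ||W||_op = |gamma|."""
    op_norm = np.linalg.norm(w_bar, ord=2)   # largest singular value of w_bar
    return gamma * w_bar / op_norm

rng = np.random.default_rng(0)
w_bar = rng.normal(size=(64, 32))
w = spectral_parameterize(w_bar, gamma=3.0)
print(np.linalg.norm(w, ord=2))   # operator norm of W, set by gamma
```

Since \gamma is learned, the scale of W stays free while \bar W only contributes its direction.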

What if we just did spectral normalization?

  • W = \bar W / \lVert \bar W \rVert_\text{op} , so that \lVert W \rVert_\text{op} = 1 , \lVert \phi_\psi \rVert_L \le 1
  • Works well for original GANs [Miyato+ ICLR-18]
  • …but doesn't work at all as only constraint in a WGAN
  • Limits representation too much
    • In the toy problem, constrains to \phi_\psi(x) = x
    • \lVert x \mapsto \sigma(W_n \cdots \sigma(W_1 x)) \rVert_L \ll \lVert W_n \rVert_\text{op} \cdots \lVert W_1 \rVert_\text{op}

Implicit generative model evaluation

  • No likelihoods, so…how to compare models?
  • Main approach:
    look at a bunch of pictures and see if they're pretty or not
    • Easy to find (really) bad samples
    • Hard to see if modes are missing / have wrong probabilities
    • Hard to compare models beyond certain threshold
  • Need better, quantitative methods

Inception score [Salimans+ NIPS-16]

  • Current standard quantitative method
  • Based on ImageNet classifier label predictions
    • Classifier should be confident on individual images
    • Predicted labels should be diverse across the sample
  • No notion of target distribution \PP
  • Scores completely meaningless on LSUN, Celeb-A, SVHN, …
  • Not great on CIFAR-10 either

Fréchet Inception Distance (FID) [Heusel+ NIPS-17]

  • Fit normals to Inception hidden layer activations of \PP and \QQ
  • Compute Fréchet (Wasserstein-2) distance between fits
  • Meaningful on not-ImageNet datasets
  • Estimator extremely biased, tiny variance
  • Can have \operatorname{FID}(\PP_1, \QQ) < \operatorname{FID}(\PP_2, \QQ) but \E \operatorname{FID}(\hat \PP_1, \QQ) > \E \operatorname{FID}(\hat \PP_2, \QQ)
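A sketch of the Fréchet distance computation and of its bias. Random Gaussian vectors stand in for Inception activations (an assumption; no Inception network is involved), and the trace of the matrix square root is obtained from eigenvalues to stay within NumPy.

```python
import numpy as np

def frechet_distance(feats_p, feats_q):
    """Frechet distance between Gaussians fitted to two feature sets of shape (n, d)."""
    mu_p, mu_q = feats_p.mean(axis=0), feats_q.mean(axis=0)
    cov_p = np.cov(feats_p, rowvar=False)
    cov_q = np.cov(feats_q, rowvar=False)
    # tr((cov_p cov_q)^{1/2}) equals the sum of square roots of the eigenvalues
    # of cov_p @ cov_q, which are real and non-negative for PSD covariances.
    eigvals = np.linalg.eigvals(cov_p @ cov_q)
    tr_sqrt = np.sqrt(np.clip(eigvals.real, 0.0, None)).sum()
    return (np.sum((mu_p - mu_q) ** 2)
            + np.trace(cov_p) + np.trace(cov_q) - 2.0 * tr_sqrt)

rng = np.random.default_rng(0)
dim = 32
estimates = {}
for n in (200, 5000):
    # Both samples come from the same N(0, I), so the true distance is 0.
    a, b = rng.normal(size=(n, dim)), rng.normal(size=(n, dim))
    estimates[n] = frechet_distance(a, b)
    print(n, estimates[n])
```

Both estimates are strictly positive even though the population value is zero, and the gap depends on the sample size, which is why FID values from different sample sizes are hard to compare.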

New method: Kernel Inception Distance (KID)

  • \mmdhat^2 between Inception hidden layer activations
  • Use default polynomial kernel: k(x, y) = \left( \frac1d \langle x, y \rangle + 1 \right)^3
  • Unbiased estimator: more able to compare estimates
  • Reasonable estimates with fewer samples
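A sketch of the KID computation: the unbiased \mmdhat^2 with the cubic polynomial kernel above. Random 16-dimensional Gaussian features stand in for Inception activations; that choice and the function names are assumptions of this sketch.

```python
import numpy as np

def poly_kernel(a, b):
    """Cubic polynomial kernel k(x, y) = (<x, y> / d + 1)^3."""
    d = a.shape[1]
    return (a @ b.T / d + 1.0) ** 3

def kid(feats_p, feats_q):
    """Unbiased MMD^2 estimate between two feature sets of shape (n, d)."""
    m, n = len(feats_p), len(feats_q)
    k_pp = poly_kernel(feats_p, feats_p)
    k_qq = poly_kernel(feats_q, feats_q)
    k_pq = poly_kernel(feats_p, feats_q)
    # Off-diagonal means keep the estimate unbiased.
    return ((k_pp.sum() - np.trace(k_pp)) / (m * (m - 1))
            + (k_qq.sum() - np.trace(k_qq)) / (n * (n - 1))
            - 2.0 * k_pq.mean())

rng = np.random.default_rng(0)
reference = rng.normal(size=(500, 16))
kid_same = kid(reference, rng.normal(size=(500, 16)))
kid_shifted = kid(reference, rng.normal(size=(500, 16)) + 0.5)
print(f"same distribution:    {kid_same:+.4f}")
print(f"shifted distribution: {kid_shifted:+.4f}")
```

Unlike the Fréchet distance, the estimate for matching distributions is centered on zero, so it may be slightly negative.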

Automatic learning rate adaptation with KID

  • Models need appropriate learning rate schedule to work well
  • Automate with three-sample MMD test [Bounliphone+ ICLR-16]:
[Table: LSUN Bedrooms, Small vs. Big; row labels not recoverable. Values as extracted: .116 ± .002, .026 ± .001, .032 ± .001, .027 ± .001; .370 ± .003, .039 ± .002, .091 ± .002, .028 ± .002]

Training on 160 \times 160 CelebA

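160 \times 160 CelebA Samples

[Two sets of samples, model labels not recoverable; KID: 0.006 and KID: 0.022]

64 \times 64 ImageNet Samples

[Three sets of samples, model labels not recoverable; KID: 0.035, KID: 0.044, and KID: 0.047]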

Demystifying MMD GANs [ICLR-18]
Mikołaj Bińkowski*, Dougal J. Sutherland*, Michael Arbel, Arthur Gretton

  • MMD GANs do some optimization in closed form
    • Can handle smaller critic networks
  • Bias situation is the same between WGAN and MMD GAN
  • Evaluation and learning rate adaptation with KID

On gradient regularizers for MMD GANs [arXiv:1805.11565]
Michael Arbel, Dougal J. Sutherland, Mikołaj Bińkowski, Arthur Gretton

  • Gradient control is important
    • Scaled MMD does it in closed form, seems to help a lot
  • Spectral normalization plays nice with SMMD

Estimator bias

  • Bellemare+ [2017] say that:
    • WGANs have biased gradients
    • which can lead SGD to wrong minimum even in expectation
    • but Cramér GANs have unbiased gradients
      • Cramér GAN \approx MMD GAN with particular kernel
  • We show:
    • Gradients of fixed critic, \nabla_{\psi,\theta} \D_\psi(X, G_\theta(Z)) , are unbiased
    • Gradients of optimized critic, \nabla_\theta \D(X, G_\theta(Z)) , are biased
    • Exact same situation for WGAN and MMD GAN

Unbiasedness theorem for fixed critic

  • For almost all feedforward architectures G , h
    • Works for ReLU, max-pooling, …
  • For any distributions \PP , \ZZ with \E[ \lVert X \rVert^2 ] < \infty ,
  • For most kernels k_\mathrm{top} used in practice,
    • Includes linear kernel, RBF, RQ, distance kernel, …
  • For Lebesgue-almost all parameters (\theta, \psi) :
  • \E \nabla_\psi \phi_\psi(X) = \nabla_\psi \E \phi_\psi(X) , \E \nabla_{\psi,\theta} \phi_\psi(G_\theta(Z)) = \nabla_{\psi,\theta} \E \phi_\psi(G_\theta(Z)) , so:
    • \nabla_{\psi} \phi_\psi(X) , \nabla_{\psi,\theta} \phi_\psi(G_\theta(Z)) unbiased (WGANs)
    • \nabla_{\psi,\theta} \mmdhat^2(\phi_\psi(X), \phi_\psi(G_\theta(Z))) unbiased

Proof of unbiasedness theorem

  • Can't use standard argument: needs \phi_\psi(X) differentiable in a neighborhood of \psi for almost all inputs X
  • But take \phi_\psi(X) = \max(0, \psi \cdot X)
  • Fixed \psi : differentiable everywhere but \psi \cdot X = 0
  • \{ X \mid \psi' \cdot X = 0 \text{ for some } \psi' \in B(\psi, r) \} might have probability >0

Proof of unbiasedness theorem

  • Instead: Show gradient exchanges when \phi_\psi differentiable in \psi for almost all inputs, using Lipschitz properties of the network
  • Then show the set of non-differentiable \psi has 0 measure, with a bit of geometry
  • Example: for h_{\psi}(X) = \max(0,\psi X) , \psi = 0 is a critical point:

Proof of unbiasedness theorem

  • By Fubini theorem, only need to show:
    • For fixed input X , \phi_{\psi}(X) differentiable for almost all \psi
  • Recall: \phi_{\psi}^1(X) = \sigma (\phi_{\psi}^0(X)) , with piecewise smooth \sigma :
  • case 1: \phi_{\psi}^{0}(X) inside domain of analyticity
  • case 2: \phi_{\psi}^{0}(X) on boundary
  • case 3: \phi_{\psi}^{0}(X) crosses the boundary:
    • \psi in a union of manifolds of 0 measure

Gradient bias for “full” loss

  • Recall \D(\PP, \QQ) = \sup_\psi \D_\psi(\PP, \QQ)
    • Estimator splits data:
      • Pick \hat\psi on train set, estimate \D_{\hat\psi} on test set
      • GAN test set: current minibatch
  • Showed: \nabla_{(\psi,\theta)} \hat\D_\psi(X, G_\theta(Z)) unbiased for any fixed \psi
  • Now: \nabla_\theta \hat\D(X, G_\theta(Z)) is biased
    • Estimators have non-constant bias iff gradients are biased
    • Will show \hat\D(X, G_\theta(Z)) is biased

Bias of \hat\D

  • Eval on test set is unbiased: \E[\hat\D_\psi(X, Y)] = \D_\psi(\PP, \QQ)
  • But training introduces bias:
    • \D_\psi(\PP, \QQ) \le \sup_\psi \D_\psi(\PP, \QQ) = \D(\PP, \QQ)
    • If \Pr(\hat\psi \text{ is optimal}) \ne 1 , estimator must be biased down
  • Probably not a big deal in practice
    • No (direct) bias due to minibatch size
    • Can decrease bias by training critic longer
  • But informs theory
    • Convergence analysis based on SGD on \hat\D is made more difficult
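A Monte Carlo sketch of this downward bias, under simplifying assumptions not in the talk: \PP = N(\mu, I) , \QQ = N(0, I) , and unit-norm linear critics f_\psi(x) = \psi \cdot x , for which \D(\PP, \QQ) = \lVert \mu \rVert . The train-set size plays the role of how well the critic is trained.

```python
import numpy as np

def split_estimate(rng, mu, n_train, n_test):
    """Pick a unit-norm linear critic on a train split, evaluate it on a test split."""
    d = len(mu)
    x_train = rng.normal(size=(n_train, d)) + mu    # samples from P = N(mu, I)
    y_train = rng.normal(size=(n_train, d))         # samples from Q = N(0, I)
    diff = x_train.mean(axis=0) - y_train.mean(axis=0)
    psi_hat = diff / np.linalg.norm(diff)           # maximizes the train-set estimate
    x_test = rng.normal(size=(n_test, d)) + mu
    y_test = rng.normal(size=(n_test, d))
    return psi_hat @ (x_test.mean(axis=0) - y_test.mean(axis=0))

rng = np.random.default_rng(0)
mu = np.zeros(10)
mu[0] = 0.3
true_distance = np.linalg.norm(mu)   # sup over unit-norm linear critics

mean_estimates = {}
for n_train in (10, 500):
    trials = [split_estimate(rng, mu, n_train, n_test=100) for _ in range(2000)]
    mean_estimates[n_train] = np.mean(trials)
    print(n_train, mean_estimates[n_train], "true:", true_distance)
```

The test-split evaluation is unbiased for the chosen critic, so the whole gap comes from \hat\psi being suboptimal; it narrows as the critic is fit on more data.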

Non-existence of an unbiased IPM estimator

  • Beautiful argument of Bickel & Lehmann [Ann. Math. Stat. 1969]:
  • Let \PP_0 \ne \PP_1 , \PP_\alpha = (1 - \alpha) \PP_0 + \alpha \PP_1
  • Suppose \E_{\substack{\mathbf X \sim \PP_\alpha^m \\ \mathbf Y \sim \QQ^n}} \hat\D(\mathbf X, \mathbf Y) = \D(\PP_\alpha, \QQ) ; then \begin{align} \D(\PP_\alpha, \QQ) &= \int_{X_1} \cdots \int_{X_m} \int_{\mathbf Y} \hat\D(\mathbf X, \mathbf Y) [(1 - \alpha) \mathrm{d}\PP_0(X_1) + \alpha \mathrm{d}\PP_1(X_1)] \cdots [(1 - \alpha) \mathrm{d}\PP_0(X_m) + \alpha \mathrm{d}\PP_1(X_m)] \, \mathrm{d}\QQ^n(\mathbf Y) \\&= (1 - \alpha)^m \E_{\substack{\mathbf X \sim \PP_0^m \\ \mathbf Y \sim \QQ^n}}[ \hat\D(\mathbf X, \mathbf Y) ] + \cdots + \alpha^m \E_{\substack{\mathbf X \sim \PP_1^m \\ \mathbf Y \sim \QQ^n}}[ \hat\D(\mathbf X, \mathbf Y) ] \end{align} which is a polynomial in \alpha
  • But with \QQ = \PP_{\frac12} , \D(\PP_\alpha, \PP_{\frac12}) = \lvert \frac12 - \alpha \rvert \D(\PP_0, \PP_1) isn't a polynomial in \alpha
  • So no unbiased \hat\D can exist, though \hat\D^2 could