Maximum Mean Discrepancy: Comparing Distributions with Kernels

math
statistics
machine learning
Published

June 22, 2026

I’ve recently needed to understand the metrics used to assess quality in some generative modeling research I’ve been involved in. One such metric is the Maximum Mean Discrepancy (MMD). The best way to understand something is to write about it or teach it, so here goes.

This post is based on my reading of the excellent article A Kernel Two-Sample Test (Gretton et al. 2012), which I believe is one of the original resources for the topic.

Broadly speaking, we want a way to measure how different two probability distributions are, using only samples from each. That is, we want a scalar quantity that answers the question

Do samples from model distribution \(Q\) look like samples from target distribution \(P\)?

Fundamentals

Suppose we have two distributions \(P\) and \(Q\) over some space \(\mathcal{X}\). The idea is to examine the average behavior of both distributions under some rich family of test functions. For any test function \(f\), we can compare the two expectations

\[ \mathbb{E}_{X \sim P}[f(X)] \quad\text{and}\quad \mathbb{E}_{Y \sim Q}[f(Y)]. \]

If these expectations differ for some meaningful function \(f\), then \(P\) and \(Q\) are different. We formulate the MMD by seeking the function \(f\) that separates the two distributions the most: \[ \operatorname{MMD}(P, Q) = \sup_{\lVert f \rVert_\mathcal{H} \le 1} \left( \mathbb{E}_{X \sim P}[f(X)] - \mathbb{E}_{Y \sim Q}[f(Y)] \right). \]

Here \(\mathcal{H}\) is a reproducing kernel Hilbert space (RKHS), i.e., a Hilbert space of functions defined by a symmetric positive definite kernel \(k(x, y)\) that measures the similarity between two points, and the supremum runs over the unit ball of \(\mathcal{H}\).


Remark (The Gaussian/RBF Kernel). The Gaussian/RBF kernel is \[ k(x, y) = \exp\left(-\frac{\lVert x - y \rVert^2}{2\sigma^2}\right). \]

This kernel is close to 1 when \(x\) and \(y\) are close, and close to 0 when they are far apart. The bandwidth \(\sigma\) controls what “close” means.
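A minimal NumPy sketch of this kernel (the helper name `rbf` is my own, not a library function). It evaluates the kernel at distance 1 for a few bandwidths, to show how \(\sigma\) sets the scale, and checks numerically that an RBF Gram matrix on random points has no meaningfully negative eigenvalues, which is what positive definiteness of the kernel guarantees.

```python
import numpy as np


def rbf(x, y, sigma):
    """Gaussian/RBF kernel between two points (hypothetical helper)."""
    x, y = np.asarray(x, dtype=float), np.asarray(y, dtype=float)
    return np.exp(-np.sum((x - y) ** 2) / (2 * sigma**2))


# Kernel value at distance 1: a larger sigma pushes it closer to 1.
for sigma in [0.5, 1.0, 2.0]:
    print(f"sigma={sigma}: k(0, 1) = {rbf([0.0], [1.0], sigma):.4f}")

# Gram matrix on random 2D points. Positive definiteness of the kernel means
# every eigenvalue is >= 0, up to floating-point roundoff.
rng = np.random.default_rng(0)
points = rng.normal(size=(50, 2))
gram = np.array([[rbf(p, q, 1.0) for q in points] for p in points])
eigenvalues = np.linalg.eigvalsh(gram)
print(f"smallest eigenvalue = {eigenvalues.min():.2e}")
```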

Using the RKHS structure, we can write an alternative, attractive definition of the MMD. Let’s first define the kernel mean embedding, \[ \mu_P(\cdot) = \mathbb{E}_{X \sim P}[k(X, \cdot)]. \]

This is the average feature representation of samples from \(P\). Defining \(\mu_Q\) analogously, one can show that the MMD is the RKHS distance between the two embeddings: \[ \operatorname{MMD}(P, Q) = \lVert \mu_P - \mu_Q \rVert_\mathcal{H}. \]

Proposition 1 Let \(\mathcal{H}\) be a real RKHS with reproducing kernel \(k\). Assume the kernel mean embeddings \[ \mu_P = \mathbb{E}_{X \sim P}[k(X, \cdot)], \qquad \mu_Q = \mathbb{E}_{Y \sim Q}[k(Y, \cdot)] \] exist as elements of \(\mathcal{H}\). Then \[ \sup_{\lVert f \rVert_\mathcal{H} \le 1} \left( \mathbb{E}_{X \sim P}[f(X)] - \mathbb{E}_{Y \sim Q}[f(Y)] \right) = \lVert \mu_P - \mu_Q \rVert_\mathcal{H}. \]

Proof. The defining property of an RKHS is the reproducing property: for every \(f \in \mathcal{H}\) and every point \(x\), \[ f(x) = \langle f, k(x, \cdot) \rangle_\mathcal{H}. \]

Using this property under the expectations, \[ \begin{aligned} \mathbb{E}_{X \sim P}[f(X)] - \mathbb{E}_{Y \sim Q}[f(Y)] &= \mathbb{E}_{X \sim P}[\langle f, k(X, \cdot) \rangle_\mathcal{H}] - \mathbb{E}_{Y \sim Q}[\langle f, k(Y, \cdot) \rangle_\mathcal{H}] \\ &= \left\langle f, \mathbb{E}_{X \sim P}[k(X, \cdot)] - \mathbb{E}_{Y \sim Q}[k(Y, \cdot)] \right\rangle_\mathcal{H} \\ &= \langle f, \mu_P - \mu_Q \rangle_\mathcal{H}. \end{aligned} \]

The first equality is the reproducing property, the second moves the expectations inside the inner product using its linearity, and the third is the definition of the mean embeddings. So the original MMD definition becomes \[ \sup_{\lVert f \rVert_\mathcal{H} \le 1} \langle f, \mu_P - \mu_Q \rangle_\mathcal{H}. \]

Let \(g = \mu_P - \mu_Q\). For any \(f\) with \(\lVert f \rVert_\mathcal{H} \le 1\), Cauchy-Schwarz gives \[ \langle f, g \rangle_\mathcal{H} \le \lvert \langle f, g \rangle_\mathcal{H} \rvert \le \lVert f \rVert_\mathcal{H}\lVert g \rVert_\mathcal{H} \le \lVert g \rVert_\mathcal{H}. \] Thus the supremum is at most \(\lVert g \rVert_\mathcal{H}\).

If \(g \ne 0\), choose \[ f = \frac{g}{\lVert g \rVert_\mathcal{H}}. \] Then \(\lVert f \rVert_\mathcal{H} = 1\), so this is an allowed test function, and \[ \langle f, g \rangle_\mathcal{H} = \left\langle \frac{g}{\lVert g \rVert_\mathcal{H}}, g \right\rangle_\mathcal{H} = \lVert g \rVert_\mathcal{H}. \] So the upper bound is achieved. If \(g = 0\), then \(\langle f, g \rangle_\mathcal{H} = 0\) for every \(f\), and the supremum is also \(0 = \lVert g \rVert_\mathcal{H}\). Therefore the supremum definition and the distance-between-mean-embeddings definition are the same quantity. \(\square\)
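The proof also identifies the optimal test function, often called the witness function: \(f^\ast = (\mu_P - \mu_Q)/\lVert \mu_P - \mu_Q \rVert_\mathcal{H}\). With samples, each mean embedding can be replaced by its empirical average, \(\hat{\mu}_X(t) = \frac{1}{m}\sum_i k(x_i, t)\). The NumPy sketch below (helper names are my own) builds this empirical witness for the sample sets \(X = \{0, 1, 2\}\) and \(Y = \{0, 1, 3\}\) with an RBF kernel, \(\sigma = 1\), and checks the proof's conclusion: the gap between its sample means equals \(\lVert \hat{\mu}_X - \hat{\mu}_Y \rVert_\mathcal{H}\), i.e. the biased MMD estimate.

```python
import numpy as np


def rbf_matrix(a, b, sigma):
    """RBF kernel matrix between the rows of a and the rows of b."""
    sq = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq / (2 * sigma**2))


def empirical_witness(x, y, sigma):
    """Return t -> (mu_X(t) - mu_Y(t)) / ||mu_X - mu_Y||_H, plus that norm."""
    kxx = rbf_matrix(x, x, sigma)
    kyy = rbf_matrix(y, y, sigma)
    kxy = rbf_matrix(x, y, sigma)
    # ||mu_X - mu_Y||_H^2 expanded with the reproducing property.
    norm = np.sqrt(kxx.mean() + kyy.mean() - 2 * kxy.mean())

    def f(t):
        t = np.atleast_2d(t)
        g = rbf_matrix(t, x, sigma).mean(axis=1) - rbf_matrix(t, y, sigma).mean(axis=1)
        return g / norm

    return f, norm


X = np.array([[0.0], [1.0], [2.0]])
Y = np.array([[0.0], [1.0], [3.0]])
f_star, mmd = empirical_witness(X, Y, sigma=1.0)

# Positive near t = 2 (only X has a point there), negative near t = 3 (only Y does).
grid = np.linspace(-1.0, 4.0, 6)[:, None]
print(np.round(f_star(grid), 3))

# E_X[f*] - E_Y[f*] recovers ||mu_X - mu_Y||_H.
gap = f_star(X).mean() - f_star(Y).mean()
print(f"mean gap = {gap:.6f}, biased MMD = {mmd:.6f}")
```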

For characteristic kernels, such as the Gaussian RBF kernel on \(\mathbb{R}^d\), this distance is zero if and only if \(P = Q\). That is why MMD can be used as a discrepancy between entire distributions, not just a comparison of means in the original data space.
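A linear kernel \(k(x, y) = x^\top y\) is a useful contrast: it is not characteristic, and its biased MMD² reduces to \(\lVert \bar{x} - \bar{y} \rVert^2\), the squared distance between the sample means. The NumPy sketch below (sample values chosen purely for illustration, helper names my own) uses two samples with equal means but different spreads. The linear kernel cannot tell them apart, while the RBF kernel can.

```python
import numpy as np


def biased_mmd2(kxx, kyy, kxy):
    """Biased MMD^2 from precomputed kernel matrices."""
    return kxx.mean() + kyy.mean() - 2 * kxy.mean()


def linear_gram(a, b):
    return a @ b.T


def rbf_gram(a, b, sigma=1.0):
    sq = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq / (2 * sigma**2))


# Both samples have mean zero, but Y is three times as spread out.
X = np.array([[-1.0], [1.0]])
Y = np.array([[-3.0], [3.0]])

linear_mmd2 = biased_mmd2(linear_gram(X, X), linear_gram(Y, Y), linear_gram(X, Y))
rbf_mmd2 = biased_mmd2(rbf_gram(X, X), rbf_gram(Y, Y), rbf_gram(X, Y))

# Linear kernel: equals ||mean(X) - mean(Y)||^2 = 0, so the difference is invisible.
print(f"linear-kernel MMD^2 = {linear_mmd2:.6f}")
# RBF kernel: strictly positive, because it compares more than the means.
print(f"RBF-kernel MMD^2    = {rbf_mmd2:.6f}")
```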

I’ve skipped a lot of details, mainly regarding assumptions, existence, and other technicalities. I refer the interested reader back to (Gretton et al. 2012) for these.

Computationally, the following representation of the squared MMD is often more convenient, because it involves only kernel evaluations between samples.

Proposition 2 The squared MMD has the expectation form \[ \operatorname{MMD}^2(P, Q) = \mathbb{E}_{X, X' \sim P}[k(X, X')] + \mathbb{E}_{Y, Y' \sim Q}[k(Y, Y')] - 2\mathbb{E}_{X \sim P, Y \sim Q}[k(X, Y)], \] where \(X'\) is an independent copy of \(X\), \(Y'\) is an independent copy of \(Y\), and \(X\) and \(Y\) are independent.

Proof. By Proposition 1, \[ \operatorname{MMD}^2(P, Q) = \lVert \mu_P - \mu_Q \rVert_\mathcal{H}^2. \]

Expanding the squared norm gives \[ \begin{aligned} \lVert \mu_P - \mu_Q \rVert_\mathcal{H}^2 &= \langle \mu_P, \mu_P \rangle_\mathcal{H} + \langle \mu_Q, \mu_Q \rangle_\mathcal{H} - 2\langle \mu_P, \mu_Q \rangle_\mathcal{H}. \end{aligned} \]

Using the definitions \[ \mu_P = \mathbb{E}_{X \sim P}[k(X, \cdot)], \qquad \mu_Q = \mathbb{E}_{Y \sim Q}[k(Y, \cdot)], \] linearity of the inner product, and the reproducing property, \[ \begin{aligned} \langle \mu_P, \mu_P \rangle_\mathcal{H} &= \mathbb{E}_{X, X' \sim P} \left[ \langle k(X, \cdot), k(X', \cdot) \rangle_\mathcal{H} \right] \\ &= \mathbb{E}_{X, X' \sim P}[k(X, X')], \end{aligned} \] and similarly \[ \langle \mu_Q, \mu_Q \rangle_\mathcal{H} = \mathbb{E}_{Y, Y' \sim Q}[k(Y, Y')], \qquad \langle \mu_P, \mu_Q \rangle_\mathcal{H} = \mathbb{E}_{X \sim P, Y \sim Q}[k(X, Y)]. \]

Substituting these three identities into the squared norm expansion proves the formula. \(\square\)
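As a worked instance of Proposition 2 (my own example, not taken from the paper), let \(P = \mathcal{N}(0, 1)\) and \(Q = \mathcal{N}(\delta, 1)\) on \(\mathbb{R}\) with the RBF kernel. For \(Z \sim \mathcal{N}(m, s^2)\), a Gaussian integral gives \(\mathbb{E}[\exp(-Z^2/(2\sigma^2))] = \frac{\sigma}{\sqrt{\sigma^2 + s^2}} \exp\left(-\frac{m^2}{2(\sigma^2 + s^2)}\right)\). Since \(X - X' \sim \mathcal{N}(0, 2)\), \(Y - Y' \sim \mathcal{N}(0, 2)\), and \(X - Y \sim \mathcal{N}(-\delta, 2)\), the three expectations combine to
\[ \operatorname{MMD}^2(P, Q) = \frac{2\sigma}{\sqrt{\sigma^2 + 2}}\left(1 - \exp\left(-\frac{\delta^2}{2(\sigma^2 + 2)}\right)\right). \]
The NumPy sketch below compares this closed form with an unbiased sample estimate (one that drops the diagonal self-similarities) from 2000 draws per distribution; the helper names are my own.

```python
import numpy as np


def gaussian_mmd2_closed_form(delta, sigma):
    """Population MMD^2 between N(0,1) and N(delta,1), RBF bandwidth sigma."""
    scale = sigma / np.sqrt(sigma**2 + 2.0)  # E[k(X, X')] = E[k(Y, Y')]
    return 2.0 * scale * (1.0 - np.exp(-delta**2 / (2.0 * (sigma**2 + 2.0))))


def mmd2_unbiased_1d(x, y, sigma):
    """Unbiased MMD^2 for 1D samples (drops the diagonals of Kxx and Kyy)."""
    def k(a, b):
        return np.exp(-(a[:, None] - b[None, :]) ** 2 / (2 * sigma**2))

    kxx, kyy, kxy = k(x, x), k(y, y), k(x, y)
    m, n = len(x), len(y)
    off_xx = (kxx.sum() - np.trace(kxx)) / (m * (m - 1))
    off_yy = (kyy.sum() - np.trace(kyy)) / (n * (n - 1))
    return off_xx + off_yy - 2 * kxy.mean()


rng = np.random.default_rng(0)
delta, sigma = 1.0, 1.0
x = rng.normal(0.0, 1.0, size=2000)
y = rng.normal(delta, 1.0, size=2000)

print(f"closed form       = {gaussian_mmd2_closed_form(delta, sigma):.4f}")
print(f"unbiased estimate = {mmd2_unbiased_1d(x, y, sigma):.4f}")
```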

A Tiny Numerical Example

Let’s compare two tiny one-dimensional sample sets: \[ X = \{0, 1, 2\}, \qquad Y = \{0, 1, 3\}. \]

The two sets share two points, but \(Y\) has a point at 3 where \(X\) has a point at 2. We will use an RBF kernel with \(\sigma = 1\).

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics.pairwise import euclidean_distances, rbf_kernel

np.set_printoptions(precision=3, suppress=True)

X = np.array([[0.0], [1.0], [2.0]])
Y = np.array([[0.0], [1.0], [3.0]])
sigma = 1.0


def rbf_kernel_matrix(a, b, sigma):
    gamma = 1 / (2 * sigma**2)
    return rbf_kernel(a, b, gamma=gamma)


Kxx = rbf_kernel_matrix(X, X, sigma)
Kyy = rbf_kernel_matrix(Y, Y, sigma)
Kxy = rbf_kernel_matrix(X, Y, sigma)

print("Kxx: similarities within X")
print(Kxx)
print("\nKyy: similarities within Y")
print(Kyy)
print("\nKxy: similarities between X and Y")
print(Kxy)
Kxx: similarities within X
[[1.    0.607 0.135]
 [0.607 1.    0.607]
 [0.135 0.607 1.   ]]

Kyy: similarities within Y
[[1.    0.607 0.011]
 [0.607 1.    0.135]
 [0.011 0.135 1.   ]]

Kxy: similarities between X and Y
[[1.    0.607 0.011]
 [0.607 1.    0.135]
 [0.135 0.607 0.607]]

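The diagonal entries of Kxx and Kyy are 1 because every point is exactly similar to itself (\(k(x, x) = 1\) for the RBF kernel); the first two diagonal entries of Kxy are also 1 because X and Y share the points 0 and 1. Elsewhere, the entries decay as points get farther apart.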

The simplest empirical estimate includes all entries of these kernel matrices: \[ \widehat{\operatorname{MMD}}^2_{\text{biased}} = \frac{1}{m^2}\sum_{i=1}^m\sum_{j=1}^m k(x_i, x_j) + \frac{1}{n^2}\sum_{i=1}^n\sum_{j=1}^n k(y_i, y_j) - \frac{2}{mn}\sum_{i=1}^m\sum_{j=1}^n k(x_i, y_j). \]

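It is called the biased estimator because including the diagonal self-similarities \(k(x_i, x_i)\) and \(k(y_i, y_i)\) makes its expectation differ from the population \(\operatorname{MMD}^2\). It is still useful: it is exactly the squared RKHS distance between the empirical mean embeddings, so it is always nonnegative up to floating-point roundoff.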

def mmd2_biased(x, y, sigma):
    kxx = rbf_kernel_matrix(x, x, sigma)
    kyy = rbf_kernel_matrix(y, y, sigma)
    kxy = rbf_kernel_matrix(x, y, sigma)
    return kxx.mean() + kyy.mean() - 2 * kxy.mean()


mmd2 = mmd2_biased(X, Y, sigma)
print(f"biased MMD^2 = {mmd2:.6f}")
print(f"biased MMD   = {np.sqrt(mmd2):.6f}")
biased MMD^2 = 0.087438
biased MMD   = 0.295699

For this example the value is small but not zero. That matches the data: the samples mostly agree, but one point moved from 2 to 3.

The Impact of Bandwidth

MMD depends on both the samples and the kernel. The RBF bandwidth \(\sigma\) determines the scale at which we compare the distributions.

If \(\sigma\) is very small, then only nearly identical points count as similar. If \(\sigma\) is very large, then almost every point looks similar to every other point and the discrepancy can get washed out.
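Both extremes can be worked out exactly for the tiny example, using the biased estimator defined above. As \(\sigma \to 0\), \(k(x, y)\) tends to 1 when \(x = y\) and to 0 otherwise, so each kernel mean just counts exact matches among the nine pairs. Each sample matches itself three times, and \(X\) and \(Y\) share two points, so
\[ \lim_{\sigma \to 0} \widehat{\operatorname{MMD}}^2_{\text{biased}} = \frac{3}{9} + \frac{3}{9} - 2 \cdot \frac{2}{9} = \frac{2}{9} \approx 0.2222. \]
As \(\sigma \to \infty\), every kernel entry tends to 1 and the estimate tends to \(1 + 1 - 2 = 0\). The sweep that follows approaches both limits.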

sigmas = np.array([0.25, 0.5, 1.0, 2.0, 4.0])
values = np.array([mmd2_biased(X, Y, s) for s in sigmas])

for s, v in zip(sigmas, values):
    print(f"sigma={s:>4}: biased MMD^2={v:.6f}")
sigma=0.25: biased MMD^2=0.222148
sigma= 0.5: biased MMD^2=0.192148
sigma= 1.0: biased MMD^2=0.087438
sigma= 2.0: biased MMD^2=0.026112
sigma= 4.0: biased MMD^2=0.006837
fig, ax = plt.subplots(figsize=(8,4),layout="constrained")
ax.plot(sigmas, values, marker="o")
ax.set_xscale("log")
ax.set_xlabel(r"RBF bandwidth $\sigma$")
ax.set_ylabel(r"biased $\widehat{\mathrm{MMD}}^2$")
ax.set_title("MMD depends on the comparison scale")
ax.grid(True, alpha=0.3)
plt.show()

There is no universally correct bandwidth. In practice, people often use the median pairwise distance heuristic, average several RBF kernels with different bandwidths, or tune the bandwidth for the application.
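Averaging RBF kernels over several bandwidths is easy to sketch. A nonnegative combination of positive definite kernels is again positive definite, and since the MMD² estimate is linear in the kernel, the MMD² under the averaged kernel equals the average of the per-bandwidth values. The bandwidth list below is an arbitrary choice for illustration, not a recommendation, and the helper names are my own.

```python
import numpy as np


def rbf_gram(a, b, sigma):
    sq = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq / (2 * sigma**2))


def mmd2_biased_multi(x, y, sigmas):
    """Biased MMD^2 with the kernel k = average over sigmas of RBF(sigma)."""
    def k(a, b):
        return np.mean([rbf_gram(a, b, s) for s in sigmas], axis=0)

    return k(x, x).mean() + k(y, y).mean() - 2 * k(x, y).mean()


def mmd2_biased_single(x, y, sigma):
    return mmd2_biased_multi(x, y, [sigma])


X = np.array([[0.0], [1.0], [2.0]])
Y = np.array([[0.0], [1.0], [3.0]])
sigmas = [0.25, 0.5, 1.0, 2.0, 4.0]

multi = mmd2_biased_multi(X, Y, sigmas)
average = np.mean([mmd2_biased_single(X, Y, s) for s in sigmas])
print(f"averaged-kernel MMD^2 = {multi:.6f}")
print(f"average of MMD^2s     = {average:.6f}")
```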

A Slightly Bigger Comparison

Now compare two small 2D sample clouds. The first cloud is centered near \((0, 0)\) and the second is shifted to the right. MMD should be small when we compare the cloud with itself and larger when we compare the original cloud with the shifted one.

rng = np.random.default_rng(0)

A = rng.normal(loc=0.0, scale=0.7, size=(100, 2))
B = rng.normal(loc=np.array([0.8, 0.0]), scale=0.7, size=(100, 2))

pooled = np.vstack([A, B])
pairwise_distances = euclidean_distances(pooled)
median_sigma = np.median(pairwise_distances[pairwise_distances > 0])

aa = mmd2_biased(A, A, median_sigma)
ab = mmd2_biased(A, B, median_sigma)

print(f"median bandwidth = {median_sigma:.3f}")
print(f"MMD^2(A, A)      = {aa:.6f}")
print(f"MMD^2(A, B)      = {ab:.6f}")
median bandwidth = 1.242
MMD^2(A, A)      = 0.000000
MMD^2(A, B)      = 0.145871
fig, ax = plt.subplots(figsize=(4.5,4),layout="constrained")
ax.scatter(A[:, 0], A[:, 1], s=24, alpha=0.75, label="A")
ax.scatter(B[:, 0], B[:, 1], s=24, alpha=0.75, label="B")
ax.set_aspect("equal", adjustable="box")
ax.set_title("Two sample clouds")
ax.legend()
ax.grid(True, alpha=0.25)
plt.show()

The self-comparison is exactly zero for the biased estimator because the same sample set appears on both sides. The shifted comparison is positive because the two clouds have different locations, which changes their kernel mean embeddings.

Biased vs. Unbiased Estimates

There is also an unbiased estimator: \[ \widehat{\operatorname{MMD}}^2_{\text{unbiased}} = \frac{1}{m(m-1)}\sum_{i \ne j} k(x_i, x_j) + \frac{1}{n(n-1)}\sum_{i \ne j} k(y_i, y_j) - \frac{2}{mn}\sum_{i=1}^m\sum_{j=1}^n k(x_i, y_j). \]

This removes the diagonal self-similarities. It is an unbiased estimate of the population \(\operatorname{MMD}^2\), but because it is no longer exactly a squared norm, it can be negative for finite samples, and for very small samples the negative values need not be small.

def off_diagonal_mean(k):
    n = k.shape[0]
    return (k.sum() - np.trace(k)) / (n * (n - 1))


def mmd2_unbiased(x, y, sigma):
    kxx = rbf_kernel_matrix(x, x, sigma)
    kyy = rbf_kernel_matrix(y, y, sigma)
    kxy = rbf_kernel_matrix(x, y, sigma)
    return off_diagonal_mean(kxx) + off_diagonal_mean(kyy) - 2 * kxy.mean()


print(f"biased tiny example   = {mmd2_biased(X, Y, sigma):.6f}")
print(f"unbiased tiny example = {mmd2_unbiased(X, Y, sigma):.6f}")
print(f"biased A vs B         = {mmd2_biased(A, B, median_sigma):.6f}")
print(f"unbiased A vs B       = {mmd2_unbiased(A, B, median_sigma):.6f}")
biased tiny example   = 0.087438
unbiased tiny example = -0.345743
biased A vs B         = 0.145871
unbiased A vs B       = 0.138175
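The large gap between the two tiny-example values has a simple explanation for the RBF kernel, where \(k(x, x) = 1\). Subtracting the two estimators, the cross terms cancel and only the diagonal corrections remain:
\[ \widehat{\operatorname{MMD}}^2_{\text{biased}} - \widehat{\operatorname{MMD}}^2_{\text{unbiased}} = \frac{1}{m}\left(1 - \bar{k}_{XX}\right) + \frac{1}{n}\left(1 - \bar{k}_{YY}\right), \]
where \(\bar{k}_{XX}\) and \(\bar{k}_{YY}\) are the off-diagonal means of the within-sample kernel matrices. With \(m = n = 3\) the correction is large. A NumPy check of this identity on the tiny example (helper names are my own):

```python
import numpy as np


def rbf_gram(a, b, sigma):
    sq = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq / (2 * sigma**2))


def off_diagonal_mean(k):
    n = k.shape[0]
    return (k.sum() - np.trace(k)) / (n * (n - 1))


X = np.array([[0.0], [1.0], [2.0]])
Y = np.array([[0.0], [1.0], [3.0]])
kxx, kyy, kxy = rbf_gram(X, X, 1.0), rbf_gram(Y, Y, 1.0), rbf_gram(X, Y, 1.0)
m, n = len(X), len(Y)

biased = kxx.mean() + kyy.mean() - 2 * kxy.mean()
unbiased = off_diagonal_mean(kxx) + off_diagonal_mean(kyy) - 2 * kxy.mean()

# Diagonal correction from the identity (relies on k(x, x) = 1 for RBF).
correction = (1 - off_diagonal_mean(kxx)) / m + (1 - off_diagonal_mean(kyy)) / n
print(f"biased - unbiased = {biased - unbiased:.6f}")
print(f"correction        = {correction:.6f}")
```

Because every off-diagonal RBF value is at most 1, the correction is never negative, so for this kernel the biased estimate is always at least the unbiased one.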

For optimization objectives, the biased estimator is often convenient because it behaves like a squared distance. For statistical testing, the unbiased estimator and permutation/bootstrap procedures are often more appropriate.
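A permutation test is straightforward to sketch on top of an MMD estimate: pool the two samples, repeatedly reshuffle them into two groups of the original sizes, and count how often the shuffled statistic is at least as large as the observed one. The version below is my own minimal sketch, not the procedure from the paper: it uses the unbiased estimate, a median-heuristic bandwidth computed once on the pooled data (pooling is unchanged by permutation), an arbitrary 200 permutations, and a +1 correction in the p-value.

```python
import numpy as np


def rbf_gram(a, b, sigma):
    sq = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
    return np.exp(-sq / (2 * sigma**2))


def mmd2_unbiased_from_gram(k, m):
    """Unbiased MMD^2 from a pooled Gram matrix whose first m rows are sample X."""
    kxx, kyy, kxy = k[:m, :m], k[m:, m:], k[:m, m:]
    n = k.shape[0] - m
    off_xx = (kxx.sum() - np.trace(kxx)) / (m * (m - 1))
    off_yy = (kyy.sum() - np.trace(kyy)) / (n * (n - 1))
    return off_xx + off_yy - 2 * kxy.mean()


def mmd_permutation_test(x, y, sigma, n_permutations=200, seed=0):
    """Return (observed MMD^2, permutation p-value) for H0: P = Q."""
    rng = np.random.default_rng(seed)
    pooled = np.vstack([x, y])
    k = rbf_gram(pooled, pooled, sigma)  # computed once; permutations reindex it
    m = len(x)
    observed = mmd2_unbiased_from_gram(k, m)
    exceed = 0
    for _ in range(n_permutations):
        idx = rng.permutation(len(pooled))
        if mmd2_unbiased_from_gram(k[np.ix_(idx, idx)], m) >= observed:
            exceed += 1
    return observed, (exceed + 1) / (n_permutations + 1)


rng = np.random.default_rng(0)
A = rng.normal(0.0, 0.7, size=(100, 2))
B = rng.normal([0.8, 0.0], 0.7, size=(100, 2))
A2 = rng.normal(0.0, 0.7, size=(100, 2))  # same distribution as A

pooled = np.vstack([A, B])
dists = np.sqrt(np.sum((pooled[:, None, :] - pooled[None, :, :]) ** 2, axis=-1))
sigma = np.median(dists[dists > 0])

print(mmd_permutation_test(A, B, sigma))   # shifted clouds: expect a small p-value
print(mmd_permutation_test(A, A2, sigma))  # same distribution: p-value usually not small
```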

Optimized JAX Implementation

The scikit-learn implementation above is a clear baseline. It builds three kernel matrices using rbf_kernel and then averages them. However, if you had to implement this outside of NumPy and scikit-learn, say in JAX, it’s worth thinking about what an optimized implementation would look like.

There are two natural JAX implementations, with different tradeoffs. The most direct one uses a nested vmap over a scalar squared-distance function, computing \(\lVert x - y \rVert^2\) literally from the coordinate differences.

from time import perf_counter

import jax
import jax.numpy as jnp


def squared_distance_direct(x, y):
    return jnp.sum((x - y) ** 2)


pairwise_squared_distances_vmap = jax.vmap(
    jax.vmap(squared_distance_direct, in_axes=(None, 0)),
    in_axes=(0, None),
)


def rbf_mean_jax_vmap(x, y, sigma):
    squared_distances = pairwise_squared_distances_vmap(x, y)
    return jnp.mean(jnp.exp(-squared_distances / (2 * sigma**2)))


@jax.jit
def mmd2_biased_jax_vmap(x, y, sigma):
    return (
        rbf_mean_jax_vmap(x, x, sigma)
        + rbf_mean_jax_vmap(y, y, sigma)
        - 2 * rbf_mean_jax_vmap(x, y, sigma)
    )

The faster implementation computes squared distances through matrix multiplication: \[ \lVert x_i - y_j \rVert^2 = \lVert x_i \rVert^2 + \lVert y_j \rVert^2 - 2x_i^\top y_j. \]

The corresponding implementation is:

def rbf_mean_jax_gram(x, y, sigma):
    x2 = jnp.sum(x * x, axis=1, keepdims=True)
    y2 = jnp.sum(y * y, axis=1, keepdims=True).T
    squared_distances = jnp.maximum(x2 + y2 - 2 * (x @ y.T), 0.0)
    return jnp.mean(jnp.exp(-squared_distances / (2 * sigma**2)))


@jax.jit
def mmd2_biased_jax_gram(x, y, sigma):
    return (
        rbf_mean_jax_gram(x, x, sigma)
        + rbf_mean_jax_gram(y, y, sigma)
        - 2 * rbf_mean_jax_gram(x, y, sigma)
    )

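This can be much faster, because the dominant cost becomes a single matrix multiplication, which optimized linear algebra routines handle very efficiently.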
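The identity is easy to sanity-check numerically. The NumPy sketch below (random data of my choosing) compares the direct broadcasting computation of \(\lVert x_i - y_j \rVert^2\) with the Gram form, including the same clamp at zero used in rbf_mean_jax_gram. On well-scaled float64 data the two agree to roundoff.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(size=(5, 3))
y = rng.normal(size=(4, 3))

# Direct: form every difference x_i - y_j, then square and sum over coordinates.
direct = np.sum((x[:, None, :] - y[None, :, :]) ** 2, axis=-1)

# Gram: ||x_i||^2 + ||y_j||^2 - 2 x_i^T y_j, clamped at zero against roundoff.
gram = np.maximum(
    np.sum(x * x, axis=1, keepdims=True)
    + np.sum(y * y, axis=1, keepdims=True).T
    - 2 * (x @ y.T),
    0.0,
)

print(direct.shape, np.max(np.abs(direct - gram)))
```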

Here’s a small benchmark against the scikit-learn version. Note that scikit-learn itself uses the Gram formulation above internally. The benchmark checks that the implementations produce the same biased MMD estimate on a fixed synthetic two-sample problem, then compares their median wall-clock runtimes (after JAX compilation).

def timed(f):
    jax.block_until_ready(f())
    t = perf_counter()
    v = jax.block_until_ready(f())
    return v, perf_counter() - t


def median_timed(f, repeats=7):
    value = None
    seconds = []
    for _ in range(repeats):
        value, elapsed = timed(f)
        seconds.append(elapsed)
    return value, float(np.median(seconds))
bench_rng = np.random.default_rng(1)
bench_n = 2_000
bench_dim = 16

X_bench = bench_rng.normal(size=(bench_n, bench_dim)).astype(np.float32)
Y_bench = bench_rng.normal(loc=0.15, size=(bench_n, bench_dim)).astype(np.float32)
sigma_bench = np.float32(np.sqrt(bench_dim))

X_bench_jax = jnp.asarray(X_bench)
Y_bench_jax = jnp.asarray(Y_bench)
sigma_bench_jax = jnp.asarray(sigma_bench)

# Compile once before benchmarking.
jax.block_until_ready(
    mmd2_biased_jax_gram(X_bench_jax, Y_bench_jax, sigma_bench_jax)
)
jax.block_until_ready(
    mmd2_biased_jax_vmap(X_bench_jax, Y_bench_jax, sigma_bench_jax)
)

sklearn_value, sklearn_seconds = median_timed(
    lambda: mmd2_biased(X_bench, Y_bench, sigma_bench)
)
jax_gram_value, jax_gram_seconds = median_timed(
    lambda: mmd2_biased_jax_gram(X_bench_jax, Y_bench_jax, sigma_bench_jax)
)
jax_vmap_value, jax_vmap_seconds = median_timed(
    lambda: mmd2_biased_jax_vmap(X_bench_jax, Y_bench_jax, sigma_bench_jax)
)

print(f"JAX device       = {jax.devices()[0]}")
print(f"sample shape     = {X_bench.shape}")
print(f"sklearn MMD^2    = {sklearn_value:.8f}")
print(f"JAX Gram MMD^2   = {float(jax_gram_value):.8f}")
print(f"JAX vmap MMD^2   = {float(jax_vmap_value):.8f}")
print(f"sklearn time     = {sklearn_seconds * 1_000:.2f} ms")
print(f"JAX Gram time    = {jax_gram_seconds * 1_000:.2f} ms")
print(f"JAX vmap time    = {jax_vmap_seconds * 1_000:.2f} ms")
print(f"Gram speedup     = {sklearn_seconds / jax_gram_seconds:.2f}x")
print(f"vmap speedup     = {sklearn_seconds / jax_vmap_seconds:.2f}x")
JAX device       = cpu:0
sample shape     = (2000, 16)
sklearn MMD^2    = 0.00887656
JAX Gram MMD^2   = 0.00887656
JAX vmap MMD^2   = 0.00887656
sklearn time     = 83.09 ms
JAX Gram time    = 15.05 ms
JAX vmap time    = 88.92 ms
Gram speedup     = 5.52x
vmap speedup     = 0.93x

The Gram version is usually the fastest of these dense implementations. However, as discussed in this JAX thread on pairwise distances, the Gram version can be numerically unstable when the true distance is small relative to the magnitude of the coordinates, because it subtracts large, nearly equal numbers (catastrophic cancellation). For example, consider two points that are \(\sqrt{2}\) apart but share a large common offset. The direct computation recovers the distance, while the Gram calculation loses it.

edge_offset = jnp.float32(4096.0)
edge_x = jnp.array([[0.0, 1.0]], dtype=jnp.float32) + edge_offset
edge_y = jnp.array([[1.0, 0.0]], dtype=jnp.float32) + edge_offset
edge_sigma = jnp.float32(1.0)

edge_gram_distance2 = (
    jnp.sum(edge_x * edge_x, axis=1, keepdims=True)
    + jnp.sum(edge_y * edge_y, axis=1, keepdims=True).T
    - 2 * (edge_x @ edge_y.T)
)
edge_direct_distance2 = pairwise_squared_distances_vmap(edge_x, edge_y)

edge_gram_kernel = jnp.exp(
    -jnp.maximum(edge_gram_distance2, 0.0) / (2 * edge_sigma**2)
)
edge_direct_kernel = jnp.exp(
    -edge_direct_distance2 / (2 * edge_sigma**2)
)

print(f"Gram squared distance   = {float(edge_gram_distance2[0, 0]):.1f}")
print(f"direct squared distance = {float(edge_direct_distance2[0, 0]):.1f}")
print(f"JAX Gram RBF            = {float(edge_gram_kernel[0, 0]):.8f}")
print(f"JAX direct RBF          = {float(edge_direct_kernel[0, 0]):.8f}")
Gram squared distance   = 0.0
direct squared distance = 2.0
JAX Gram RBF            = 1.00000000
JAX direct RBF          = 0.36787945
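Why does the Gram calculation lose the distance? In float32, adjacent representable numbers near the squared norms in this example (about \(3.4 \times 10^7\)) are 4 apart, so a true squared distance of 2 is below the rounding error of the large, nearly equal terms being subtracted. The sketch below uses NumPy as a stand-in for JAX on the same float32 inputs; it checks that spacing with np.spacing and repeats both computations.

```python
import numpy as np

offset = np.float32(4096.0)
x = np.array([[0.0, 1.0]], dtype=np.float32) + offset
y = np.array([[1.0, 0.0]], dtype=np.float32) + offset

x2 = np.sum(x * x, axis=1, keepdims=True)  # roughly 3.36e7
y2 = np.sum(y * y, axis=1, keepdims=True).T

# Gap between adjacent float32 values at the magnitude of x2.
print(np.spacing(x2[0, 0]))  # 4.0

gram = x2 + y2 - 2 * (x @ y.T)          # cancellation of huge terms
direct = np.sum((x - y) ** 2, axis=1)   # differences are small and exact
print(gram[0, 0], direct[0])
```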

The same issue can change the MMD itself. Below, X_edge and Y_edge are different point clouds, but the Gram implementation rounds away their small distances, so cross-similarities look like self-similarities and the estimated MMD collapses.

X_edge =  jnp.array([[0.0, 1.0], [0.0, 2.0]], dtype=jnp.float32) + edge_offset
Y_edge =  jnp.array([[1.0, 0.0], [1.0, 2.0]], dtype=jnp.float32) + edge_offset
sigma_edge = jnp.float32(1.0)

edge_gram_mmd = mmd2_biased_jax_gram(
    X_edge, Y_edge, sigma_edge
)
edge_vmap_mmd = mmd2_biased_jax_vmap(
    X_edge, Y_edge, sigma_edge
)

print(f"JAX Gram MMD^2 = {float(edge_gram_mmd):.8f}")
print(f"JAX vmap MMD^2 = {float(edge_vmap_mmd):.8f}")
JAX Gram MMD^2 = 0.00000000
JAX vmap MMD^2 = 0.65874577

Practically speaking, I would likely reach for the direct implementation until convinced that the MMD computation itself is a bottleneck. Alternatively, one could center the data before the computation, subtracting the same shift from both samples: Euclidean distances are translation-invariant, but the floating-point cancellation in the Gram identity is not, and it is worst when the coordinates are large relative to the distances.
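A minimal NumPy sketch of the centering idea, assuming we subtract the pooled mean of both samples (any common shift preserves the distances; what matters is using the same shift for X and Y). On the edge-case point clouds above, reproduced here in float32, the centered Gram computation should agree with the direct one. The function names are hypothetical.

```python
import numpy as np


def gram_sq_dists(a, b):
    """Squared distances via the Gram identity, clamped at zero."""
    a2 = np.sum(a * a, axis=1, keepdims=True)
    b2 = np.sum(b * b, axis=1, keepdims=True).T
    return np.maximum(a2 + b2 - 2 * (a @ b.T), 0.0)


def direct_sq_dists(a, b):
    """Squared distances from explicit coordinate differences."""
    return np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)


def mmd2_biased_with(x, y, sigma, sq_dists):
    """Biased MMD^2 with an RBF kernel built from the given distance function."""
    def k_mean(a, b):
        return np.mean(np.exp(-sq_dists(a, b) / (2 * sigma**2)))

    return k_mean(x, x) + k_mean(y, y) - 2 * k_mean(x, y)


offset = np.float32(4096.0)
X_edge = np.array([[0.0, 1.0], [0.0, 2.0]], dtype=np.float32) + offset
Y_edge = np.array([[1.0, 0.0], [1.0, 2.0]], dtype=np.float32) + offset
sigma = np.float32(1.0)

# Subtract one common shift (the pooled mean) from both samples.
shift = np.vstack([X_edge, Y_edge]).mean(axis=0)
Xc, Yc = X_edge - shift, Y_edge - shift

print(f"Gram, raw      = {mmd2_biased_with(X_edge, Y_edge, sigma, gram_sq_dists):.8f}")
print(f"Gram, centered = {mmd2_biased_with(Xc, Yc, sigma, gram_sq_dists):.8f}")
print(f"direct, raw    = {mmd2_biased_with(X_edge, Y_edge, sigma, direct_sq_dists):.8f}")
```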

Beyond these dense estimators, one can construct estimators of the MMD from statistically principled subsamplings of the dataset. This is much of the focus of the second half of the paper (Gretton et al. 2012).
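One such estimator is a linear-time statistic. A minimal NumPy sketch, assuming equal sample sizes (extra samples are dropped): pair consecutive samples into \(z_i = (x_{2i-1}, y_{2i-1})\) and \(z_i' = (x_{2i}, y_{2i})\), then average
\[ h(z, z') = k(x, x') + k(y, y') - k(x, y') - k(x', y) \]
over the \(\lfloor m/2 \rfloor\) disjoint pairs. This needs \(O(m)\) kernel evaluations instead of \(O(m^2)\), at the price of a noisier but still unbiased estimate. The Gaussian test case and the helper names are my own choices; the sketch compares the result with the quadratic-time unbiased estimate on a subsample.

```python
import numpy as np


def rbf_pointwise(a, b, sigma):
    """RBF kernel between matching rows: k(a_i, b_i)."""
    return np.exp(-np.sum((a - b) ** 2, axis=-1) / (2 * sigma**2))


def mmd2_linear_time(x, y, sigma):
    """Linear-time MMD^2 estimate from disjoint consecutive pairs."""
    m2 = min(len(x), len(y)) // 2
    x1, x2 = x[0:2 * m2:2], x[1:2 * m2:2]
    y1, y2 = y[0:2 * m2:2], y[1:2 * m2:2]
    h = (
        rbf_pointwise(x1, x2, sigma)
        + rbf_pointwise(y1, y2, sigma)
        - rbf_pointwise(x1, y2, sigma)
        - rbf_pointwise(x2, y1, sigma)
    )
    return h.mean()


def mmd2_unbiased_quadratic(x, y, sigma):
    """Quadratic-time unbiased MMD^2, for comparison."""
    def k(a, b):
        sq = np.sum((a[:, None, :] - b[None, :, :]) ** 2, axis=-1)
        return np.exp(-sq / (2 * sigma**2))

    kxx, kyy, kxy = k(x, x), k(y, y), k(x, y)
    m, n = len(x), len(y)
    return ((kxx.sum() - np.trace(kxx)) / (m * (m - 1))
            + (kyy.sum() - np.trace(kyy)) / (n * (n - 1))
            - 2 * kxy.mean())


rng = np.random.default_rng(0)
x = rng.normal(0.0, 1.0, size=(20_000, 1))
y = rng.normal(1.0, 1.0, size=(20_000, 1))

print(f"linear time, 20000 samples   = {mmd2_linear_time(x, y, 1.0):.4f}")
print(f"quadratic time, 2000 samples = {mmd2_unbiased_quadratic(x[:2000], y[:2000], 1.0):.4f}")
```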

References

Gretton, Arthur, Karsten M. Borgwardt, Malte J. Rasch, Bernhard Schölkopf, and Alexander Smola. 2012. “A Kernel Two-Sample Test.” Journal of Machine Learning Research 13 (25): 723–73. https://jmlr.csail.mit.edu/papers/v13/gretton12a.html.