Bures-Wasserstein Gradient Flow

Author

Arie Ogranovich

The Bures-Wasserstein gradient flow is a Wasserstein gradient flow in the space of Gaussian measures. Bures-Wasserstein gradient flows take advantage of the simple form of optimal transport maps between Gaussian distributions to provide an alternative, computationally tractable viewpoint on certain Wasserstein gradient flows (Chen, Georgiou, and Tannenbaum 2019; Delon and Desolneux 2020).

Definition of Bures-Wasserstein Space and Gradient

The Bures-Wasserstein Space \(\text{BW}(\mathbb R^d)\) is the Riemannian manifold of non-degenerate Gaussian distributions over \(\mathbb R^d\) endowed with the Wasserstein metric.

Letting \(S^d_{++}\) denote the cone of symmetric positive definite \(d \times d\) matrices, \(\text{BW}(\mathbb R^d)\) can be identified with \(\mathbb R^d \times S^d_{++}\) by identifying a Gaussian \(\mu\) with its mean \(\mathbb E[\mu] \in \mathbb R^d\) together with its covariance matrix \(\Sigma_{\mu} \in S^d_{++}\).

The space \(\mathbb R^d \times S^d_{++}\) carries a smooth structure, and the Wasserstein metric is a smooth metric on this space; this is what gives \(\text{BW}(\mathbb R^d)\) the structure of a Riemannian manifold (in contrast to the Wasserstein spaces \(\mathcal{P}_2(\mathbb R^d)\) and \(\mathcal{P}_{2,\text{ac}}(\mathbb R^d)\); see Chewi, Niles-Weed, and Rigollet 2024).

Properties of \(\text{BW}(\mathbb R^d)\)

As a smooth manifold, the tangent space \(T_{\mu}\text{BW}(\mathbb R^d)\) has a convenient characterization.

Tangent Space of \(\text{BW}(\mathbb R^d)\)

The tangent space of \(\text{BW}(\mathbb R^d)\) at a measure \(\mu = \mathcal{N}(m,\Sigma)\) satisfies \[T_{\mu}\text{BW}(\mathbb R^d)= \big\{x \mapsto Sx+a \,\big|\, a \in \mathbb R^d,\ S \in S^d\big\}\] or, equivalently, \[T_{\mu}\text{BW}(\mathbb R^d)= \big\{x \mapsto S(x-m)+a \,\big|\, a \in \mathbb R^d,\ S \in S^d\big\},\] where \(S^d\) denotes the space of symmetric \(d \times d\) matrices.
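The two descriptions coincide: re-centering the affine part absorbs the mean, as the one-line computation \[S(x-m) + a \;=\; Sx + \underbrace{(a - Sm)}_{=:\,a'}, \qquad a' \in \mathbb R^d,\ S \in S^d,\] shows, so each set of affine maps is carried onto the other by the change of variables \(a \leftrightarrow a'\).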

It is also possible to analytically characterize the optimal transport plan between two non-degenerate Gaussian distributions in terms of their means and covariances.

Optimal Transport Plans and Geodesic Convexity

Let \(\mu_1, \mu_2 \in \text{BW}(\mathbb R^d)\) be two Gaussian measures with \(\mu_1 = \mathcal{N}(m_1, \Sigma_1)\) and \(\mu_2 = \mathcal{N}(m_2, \Sigma_2)\). The optimal transport map \(T_{\mu_1 \rightarrow \mu_2}\) is affine-linear and given by \[T_{\mu_1 \rightarrow \mu_2}(x) = \Sigma_1^{-1/2} \Big( \Sigma_1^{1/2} \Sigma_2 \Sigma_1^{1/2} \Big)^{1/2} \Sigma_1^{-1/2} (x-m_1) + m_2\] with Wasserstein distance given by \[W_2^2(\mu_1,\mu_2) = \|m_1-m_2\|^2 + \text{tr}\Big[ \Sigma_1 + \Sigma_2 - 2 \big( \Sigma_1^{1/2} \Sigma_2 \Sigma_1^{1/2} \big)^{1/2} \Big].\]
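As a numerical sanity check, the map and the distance can be computed directly with NumPy and SciPy. This is a minimal sketch; the helper names (`bw_transport_map`, `bw_distance_sq`) are ours, not from the cited references.

```python
import numpy as np
from scipy.linalg import sqrtm

def bw_transport_map(m1, S1, m2, S2):
    """Return (A, b) with T(x) = A @ (x - m1) + b pushing N(m1, S1) onto N(m2, S2)."""
    S1_half = sqrtm(S1)                          # Sigma_1^{1/2}
    S1_half_inv = np.linalg.inv(S1_half)         # Sigma_1^{-1/2}
    middle = sqrtm(S1_half @ S2 @ S1_half)       # (Sigma_1^{1/2} Sigma_2 Sigma_1^{1/2})^{1/2}
    A = np.real(S1_half_inv @ middle @ S1_half_inv)
    return A, m2

def bw_distance_sq(m1, S1, m2, S2):
    """Squared 2-Wasserstein distance between N(m1, S1) and N(m2, S2)."""
    S1_half = sqrtm(S1)
    cross = np.real(sqrtm(S1_half @ S2 @ S1_half))
    return float(np.sum((m1 - m2) ** 2) + np.trace(S1 + S2 - 2 * cross))

# Sanity check: the pushforward covariance A @ S1 @ A.T should equal S2.
m1, S1 = np.zeros(2), np.array([[2.0, 0.3], [0.3, 1.0]])
m2, S2 = np.ones(2), np.array([[1.0, -0.2], [-0.2, 0.5]])
A, b = bw_transport_map(m1, S1, m2, S2)
assert np.allclose(A @ S1 @ A.T, S2, atol=1e-6)
```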

In particular, the Benamou-Brenier theorem (see Chewi, Niles-Weed, and Rigollet 2024) implies that the Wasserstein geodesic, i.e., the displacement interpolation, between Gaussian measures \(\mu_1 = \mathcal{N}(m_1, \Sigma_1)\) and \(\mu_2 = \mathcal{N}(m_2, \Sigma_2)\) is given by the curve \[\mu_t = \Big[x \mapsto (1-t) x + t\Big(\Sigma_1^{-1/2} \big( \Sigma_1^{1/2} \Sigma_2 \Sigma_1^{1/2} \big)^{1/2} \Sigma_1^{-1/2} (x-m_1) + m_2\Big)\Big]\#\mu_1.\] That is, \(\mu_t\) is the pushforward of \(\mu_1\) under the invertible affine-linear map \((1-t)\,\text{id} + t\,T_{\mu_1 \rightarrow \mu_2}\), which implies that \(\mu_t\) is itself a non-degenerate Gaussian measure. The geodesic \((\mu_t)_{t\in [0,1]}\) is therefore contained in \(\text{BW}(\mathbb R^d)\), so \(\text{BW}(\mathbb R^d)\) is geodesically convex.
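The interpolation can likewise be computed in closed form: pushing \(\mu_1\) forward under the affine map \((1-t)\,\text{id} + t\,T_{\mu_1 \rightarrow \mu_2}\) gives a Gaussian whose mean interpolates linearly and whose covariance is conjugated by the map's Jacobian. A small sketch (the function name `bw_geodesic` is ours):

```python
import numpy as np
from scipy.linalg import sqrtm

def bw_geodesic(m1, S1, m2, S2, t):
    """Mean and covariance of the displacement interpolation mu_t, t in [0, 1]."""
    S1_half = sqrtm(S1)
    S1_half_inv = np.linalg.inv(S1_half)
    A = np.real(S1_half_inv @ sqrtm(S1_half @ S2 @ S1_half) @ S1_half_inv)
    C = (1 - t) * np.eye(len(m1)) + t * A        # Jacobian of (1-t) id + t T
    m_t = (1 - t) * np.asarray(m1) + t * np.asarray(m2)
    S_t = C @ S1 @ C.T                           # covariance of the pushforward
    return m_t, S_t
```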

Characterizations of the Bures-Wasserstein Gradient Flow

We are now prepared to compute the Bures-Wasserstein gradient and the corresponding gradient flow. Recall that \(\mathcal{P}_{2,\text{ac}}(\mathbb R^d)\) carries its own tangent structure, and \(T_{\mu}\text{BW}(\mathbb R^d)\) is a subspace of \(T_{\mu}\mathcal{P}_{2,\text{ac}}(\mathbb R^d)\); consequently, the Bures-Wasserstein gradient is the projection of the Wasserstein gradient onto \(T_{\mu}\text{BW}(\mathbb R^d)\). This projection can be made explicit in terms of first variations, via the variational characterization of the Wasserstein gradient, and from it the Bures-Wasserstein gradient flow can be described entirely in terms of the mean and covariance.

Variational Characterization

Let \(\mathcal{F} : \mathcal{P}_{2,\text{ac}}(\mathbb R^d)\rightarrow \mathbb R\) be a functional with first variation \(\delta \mathcal{F}(\mu)\) at \(\mu = \mathcal{N}(m,\Sigma) \in \text{BW}(\mathbb R^d)\). Then the Bures-Wasserstein gradient of \(\mathcal{F}\) at \(\mu\) is the affine mapping \[x \mapsto \bigg( \int\nabla^2 \delta \mathcal{F}(\mu) d\mu\bigg)(x-m) + \int\nabla \delta \mathcal{F}(\mu) d\mu\] where \(m\) is the mean of \(\mu\).
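In practice the two integrals above can be estimated by Monte Carlo. Below is a minimal sketch, assuming the caller supplies \(\nabla \delta \mathcal{F}(\mu)\) and \(\nabla^2 \delta \mathcal{F}(\mu)\) as pointwise callables; the name `bw_gradient` and its interface are ours.

```python
import numpy as np

def bw_gradient(m, S, grad_dF, hess_dF, n_samples=10_000, seed=0):
    """Monte Carlo estimate of the Bures-Wasserstein gradient at mu = N(m, S).

    grad_dF(x) and hess_dF(x) evaluate the gradient and Hessian of the first
    variation of F at the point x.  Returns (H, g) so that the gradient is
    the affine map x -> H @ (x - m) + g.
    """
    rng = np.random.default_rng(seed)
    X = rng.multivariate_normal(m, S, size=n_samples)
    H = np.mean([hess_dF(x) for x in X], axis=0)   # approximates the integral of hess dF(mu) against mu
    g = np.mean([grad_dF(x) for x in X], axis=0)   # approximates the integral of grad dF(mu) against mu
    return H, g
```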

Characterization in terms of Mean and Covariance

The Bures-Wasserstein gradient flow of the functional \(\mathcal{F}\) is the curve \((\mu_t=\mathcal{N}(m_t,\Sigma_t))_{t \geq 0}\), where \[\begin{align*} \dot{m}_t &= - \mathbb E\big[\nabla \delta \mathcal{F}(\mu_t) (X_t)\big]\\ \dot{\Sigma}_t &= -\mathbb E\big[\nabla^2 \delta \mathcal{F}(\mu_t) (X_t)\big]\Sigma_t - \Sigma_t\, \mathbb E\big[\nabla^2 \delta \mathcal{F}(\mu_t) (X_t)\big] \end{align*}\] where \(X_t \sim \mu_t\).

The theorem above characterizes the Bures-Wasserstein gradient flow entirely in terms of the mean and covariance of each measure \(\mu_t\) along the flow; it can also be viewed as a Lagrangian, or particle-based, description of the flow.
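A forward-Euler discretization of these two ODEs, with the expectations estimated by sampling from the current iterate, gives a simple numerical scheme. The sketch below is illustrative, under our own naming and interface choices, rather than an algorithm from the cited works.

```python
import numpy as np

def bw_gradient_flow(m0, S0, grad_dF, hess_dF, step=0.01, n_steps=1000,
                     n_samples=1000, seed=0):
    """Forward-Euler discretization of the mean/covariance ODEs.

    grad_dF(mu, x) and hess_dF(mu, x) evaluate grad dF(mu) and hess dF(mu)
    at x, where mu = (mean, covariance) is the current iterate.
    """
    rng = np.random.default_rng(seed)
    m, S = np.asarray(m0, float), np.asarray(S0, float)
    for _ in range(n_steps):
        X = rng.multivariate_normal(m, S, size=n_samples)
        g = np.mean([grad_dF((m, S), x) for x in X], axis=0)
        h = np.mean([hess_dF((m, S), x) for x in X], axis=0)
        m = m - step * g                     # Euler step for the mean ODE
        S = S - step * (h @ S + S @ h)       # Euler step for the covariance ODE
        S = (S + S.T) / 2                    # re-symmetrize against round-off
    return m, S
```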

Extensions and Applications

Wasserstein Gradient Flows over Gaussian Mixture Models

Bures-Wasserstein gradients can be used to provide an alternative interpretation of gradient flows in \(\mathcal{P}_{2,\text{ac}}(\mathbb R^d)\) by treating probability measures as mixtures of (possibly degenerate) Gaussians. Traditionally, a Gaussian mixture is a measure whose density is a finite convex combination of Gaussian densities. This is generalized by defining a Gaussian mixture in terms of a mixing measure \(\nu \in \mathcal{P}_2(\text{BW}(\mathbb R^d)) \cong \mathcal{P}_2(\mathbb R^d \times S^d_{++})\): the corresponding Gaussian mixture is \(G_{\nu} = \int \mathcal{N}(m,\Sigma)\, \nu(dm, d\Sigma)\). One may treat \(\nu\) as the joint distribution of mean and covariance variables, i.e., \(\nu = \text{law}(m,\Sigma)\) for random variables \(m\) and \(\Sigma\).
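For a discrete mixing measure, sampling from \(G_{\nu}\) is a two-stage procedure: draw a component \((m,\Sigma) \sim \nu\), then draw \(X \sim \mathcal{N}(m,\Sigma)\). A small illustration, with hypothetical atoms and weights of our choosing:

```python
import numpy as np

rng = np.random.default_rng(0)

# A discrete mixing measure nu: three (mean, covariance) atoms with weights w.
atoms = [(np.array([0.0, 0.0]), np.eye(2)),
         (np.array([3.0, 1.0]), 0.5 * np.eye(2)),
         (np.array([-2.0, 2.0]), np.diag([1.0, 0.25]))]
w = np.array([0.5, 0.3, 0.2])

def sample_G_nu(n):
    """Draw n samples from G_nu: pick (m, Sigma) ~ nu, then X ~ N(m, Sigma)."""
    idx = rng.choice(len(atoms), size=n, p=w)
    return np.array([rng.multivariate_normal(*atoms[i]) for i in idx])

samples = sample_G_nu(1000)   # samples.shape == (1000, 2)
```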

A Wasserstein metric can be defined on \(\mathcal{P}_2(\text{BW}(\mathbb R^d))\) in the same fashion as on \(\mathcal{P}_2(\mathbb R^d)\): the underlying metric space changes, but the fundamental construction is unaltered. The mapping \(\nu \mapsto G_{\nu}\) sends \(\mathcal{P}_2(\text{BW}(\mathbb R^d))\) into \(\mathcal{P}_2(\mathbb R^d)\) (and onto it, if degenerate Gaussians are admitted as mixture components), which allows us to reinterpret gradient flows of a functional \(\mathcal{F} : \mathcal{P}_2(\mathbb R^d)\rightarrow \mathbb R\) in terms of the functional \(\mathcal{G}: \nu \mapsto \mathcal{F}(G_{\nu})\) via the proposition below.

Gradient Flow Expression in \(\mathcal{P}_2(\text{BW}(\mathbb R^d))\)

Let \(\nu \in \mathcal{P}_2(\text{BW}(\mathbb R^d))\) be given and let \(G_{\nu} = \int \mathcal{N}(m,\Sigma)\, \nu(dm, d\Sigma)\) denote the corresponding Gaussian mixture. Let \(\mathcal{F}\) be a functional over \(\mathcal{P}_2(\mathbb R^d)\) and let \(\mathcal{G}\) be the functional over \(\mathcal{P}_2(\text{BW}(\mathbb R^d))\) defined by \(\mathcal{G}(\nu) = \mathcal{F}(G_{\nu})\). The Wasserstein gradient flow of \(\mathcal{G}\) in \(\mathcal{P}_2(\text{BW}(\mathbb R^d))\) is the curve \((\nu_t = \text{law}(m_t,\Sigma_t))_{t \geq 0}\) for random variables \(m_t\) and \(\Sigma_t\) that follow the equations \[\begin{align*} \dot{m}_t &= - \mathbb E\big[\nabla \delta \mathcal{F}(G_t) (X_t)\big]\\ \dot{\Sigma}_t &= -\mathbb E\big[\nabla^2 \delta \mathcal{F}(G_t) (X_t)\big]\Sigma_t - \Sigma_t\, \mathbb E\big[\nabla^2 \delta \mathcal{F}(G_t) (X_t)\big] \end{align*}\] where \(G_t = G_{\nu_t}\) and, conditionally on \((m_t,\Sigma_t)\), \(X_t \sim \mathcal{N}(m_t,\Sigma_t)\).
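When \(\nu\) is approximated by \(K\) equally weighted atoms \((m^{(i)}, \Sigma^{(i)})\), the proposition suggests evolving each atom by the same ODEs, with the first-variation terms evaluated under the current mixture \(G_t\). The following is a sketch under that particle approximation; `grad_dF` and `hess_dF` are assumed oracles for \(\nabla\delta\mathcal{F}(G)\) and \(\nabla^2\delta\mathcal{F}(G)\), and all names are ours.

```python
import numpy as np

def mixture_gradient_flow(means, covs, grad_dF, hess_dF, step=0.01,
                          n_steps=500, n_samples=500, seed=0):
    """Evolve K equally weighted components by the mixture gradient flow.

    grad_dF(G, x) and hess_dF(G, x) evaluate grad dF(G) and hess dF(G) at x,
    where G = (means, covs) is the current mixture.
    """
    rng = np.random.default_rng(seed)
    means = [np.asarray(m, float) for m in means]
    covs = [np.asarray(S, float) for S in covs]
    for _ in range(n_steps):
        G = (means, covs)                    # current mixture, passed to the oracles
        new_means, new_covs = [], []
        for m, S in zip(means, covs):
            X = rng.multivariate_normal(m, S, size=n_samples)   # X_t ~ N(m_t, Sigma_t)
            g = np.mean([grad_dF(G, x) for x in X], axis=0)
            h = np.mean([hess_dF(G, x) for x in X], axis=0)
            new_means.append(m - step * g)
            new_covs.append(S - step * (h @ S + S @ h))
        means, covs = new_means, new_covs
    return means, covs
```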

Gaussian Variational Inference

Bures-Wasserstein gradient flows can be used to solve variational inference problems when the set of admissible measures is \(\text{BW}(\mathbb R^d)\) (Lambert et al. 2023).

Variational inference is an optimization problem that seeks a simple distribution from an admissible set \(\Omega\) which closely approximates a given measure \(\pi \in \mathcal{P}_2(\mathbb R^d)\) in KL divergence. More precisely, variational inference is the problem of finding a measure \(q_*\) such that \[q_* = \underset{q \in \Omega}{\arg\min}\, \text{KL}(q \,||\, \pi).\] Gaussian variational inference is the special case \[q_* = \underset{q \in \text{BW}(\mathbb R^d)}{\arg\min}\, \text{KL}(q \,||\, \pi).\]

Recent work (Lambert et al. 2023) takes a gradient-flow approach to Gaussian variational inference. By computing the Bures-Wasserstein gradient flow of the functional \(\mathcal{F}(q) = \text{KL}(q \,||\, \pi)\), one can discretize the flow to obtain a numerical algorithm that approximates a local minimizer of \(\mathcal{F}\).

In order to do so, one must derive the Bures-Wasserstein gradient of \(\mathcal{F}\). This is done via the variational characterization of Bures-Wasserstein gradients. For an absolutely continuous measure \(\pi\) with density \(\exp(-V)\) (\(V\) is called the potential), the first variation of \(\mathcal{F}(q) = \text{KL}(q \,||\, \pi)\) satisfies \[\nabla\delta\mathcal{F}(q) = \nabla V + \nabla \log q\] and \[\nabla^2\delta\mathcal{F}(q) = \nabla^2 V + \nabla^2 \log q,\] where \(\log q(x)\) denotes the logarithm of the density of \(q\) at \(x\).
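To see where these formulas come from, expand the KL divergence against \(\pi \propto \exp(-V)\) as \[\mathcal{F}(q) = \int q \log q \, dx + \int V \, dq + \text{const},\] whose first variation is \(\delta\mathcal{F}(q) = \log q + 1 + V\); taking the gradient in \(x\) removes the additive constant and gives \(\nabla\delta\mathcal{F}(q) = \nabla V + \nabla \log q\), and differentiating once more gives the Hessian formula.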

By the formula for Bures-Wasserstein gradients, we see that \[\begin{align*} \nabla_{\text{BW}}\mathcal{F}(q)(x) &= \bigg( \int \big( \nabla^2 V + \nabla^2 \log q \big)\, dq \bigg) (x-m) + \int \big( \nabla V + \nabla \log q \big)\, dq\\ &= \bigg( \int \big( \nabla^2 V + \nabla^2 \log q \big)\, dq \bigg) (x-m) + \int \nabla V\, dq + \int \nabla \log q\, dq. \end{align*}\] Because \(q = \mathcal{N}(m,\Sigma)\) is Gaussian, we know \(\log q (x) = - \frac{1}{2}(x-m)^{\top}\Sigma^{-1}(x-m) + c\) for some constant \(c\). This implies that \(\log q\) is symmetric about the mean \(m\) of \(q\) and \(\nabla \log q\) is antisymmetric about \(m\), hence the integral \(\int_{\mathbb R^d} \nabla \log q \, dq = 0\). Thus, \[\nabla_{\text{BW}}\mathcal{F}(q)(x) = \bigg( \int \big( \nabla^2 V + \nabla^2 \log q \big)\, dq \bigg) (x-m) + \int \nabla V\, dq.\] Further, differentiating \(\log q\) twice gives \(\nabla^2 \log q (x) = -\Sigma^{-1}\), so we conclude that \[\nabla_{\text{BW}}\mathcal{F}(q)(x) = \bigg( \int \nabla^2 V\, dq - \Sigma^{-1} \bigg) (x-m) + \int \nabla V\, dq.\]
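Substituting this gradient into the mean/covariance characterization (and using \(\int \nabla \log q \, dq = 0\) and \(\nabla^2 \log q = -\Sigma^{-1}\)) yields the ODEs \[\dot m_t = -\mathbb E\big[\nabla V(X_t)\big], \qquad \dot \Sigma_t = 2I - \mathbb E\big[\nabla^2 V(X_t)\big]\,\Sigma_t - \Sigma_t\,\mathbb E\big[\nabla^2 V(X_t)\big],\] with \(X_t \sim q_t\). The sketch below discretizes these with forward Euler; it is an illustration under our own naming and step-size choices, not the exact scheme of Lambert et al. (2023). A Gaussian target makes the fixed point checkable.

```python
import numpy as np

def gaussian_vi(m0, S0, grad_V, hess_V, step=0.05, n_steps=2000,
                n_samples=500, seed=0):
    """Gaussian VI via forward-Euler discretization of the BW gradient flow
    of KL(. || pi), where pi has density proportional to exp(-V)."""
    rng = np.random.default_rng(seed)
    m, S = np.asarray(m0, float), np.asarray(S0, float)
    I = np.eye(len(m))
    for _ in range(n_steps):
        X = rng.multivariate_normal(m, S, size=n_samples)
        g = np.mean([grad_V(x) for x in X], axis=0)    # E[grad V(X)]
        H = np.mean([hess_V(x) for x in X], axis=0)    # E[hess V(X)]
        m = m - step * g                               # mean ODE step
        S = S + step * (2 * I - H @ S - S @ H)         # covariance ODE step
        S = (S + S.T) / 2                              # keep Sigma symmetric
    return m, S

# Example: Gaussian target pi = N(mu*, Sigma*), i.e. V(x) = (x-mu*)^T P (x-mu*)/2
# with P = Sigma*^{-1}; the flow should converge to (mu*, Sigma*).
mu_star = np.array([1.0, -2.0])
Sigma_star = np.array([[2.0, 0.5], [0.5, 1.0]])
P = np.linalg.inv(Sigma_star)
m, S = gaussian_vi(np.zeros(2), np.eye(2),
                   grad_V=lambda x: P @ (x - mu_star),
                   hess_V=lambda x: P)
assert np.allclose(m, mu_star, atol=0.1) and np.allclose(S, Sigma_star, atol=0.1)
```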

In addition, when the potential \(V\) is convex (i.e., \(\pi\) is log-concave), \(\text{KL}(\,\cdot\,|| \pi)\) is geodesically convex on \(\mathcal{P}_2(\mathbb R^d)\), and since \(\text{BW}(\mathbb R^d)\) is geodesically convex, its restriction to \(\text{BW}(\mathbb R^d)\) is convex as well. If \(V\) is moreover strongly convex, then \(\text{KL}(q_t|| \pi)\) converges exponentially fast to \(\text{KL}(q_*|| \pi)\) along the gradient flow \((q_t)_{t\geq 0}\) of \(\text{KL}(\,\cdot\,|| \pi)\) (Lambert et al. 2023).

References

Chen, Yongxin, Tryphon T. Georgiou, and Allen Tannenbaum. 2019. “Optimal Transport for Gaussian Mixture Models.” IEEE Access 7: 6269–78. https://doi.org/10.1109/ACCESS.2018.2889838.
Chewi, Sinho, Jonathan Niles-Weed, and Philippe Rigollet. 2024. “Statistical Optimal Transport.” https://arxiv.org/abs/2407.18163.
Delon, Julie, and Agnès Desolneux. 2020. “A Wasserstein-Type Distance in the Space of Gaussian Mixture Models.” SIAM Journal on Imaging Sciences 13 (2): 936–70. https://doi.org/10.1137/19M1301047.
Lambert, Marc, Sinho Chewi, Francis Bach, Silvère Bonnabel, and Philippe Rigollet. 2023. “Variational Inference via Wasserstein Gradient Flows.” https://arxiv.org/abs/2205.15902.