
Mirror Descent

Bregman Divergence

In the previous lecture, we introduced the proximal perspective of gradient descent. To minimize \(f(x)\), we approximate the objective function \(f(x)\) around \(x=x_t\) using a quadratic function:

\[ f(x) \approx f(x_t)+\langle \nabla f(x_t), x-x_t\rangle + \frac{1}{2\eta_t}\|x-x_t\|_2^2 \]

This approximation consists of the first-order Taylor expansion plus a proximal term. For constrained optimization \(\min_{x\in \mathcal{X}}f(x)\), starting at \(x_t\), we obtain the next iterate by minimizing this quadratic approximation over \(\mathcal{X}\):

\[ x_{t+1} = \arg\min_{x\in\mathcal{X}}\left\{ f(x_t)+\langle \nabla f(x_t), x-x_t\rangle + \frac{1}{2\eta_t}\|x-x_t\|_2^2 \right\} \]
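Completing the square makes the structure of this sub-problem explicit. Dropping the constant \(f(x_t)\), a short worked identity is:

\[ \langle \nabla f(x_t), x-x_t\rangle + \frac{1}{2\eta_t}\|x-x_t\|_2^2 = \frac{1}{2\eta_t}\left\|x - \left(x_t - \eta_t\nabla f(x_t)\right)\right\|_2^2 - \frac{\eta_t}{2}\|\nabla f(x_t)\|_2^2 \]

The last term does not depend on \(x\), so \(x_{t+1}\) is the point of \(\mathcal{X}\) closest in \(\ell_2\)-distance to \(x_t - \eta_t\nabla f(x_t)\).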

If there are no constraints, i.e., \(\mathcal{X} = \mathbb{R}^d\), this step simplifies to \(x_{t+1} = x_t -\eta_t\nabla f(x_t)\). Otherwise, it becomes projected gradient descent. Without the proximal term, the sub-problem is the linear minimization step of the Frank-Wolfe algorithm. The proximal term \(\frac{1}{2\eta_t}\|x-x_t\|_2^2\) keeps \(x_{t+1}\) from straying too far from \(x_t\). A natural question arises: why use the \(\ell_2\)-norm in the proximal term? Can we use another distance?

Example: Quadratic Optimization

Consider the quadratic optimization problem:

\[ \min_{x\in \mathbb{R}^d}f(x) = \min_{x\in \mathbb{R}^d} \frac{1}{2}(x - x^*)^\top Q (x - x^*) \]

where \(Q\) is a positive definite matrix.

Using the \(\ell_2\)-norm in the proximal term, we have gradient descent:

\[ x_{t+1} = x_t - \eta_t Q(x_t-x^*) \]

[Figure: Quadratic optimization]

In the figure above, the trajectory of gradient descent zigzags. This happens because the \(\ell_2\)-norm does not match the geometry of the objective \(f(x)\), whose contours are scaled by the matrix \(Q\). What if we instead use the norm \(\|x\|_Q^2 = x^\top Q x\) in the proximal term, i.e., \(\frac{1}{2\eta_t}\|x-x_t\|_Q^2\)? Then we update \(x_{t+1}\) as:

\[ x_{t+1} = x_t - \eta_t Q^{-1}\nabla f(x_t) = x_t - \eta_t (x_t-x^*) \]

[Figure: Quadratic optimization]

In the figure above, the descent direction points directly at the minimizer \(x^*\), resulting in a much faster algorithm. With \(\eta_t = 1\), a single step lands exactly on \(x^*\).
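The contrast between the two updates can be reproduced numerically. The following is a minimal sketch in plain Python; the matrix `Q`, the minimizer `X_STAR`, the starting point, and the step sizes are illustrative choices, not values from the lecture.

```python
# Minimal sketch: compare plain gradient steps with Q-scaled steps on
# f(x) = 0.5 * (x - x_star)^T Q (x - x_star) in two dimensions.
# Q, X_STAR, x0 and the step sizes are illustrative assumptions.

Q = [[10.0, 0.0], [0.0, 1.0]]     # ill-conditioned positive definite matrix
Q_INV = [[0.1, 0.0], [0.0, 1.0]]  # inverse of Q (diagonal, so inverted entrywise)
X_STAR = [1.0, -2.0]              # minimizer x^*


def matvec(m, v):
    """Matrix-vector product for small dense lists."""
    return [sum(m[i][j] * v[j] for j in range(len(v))) for i in range(len(m))]


def grad(x):
    """Gradient of f: Q (x - x_star)."""
    return matvec(Q, [xi - si for xi, si in zip(x, X_STAR)])


def gd_step(x, eta):
    """Step induced by the l2 proximal term: x - eta * grad f(x)."""
    g = grad(x)
    return [xi - eta * gi for xi, gi in zip(x, g)]


def q_step(x, eta):
    """Step induced by the Q-norm proximal term: x - eta * Q^{-1} grad f(x)."""
    d = matvec(Q_INV, grad(x))
    return [xi - eta * di for xi, di in zip(x, d)]


def distance_to_minimizer(x):
    return sum((xi - si) ** 2 for xi, si in zip(x, X_STAR)) ** 0.5


x0 = [5.0, 5.0]
x_gd = x0
for _ in range(10):
    # eta must stay below 2 / 10 here, otherwise the first coordinate diverges
    x_gd = gd_step(x_gd, eta=0.18)
x_q = q_step(x0, eta=1.0)  # a single Q-scaled step with eta = 1

print("GD after 10 steps, distance to x*:", distance_to_minimizer(x_gd))
print("Q-scaled after 1 step, distance to x*:", distance_to_minimizer(x_q))
```

In this sketch the first coordinate of the gradient-descent iterate changes sign at every step, which is the zigzag pattern, while the Q-scaled step does not depend on the conditioning of `Q`.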

Bregman Divergence

The previous example shows the need for a better distance fitting the problem's geometry. This motivates the definition of Bregman divergence.

Definition (Bregman Divergence): Let \(\varphi: \mathcal{X} \rightarrow \mathbb{R}\) be a convex and differentiable function. The Bregman divergence between \(x\) and \(z\) is:

\[ D_{\varphi}(x,z) = \varphi(x) - \varphi(z) - \langle \nabla\varphi(z), x - z\rangle \]

By convexity, \(D_{\varphi}(x,z) \ge 0\), so it behaves like a distance, although in general it is neither symmetric nor subject to the triangle inequality. The Bregman divergence generalizes the squared quadratic norm \(\|\cdot\|_Q^2\).
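As a worked instance, take \(\varphi(x) = \frac{1}{2}x^\top Q x\) with \(Q\) symmetric positive definite (an illustrative choice of \(\varphi\)). Then \(\nabla\varphi(z) = Qz\) and

\[ D_{\varphi}(x,z) = \frac{1}{2}x^\top Q x - \frac{1}{2}z^\top Q z - \langle Qz, x - z\rangle = \frac{1}{2}(x-z)^\top Q (x-z) = \frac{1}{2}\|x-z\|_Q^2 \]

Setting \(Q = I\) recovers the Euclidean term \(\frac{1}{2}\|x-z\|_2^2\) used in gradient descent.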

The table below shows some common Bregman divergences.

| Function name | \(\varphi(x)\) | \(\text{dom } \varphi\) | \(D_{\varphi}(x, y)\) |
| --- | --- | --- | --- |
| Squared norm | \(\frac{1}{2} x^2\) | \((-\infty, +\infty)\) | \(\frac{1}{2} (x - y)^2\) |
| Shannon entropy | \(x \log x - x\) | \([0, +\infty)\) | \(x \log \frac{x}{y} - x + y\) |
| Bit entropy | \(x \log x + (1 - x) \log(1 - x)\) | \([0, 1]\) | \(x \log \frac{x}{y} + (1 - x) \log \frac{1 - x}{1 - y}\) |
| Burg entropy | \(-\log x\) | \((0, +\infty)\) | \(\frac{x}{y} - \log \frac{x}{y} - 1\) |
| Hellinger | \(-\sqrt{1 - x^2}\) | \([-1, 1]\) | \((1 - xy)(1 - y^2)^{-1/2} - (1 - x^2)^{1/2}\) |
| Exponential | \(\exp x\) | \((-\infty, +\infty)\) | \(\exp x - (x - y + 1) \exp y\) |
| Inverse | \(1/x\) | \((0, +\infty)\) | \(\frac{1}{x} + \frac{x}{y^2} - \frac{2}{y}\) |
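The closed forms in the table can be checked against the definition numerically. The sketch below does this for several rows in plain Python; the helper `bregman`, the list `ROWS`, and the test points are names and values introduced here for illustration, and the derivatives are written out by hand.

```python
import math


def bregman(phi, dphi, x, y):
    """Scalar Bregman divergence: phi(x) - phi(y) - phi'(y) * (x - y)."""
    return phi(x) - phi(y) - dphi(y) * (x - y)


# Each entry: (name, phi, derivative of phi, closed form taken from the table)
ROWS = [
    ("Squared norm", lambda x: 0.5 * x * x, lambda x: x,
     lambda x, y: 0.5 * (x - y) ** 2),
    ("Shannon entropy", lambda x: x * math.log(x) - x, lambda x: math.log(x),
     lambda x, y: x * math.log(x / y) - x + y),
    ("Burg entropy", lambda x: -math.log(x), lambda x: -1.0 / x,
     lambda x, y: x / y - math.log(x / y) - 1.0),
    ("Exponential", math.exp, math.exp,
     lambda x, y: math.exp(x) - (x - y + 1.0) * math.exp(y)),
    ("Inverse", lambda x: 1.0 / x, lambda x: -1.0 / (x * x),
     lambda x, y: 1.0 / x + x / (y * y) - 2.0 / y),
]

x, y = 0.3, 0.7  # arbitrary points lying inside every domain listed above
for name, phi, dphi, closed_form in ROWS:
    from_definition = bregman(phi, dphi, x, y)
    from_table = closed_form(x, y)
    print(f"{name:16s} definition={from_definition:.6f} table={from_table:.6f}")
```

Swapping `x` and `y` in this sketch generally changes the value, which illustrates that a Bregman divergence need not be symmetric.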

Mirror Descent

Replacing the squared \(\ell_2\)-norm in the proximal term with the Bregman divergence, and dropping the terms that do not depend on \(x\), gives the mirror descent algorithm.

Mirror Descent Algorithm:

\[ x_{t+1} = \arg\min_{x\in\mathcal{X}}\left\{ \langle \nabla f(x_t), x\rangle + \frac{1}{\eta_t}D_{\varphi}(x,x_t) \right\} \]

The quadratic example above shows how the geometry of the objective function can call for a Bregman divergence. In most cases, however, the choice of Bregman divergence is driven by the geometry of the constraint set. The following example illustrates choosing the right Bregman divergence under a specific constraint.

Example: Probability Simplex

In many statistical problems, we estimate the probability mass function of a discrete distribution. The parameters are constrained to the probability simplex:

\[ \Delta = \left\{x \in \mathbb{R}^d ~\Big|~ \sum_{i=1}^d x_i = 1, x_i \ge 0 \text{ for all } i = 1, \ldots, d\right\} \]

[Figure: The probability simplex]

A widely used Bregman divergence for the probability simplex uses \(\varphi\) as the negative entropy:

\[ \varphi(x) = \sum_{i=1}^d x_i\log x_i \]

For \(x, z \in \Delta\), the corresponding Bregman divergence reduces to the Kullback–Leibler (KL) divergence:

\[ D_{\varphi}(x,z) = \sum_{i=1}^d x_i\log \frac{x_i}{z_i} \]

[Figure: Mirror descent]

This is also written \(D_{\rm KL}(x\|z)\). Maximum likelihood estimation can be viewed as finding the distribution \(P_{\theta}\) that is closest in KL divergence to the true distribution \(P_{\theta^*}\).
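The step from the negative entropy to the KL divergence can be checked term by term. Using \(\partial\varphi/\partial z_i = \log z_i + 1\) in the definition of \(D_{\varphi}\):

\[ D_{\varphi}(x,z) = \sum_{i=1}^d x_i\log x_i - \sum_{i=1}^d z_i\log z_i - \sum_{i=1}^d (\log z_i + 1)(x_i - z_i) = \sum_{i=1}^d x_i\log \frac{x_i}{z_i} - \sum_{i=1}^d x_i + \sum_{i=1}^d z_i \]

On the simplex both \(\sum_i x_i\) and \(\sum_i z_i\) equal 1, so the last two sums cancel and only the KL term remains.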

Therefore, for the constrained optimization problem

\[ \min_{x\in \Delta} f(x) \]

the mirror descent algorithm with the KL-divergence is:

\[ x_{t+1} = \arg\min_{x\in\Delta}\left\{ \langle \nabla f(x_t), x\rangle + \frac{1}{\eta_t}D_{\rm KL}(x\|x_t) \right\} \]

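This sub-problem has a closed-form solution: a multiplicative update of \(x_t\) followed by normalization, with \(\log\) and \(\exp\) applied elementwise.

Mirror Descent for Probability Simplex

\[ x_{t+1} = \text{softmax}\left(\log x_t - \eta_t\nabla f(x_t)\right), \quad\text{i.e.,}\quad x_{t+1,i} = \frac{x_{t,i}\exp\left(-\eta_t\nabla f(x_t)_i\right)}{\sum_{j=1}^d x_{t,j}\exp\left(-\eta_t\nabla f(x_t)_j\right)} \]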
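A brief sketch of how the closed form for this sub-problem is obtained, writing \(g = \nabla f(x_t)\) and introducing a multiplier \(\lambda\) (notation assumed here) for the constraint \(\sum_i x_i = 1\). Setting the derivative of the Lagrangian with respect to \(x_i\) to zero gives:

\[ g_i + \frac{1}{\eta_t}\left(\log\frac{x_i}{x_{t,i}} + 1\right) + \lambda = 0 \quad\Longrightarrow\quad x_i = x_{t,i}\exp\left(-\eta_t g_i\right)\exp\left(-1-\eta_t\lambda\right) \]

The last factor is the same for every \(i\), so it is fixed by normalization, which yields \(x_{t+1,i} \propto x_{t,i}\exp(-\eta_t g_i)\). The constraints \(x_i \ge 0\) hold automatically because the exponential is positive.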

The algorithm can be implemented as:

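import torch

def f(x): # define the objective function; it must return a scalar tensor
    ...

dim = 5  # Problem dimension d (example value)
# Initialize x at the center of the probability simplex.
# The tensor is created directly as a leaf: dividing a requires_grad tensor
# by dim would produce a non-leaf tensor whose .grad is never populated.
x = torch.full((dim,), 1.0 / dim, requires_grad=True)

# Mirror Descent Parameters
lr = 0.1  # Learning rate (η_t)
num_iters = 100  # Number of iterations
for t in range(num_iters):
    # Compute function value at x_t
    loss = f(x)
    # Compute gradient using autograd
    loss.backward()
    with torch.no_grad():
        # Entropic mirror descent update: x_i ∝ x_i * exp(-lr * grad_i)
        x_new = torch.softmax(torch.log(x) - lr * x.grad, dim=0)
        # Update variables in place so x remains the same leaf tensor
        x.copy_(x_new)
        # Zero out gradients for the next iteration
        x.grad.zero_()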
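For a version that runs end to end without an autograd library, here is a minimal sketch in plain Python of mirror descent on the simplex with the KL divergence, using the multiplicative form \(x_{t+1,i} \propto x_{t,i}\exp(-\eta_t \nabla f(x_t)_i)\). The objective \(f(x) = \frac{1}{2}\|x - p\|_2^2\), the target `p`, the hand-written gradient, and the helper names are assumptions chosen for illustration.

```python
import math


def softmax(v):
    """Numerically stable softmax of a list of floats."""
    m = max(v)
    e = [math.exp(vi - m) for vi in v]
    s = sum(e)
    return [ei / s for ei in e]


def mirror_descent_simplex(grad_f, x0, lr, num_iters):
    """Entropic mirror descent: x_i <- x_i * exp(-lr * g_i), then normalize."""
    x = list(x0)
    for _ in range(num_iters):
        g = grad_f(x)
        # softmax(log x - lr * g) equals the normalized multiplicative update
        x = softmax([math.log(xi) - lr * gi for xi, gi in zip(x, g)])
    return x


# Hypothetical objective: f(x) = 0.5 * ||x - p||^2 with p inside the simplex,
# so the constrained minimizer is p itself.
p = [0.5, 0.3, 0.2]


def grad_f(x):
    return [xi - pi for xi, pi in zip(x, p)]


x_final = mirror_descent_simplex(
    grad_f, x0=[1 / 3, 1 / 3, 1 / 3], lr=0.5, num_iters=500
)
print(x_final)
```

Every iterate stays strictly inside the simplex by construction, so no separate projection step is needed.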

Summary for Constrained Optimization

Now consider the constrained optimization problem \(\min_{x\in \Delta} f(x)\). We have learned three algorithms for solving constrained optimization:

  1. Frank-Wolfe Algorithm: Essentially mirror descent with \(\varphi = 0\). It involves solving a linear programming sub-problem.
  2. Projected Gradient Descent: Uses the squared \(\ell_2\)-norm as the proximal term, involving a quadratic programming sub-problem.
  3. Mirror Descent: Using KL-divergence \(D_{\rm KL}(x\|x_t)\), the sub-problem has a closed-form solution.
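On the simplex, the three sub-problems can be sketched side by side in plain Python. This is an illustrative sketch, not a reference implementation: the function names are introduced here, the projection uses a standard sort-based method that the text does not describe, the Frank-Wolfe function returns only the linear minimizer (not the full averaged iterate), and the gradient `g` is a made-up example.

```python
import math


def frank_wolfe_lmo(g):
    """Linear sub-problem: argmin over the simplex of <g, s>.

    The minimizer is the vertex at the smallest gradient entry.
    """
    k = min(range(len(g)), key=lambda i: g[i])
    return [1.0 if i == k else 0.0 for i in range(len(g))]


def project_to_simplex(v):
    """Quadratic sub-problem: Euclidean projection of v onto the simplex."""
    u = sorted(v, reverse=True)
    cumsum, theta = 0.0, 0.0
    for j, uj in enumerate(u, start=1):
        cumsum += uj
        t = (cumsum - 1.0) / j
        if uj - t > 0:
            theta = t  # keep the threshold from the largest valid j
    return [max(vi - theta, 0.0) for vi in v]


def pgd_step(x, g, lr):
    """Projected gradient descent: gradient step, then projection."""
    return project_to_simplex([xi - lr * gi for xi, gi in zip(x, g)])


def md_step(x, g, lr):
    """Mirror descent with KL divergence: x_i * exp(-lr * g_i), normalized."""
    w = [xi * math.exp(-lr * gi) for xi, gi in zip(x, g)]
    s = sum(w)
    return [wi / s for wi in w]


x = [0.25, 0.25, 0.25, 0.25]
g = [0.3, -1.0, 0.5, 0.0]  # made-up gradient for illustration

s_fw = frank_wolfe_lmo(g)
x_pgd = pgd_step(x, g, lr=0.5)
x_md = md_step(x, g, lr=0.5)
print("Frank-Wolfe vertex:", s_fw)
print("Projected GD step: ", x_pgd)
print("Mirror descent step:", x_md)
```

In this sketch the projection needs a sort, while the mirror descent step is a single elementwise pass followed by a normalization.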