
Mirror Descent

Bregman Divergence

In the previous lecture, we introduced the proximal perspective of gradient descent. To minimize $f(x)$, we approximate the objective function $f(x)$ around $x = x_t$ using a quadratic function:

$$f(x) \approx f(x_t) + \langle \nabla f(x_t), x - x_t \rangle + \frac{1}{2\eta_t}\|x - x_t\|_2^2$$

This is composed of the first-order Taylor expansion and a proximal term. For constrained optimization $\min_{x \in \mathcal{X}} f(x)$, starting at $x_t$, we obtain the next iterate by minimizing this quadratic approximation:

$$x_{t+1} = \arg\min_{x \in \mathcal{X}} \left\{ f(x_t) + \langle \nabla f(x_t), x - x_t \rangle + \frac{1}{2\eta_t}\|x - x_t\|_2^2 \right\}$$

If there are no constraints, i.e., $\mathcal{X} = \mathbb{R}^d$, this step simplifies to $x_{t+1} = x_t - \eta_t \nabla f(x_t)$. Otherwise, it becomes projected gradient descent. Without the proximal term, it reduces to the Frank-Wolfe algorithm. The proximal term $\frac{1}{2\eta_t}\|x - x_t\|_2^2$ prevents $x_{t+1}$ from straying too far from $x_t$. A natural question arises: why use the $\ell_2$-norm in the proximal term? Can we use another distance?
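
To see the unconstrained case, set the gradient of the quadratic model to zero:

$$\nabla f(x_t) + \frac{1}{\eta_t}(x - x_t) = 0 \quad\Longrightarrow\quad x_{t+1} = x_t - \eta_t \nabla f(x_t)$$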

Example: Quadratic Optimization

Consider the quadratic optimization problem:

$$\min_{x \in \mathbb{R}^d} f(x) = \min_{x \in \mathbb{R}^d} \frac{1}{2}(x - x^*)^\top Q (x - x^*)$$

where $Q$ is a positive definite matrix.

Using the $\ell_2$-norm in the proximal term, we have gradient descent:

$$x_{t+1} = x_t - \eta_t Q(x_t - x^*)$$

[Figure: Quadratic optimization]

In the figure above, the trajectory of gradient descent zigzags. This zigzag pattern occurs because the $\ell_2$-norm is not the ideal distance for the objective $f(x)$: the contours are scaled by the matrix $Q$. What if we use the norm $\|x\|_Q$ defined by $\|x\|_Q^2 = x^\top Q x$ in the proximal term? Then we update $x_{t+1}$ as:

$$x_{t+1} = x_t - \eta_t Q^{-1} \nabla f(x_t) = x_t - \eta_t (x_t - x^*)$$

[Figure: Quadratic optimization]

In the figure above, the descent direction points directly at the minimizer $x^*$, resulting in a much faster algorithm.
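
A minimal numerical sketch of this comparison (the matrix $Q$, starting point, and step size below are illustrative choices, not taken from the lecture):

import torch

# Illustrative 2-D quadratic f(x) = 1/2 (x - x_star)^T Q (x - x_star)
Q = torch.tensor([[10.0, 0.0], [0.0, 1.0]])  # ill-conditioned, so plain gradient descent zigzags
x_star = torch.tensor([3.0, -1.0])
Q_inv = torch.linalg.inv(Q)

def grad_f(x):
    return Q @ (x - x_star)

eta = 0.18             # close to the stability limit 2 / lambda_max(Q)
x_gd = torch.zeros(2)  # iterate with the l2 proximal term (gradient descent)
x_q = torch.zeros(2)   # iterate with the ||.||_Q proximal term
for t in range(15):
    x_gd = x_gd - eta * grad_f(x_gd)
    x_q = x_q - eta * (Q_inv @ grad_f(x_q))  # equals x_q - eta * (x_q - x_star)
    print(t, x_gd.tolist(), x_q.tolist())
# x_q moves along the straight line toward x_star, while the first coordinate of x_gd oscillates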

Bregman Divergence

The previous example shows the need for a better distance fitting the problem's geometry. This motivates the definition of Bregman divergence.

Definition (Bregman Divergence): Let $\varphi : \mathcal{X} \to \mathbb{R}$ be a convex and differentiable function. The Bregman divergence between $x$ and $z$ is:

$$D_\varphi(x, z) = \varphi(x) - \varphi(z) - \langle \nabla \varphi(z), x - z \rangle$$

By convexity, $D_\varphi(x, z) \ge 0$, making it a type of distance. The Bregman divergence generalizes the quadratic norm $\|\cdot\|_Q$.
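
For instance, taking $\varphi(x) = \frac{1}{2}x^\top Q x$ recovers the quadratic norm from the example above:

$$D_\varphi(x, z) = \tfrac{1}{2}x^\top Q x - \tfrac{1}{2}z^\top Q z - \langle Q z, x - z \rangle = \tfrac{1}{2}(x - z)^\top Q (x - z) = \tfrac{1}{2}\|x - z\|_Q^2$$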

The table below shows some common Bregman divergences.

| Function Name | $\varphi(x)$ | $\operatorname{dom}\varphi$ | $D_\varphi(x, y)$ |
|---|---|---|---|
| Squared norm | $\frac{1}{2}x^2$ | $(-\infty, +\infty)$ | $\frac{1}{2}(x - y)^2$ |
| Shannon entropy | $x\log x - x$ | $[0, +\infty)$ | $x\log\frac{x}{y} - x + y$ |
| Bit entropy | $x\log x + (1 - x)\log(1 - x)$ | $[0, 1]$ | $x\log\frac{x}{y} + (1 - x)\log\frac{1 - x}{1 - y}$ |
| Burg entropy | $-\log x$ | $(0, +\infty)$ | $\frac{x}{y} - \log\frac{x}{y} - 1$ |
| Hellinger | $-\sqrt{1 - x^2}$ | $[-1, 1]$ | $(1 - xy)(1 - y^2)^{-1/2} - (1 - x^2)^{1/2}$ |
| Exponential | $\exp x$ | $(-\infty, +\infty)$ | $\exp x - (x - y + 1)\exp y$ |
| Inverse | $1/x$ | $(0, +\infty)$ | $\frac{1}{x} + \frac{x}{y^2} - \frac{2}{y}$ |
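
As a quick sanity check, a Bregman divergence can be computed directly from the definition using autograd. The helper below is a small sketch (not from the lecture) that verifies the Burg entropy row numerically:

import torch

def bregman(phi, x, z):
    # D_phi(x, z) = phi(x) - phi(z) - <grad phi(z), x - z>
    z = z.clone().requires_grad_(True)
    (grad_z,) = torch.autograd.grad(phi(z), z)
    return (phi(x) - phi(z) - grad_z @ (x - z)).detach()

phi = lambda x: -torch.log(x).sum()  # Burg entropy, applied coordinate-wise
x = torch.tensor([0.7])
y = torch.tensor([0.2])
print(bregman(phi, x, y))            # ~1.2472
print(x / y - torch.log(x / y) - 1)  # matches the table entry x/y - log(x/y) - 1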

Mirror Descent

Replacing the squared $\ell_2$-norm in the proximal term with the Bregman divergence (and dropping the terms that do not depend on $x$), we have:

$$x_{t+1} = \arg\min_{x \in \mathcal{X}} \left\{ \langle \nabla f(x_t), x \rangle + \frac{1}{\eta_t} D_\varphi(x, x_t) \right\}$$

This leads to the mirror descent algorithm:

Mirror Descent Algorithm:

$$x_{t+1} = \arg\min_{x \in \mathcal{X}} \left\{ \langle \nabla f(x_t), x \rangle + \frac{1}{\eta_t} D_\varphi(x, x_t) \right\}$$
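
One way to read this update: when the constraint is inactive (for instance $\mathcal{X} = \mathbb{R}^d$), setting the gradient of the subproblem to zero gives

$$\nabla \varphi(x_{t+1}) = \nabla \varphi(x_t) - \eta_t \nabla f(x_t),$$

i.e., the gradient step is taken in the "mirror" space of $\nabla\varphi$, which is where the algorithm's name comes from.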

The quadratic example shows the importance of choosing a Bregman divergence suited to the objective function's geometry. In most cases, however, we need a Bregman divergence suited to the constraint set's geometry. The following example illustrates how to choose the right Bregman divergence for a specific constraint.

Example: Probability Simplex

In many statistical problems, we estimate the probability mass function of a discrete distribution. The parameters are constrained to the probability simplex:

$$\Delta = \left\{ x \in \mathbb{R}^d \;\middle|\; \sum_{i=1}^d x_i = 1, \ x_i \ge 0 \text{ for all } i = 1, \dots, d \right\}$$

[Figure: the probability simplex]

A widely used Bregman divergence for the probability simplex takes $\varphi$ to be the negative entropy:

$$\varphi(x) = \sum_{i=1}^d x_i \log x_i$$

The corresponding Bregman divergence becomes the Kullback–Leibler (KL) divergence:

$$D_\varphi(x, z) = \sum_{i=1}^d x_i \log \frac{x_i}{z_i}$$
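
To verify this, expand the definition with $\nabla\varphi(z)_i = \log z_i + 1$; the leftover linear terms cancel because $x, z \in \Delta$ both sum to one:

$$D_\varphi(x, z) = \sum_{i=1}^d x_i \log x_i - \sum_{i=1}^d z_i \log z_i - \sum_{i=1}^d (\log z_i + 1)(x_i - z_i) = \sum_{i=1}^d x_i \log \frac{x_i}{z_i}$$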


This is also known as $D_{\mathrm{KL}}(x \,\|\, z)$. The maximum likelihood estimator can be seen as finding the distribution $P_\theta$ that is closest, in KL divergence, to the true distribution.

Therefore, for the constrained optimization problem

$$\min_{x \in \Delta} f(x)$$

the mirror descent algorithm with the KL-divergence is:

$$x_{t+1} = \arg\min_{x \in \Delta} \left\{ \langle \nabla f(x_t), x \rangle + \frac{1}{\eta_t} D_{\mathrm{KL}}(x \,\|\, x_t) \right\}$$

And it has a closed-form solution:

Mirror Descent for Probability Simplex

$$x_{t+1} = \mathrm{softmax}\big(\log x_t - \eta_t \nabla f(x_t)\big), \qquad x_{t+1, i} = \frac{x_{t, i} \exp\!\big(-\eta_t \nabla f(x_t)_i\big)}{\sum_{j=1}^d x_{t, j} \exp\!\big(-\eta_t \nabla f(x_t)_j\big)}$$
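
This follows from the first-order optimality condition of the subproblem: with a multiplier $\lambda$ for the constraint $\sum_i x_i = 1$,

$$\nabla f(x_t)_i + \frac{1}{\eta_t}\left(\log\frac{x_i}{x_{t,i}} + 1\right) + \lambda = 0 \quad\Longrightarrow\quad x_{t+1, i} \propto x_{t, i}\exp\!\big(-\eta_t \nabla f(x_t)_i\big),$$

and normalizing over $i$ gives the softmax form above. This update is also known as exponentiated gradient or multiplicative weights.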

The algorithm can be implemented as:

import torch

def f(x): # define the objective function
    ...

x = (torch.ones(dim) / dim).requires_grad_()  # Initialize x at the uniform distribution in the simplex (leaf tensor, so x.grad is populated)

# Mirror Descent Parameters
lr = 0.1  # Learning rate (η_t)
num_iters = 100  # Number of iterations
for t in range(num_iters):
    # Compute function value at x_t
    loss = f(x)
    # Compute gradient using autograd
    loss.backward()
    with torch.no_grad():
        # Entropic mirror descent update: x_{t+1} = softmax(log x_t - η_t ∇f(x_t))
        x_new = torch.softmax(torch.log(x) - lr * x.grad, dim=0)
        # Update variables
        x.copy_(x_new)  # In-place update so x stays the same tracked leaf tensor
        # Zero out gradients for the next iteration
        x.grad.zero_()
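
As a concrete (made-up) test of this template, one can take $f$ to be the squared distance to a fixed target distribution $p \in \Delta$; the iterates then stay in the simplex and approach $p$:

import torch

dim = 4
p = torch.tensor([0.1, 0.2, 0.3, 0.4])  # made-up target distribution in the simplex

def f(x):  # illustrative objective, minimized over Δ at x = p
    return 0.5 * torch.sum((x - p) ** 2)

x = (torch.ones(dim) / dim).requires_grad_()
lr, num_iters = 0.5, 200
for t in range(num_iters):
    loss = f(x)
    loss.backward()
    with torch.no_grad():
        x.copy_(torch.softmax(torch.log(x) - lr * x.grad, dim=0))
        x.grad.zero_()

print(x.detach())      # close to p
print(x.sum().item())  # remains 1: the softmax update never leaves the simplex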

Summary for Constrained Optimization

Now consider the constrained optimization problem $\min_{x \in \Delta} f(x)$. We have learned three algorithms for solving constrained optimization:

  1. Frank-Wolfe Algorithm: Essentially mirror descent with $\varphi = 0$ (no proximal term). It involves solving a linear programming sub-problem.
  2. Projected Gradient Descent: Uses the squared $\ell_2$-norm as the proximal term, involving a quadratic programming sub-problem.
  3. Mirror Descent: Using the KL-divergence $D_{\mathrm{KL}}(x \,\|\, x_t)$, the sub-problem has a closed-form solution (the softmax update above).