Proximal Gradient Descent¶
Proximal Perspective¶
In the previous lecture, we introduced gradient descent and accelerated gradient algorithms for solving unconstrained optimization problems. We showed the convergence rates of these algorithms when the objective function is smooth. However, in problems like Lasso:

\[\min_{x\in\mathbb{R}^d} \frac{1}{2}\|Ax - b\|_2^2 + \lambda \|x\|_1,\]

the \(\ell_1\)-norm penalty term is not smooth. Many high-dimensional \(M\)-estimators can be formulated as:

\[\min_{x\in\mathbb{R}^d} F(x) = f(x) + h(x),\]
where \(f(x)\) is the loss function, typically convex and smooth, and \(h(x)\) is the penalty term, convex but usually non-differentiable. Directly applying subgradient descent to such composite problems degrades the convergence rate (to \(O(1/\sqrt{t})\) in general) because of the non-smooth part.
We focus on algorithms for solving this type of composite loss, aiming for convergence rates similar to gradient descent for smooth functions. Before introducing the new algorithm, let's gain insight into gradient descent:
Previously, we motivated gradient descent by showing that \(- \nabla f(x_t)\) is the steepest descent direction. An alternative perspective is to approximate \(f(x)\) around \(x = x_t\) with a quadratic function:

\[f(x) \approx f(x_t) + \langle \nabla f(x_t), x - x_t \rangle + \frac{1}{2\eta_t}\|x - x_t\|_2^2.\]
Instead of minimizing \(f(x)\), we minimize its quadratic approximation:

\[x_{t+1} = \arg\min_{x\in\mathbb{R}^d}\left\{ f(x_t) + \langle \nabla f(x_t), x - x_t \rangle + \frac{1}{2\eta_t}\|x - x_t\|_2^2 \right\}.\]
This problem has a closed-form solution, which is exactly the gradient descent update:

\[x_{t+1} = x_t - \eta_t \nabla f(x_t).\]
From the proximal perspective, gradient descent minimizes a local quadratic approximation of the objective function in each iteration.
Now, let's return to the composite loss \(F(x) = f(x)+h(x)\). Since \(h\) may be non-smooth, we leave it untouched and only approximate the smooth part \(f\), modifying the proximal perspective of gradient descent as:

\[x_{t+1} = \arg\min_{x\in\mathbb{R}^d}\left\{ f(x_t) + \langle \nabla f(x_t), x - x_t \rangle + \frac{1}{2\eta_t}\|x - x_t\|_2^2 + h(x) \right\}.\]
This leads to the following algorithm for solving \(\min_{x\in\mathbb{R}^d} f(x)+h(x)\).
Proximal Gradient Descent: Completing the square in the subproblem above gives the two-step update

\[y_t = x_t - \eta_t \nabla f(x_t), \qquad x_{t+1} = \arg\min_{x\in\mathbb{R}^d}\left\{ \frac{1}{2\eta_t}\|x - y_t\|_2^2 + h(x) \right\}.\]
Define the proximal operator as:

\[\operatorname{prox}_{\eta h}(y) = \arg\min_{x\in\mathbb{R}^d}\left\{ \frac{1}{2}\|x - y\|_2^2 + \eta\, h(x) \right\}.\]
Proximal gradient descent can then be written compactly as:

\[x_{t+1} = \operatorname{prox}_{\eta_t h}\big(x_t - \eta_t \nabla f(x_t)\big).\]
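To make the update concrete, here is a minimal PyTorch sketch of one generic proximal gradient step; the helper names (proximal_gradient_step, prox_h, identity_prox) are ours for illustration, not from any library:

import torch

def proximal_gradient_step(x, grad, prox_h, eta):
    # One update: x_{t+1} = prox_{eta * h}(x_t - eta * grad f(x_t))
    return prox_h(x - eta * grad, eta)

# With h = 0 the proximal operator is the identity, so we recover plain gradient descent.
identity_prox = lambda y, eta: y
x = torch.randn(3)
grad = 2 * x  # Gradient of f(x) = ||x||_2^2
x_next = proximal_gradient_step(x, grad, identity_prox, eta=0.1)

Different choices of prox_h specialize this template to the algorithms in the examples below.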
Examples of Proximal Gradient Descent¶
Example: Constrained Optimization¶
Although proximal gradient descent is designed for unconstrained problems, we can reformulate constrained optimization \(\min_{x \in \mathcal{X}} f(x)\) as the unconstrained composite form \(\min_{x\in\mathbb{R}^d} f(x)+h(x)\) by taking \(h\) to be the indicator function of \(\mathcal{X}\):

\[h(x) = \mathbb{I}_{\mathcal{X}}(x) = \begin{cases} 0, & x \in \mathcal{X}, \\ +\infty, & x \notin \mathcal{X}. \end{cases}\]
Solving the proximal operator for this choice of \(h\):

\[\operatorname{prox}_{\eta h}(y) = \arg\min_{x\in\mathcal{X}} \frac{1}{2}\|x - y\|_2^2 = \Pi_{\mathcal{X}}(y).\]
This projects \(y\) onto the constraint set \(\mathcal{X}\). Thus, proximal gradient descent recovers the projected gradient descent algorithm for constrained optimization.
Projected Gradient Descent:

\[x_{t+1} = \Pi_{\mathcal{X}}\big(x_t - \eta_t \nabla f(x_t)\big).\]
[Figures: comparison between projected gradient descent and the Frank-Wolfe algorithm.]
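As an illustration, the following minimal sketch runs projected gradient descent for a least-squares objective over the unit \(\ell_2\)-ball; the closed-form projection \(\Pi(y) = y / \max(1, \|y\|_2)\) is specific to this constraint set, and the problem data are random:

import torch

torch.manual_seed(0)
dim = 5
A = torch.randn(dim, dim)
b = torch.randn(dim)

def grad_f(x):
    # Gradient of f(x) = 0.5 * ||Ax - b||^2
    return A.T @ (A @ x - b)

def project_l2_ball(y):
    # Projection onto {x : ||x||_2 <= 1}: rescale y if it lies outside the ball
    return y / torch.clamp(torch.norm(y), min=1.0)

lr = 1.0 / torch.linalg.matrix_norm(A.T @ A, ord=2).item()  # Step size 1/L
x = torch.zeros(dim)
for t in range(100):
    x = project_l2_ball(x - lr * grad_f(x))  # Gradient step, then projection

print(torch.norm(x))  # Always <= 1: every iterate stays feasible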
Example: Lasso¶
The \(\ell_1\)-penalized problem is \(\min_x f(x)+ \lambda \|x\|_1\). The proximal operator becomes the soft-thresholding operator:

\[\big[\operatorname{prox}_{\eta\lambda \|\cdot\|_1}(y)\big]_j = S_{\eta\lambda}(y)_j = \operatorname{sign}(y_j)\max\{|y_j| - \eta\lambda,\, 0\}\]

for all \(j = 1, \ldots, d\).
[Figures: the soft-thresholding operator \(S_{\eta\lambda}\).]
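As a quick numeric check of the formula (the input values below are arbitrary; the same helper reappears in the ISTA implementation further down):

import torch

def soft_thresholding(y, threshold):
    return torch.sign(y) * torch.clamp(torch.abs(y) - threshold, min=0)

y = torch.tensor([-2.0, -0.3, 0.0, 0.5, 1.5])
print(soft_thresholding(y, 0.5))
# Entries with |y_j| <= 0.5 are set to 0; the rest shrink toward 0 by 0.5,
# giving [-1.5, 0, 0, 0, 1.0].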
Iterative Shrinkage-Thresholding Algorithm (ISTA): Solves the Lasso problem \(\min_{x \in \mathbb{R}^d} f(x)+\lambda \|x\|_1\) via:

\[x_{t+1} = S_{\eta_t \lambda}\big(x_t - \eta_t \nabla f(x_t)\big).\]
Here is an implementation of ISTA in PyTorch:
import torch

# Define function f(x) = 0.5 * ||Ax - b||^2 (L2 loss)
dim = 5
A = torch.randn(dim, dim)  # Random matrix A
b = torch.randn(dim)       # Random vector b

def f(x):
    return 0.5 * torch.norm(A @ x - b)**2  # Quadratic loss function

# Soft-thresholding function
def soft_thresholding(y, threshold):
    return torch.sign(y) * torch.clamp(torch.abs(y) - threshold, min=0)

x = torch.zeros(dim, requires_grad=True)  # Initialize x

# ISTA parameters
lr = 0.1         # Learning rate (eta_t)
lambda_ = 0.1    # Regularization parameter
num_iters = 100  # Number of iterations

for t in range(num_iters):
    loss = f(x)
    loss.backward()
    with torch.no_grad():
        # Gradient descent step
        y_t = x - lr * x.grad
        # Soft-thresholding step
        x_new = soft_thresholding(y_t, lambda_ * lr)
        # Update variables
        x.copy_(x_new)  # In-place update to maintain tracking
    # Zero out gradients for the next iteration
    x.grad.zero_()
Even if the objective function \(F(x) = f(x) + h(x)\) has a non-smooth \(h(x)\), the following theorem shows that proximal gradient descent has the same convergence rate \(O(1/t)\) as gradient descent.
Theorem (Convergence rate of proximal gradient descent): Suppose \(f\) is convex and \(L\)-smooth, and \(h\) is convex. If \(\eta_t = 1/L\), proximal gradient descent achieves:

\[F(x_t) - F(x^*) \le \frac{L\|x_0 - x^*\|_2^2}{2t}.\]
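The bound can be sanity-checked numerically (an illustration, not a proof). The sketch below builds a random lasso instance, approximates \(F(x^*)\) with a long ISTA run, and prints the scaled gap \(t\,(F(x_t) - F(x^*))\), which should stay bounded if the rate is \(O(1/t)\):

import torch

torch.manual_seed(0)
dim = 5
A = torch.randn(dim, dim)
b = torch.randn(dim)
lam = 0.1

def F(x):
    return 0.5 * torch.norm(A @ x - b) ** 2 + lam * torch.norm(x, p=1)

def soft_thresholding(y, threshold):
    return torch.sign(y) * torch.clamp(torch.abs(y) - threshold, min=0)

L = torch.linalg.matrix_norm(A.T @ A, ord=2).item()  # Smoothness constant of f
eta = 1.0 / L

def ista(num_iters):
    x = torch.zeros(dim)
    vals = []
    for t in range(num_iters):
        grad = A.T @ (A @ x - b)
        x = soft_thresholding(x - eta * grad, eta * lam)
        vals.append(F(x).item())
    return x, vals

x_star, _ = ista(20000)          # Long run approximates the minimizer
F_star = F(x_star).item()
_, vals = ista(200)
for t in (10, 50, 200):
    print(t, t * (vals[t - 1] - F_star))  # Scaled gap stays bounded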
Accelerated Proximal Gradient Descent¶
The theorem above shows that proximal gradient descent has a convergence rate of \(O(1/t)\), matching gradient descent. In the previous lecture, we introduced Nesterov's accelerated gradient descent, which converges faster, at a rate of \(O(1/t^2)\).
We can apply Nesterov's idea to proximal gradient descent, resulting in the following algorithm.
Accelerated Proximal Gradient Descent
Initialize \(x_0 = y_0\). For \(t = 0, 1, 2, \ldots\):

\[x_{t+1} = \operatorname{prox}_{\eta h}\big(y_t - \eta \nabla f(y_t)\big), \qquad y_{t+1} = x_{t+1} + \frac{\lambda_t - 1}{\lambda_{t+1}}\,(x_{t+1} - x_t),\]
where \(\lambda_0 = 1, \lambda_t = \frac{1 + \sqrt{1+4\lambda_{t-1}^2}}{2}\).
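To see what these momentum weights look like, here is a short plain-Python sketch of the recursion (no assumptions beyond the formula above): \(\lambda_t\) grows roughly like \(t/2\), so the coefficient \((\lambda_t - 1)/\lambda_{t+1}\) increases toward 1.

import math

lambdas = [1.0]  # lambda_0 = 1
for t in range(1, 11):
    lambdas.append((1 + math.sqrt(1 + 4 * lambdas[-1] ** 2)) / 2)

for t in range(1, 10):
    coeff = (lambdas[t] - 1) / lambdas[t + 1]  # Momentum weight at step t
    print(t, round(lambdas[t], 3), round(coeff, 3))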
For example, we can accelerate ISTA for Lasso \(\min_{x} f(x)+ \lambda \|x\|_1\) using the following algorithm.
Fast Iterative Shrinkage-Thresholding Algorithm (FISTA): Solves the Lasso problem \(\min_{x \in \mathbb{R}^d} f(x)+\lambda \|x\|_1\) via:

\[x_{t+1} = S_{\eta\lambda}\big(y_t - \eta \nabla f(y_t)\big), \qquad y_{t+1} = x_{t+1} + \frac{\lambda_t - 1}{\lambda_{t+1}}\,(x_{t+1} - x_t),\]

where \(\lambda_0 = 1\), \(\lambda_t = \frac{1 + \sqrt{1+4\lambda_{t-1}^2}}{2}\), and \(x_0 = y_0\).
Here is a PyTorch implementation of FISTA, reusing f and soft_thresholding from the ISTA code above:
import math

# FISTA parameters
lr = 0.1        # Learning rate
lambda_ = 0.1   # l1 regularization parameter
lambda_t = 1.0  # Momentum term

x = torch.zeros(dim)                      # Iterate x_t
y = torch.zeros(dim, requires_grad=True)  # Extrapolated point y_t (x_0 = y_0)

for t in range(100):
    # Compute function value at y_t
    loss = f(y)
    # Compute gradient using autograd
    loss.backward()
    with torch.no_grad():
        # Gradient descent step + soft-thresholding
        x_new = soft_thresholding(y - lr * y.grad, lambda_ * lr)
        # Update momentum parameter
        lambda_new = (1 + math.sqrt(1 + 4 * lambda_t ** 2)) / 2
        # Compute y_{t+1} using acceleration
        y_new = x_new + ((lambda_t - 1) / lambda_new) * (x_new - x)
        # In-place updates so y stays a leaf tensor that autograd can track
        x.copy_(x_new)
        y.copy_(y_new)
        lambda_t = lambda_new
    # Zero out gradients for the next iteration
    y.grad.zero_()
The convergence rate of accelerated proximal gradient descent is \(O(1/t^2)\).
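To observe the speedup empirically, one can run ISTA and FISTA on the same random lasso instance and compare objective values after the same number of iterations; a self-contained sketch (the data and iteration count are arbitrary):

import math
import torch

torch.manual_seed(0)
dim, lam = 5, 0.1
A = torch.randn(dim, dim)
b = torch.randn(dim)
eta = 1.0 / torch.linalg.matrix_norm(A.T @ A, ord=2).item()  # Step size 1/L

def F(x):
    return 0.5 * torch.norm(A @ x - b) ** 2 + lam * torch.norm(x, p=1)

def grad_f(x):
    return A.T @ (A @ x - b)

def soft_thresholding(y, threshold):
    return torch.sign(y) * torch.clamp(torch.abs(y) - threshold, min=0)

# ISTA
x = torch.zeros(dim)
for t in range(50):
    x = soft_thresholding(x - eta * grad_f(x), eta * lam)

# FISTA
x_f = torch.zeros(dim)
y = x_f.clone()
lambda_t = 1.0
for t in range(50):
    x_new = soft_thresholding(y - eta * grad_f(y), eta * lam)
    lambda_new = (1 + math.sqrt(1 + 4 * lambda_t ** 2)) / 2
    y = x_new + ((lambda_t - 1) / lambda_new) * (x_new - x_f)
    x_f, lambda_t = x_new, lambda_new

# FISTA's objective is typically smaller after the same number of iterations
print("ISTA: ", F(x).item())
print("FISTA:", F(x_f).item())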