Thursday, June 1, 2023

Solving Elastic-Net problem using Proximal Gradient Descent

The Python code for this post is available on my GitHub.

Problem definition

In this note, we consider the so-called Elastic-Net problem

$$\min_{x\in \mathbb R^n} \quad \frac{1}{2}||b-Ax||_2^2 + \lambda_1||x||_1 + \frac{\lambda_2}{2}||x||_2^2$$

where $b\in \mathbb R^m$, $A \in \mathbb R^{m\times n}$, and $\lambda_1, \lambda_2\geq 0$. The $\ell_1$ term promotes sparsity, so the optimal solution is sparse when $\lambda_2$ is small relative to $\lambda_1$. In particular, the Lasso problem corresponds to $\lambda_2=0$.
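
To make the objective concrete, here is a minimal NumPy sketch of the Elastic-Net objective (the names `A`, `b`, `lam1`, `lam2` are placeholders of my own, not taken from the repository):

```python
import numpy as np

def elastic_net_objective(x, A, b, lam1, lam2):
    """Evaluate 0.5*||b - Ax||_2^2 + lam1*||x||_1 + 0.5*lam2*||x||_2^2."""
    residual = b - A @ x
    return 0.5 * residual @ residual + lam1 * np.abs(x).sum() + 0.5 * lam2 * x @ x
```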

The following figures demonstrate the difference between the solutions of the Lasso and Elastic-Net problems. As you can see, the solution obtained through Lasso is sparser, while the solution obtained through Elastic-Net is smoother.

Elastic-Net and Lasso

The evolution of the iterates for the two problems is shown below.

Elastic-Net and Lasso evolution

Solving with Proximal Gradient Descent

The proximal gradient method applies the following update rule

$$x^{(k+1)} = \text{prox}_{g/\alpha} \left(x^{(k)} - \frac{1}{\alpha} \nabla f(x^{(k)})\right)$$

Here $\alpha>0$ is a Lipschitz constant of the gradient of $f$ (assuming one exists); in this case $f$ is said to be $\alpha$-smooth.
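
As a concrete sketch of this update rule (the generic loop below uses names of my own choosing and a fixed iteration count rather than a real stopping criterion):

```python
import numpy as np

def proximal_gradient(grad_f, prox, alpha, x0, n_iter=500):
    """Generic proximal gradient descent: x <- prox_{g/alpha}(x - grad_f(x)/alpha)."""
    x = x0.copy()
    for _ in range(n_iter):
        x = prox(x - grad_f(x) / alpha)
    return x
```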

Let us explain why this update rule yields a convergent algorithm.
Note that the Elastic-Net problem is a special case of the more general composite form

$$\min_{x} \quad P(x) = f(x) + g(x)$$

where $f$ is the main smooth loss function and $g$ is a regularization function, which is usually non-smooth. If one assumes that $f$ has an $\alpha$-Lipschitz continuous gradient, then

$$P(x') \leq f(x) + \langle \nabla f(x), x'-x\rangle +\frac{\alpha}{2} ||x'-x||^2_2 + g(x') = M(x')$$

The function $M$ is called the surrogate function of $P$, with equality when $x'=x$. To minimize $P$, we can instead minimize $M$; this is known as the Majorize-Minimization method. To do so, let us rewrite $M$:

$$M(x') = f(x) - \frac{1}{2\alpha} ||\nabla f(x)||^2_2 + \frac{\alpha}{2} ||x' - x + \frac{1}{\alpha}\nabla f(x)||_2^2 + g(x')$$

Then, to minimize $M$, we apply the proximal operator of $g/\alpha$ at the point $x^+ = x - \frac{1}{\alpha} \nabla f(x)$, which is the gradient descent update from $x$, i.e.

$$\arg\min_{x'} \quad \frac{1}{2} ||x' - x^+||_2^2 + \frac{1}{\alpha} g(x') = \text{prox}_{g/\alpha} (x^+).$$

This is exactly the proximal gradient descent method described above.
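
One way to convince yourself of this rewriting is to check numerically that $M(x')$ and $\frac{\alpha}{2}||x' - x^+||_2^2 + g(x')$ differ only by the constant $f(x) - \frac{1}{2\alpha}||\nabla f(x)||_2^2$, independently of $x'$, so both have the same minimizer. A small sketch of such a check (random data, $g = ||\cdot||_1$ as an example; all names are my own):

```python
import numpy as np

rng = np.random.default_rng(0)
A = rng.standard_normal((20, 5))
b = rng.standard_normal(20)

f = lambda z: 0.5 * np.sum((b - A @ z) ** 2)
grad_f = lambda z: -A.T @ (b - A @ z)
g = lambda z: np.abs(z).sum()      # any regularizer works; l1 is used as an example
alpha = np.sum(A ** 2)             # ||A||_F^2, an upper bound on the Lipschitz constant

x = rng.standard_normal(5)
x_plus = x - grad_f(x) / alpha

M = lambda xp: f(x) + grad_f(x) @ (xp - x) + 0.5 * alpha * np.sum((xp - x) ** 2) + g(xp)
prox_obj = lambda xp: 0.5 * alpha * np.sum((xp - x_plus) ** 2) + g(xp)  # alpha * prox objective

constant = f(x) - grad_f(x) @ grad_f(x) / (2 * alpha)
for _ in range(3):
    xp = rng.standard_normal(5)
    print(np.isclose(M(xp) - prox_obj(xp), constant))   # True for every xp
```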

Back to the Elastic-Net problem

To apply proximal gradient descent, we need closed-form expressions for $\alpha$, $\nabla f$, and $\text{prox}_{g/\alpha}$. This is a simple task.

For the Elastic-Net problem, we can choose $\alpha = ||A||^2_{\text{Fro}} = \sum_{i=1}^n ||a_i||_2^2$, the squared Frobenius norm of $A$, i.e. the sum of the squared norms of the columns $a_i$ of $A$. This is a valid choice because it upper-bounds the largest eigenvalue of $A^TA$, which is the Lipschitz constant of $\nabla f$.

With $f(x) = \frac{1}{2} ||b - Ax||_2^2$, we have $\nabla f(x) = -A^T(b-Ax)$.
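
In code, both quantities are one-liners (a sketch with names of my own):

```python
import numpy as np

def step_constant(A):
    """alpha = ||A||_F^2, the sum of the squared column norms of A."""
    return np.sum(A ** 2)

def grad_f(x, A, b):
    """Gradient of f(x) = 0.5*||b - Ax||_2^2."""
    return -A.T @ (b - A @ x)
```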

For $g(x) = \lambda_1||x||_1 +\frac{\lambda_2}{2}||x||_2^2$, we derive the proximal operator of $g/\alpha$ as follows.

$$\arg\min_{x'} \quad \frac{1}{2} ||x' - x^+||_2^2 + \frac{\lambda_1}{\alpha}||x'||_1 + \frac{\lambda_2}{2\alpha} ||x'||_2^2$$

Let $s = \text{sign}(x')$ (taken entrywise); then Fermat's rule reads as follows:

$$x' - x^+ + \frac{\lambda_1}{\alpha} s + \frac{\lambda_2}{\alpha} x'=0 \;\Leftrightarrow\; x' = \frac{\alpha}{\alpha + \lambda_2} \left[x^+ - \frac{\lambda_1}{\alpha}s\right]$$

By considering the different cases for the signs of $x'$ and of $x^+ \pm \frac{\lambda_1}{\alpha}$, we can eliminate $s$ from the formula for $x'$ as follows:

$$\text{prox}_{g/\alpha}(x^+) = x' = \frac{\alpha}{\alpha + \lambda_2}\, \text{sign}(x^+) \left[|x^+| - \frac{\lambda_1}{\alpha}\right]_+$$

This is the closed-form expression for the proximal operator of $g/\alpha$.
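
Putting the pieces together, here is a minimal sketch of the full solver (function and variable names are my own and need not match the accompanying repository):

```python
import numpy as np

def prox_elastic_net(x_plus, alpha, lam1, lam2):
    """Closed-form prox of (lam1*||.||_1 + lam2/2*||.||_2^2)/alpha at x_plus."""
    return alpha / (alpha + lam2) * np.sign(x_plus) * np.maximum(np.abs(x_plus) - lam1 / alpha, 0.0)

def solve_elastic_net(A, b, lam1, lam2, n_iter=1000):
    """Proximal gradient descent for the Elastic-Net problem."""
    alpha = np.sum(A ** 2)                       # step constant ||A||_F^2
    x = np.zeros(A.shape[1])
    for _ in range(n_iter):
        x_plus = x + A.T @ (b - A @ x) / alpha   # gradient step x - grad_f(x)/alpha
        x = prox_elastic_net(x_plus, alpha, lam1, lam2)
    return x
```

Setting `lam2 = 0` recovers the usual soft-thresholding iteration for the Lasso, consistent with the earlier remark that the Lasso corresponds to $\lambda_2 = 0$.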
