Sunday, March 20, 2022

Scaling factor of Lasso problem

a

Scaling factor of Lasso problem

The Lasso problem is defined as the optimization problem

$$\min_{x\in \mathbb R^n} P(x):= \frac{1}{2}\|y-Ax\|^2 + \lambda\|x\|_1$$
where $\lambda>0$, $y\in \mathbb R^m$ and $A\in \mathbb R^{m\times n}$. Here $\|\cdot\|$ and $\|\cdot\|_1$ denote the $\ell_2$-norm and the $\ell_1$-norm, respectively.

For $x\in \mathbb R^n$, $x\neq 0_n$, the scaling factor of $x$ is defined by
$$[x]:=\text{argmin}_{\tau\geq 0} P(\tau x)$$

If $\langle Ax, y\rangle>\lambda\|x\|_1$, then $Ax\neq 0_m$ and we have a closed-form expression for $[x]$,
$$[x]=\frac{\langle Ax, y\rangle-\lambda\|x\|_1}{\|Ax\|^2}>0.$$

The following experiment illustrates the observation that $[x]$ is lower bounded by $1$ and converges to $1$ when $x$ is generated by the ISTA method.

import torch 
import matplotlib.pyplot as plt 
import seaborn as sns 
sns.set_theme()

def dic(x, t, w):
    """Build a column-normalised dictionary of base-2 Gaussian-shaped atoms.

    Before normalisation, entry ``(i, j)`` equals
    ``2 ** (-(x[i] - t[j]) ** 2 / w ** 2)``. Each column is then divided by
    its Euclidean norm.

    Args:
        x: Tensor of sample locations; flattened to one row per element.
        t: Tensor of atom centres; flattened to one column per element.
        w: Width parameter; the squared distance is divided by ``w ** 2``.

    Returns:
        Tensor of shape ``(x.numel(), t.numel())`` with one atom per column.
    """
    samples = x.reshape(-1, 1)
    centres = t.reshape(1, -1)
    # Broadcasting gives the full table of sample-minus-centre offsets.
    exponent = -(samples - centres) ** 2 / w ** 2
    atoms = torch.pow(2, exponent)
    col_norms = atoms.norm(dim=0)
    return atoms / col_norms

def setup(x1, x2):
    """Create the synthetic two-spike problem used by the experiment.

    A dictionary of 16 atoms sampled at 100 points on ``[0, 1]`` is built
    with ``dic`` (width 0.1). The observation is the dictionary applied to a
    coefficient vector that is zero except at two indices, which carry
    ``x1`` and ``x2``.

    Args:
        x1: Coefficient placed at the first active index.
        x2: Coefficient placed at the second active index.

    Returns:
        Tuple ``(y, atoms, lbd, infor)``: the ``(m, 1)`` observation, the
        ``(m, n)`` dictionary, the default regularisation weight ``1.`` and a
        dict describing the problem (sizes, grids, width, active indices and
        the ground-truth coefficients).
    """
    n_samples, n_atoms = 100, 16
    width = 0.1
    grid = torch.linspace(0., 1., n_samples, dtype=torch.float64)
    centres = torch.linspace(0., 1., n_atoms, dtype=torch.float64)
    dictionary = dic(grid, centres, width)

    # The two active atoms sit at one third and two thirds of the index range.
    k = 3
    stride = n_atoms // k
    support = [stride, (k - 1) * stride]

    truth = torch.zeros(n_atoms, 1, dtype=torch.float64)
    truth[support[0], 0] = x1
    truth[support[1], 0] = x2
    observation = dictionary @ truth

    infor = {
        "m": n_samples,
        "n": n_atoms,
        "x": grid,
        "t": centres,
        "w": width,
        "ids": support,
        "coeffs": truth,
    }

    return observation, dictionary, 1., infor

def ISTA(y, atoms, lbd, max_iters=500):
    """Run ISTA-style proximal-gradient iterations from zero, keeping every iterate.

    Each iteration takes a gradient step on ``0.5 * ||y - A x||^2`` with step
    size ``1 / L``, where ``L`` is the squared largest singular value of
    ``A``. It then subtracts ``lbd / L`` from every entry and clips the
    result at zero.

    Args:
        y: float64 tensor of shape ``(m, 1)`` (observation).
        atoms: float64 tensor of shape ``(m, n)`` (the matrix ``A``).
        lbd: Regularisation weight.
        max_iters: Number of iterations to perform.

    Returns:
        List of ``max_iters`` tensors of shape ``(n, 1)``: the iterate after
        each iteration, in order (the zero starting point is not included).

    Raises:
        AssertionError: If ``y`` is not ``(m, 1)`` or either input is not
            float64.
    """
    m, n = atoms.shape
    assert y.shape == (m, 1)
    assert y.dtype == atoms.dtype == torch.float64
    A = atoms

    # Only the largest singular value is needed, so compute the singular
    # values alone instead of a full SVD with U and V.
    L = torch.linalg.svdvals(A).max() ** 2

    # Loop-invariant quantities, hoisted out of the iteration.
    scaled_At = (1 / L) * A.T
    threshold = lbd / L

    x = torch.zeros(n, 1, dtype=torch.float64)
    x_list = []
    for _ in range(max_iters):
        # Gradient step on the quadratic term.
        x = x + scaled_At @ (y - A @ x)
        # NOTE(review): this shrinkage is one-sided (negative entries are
        # zeroed), i.e. the non-negative variant; the two-sided
        # soft-threshold would be sign(x) * max(|x| - lbd/L, 0). Confirm the
        # non-negativity constraint is intended.
        x = (x - threshold).clip(min=0.)
        x_list.append(x)
    return x_list

def run():
    """Run the experiment and plot ``[x] - 1`` along the ISTA iterates.

    Builds the problem with ``setup(1, 1)``, runs 200 ISTA iterations with
    ``lbd = 0.1``, evaluates the closed-form scaling factor
    ``(<Ax, y> - lbd * ||x||_1) / ||Ax||^2`` at every iterate and shows
    ``[x] - 1`` on a logarithmic y-axis.
    """
    # The default weight and the info dict returned by setup are not used;
    # the experiment runs with its own regularisation weight.
    y, atoms, _, _ = setup(1, 1)
    lbd = 0.1
    x_list = ISTA(y, atoms, lbd, max_iters=200)

    # Scaling factor minus one for each iterate, as plain floats for plotting.
    factor_list = []
    for x in x_list:
        Ax = atoms @ x
        scaling = ((Ax * y).sum() - lbd * x.norm(1)) / Ax.norm() ** 2
        factor_list.append((scaling - 1).item())

    f, ax = plt.subplots(1, 1, figsize=(7, 5))
    ax.plot(factor_list)
    ax.set_yscale("log")
    ax.set_ylabel("[x]-1")
    ax.set_xlabel("Iterations")
    ax.set_title("[x]-1 vs iterations")
    plt.show()

# Execute the experiment: runs ISTA on the synthetic problem and shows the plot.
run()

Popular Posts