Sunday, December 12, 2021

A simple demo of using PyTorch to solve the Lasso problem with the ISTA method

ista

A simple demo of using PyTorch to solve the Lasso problem with the ISTA method

import torch 
import matplotlib.pyplot as plt 
import seaborn as sns 
# Apply seaborn's default theme to the matplotlib figures created below.
sns.set_theme()

def dic(x, t, w):
    """Build a dictionary of bell-shaped atoms with unit-norm columns.

    Entry ``(i, j)`` of the un-normalised matrix is
    ``2 ** (-(x[i] - t[j])**2 / w**2)``; each column is then divided by its
    Euclidean norm.

    Args:
        x: tensor of sample locations (flattened into a column).
        t: tensor of atom centres (flattened into a row).
        w: scale applied to the distance between a sample and a centre.

    Returns:
        Tensor of shape ``(x.numel(), t.numel())`` whose columns have
        unit 2-norm.
    """
    samples = x.reshape(-1, 1)
    centres = t.reshape(1, -1)
    # Broadcasting yields one offset per (sample, centre) pair.
    offsets = samples - centres
    exponent = -offsets**2 / w**2
    atoms = torch.pow(2, exponent)
    column_norms = atoms.norm(dim=0)
    return atoms / column_norms

def setup(x1, x2):
    """Create the synthetic demo problem: a signal made of two atoms.

    A 100 x 16 dictionary is built with ``dic`` on uniform grids over
    ``[0, 1]``; the observation ``y`` is the combination of two of its
    columns with weights ``x1`` and ``x2``.

    Args:
        x1: coefficient placed on the first selected atom.
        x2: coefficient placed on the second selected atom.

    Returns:
        Tuple ``(y, atoms, lbd, infor)`` where ``y`` has shape ``(m, 1)``,
        ``atoms`` is the dictionary, ``lbd`` is a default penalty weight
        of ``1.0`` and ``infor`` is a dict holding the problem settings
        and the ground-truth coefficients.
    """
    m, n = 100, 16
    w = 0.1
    k = 3

    x = torch.linspace(0., 1., m, dtype=torch.float64)
    t = torch.linspace(0., 1., n, dtype=torch.float64)
    atoms = dic(x, t, w)

    # Active atoms sit at one third and two thirds of the index range.
    stride = n // k
    ids = [stride, (k - 1) * stride]

    coeffs = torch.zeros(n, dtype=torch.float64)
    for position, amplitude in zip(ids, (x1, x2)):
        coeffs[position] = amplitude
    coeffs = coeffs.reshape(-1, 1)

    y = atoms @ coeffs

    infor = {
        "m": m, "n": n,
        "x": x, "t": t, "w": w,
        "ids": ids,
        "coeffs": coeffs,
    }
    lbd = 1.
    return y, atoms, lbd, infor

def ISTA(y, atoms, lbd, max_iters=500):
    """Run ISTA (proximal gradient) iterations for a non-negative Lasso.

    Every iteration takes a gradient step on ``0.5 * ||y - A x||^2`` with
    step size ``1 / L``, where ``L`` is the squared largest singular value
    of ``A``, then subtracts ``lbd / L`` and clips at zero. Because of the
    clipping, all iterates are non-negative (this is not the two-sided
    soft-thresholding of the unconstrained Lasso).

    Args:
        y: float64 tensor of shape ``(m, 1)``, the observation.
        atoms: float64 tensor of shape ``(m, n)``, the dictionary ``A``.
        lbd: weight of the l1 penalty.
        max_iters: number of iterations to perform.

    Returns:
        List with the ``max_iters`` successive iterates, each a tensor of
        shape ``(n, 1)``. The list is empty when ``max_iters`` is 0.

    Raises:
        AssertionError: if ``y`` is not ``(m, 1)`` or either input is not
            float64.
    """
    m, n = atoms.shape
    assert y.shape == (m, 1), "y must have shape (m, 1)"
    assert y.dtype == atoms.dtype == torch.float64, "inputs must be float64"
    A = atoms

    # Only the largest singular value is needed, so avoid building the
    # singular vectors that a full SVD would compute.
    L = torch.linalg.svdvals(A).max() ** 2

    # Loop-invariant quantities: scaled transpose and threshold.
    scaled_At = (1 / L) * A.T
    threshold = lbd / L

    # Start from zero, on the same device as the dictionary.
    x = torch.zeros(n, 1, dtype=torch.float64, device=A.device)
    x_list = []
    for _ in range(max_iters):
        x = x + scaled_At @ (y - A @ x)
        x = (x - threshold).clip(min=0.)
        x_list.append(x)
    return x_list

def run():
    """Solve the demo problem with ISTA and plot three diagnostics.

    The panels show, from left to right: the signal against the final
    reconstruction, the relative excess of the largest correlation
    ``max(A^T r)`` over ``lbd`` per iteration, and the primal-dual gap
    per iteration (the last two on a logarithmic y-axis).
    """
    y, atoms, _, infor = setup(1, 1)
    lbd = 0.1  # used instead of the penalty weight returned by setup
    x_list = ISTA(y, atoms, lbd, max_iters=200)

    # Per-iterate diagnostics.
    half_y_sq = 0.5 * y.norm() ** 2
    y_hat_list = []
    max_cert_list = []
    gap_list = []
    for x in x_list:
        fit = atoms @ x
        residual = y - fit
        max_cert = (atoms.T @ residual).max()
        # Rescale the residual so that max(A^T u) equals lbd.
        u = (lbd / max_cert) * residual
        primal = 0.5 * residual.norm() ** 2 + lbd * x.norm(p=1)
        dual = half_y_sq - 0.5 * (y - u).norm() ** 2

        y_hat_list.append(fit)
        max_cert_list.append(max_cert)
        gap_list.append(primal - dual)
    y_hat = y_hat_list[-1]

    fig, axes = plt.subplots(1, 3, figsize=(21, 5))
    signal_ax, cert_ax, gap_ax = axes

    grid = infor["x"]
    signal_ax.plot(grid, y.reshape(-1), label="y")
    signal_ax.plot(grid, y_hat.reshape(-1), label="y_hat")
    signal_ax.legend()

    relative_excess = [(mc - lbd) / lbd for mc in max_cert_list]
    cert_ax.plot(relative_excess, label="(max_cert-lbd)/lbd")
    cert_ax.legend()
    cert_ax.set_yscale("log")

    gap_ax.plot(gap_list, label="dual gap")
    gap_ax.legend()
    gap_ax.set_yscale("log")

    plt.show()

# Execute the demo; this runs whenever the module is executed or imported.
run()

Result

Popular Posts