@@ -31,11 +31,15 @@ def gram_lasso(X, y, alpha, max_iter, tol, w_init=None, weights=None, check_freq
3131 lipschitz = np .zeros (n_features , dtype = X .dtype )
3232 for j in range (n_features ):
3333 lipschitz [j ] = (X [:, j ] ** 2 ).sum () / len (y )
34- w = w_init if w_init is not None else np .zeros (n_features )
34+ w = w_init .copy () if w_init is not None else np .zeros (n_features )
35+ z = w_init .copy () if w_init is not None else np .zeros (n_features )
36+ beta_0 = beta_1 = 1
3537 weights = weights if weights is not None else np .ones (n_features )
3638 # CD
3739 for n_iter in range (max_iter ):
38- cd_epoch (X , G , grads , w , alpha , lipschitz , weights )
40+ beta_1 = (1 + np .sqrt (1 + 4 * beta_0 ** 2 )) / 2
41+ cd_epoch (X , G , grads , w , z , alpha , beta_1 , beta_0 , lipschitz , weights )
42+ beta_0 = beta_1
3943 if n_iter % check_freq == 0 :
4044 p_obj = primal (alpha , y , X , w , weights )
4145 if p_obj_prev - p_obj < tol :
@@ -58,7 +62,7 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=No
5862 for g in range (n_groups ):
5963 X_g = X [:, grp_indices [grp_ptr [g ]:grp_ptr [g + 1 ]]]
6064 lipschitz [g ] = norm (X_g , ord = 2 ) ** 2 / len (y )
61- w = w_init if w_init is not None else np .zeros (n_features )
65+ w = w_init . copy () if w_init is not None else np .zeros (n_features )
6266 weights = weights if weights is not None else np .ones (n_groups )
6367 # BCD
6468 for n_iter in range (max_iter ):
@@ -74,15 +78,17 @@ def gram_group_lasso(X, y, alpha, groups, max_iter, tol, w_init=None, weights=No
7478
7579
7680@njit
77- def cd_epoch (X , G , grads , w , alpha , lipschitz , weights ):
81+ def cd_epoch (X , G , grads , w , z , alpha , beta_1 , beta_0 , lipschitz , weights ):
7882 n_features = X .shape [1 ]
7983 for j in range (n_features ):
8084 if lipschitz [j ] == 0. or weights [j ] == np .inf :
8185 continue
8286 old_w_j = w [j ]
83- w [j ] = ST (w [j ] + grads [j ] / lipschitz [j ], alpha / lipschitz [j ] * weights [j ])
84- if old_w_j != w [j ]:
85- grads += G [j , :] * (old_w_j - w [j ]) / len (X )
87+ old_z_j = z [j ]
88+ w [j ] = ST (z [j ] + grads [j ] / lipschitz [j ], alpha / lipschitz [j ] * weights [j ])
89+ z [j ] = w [j ] + ((beta_0 - 1 ) / beta_1 ) * (w [j ] - old_w_j )
90+ if old_z_j != z [j ]:
91+ grads += G [j , :] * (old_z_j - z [j ]) / len (X )
8692
8793
8894@njit
0 commit comments