Description
Describe the bug
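This may be more of a discussion, but overall I am not sure what treatment of sample weights would preserve the convergence guarantees of the SAG(A) solver. As far as I can see, at each update step $k$ we select an index $i_k$ uniformly at random from $\{1, \dots, n\}$ and update, with step size $\gamma$,
$$x^{k+1} = x^{k} - \frac{\gamma}{n} \sum_{i=1}^{n} y_i^{k},$$
where
$$y_i^{k} = \begin{cases} f_i'(x^{k}) & \text{if } i = i_k, \\ y_i^{k-1} & \text{otherwise.} \end{cases}$$
For frequency-based weighting, one could instead sample $i_k$ with probability
$$P(i_k = i) = \frac{w_i}{\sum_{j=1}^{n} w_j}.$$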
Alternatively, as is currently done, the weights could be multiplied into the gradient update. That could also work, but I am not sure which method is best. We would also need to account for the division by the cardinality of the set of "seen" samples in each update step.
Steps/Code to Reproduce
import numpy as np
from scipy.stats import kstest
from sklearn.linear_model.tests.test_sag import sag, squared_dloss
from sklearn.datasets import make_regression
from sklearn.utils._testing import assert_allclose_dense_sparse
# Compare SAG fitted with `sample_weight` against SAG fitted on data in which
# each row is physically repeated `sample_weight[i]` times. If sample_weight
# were treated as frequency weights, the fitted coefficients (collected over
# many seeds) should be indistinguishable by a two-sample KS test.
step_size = 0.01
alpha = 1
n_features = 1
n_seeds = 100

rng = np.random.RandomState(0)
X, y = make_regression(n_samples=10000, random_state=77, n_features=n_features)
# Integer frequency weights in [0, 5); rows with weight 0 disappear after
# np.repeat.
sample_weight = rng.randint(0, 5, size=X.shape[0])
X_repeated = np.repeat(X, sample_weight, axis=0)
y_repeated = np.repeat(y, sample_weight, axis=0)

# Fitted coefficients, one column per seed.
weights_w_all = np.zeros([n_features, n_seeds])
weights_r_all = np.zeros([n_features, n_seeds])

# NOTE(review): with n_iter=1, `sag` performs one update per row. The repeated
# dataset therefore gets sample_weight.sum() updates (about 2x, since
# randint(0, 5) has mean 2), versus X.shape[0] for the weighted fit. Part of
# the distribution mismatch may come from this, not only from how
# sample_weight enters the gradient.
for seed in range(n_seeds):
    weights_w, _ = sag(
        X,
        y,
        step_size=step_size,
        alpha=alpha,
        sample_weight=sample_weight,
        dloss=squared_dloss,
        random_state=seed,
    )
    weights_w_all[:, seed] = weights_w

    weights_r, _ = sag(
        X_repeated,
        y_repeated,
        step_size=step_size,
        alpha=alpha,
        dloss=squared_dloss,
        random_state=seed,
    )
    weights_r_all[:, seed] = weights_r

# Two-sample KS test on the first coefficient across seeds.
print(kstest(weights_r_all[0], weights_w_all[0]))
Note that I modified `sag` in `test_sag.py` to accept a `random_state` argument:
def sag(
    X,
    y,
    step_size,
    alpha,
    n_iter=1,
    dloss=None,
    sparse=False,
    sample_weight=None,
    fit_intercept=True,
    saga=False,
    random_state=77,
):
    """Plain-NumPy SAG / SAGA solver (copy of the test_sag.py helper with an
    added ``random_state`` parameter).

    Runs ``n_iter`` epochs. Each epoch performs ``n_samples`` updates, each on
    a sample index drawn uniformly at random *with replacement*.

    Parameters
    ----------
    X : array of shape (n_samples, n_features)
        Training data, accessed row by row as ``X[idx]``.
    y : array of shape (n_samples,)
        Targets.
    step_size : float
        Step size used for both the coefficients and the intercept.
    alpha : float
        L2 regularization strength. ``alpha * weights`` is added to every
        stored per-sample gradient.
    n_iter : int, default=1
        Number of epochs.
    dloss : callable
        ``dloss(p, y_i)``, the derivative of the loss with respect to the
        prediction ``p`` (inferred from how it is used below). Must be
        supplied: the ``None`` default is not handled and fails when called.
    sparse : bool, default=False
        Its only effect here is that the intercept step is scaled by a fixed
        decay of 0.01 instead of 1.0.
    sample_weight : array of shape (n_samples,) or None, default=None
        If given, each sample's loss derivative is multiplied by its weight.
        The L2 term is not scaled. The averaged gradient is still divided by
        the number of distinct samples seen, not by the sum of weights.
    fit_intercept : bool, default=True
        Whether to fit an intercept. No L2 term is applied to it.
    saga : bool, default=False
        Use the SAGA update instead of SAG.
    random_state : int, default=77
        Seed for ``np.random.RandomState``, which drives the sample selection.

    Returns
    -------
    weights : ndarray of shape (n_features,)
    intercept : float
    """
    n_samples, n_features = X.shape[0], X.shape[1]
    # Current coefficients, the running sum of stored per-sample gradients,
    # and the last gradient stored for every sample.
    weights = np.zeros(X.shape[1])
    sum_gradient = np.zeros(X.shape[1])
    gradient_memory = np.zeros((n_samples, n_features))
    # Same bookkeeping for the intercept.
    intercept = 0.0
    intercept_sum_gradient = 0.0
    intercept_gradient_memory = np.zeros(n_samples)
    rng = np.random.RandomState(random_state)
    decay = 1.0
    # Distinct sample indices visited so far. len(seen) is the divisor used
    # for the averaged gradient instead of n_samples.
    seen = set()
    # sparse data has a fixed decay of .01
    if sparse:
        decay = 0.01
    for epoch in range(n_iter):
        for k in range(n_samples):
            # NOTE(review): the alternative under discussion is to draw idx
            # with probability proportional to sample_weight, e.g.
            #     idx = rng.choice(np.arange(n_samples),
            #                      p=sample_weight / np.sum(sample_weight))
            # instead of multiplying the gradient by sample_weight below.
            idx = int(rng.rand()*n_samples)
            # A deterministic cyclic pass would use idx = k instead.
            entry = X[idx]
            seen.add(idx)
            p = np.dot(entry, weights) + intercept
            gradient = dloss(p, y[idx])
            if sample_weight is not None:
                # Only the loss derivative is weighted. The alpha * weights
                # term added below is not scaled by sample_weight.
                gradient *= sample_weight[idx]
            # Gradient of this sample's loss plus the L2 term.
            update = entry * gradient + alpha * weights
            # Replace this sample's stored gradient and keep the sum in sync.
            gradient_correction = update - gradient_memory[idx]
            sum_gradient += gradient_correction
            gradient_memory[idx] = update
            if saga:
                # Combined with the averaged step at the end of this loop
                # body, this gives the SAGA step (with len(seen) in place of
                # n_samples).
                weights -= gradient_correction * step_size * (1 - 1.0 / len(seen))
            if fit_intercept:
                gradient_correction = gradient - intercept_gradient_memory[idx]
                intercept_gradient_memory[idx] = gradient
                intercept_sum_gradient += gradient_correction
                # The scaled correction is used only by the SAGA branch below.
                gradient_correction *= step_size * (1.0 - 1.0 / len(seen))
                if saga:
                    intercept -= (
                        step_size * intercept_sum_gradient / len(seen) * decay
                    ) + gradient_correction
                else:
                    intercept -= step_size * intercept_sum_gradient / len(seen) * decay
            # Averaged-gradient step on the coefficients.
            weights -= step_size * sum_gradient / len(seen)
    return weights, intercept
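Expected Results
`kstest` should return a p-value larger than 0.025. That is, the coefficients fitted with `sample_weight` and those fitted on the repeated data should follow the same distribution.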
Actual Results
KstestResult(statistic=np.float64(0.44), pvalue=np.float64(4.414205948474835e-09), statistic_location=np.float64(24.644506472064027), statistic_sign=np.int8(-1))
With an example histogram of:
Versions
System:
python: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ]
executable: /Users/shrutinath/micromamba/envs/scikit-learn/bin/python
machine: macOS-14.3-arm64-arm-64bit
Python dependencies:
sklearn: 1.8.dev0
pip: 24.0
setuptools: 75.8.0
numpy: 2.0.0
scipy: 1.14.0
Cython: 3.0.10
pandas: 2.2.2
matplotlib: 3.9.0
joblib: 1.4.2
threadpoolctl: 3.5.0
Built with OpenMP: True
threadpoolctl info:
user_api: blas
internal_api: openblas
num_threads: 8
prefix: libopenblas
...
num_threads: 8
prefix: libomp
filepath: /Users/shrutinath/micromamba/envs/scikit-learn/lib/libomp.dylib
version: None
Output is truncated. View as a scrollable element or open in a text editor. Adjust cell output settings...