Improve sample_weight handling in sag(a) · Issue #31536 · scikit-learn/scikit-learn · GitHub
Improve sample_weight handling in sag(a) #31536
Open
@snath-xoc

Description


Describe the bug

This may be more of a discussion, but overall I am not sure which treatment of weights would preserve the convergence guarantees of the SAG(A) solver. As far as I can see, at each update step we select an index $i_j$ uniformly at random, so that the update steps can be written generally as:

$x^{k+1} = x^{k} - \sum_{j=1}^{k} \alpha_{j} S(j, i_{1:k}) f'_{i_j}(x^j)$

where $S(j, i_{1:k}) = 1/n$ if $j$ is the last iteration (up to $k$) at which index $i_j$ was selected, and $0$ otherwise.
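A minimal NumPy sketch of this update, assuming a least-squares loss $f_i(x) = \frac{1}{2}(a_i^\top x - b_i)^2$, uniform sampling, zero-initialized gradient memory and the conservative constant step size $1/(16L)$ from the SAG analysis. The function name `sag_least_squares` and the toy data are hypothetical. It uses the paper-style $1/n$ normalization, whereas the `sag` helper in test_sag.py divides by the number of samples seen so far.

```python
import numpy as np


def sag_least_squares(A, b, step_size, n_iter, seed=0):
    """Plain SAG with uniform sampling for F(x) = (1/n) * sum_i 0.5 * (a_i @ x - b_i) ** 2."""
    rng = np.random.RandomState(seed)
    n, d = A.shape
    x = np.zeros(d)
    grad_memory = np.zeros((n, d))  # y_i: last gradient evaluated for sample i
    grad_sum = np.zeros(d)          # running sum_i y_i
    for _ in range(n_iter):
        i = rng.randint(n)                   # i_k drawn uniformly from {0, ..., n-1}
        g_new = (A[i] @ x - b[i]) * A[i]     # f'_i(x^k)
        grad_sum += g_new - grad_memory[i]   # swap the stale gradient for the fresh one
        grad_memory[i] = g_new
        x -= step_size * grad_sum / n        # S(j, i_{1:k}) = 1/n on each latest gradient
    return x


# Hypothetical toy problem with unit-norm rows, so each f_i' is 1-Lipschitz
rng = np.random.RandomState(0)
A = rng.randn(50, 2)
A /= np.linalg.norm(A, axis=1, keepdims=True)
b = A @ np.array([1.0, -2.0]) + 0.1 * rng.randn(50)

L = np.max(np.sum(A ** 2, axis=1))  # L = max_i ||a_i||^2 (= 1 here)
x_sag = sag_least_squares(A, b, step_size=1.0 / (16 * L), n_iter=20000)
x_star = np.linalg.lstsq(A, b, rcond=None)[0]
print(np.abs(x_sag - x_star).max())
```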

For frequency-based weighting, one could sample $i_j$ with probabilities proportional to the weights; under non-uniform sampling the SAG(A) convergence guarantees still seem to hold (see here).
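A sketch of the sampling-based variant, under assumptions that go beyond the text above: integer frequency weights $w_i$, index $i_k$ drawn with probability $w_i / \sum_j w_j$, one gradient slot per distinct sample (as if the $w_i$ copies of sample $i$ shared a single memory slot), and the stored gradients aggregated with the same weights. The function name `sag_weighted_sampling` and the data are hypothetical, and the step size is the same conservative $1/(16L)$ as in the uniform case.

```python
import numpy as np


def sag_weighted_sampling(A, b, w, step_size, n_iter, seed=0):
    """SAG for F(x) = sum_i w_i * 0.5 * (a_i @ x - b_i) ** 2 / sum(w), sampling i with p_i = w_i / sum(w)."""
    rng = np.random.RandomState(seed)
    n, d = A.shape
    p = w / w.sum()                            # sampling probabilities from the frequency weights
    x = np.zeros(d)
    grad_memory = np.zeros((n, d))             # one slot per distinct sample
    weighted_sum = np.zeros(d)                 # running sum_i w_i * y_i
    for i in rng.choice(n, size=n_iter, p=p):  # zero-weight samples are never drawn
        g_new = (A[i] @ x - b[i]) * A[i]       # unweighted f'_i(x^k)
        weighted_sum += w[i] * (g_new - grad_memory[i])
        grad_memory[i] = g_new
        x -= step_size * weighted_sum / w.sum()  # weighted average of stored gradients
    return x


rng = np.random.RandomState(0)
A = rng.randn(50, 2)
A /= np.linalg.norm(A, axis=1, keepdims=True)  # unit rows, so L = 1
b = A @ np.array([1.0, -2.0]) + 0.1 * rng.randn(50)
w = rng.randint(0, 5, size=50)                 # integer frequency weights, as in the reproducer

x_sag = sag_weighted_sampling(A, b, w, step_size=1.0 / 16, n_iter=30000)
# Reference: weighted least squares, i.e. least squares on sqrt(w)-scaled rows
sw = np.sqrt(w)
x_star = np.linalg.lstsq(sw[:, None] * A, sw * b, rcond=None)[0]
print(np.abs(x_sag - x_star).max())
```

The sampling distribution only changes how often each slot is refreshed; the point the iteration settles at is determined by how the stored gradients are weighted in the sum, which is why the aggregate here uses $w_i$.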

Alternatively, as is currently done, each sample's gradient could be multiplied by its weight in the update, which could also work. However, I am not sure which method is best; here we would also need to account for the division by the cardinality of the set of "seen" samples at each update step.
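For comparison, a sketch of the weight-multiplied variant with uniform sampling and the running $1/|\text{seen}|$ normalization used by the `sag` helper in test_sag.py, on a hypothetical least-squares toy problem without regularization or intercept. Multiplying $f'_i$ by $w_i$ scales that sample's gradient Lipschitz constant by $w_i$, so this sketch assumes the step size $1/(16 \max_i w_i \lVert a_i \rVert^2)$. The function name `sag_weighted_gradient` is hypothetical.

```python
import numpy as np


def sag_weighted_gradient(A, b, w, step_size, n_iter, seed=0):
    """SAG with uniform sampling where each stored gradient is w_i * f'_i(x)."""
    rng = np.random.RandomState(seed)
    n, d = A.shape
    x = np.zeros(d)
    grad_memory = np.zeros((n, d))
    grad_sum = np.zeros(d)
    seen = set()
    for _ in range(n_iter):
        i = rng.randint(n)                       # uniform, regardless of w
        seen.add(i)
        g_new = w[i] * (A[i] @ x - b[i]) * A[i]  # weight multiplied into the gradient
        grad_sum += g_new - grad_memory[i]
        grad_memory[i] = g_new
        x -= step_size * grad_sum / len(seen)    # average over samples seen so far
    return x


rng = np.random.RandomState(0)
A = rng.randn(50, 2)
A /= np.linalg.norm(A, axis=1, keepdims=True)
b = A @ np.array([1.0, -2.0]) + 0.1 * rng.randn(50)
w = rng.randint(0, 5, size=50)

L_w = np.max(w * np.sum(A ** 2, axis=1))  # largest per-sample Lipschitz constant after weighting
x_sag = sag_weighted_gradient(A, b, w, step_size=1.0 / (16 * L_w), n_iter=30000)
sw = np.sqrt(w)
x_star = np.linalg.lstsq(sw[:, None] * A, sw * b, rcond=None)[0]
print(np.abs(x_sag - x_star).max())
```

The fixed point is again the weighted least-squares solution, but the dynamics differ from running on repeated rows: zero-weight samples still consume iterations and still count in `len(seen)`, and each distinct sample has a single slot whose gradient is scaled by $w_i$.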

Steps/Code to Reproduce

import numpy as np
from scipy.stats import kstest
from sklearn.datasets import make_regression
# `sag` is the reference implementation from test_sag.py, modified to accept
# random_state (see below)
from sklearn.linear_model.tests.test_sag import sag, squared_dloss

step_size = 0.01
alpha = 1
n_features = 1
n_repeats = 100

rng = np.random.RandomState(0)
X, y = make_regression(n_samples=10000, random_state=77, n_features=n_features)
# Integer frequency weights in {0, ..., 4}
weights = rng.randint(0, 5, size=X.shape[0])

# Unweighted dataset in which each row is repeated weights[i] times
X_repeated = np.repeat(X, weights, axis=0)
y_repeated = np.repeat(y, weights, axis=0)

weights_w_all = np.zeros([n_features, n_repeats])  # coefficients from the weighted fit
weights_r_all = np.zeros([n_features, n_repeats])  # coefficients from the repeated fit

for random_state in range(n_repeats):
    weights_w, _ = sag(
        X, y, step_size=step_size, alpha=alpha, sample_weight=weights,
        dloss=squared_dloss, random_state=random_state,
    )
    weights_w_all[:, random_state] = weights_w
    weights_r, _ = sag(
        X_repeated, y_repeated, step_size=step_size, alpha=alpha,
        dloss=squared_dloss, random_state=random_state,
    )
    weights_r_all[:, random_state] = weights_r

# Two-sample KS test on the first coefficient across the random seeds
print(kstest(weights_r_all[0], weights_w_all[0]))

Note that I modified sag in test_sag.py to accept a random_state argument:

def sag(
    X,
    y,
    step_size,
    alpha,
    n_iter=1,
    dloss=None,
    sparse=False,
    sample_weight=None,
    fit_intercept=True,
    saga=False,
    random_state=77,
):
    n_samples, n_features = X.shape[0], X.shape[1]

    weights = np.zeros(X.shape[1])
    sum_gradient = np.zeros(X.shape[1])
    gradient_memory = np.zeros((n_samples, n_features))

    intercept = 0.0
    intercept_sum_gradient = 0.0
    intercept_gradient_memory = np.zeros(n_samples)

    rng = np.random.RandomState(random_state)
    decay = 1.0
    seen = set()

    # sparse data has a fixed decay of .01
    if sparse:
        decay = 0.01

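    for epoch in range(n_iter):
        for k in range(n_samples):
            # Frequency-based alternative: draw idx with probability
            # proportional to sample_weight instead of uniformly.
            # if sample_weight is not None:
            #     idx = rng.choice(np.arange(n_samples), p=sample_weight / np.sum(sample_weight))
            # else:
            idx = int(rng.rand() * n_samples)  # uniform sampling
            # idx = k
            entry = X[idx]
            seen.add(idx)
            p = np.dot(entry, weights) + intercept
            gradient = dloss(p, y[idx])
            if sample_weight is not None:
                # Current approach: multiply the per-sample gradient by its weight
                gradient *= sample_weight[idx]
            update = entry * gradient + alpha * weights
            # Swap the stored gradient for idx and update the running sum
            gradient_correction = update - gradient_memory[idx]
            sum_gradient += gradient_correction
            gradient_memory[idx] = update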
            if saga:
                weights -= gradient_correction * step_size * (1 - 1.0 / len(seen))

            if fit_intercept:
                gradient_correction = gradient - intercept_gradient_memory[idx]
                intercept_gradient_memory[idx] = gradient
                intercept_sum_gradient += gradient_correction
                gradient_correction *= step_size * (1.0 - 1.0 / len(seen))
                if saga:
                    intercept -= (
                        step_size * intercept_sum_gradient / len(seen) * decay
                    ) + gradient_correction
                else:
                    intercept -= step_size * intercept_sum_gradient / len(seen) * decay

            # SAG step: average of the stored gradients over the samples seen so far
            weights -= step_size * sum_gradient / len(seen)

    return weights, intercept

Expected Results

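The two-sample KS test should give a p-value larger than 0.025, i.e. the coefficients fitted with sample_weight and those fitted on the repeated data should follow the same distribution.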
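The property the KS test probes, that integer sample weights behave like repeated rows, can be checked deterministically at the level of the optimum. This sketch uses a hypothetical ridge problem in sum form, $\sum_i w_i (a_i^\top x - b_i)^2 + \alpha \lVert x \rVert^2$, solved in closed form with and without repetition; it says nothing about the stochastic SAG trajectory, only about where the two fits should end up.

```python
import numpy as np

rng = np.random.RandomState(0)
A = rng.randn(200, 3)
b = A @ np.array([1.0, -2.0, 0.5]) + 0.1 * rng.randn(200)
w = rng.randint(0, 5, size=200)  # frequency weights, zeros included
alpha = 1.0

# Weighted ridge: (A^T W A + alpha I) x = A^T W b
x_weighted = np.linalg.solve(A.T @ (w[:, None] * A) + alpha * np.eye(3), A.T @ (w * b))

# Same problem on explicitly repeated rows (zero-weight rows disappear)
A_rep = np.repeat(A, w, axis=0)
b_rep = np.repeat(b, w)
x_repeated = np.linalg.solve(A_rep.T @ A_rep + alpha * np.eye(3), A_rep.T @ b_rep)

print(np.abs(x_weighted - x_repeated).max())
```

If the objective is instead averaged over samples, the two fits coincide only when the weighted data term is divided by $\sum_i w_i$ rather than by the number of distinct samples; otherwise the relative strength of the $\alpha$ penalty differs between them.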

Actual Results

KstestResult(statistic=np.float64(0.44), pvalue=np.float64(4.414205948474835e-09), statistic_location=np.float64(24.644506472064027), statistic_sign=np.int8(-1))

[Figure: example histogram]

Versions

System:
    python: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ]
executable: /Users/shrutinath/micromamba/envs/scikit-learn/bin/python
   machine: macOS-14.3-arm64-arm-64bit

Python dependencies:
      sklearn: 1.8.dev0
          pip: 24.0
   setuptools: 75.8.0
        numpy: 2.0.0
        scipy: 1.14.0
       Cython: 3.0.10
       pandas: 2.2.2
   matplotlib: 3.9.0
       joblib: 1.4.2
threadpoolctl: 3.5.0

Built with OpenMP: True

threadpoolctl info:
       user_api: blas
   internal_api: openblas
    num_threads: 8
         prefix: libopenblas
...
    num_threads: 8
         prefix: libomp
       filepath: /Users/shrutinath/micromamba/envs/scikit-learn/lib/libomp.dylib
        version: None