8000 MAINT Simplify arguments to csr_set_problem and csr_to_sparse (#14135) · rth/scikit-learn@cf2e60b · GitHub
[go: up one dir, main page]

Skip to content

Commit cf2e60b

alexhenrierth
authored andcommitted
MAINT Simplify arguments to csr_set_problem and csr_to_sparse (scikit-learn#14135)
1 parent 214def0 commit cf2e60b

File tree

3 files changed

+13
-18
lines changed

3 files changed

+13
-18
lines changed

sklearn/svm/liblinear.pxd

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,9 +32,7 @@ cdef extern from "liblinear_helper.c":
3232
void copy_w(void *, model *, int)
3333
parameter *set_parameter(int, double, double, int, char *, char *, int, int, double)
3434
problem *set_problem (char *, char *, np.npy_intp *, double, char *)
35-
problem *csr_set_problem (char *values, np.npy_intp *n_indices,
36-
char *indices, np.npy_intp *n_indptr, char *indptr, char *Y,
37-
np.npy_intp n_features, double bias, char *)
35+
problem *csr_set_problem (char *, char *, char *, char *, int, int, double, char *)
3836

3937
model *set_model(parameter *, char *, np.npy_intp *, char *, double)
4038

sklearn/svm/liblinear.pyx

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,12 +26,10 @@ def train_wrap(X, np.ndarray[np.float64_t, ndim=1, mode='c'] Y,
2626
if is_sparse:
2727
problem = csr_set_problem(
2828
(<np.ndarray[np.float64_t, ndim=1, mode='c']>X.data).data,
29-
(<np.ndarray[np.int32_t, ndim=1, mode='c']>X.indices).shape,
3029
(<np.ndarray[np.int32_t, ndim=1, mode='c']>X.indices).data,
31-
(<np.ndarray[np.int32_t, ndim=1, mode='c']>X.indptr).shape,
3230
(<np.ndarray[np.int32_t, ndim=1, mode='c']>X.indptr).data,
33-
Y.data, (<np.int32_t>X.shape[1]), bias,
34-
sample_weight.data)
31+
Y.data, (<np.int32_t>X.shape[0]), (<np.int32_t>X.shape[1]),
32+
bias, sample_weight.data)
3533
else:
3634
problem = set_problem(
3735
(<np.ndarray[np.float64_t, ndim=2, mode='c']>X).data,

sklearn/svm/src/liblinear/liblinear_helper.c

Lines changed: 10 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -70,19 +70,18 @@ static struct feature_node **dense_to_sparse(double *x, npy_intp *dims,
7070
/*
7171
* Convert scipy.sparse.csr to libsvm's sparse data structure
7272
*/
73-
static struct feature_node **csr_to_sparse(double *values,
74-
npy_intp *shape_indices, int *indices, npy_intp *shape_indptr,
75-
int *indptr, double bias, int n_features)
73+
static struct feature_node **csr_to_sparse(double *values, int *indices,
74+
int *indptr, int n_samples, int n_features, double bias)
7675
{
7776
struct feature_node **sparse, *temp;
7877
int i, j=0, k=0, n;
7978
int have_bias = (bias > 0);
8079

81-
sparse = malloc ((shape_indptr[0]-1)* sizeof(struct feature_node *));
80+
sparse = malloc (n_samples * sizeof(struct feature_node *));
8281
if (sparse == NULL)
8382
return NULL;
8483

85-
for (i=0; i<shape_indptr[0]-1; ++i) {
84+
for (i=0; i<n_samples; ++i) {
8685
n = indptr[i+1] - indptr[i]; /* count elements in row i */
8786

8887
sparse[i] = malloc ((n+have_bias+1) * sizeof(struct feature_node));
@@ -140,14 +139,14 @@ struct problem * set_problem(char *X,char *Y, npy_intp *dims, double bias, char*
140139
return problem;
141140
}
142141

143-
struct problem * csr_set_problem (char *values, npy_intp *n_indices,
144-
char *indices, npy_intp *n_indptr, char *indptr, char *Y,
145-
npy_intp n_features, double bias, char *sample_weight) {
142+
struct problem * csr_set_problem (char *values, char *indices, char *indptr,
143+
char *Y, int n_samples, int n_features, double bias,
144+
char *sample_weight) {
146145

147146
struct problem *problem;
148147
problem = malloc (sizeof (struct problem));
149148
if (problem == NULL) return NULL;
150-
problem->l = (int) n_indptr[0] -1;
149+
problem->l = n_samples;
151150
problem->sample_weight = (double *) sample_weight;
152151

153152
if (bias > 0){
@@ -157,8 +156,8 @@ struct problem * csr_set_problem (char *values, npy_intp *n_indices,
157156
}
158157

159158
problem->y = (double *) Y;
160-
problem->x = csr_to_sparse((double *) values, n_indices, (int *) indices,
161-
n_indptr, (int *) indptr, bias, n_features);
159+
problem->x = csr_to_sparse((double *) values, (int *) indices,
160+
(int *) indptr, n_samples, n_features, bias);
162161
problem->bias = bias;
163162
problem->sample_weight = sample_weight;
164163

0 commit comments

Comments
 (0)
0