8000 ENH: Return n_iter_ from liblinear and print convergence warnings · jwjohnson314/scikit-learn@c8c72fd · GitHub


Commit c8c72fd

MechCoder authored and ogrisel committed
ENH: Return n_iter_ from liblinear and print convergence warnings
1 parent 3edfa3c commit c8c72fd

File tree: 10 files changed (+300, −142 lines)


sklearn/linear_model/logistic.py

Lines changed: 4 additions & 0 deletions
@@ -639,6 +639,10 @@ class LogisticRegression(BaseLibLinear, LinearClassifierMixin,
         Intercept (a.k.a. bias) added to the decision function.
         If `fit_intercept` is set to False, the intercept is set to zero.

+    `n_iter_` : int | array, shape (n_classes,)
+        Number of iterations run per class. Valid only for the liblinear
+        solver.
+
     See also
     --------
     SGDClassifier : incrementally trained logistic regression (when given
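The new attribute is easiest to understand from the user side. A minimal sketch of what it exposes after this change (assuming the liblinear solver, which is what LogisticRegression used by default at the time of this commit; printed values are illustrative):

    from sklearn.datasets import load_iris
    from sklearn.linear_model import LogisticRegression

    iris = load_iris()
    X, y = iris.data, iris.target

    # Multiclass one-vs-rest: liblinear solves one binary problem per
    # class, so n_iter_ carries one iteration count per class.
    clf = LogisticRegression().fit(X, y)
    print(clf.n_iter_)        # array of shape (3,) for the 3 iris classes

    # Binary problem: fit() collapses n_iter_ to a single value
    # (see the svm/base.py hunk below).
    mask = y < 2
    clf_bin = LogisticRegression().fit(X[mask], y[mask])
    print(clf_bin.n_iter_)    # a single iteration count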

sklearn/svm/base.py

Lines changed: 11 additions & 7 deletions
@@ -716,17 +716,21 @@ def fit(self, X, y):

         # LibLinear wants targets as doubles, even for classification
         y_ind = np.asarray(y_ind, dtype=np.float64).ravel()
-        raw_coef_ = liblinear.train_wrap(X, y_ind,
-                                         sp.isspmatrix(X),
-                                         self._get_solver_type(),
-                                         self.tol, self._get_bias(),
-                                         self.C, self.class_weight_,
-                                         self.max_iter,
-                                         rnd.randint(np.iinfo('i').max))
+        raw_coef_, self.n_iter_ = liblinear.train_wrap(
+            X, y_ind, sp.isspmatrix(X), self._get_solver_type(),
+            self.tol, self._get_bias(), self.C, self.class_weight_,
+            self.max_iter, rnd.randint(np.iinfo('i').max)
+        )
         # Regarding rnd.randint(..) in the above signature:
         # seed for srand in range [0..INT_MAX); due to limitations in Numpy
         # on 32-bit platforms, we can't get to the UINT_MAX limit that
         # srand supports
+        for n_iter in self.n_iter_:
+            if n_iter >= self.max_iter:
+                warnings.warn("Liblinear failed to converge, increase "
+                              "the number of iterations.", ConvergenceWarning)
+        if len(self.classes_) == 2:
+            self.n_iter_ = self.n_iter_[0]

         if self.fit_intercept:
             self.coef_ = raw_coef_[:, :-1]
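A quick way to see the new warning fire is to starve the solver of iterations. A sketch with hypothetical toy data (ConvergenceWarning's import location has moved across releases, so this filters by class name rather than importing it):

    import warnings
    import numpy as np
    from sklearn.svm import LinearSVC  # liblinear-backed, exposes max_iter

    rng = np.random.RandomState(0)
    X = rng.randn(200, 20)
    y = (X[:, 0] + 0.3 * rng.randn(200) > 0).astype(int)

    # Two iterations is far too few for this problem, so n_iter hits
    # max_iter and fit() should emit the new ConvergenceWarning.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        LinearSVC(max_iter=2).fit(X, y)

    print([w.category.__name__ for w in caught])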

sklearn/svm/liblinear.c

Lines changed: 220 additions & 104 deletions
Some generated files are not rendered by default.

sklearn/svm/liblinear.pxd

Lines changed: 1 addition & 0 deletions
@@ -13,6 +13,7 @@ cdef extern from "src/liblinear/linear.h":
     model *train(problem_const_ptr prob, parameter_const_ptr param) nogil
     int get_nr_feature (model *model)
     int get_nr_class (model *model)
+    void get_n_iter (model *model, int *n_iter)
     void free_and_destroy_model (model **)
     void destroy_param (parameter *)

sklearn/svm/liblinear.pyx

Lines changed: 8 additions & 1 deletion
@@ -55,6 +55,13 @@ def train_wrap(X, np.ndarray[np.float64_t, ndim=1, mode='c'] Y,
     # coef matrix holder created as fortran since that's what's used in liblinear
     cdef np.ndarray[np.float64_t, ndim=2, mode='fortran'] w
     cdef int nr_class = get_nr_class(model)
+
+    cdef int labels_ = nr_class
+    if nr_class == 2:
+        labels_ = 1
+    cdef np.ndarray[np.int32_t, ndim=1, mode='c'] n_iter = np.zeros(labels_, dtype=np.int32)
+    get_n_iter(model, <int *>n_iter.data)
+
     cdef int nr_feature = get_nr_feature(model)
     if bias > 0: nr_feature = nr_feature + 1
     if nr_class == 2 and solver_type != 4:  # solver is not Crammer-Singer
@@ -71,7 +78,7 @@ def train_wrap(X, np.ndarray[np.float64_t, ndim=1, mode='c'] Y,
     free_parameter(param)
     # destroy_param(param) don't call this or it will destroy class_weight_label and class_weight

-    return w
+    return w, n_iter


def set_verbosity_wrap(int verbosity):
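One detail worth spelling out: the buffer passed to get_n_iter is sized with the same rule liblinear uses internally, since a binary problem trains a single model while multiclass one-vs-rest trains one per class. A plain-Python mirror of that sizing logic (illustrative only, not part of the patch):

    import numpy as np

    def n_iter_buffer(nr_class):
        # Binary problems yield one underlying model; multiclass
        # one-vs-rest yields nr_class of them, one per class.
        n_slots = 1 if nr_class == 2 else nr_class
        return np.zeros(n_slots, dtype=np.int32)

    assert n_iter_buffer(2).shape == (1,)
    assert n_iter_buffer(5).shape == (5,)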

sklearn/svm/src/liblinear/linear.cpp

Lines changed: 46 additions & 23 deletions
@@ -480,7 +480,7 @@ class Solver_MCSVM_CS
 public:
     Solver_MCSVM_CS(const problem *prob, int nr_class, double *C, double eps=0.1, int max_iter=100000);
     ~Solver_MCSVM_CS();
-    void Solve(double *w);
+    int Solve(double *w);
 private:
     void solve_sub_problem(double A_i, int yi, double C_yi, int active_i, double *alpha_new);
     bool be_shrunk(int i, int m, int yi, double alpha_i, double minG);
@@ -555,7 +555,7 @@ bool Solver_MCSVM_CS::be_shrunk(int i, int m, int yi, double alpha_i, double minG)
     return false;
 }

-void Solver_MCSVM_CS::Solve(double *w)
+int Solver_MCSVM_CS::Solve(double *w)
 {
     int i, m, s;
     int iter = 0;
@@ -765,6 +765,7 @@ void Solver_MCSVM_CS::Solve(double *w)
     delete [] alpha_index;
     delete [] y_index;
     delete [] active_size_i;
+    return iter;
 }

 // A coordinate descent algorithm for
@@ -797,7 +798,7 @@ void Solver_MCSVM_CS::Solve(double *w)
 #define GETI(i) (y[i]+1)
 // To support weights for instances, use GETI(i) (i)

-static void solve_l2r_l1l2_svc(
+static int solve_l2r_l1l2_svc(
     const problem *prob, double *w, double eps,
     double Cp, double Cn, int solver_type, int max_iter)
 {
@@ -983,6 +984,7 @@ static void solve_l2r_l1l2_svc(
     delete [] alpha;
     delete [] y;
     delete [] index;
+    return iter;
 }

@@ -1014,7 +1016,7 @@ static void solve_l2r_l1l2_svc(
 #define GETI(i) (0)
 // To support weights for instances, use GETI(i) (i)

-static void solve_l2r_l1l2_svr(
+static int solve_l2r_l1l2_svr(
     const problem *prob, double *w, const parameter *param,
     int solver_type, int max_iter)
 {
@@ -1215,6 +1217,7 @@ static void solve_l2r_l1l2_svr(
     delete [] beta;
     delete [] QD;
     delete [] index;
+    return iter;
 }

@@ -1240,7 +1243,7 @@ static void solve_l2r_l1l2_svr(
 #define GETI(i) (y[i]+1)
 // To support weights for instances, use GETI(i) (i)

-void solve_l2r_lr_dual(const problem *prob, double *w, double eps, double Cp, double Cn,
+int solve_l2r_lr_dual(const problem *prob, double *w, double eps, double Cp, double Cn,
     int max_iter)
 {
     int l = prob->l;
@@ -1395,6 +1398,7 @@ void solve_l2r_lr_dual(const problem *prob, double *w, double eps, double Cp, double Cn,
     delete [] alpha;
     delete [] y;
     delete [] index;
+    return iter;
 }

 // A coordinate descent algorithm for
@@ -1414,7 +1418,7 @@ void solve_l2r_lr_dual(const problem *prob, double *w, double eps, double Cp, double Cn,
 #define GETI(i) (y[i]+1)
 // To support weights for instances, use GETI(i) (i)

-static void solve_l1r_l2_svc(
+static int solve_l1r_l2_svc(
     problem *prob_col, double *w, double eps,
     double Cp, double Cn, int max_iter)
 {
@@ -1681,6 +1685,7 @@ static void solve_l1r_l2_svc(
     delete [] y;
     delete [] b;
     delete [] xj_sq;
+    return iter;
 }

 // A coordinate descent algorithm for
@@ -1700,7 +1705,7 @@ static void solve_l1r_l2_svc(
 #define GETI(i) (y[i]+1)
 // To support weights for instances, use GETI(i) (i)

-static void solve_l1r_lr(
+static int solve_l1r_lr(
     const problem *prob_col, double *w, double eps,
     double Cp, double Cn, int max_newton_iter)
 {
@@ -2061,6 +2066,7 @@ static void solve_l1r_lr(
     delete [] exp_wTx_new;
     delete [] tau;
     delete [] D;
+    return newton_iter;
 }

 // transpose matrix X from row format to column format
@@ -2211,12 +2217,13 @@ static void group_classes(const problem *prob, int *nr_class_ret, int **label_ret,
     free(data_label);
 }

-static void train_one(const problem *prob, const parameter *param, double *w, double Cp, double Cn)
+static int train_one(const problem *prob, const parameter *param, double *w, double Cp, double Cn)
 {
     double eps=param->eps;
     int max_iter=param->max_iter;
     int pos = 0;
     int neg = 0;
+    int n_iter;
     for(int i=0;i<prob->l;i++)
         if(prob->y[i] > 0)
             pos++;
@@ -2240,7 +2247,7 @@ static void train_one(const problem *prob, const parameter *param, double *w, double Cp, double Cn)
             fun_obj=new l2r_lr_fun(prob, C);
             TRON tron_obj(fun_obj, primal_solver_tol, max_iter);
             tron_obj.set_print_string(liblinear_print_string);
-            tron_obj.tron(w);
+            n_iter=tron_obj.tron(w);
             delete fun_obj;
             delete [] C;
             break;
@@ -2258,23 +2265,23 @@ static void train_one(const problem *prob, const parameter *param, double *w, double Cp, double Cn)
             fun_obj=new l2r_l2_svc_fun(prob, C);
             TRON tron_obj(fun_obj, primal_solver_tol, max_iter);
             tron_obj.set_print_string(liblinear_print_string);
-            tron_obj.tron(w);
+            n_iter=tron_obj.tron(w);
             delete fun_obj;
             delete [] C;
             break;
         }
         case L2R_L2LOSS_SVC_DUAL:
-            solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, L2R_L2LOSS_SVC_DUAL, max_iter);
+            n_iter=solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, L2R_L2LOSS_SVC_DUAL, max_iter);
             break;
         case L2R_L1LOSS_SVC_DUAL:
-            solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, L2R_L1LOSS_SVC_DUAL, max_iter);
+            n_iter=solve_l2r_l1l2_svc(prob, w, eps, Cp, Cn, L2R_L1LOSS_SVC_DUAL, max_iter);
             break;
         case L1R_L2LOSS_SVC:
         {
             problem prob_col;
             feature_node *x_space = NULL;
             transpose(prob, &x_space ,&prob_col);
-            solve_l1r_l2_svc(&prob_col, w, primal_solver_tol, Cp, Cn, max_iter);
+            n_iter=solve_l1r_l2_svc(&prob_col, w, primal_solver_tol, Cp, Cn, max_iter);
             delete [] prob_col.y;
             delete [] prob_col.x;
             delete [] x_space;
@@ -2285,14 +2292,14 @@ static void train_one(const problem *prob, const parameter *param, double *w, double Cp, double Cn)
             problem prob_col;
             feature_node *x_space = NULL;
             transpose(prob, &x_space ,&prob_col);
-            solve_l1r_lr(&prob_col, w, primal_solver_tol, Cp, Cn, max_iter);
+            n_iter=solve_l1r_lr(&prob_col, w, primal_solver_tol, Cp, Cn, max_iter);
             delete [] prob_col.y;
             delete [] prob_col.x;
             delete [] x_space;
             break;
         }
         case L2R_LR_DUAL:
-            solve_l2r_lr_dual(prob, w, eps, Cp, Cn, max_iter);
+            n_iter=solve_l2r_lr_dual(prob, w, eps, Cp, Cn, max_iter);
             break;
         case L2R_L2LOSS_SVR:
         {
@@ -2303,22 +2310,23 @@ static void train_one(const problem *prob, const parameter *param, double *w, double Cp, double Cn)
             fun_obj=new l2r_l2_svr_fun(prob, C, param->p);
             TRON tron_obj(fun_obj, param->eps, max_iter);
             tron_obj.set_print_string(liblinear_print_string);
-            tron_obj.tron(w);
+            n_iter=tron_obj.tron(w);
             delete fun_obj;
             delete [] C;
             break;

         }
         case L2R_L1LOSS_SVR_DUAL:
-            solve_l2r_l1l2_svr(prob, w, param, L2R_L1LOSS_SVR_DUAL, max_iter);
+            n_iter=solve_l2r_l1l2_svr(prob, w, param, L2R_L1LOSS_SVR_DUAL, max_iter);
             break;
         case L2R_L2LOSS_SVR_DUAL:
-            solve_l2r_l1l2_svr(prob, w, param, L2R_L2LOSS_SVR_DUAL, max_iter);
+            n_iter=solve_l2r_l1l2_svr(prob, w, param, L2R_L2LOSS_SVR_DUAL, max_iter);
             break;
         default:
             fprintf(stderr, "ERROR: unknown solver_type\n");
             break;
     }
+    return n_iter;
 }

 //
@@ -2330,6 +2338,7 @@ model* train(const problem *prob, const parameter *param)
     int l = prob->l;
     int n = prob->n;
     int w_size = prob->n;
+    int n_iter;
     model *model_ = Malloc(model,1);

     if(prob->bias>=0)
@@ -2344,9 +2353,10 @@ model* train(const problem *prob, const parameter *param)
         param->solver_type == L2R_L2LOSS_SVR_DUAL)
     {
         model_->w = Malloc(double, w_size);
+        model_->n_iter = Malloc(int, 1);
         model_->nr_class = 2;
         model_->label = NULL;
-        train_one(prob, param, &model_->w[0], 0, 0);
+        model_->n_iter[0] = train_one(prob, param, &model_->w[0], 0, 0);
     }
     else
     {
@@ -2398,31 +2408,33 @@ model* train(const problem *prob, const parameter *param)
         if(param->solver_type == MCSVM_CS)
         {
             model_->w=Malloc(double, n*nr_class);
+            model_->n_iter=Malloc(int, 1);
             for(i=0;i<nr_class;i++)
                 for(j=start[i];j<start[i]+count[i];j++)
                     sub_prob.y[j] = i;
             Solver_MCSVM_CS Solver(&sub_prob, nr_class, weighted_C, param->eps);
-            Solver.Solve(model_->w);
+            model_->n_iter[0]=Solver.Solve(model_->w);
         }
         else
         {
             if(nr_class == 2)
             {
                 model_->w=Malloc(double, w_size);
-
+                model_->n_iter=Malloc(int, 1);
                 int e0 = start[0]+count[0];
                 k=0;
                 for(; k<e0; k++)
                     sub_prob.y[k] = -1;
                 for(; k<sub_prob.l; k++)
                     sub_prob.y[k] = +1;

-                train_one(&sub_prob, param, &model_->w[0], weighted_C[1], weighted_C[0]);
+                model_->n_iter[0]=train_one(&sub_prob, param, &model_->w[0], weighted_C[1], weighted_C[0]);
             }
             else
             {
                 model_->w=Malloc(double, w_size*nr_class);
                 double *w=Malloc(double, w_size);
+                model_->n_iter=Malloc(int, nr_class);
                 for(i=0;i<nr_class;i++)
                 {
                     int si = start[i];
@@ -2436,7 +2448,7 @@ model* train(const problem *prob, const parameter *param)
                     for(; k<sub_prob.l; k++)
                         sub_prob.y[k] = -1;

-                    train_one(&sub_prob, param, w, weighted_C[i], param->C);
+                    model_->n_iter[i]=train_one(&sub_prob, param, w, weighted_C[i], param->C);

                     for(int j=0;j<w_size;j++)
                         model_->w[j*nr_class+i] = w[j];
@@ -2795,6 +2807,17 @@ void get_labels(const model *model_, int* label)
         label[i] = model_->label[i];
 }

+void get_n_iter(const model *model_, int* n_iter)
+{
+    int labels;
+    labels = model_->nr_class;
+    if (labels == 2)
+        labels = 1;
+    if (model_->n_iter != NULL)
+        for(int i=0;i<labels;i++)
+            n_iter[i] = model_->n_iter[i];
+}
+
 void free_model_content(struct model *model_ptr)
 {
     if(model_ptr->w != NULL)
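Taken together, train() now records each solver's return value in model->n_iter, and get_n_iter() copies those counts into a caller-supplied buffer using the same binary-collapses-to-one convention as the Cython side. A rough Python rendering of the accessor's behavior (purely illustrative; the real code is the C above):

    def get_n_iter(model):
        # A binary model stores a single count; a multiclass model stores
        # one per class. Returns a copy, like the C version writing into
        # a caller-supplied int buffer.
        labels = model["nr_class"]
        if labels == 2:
            labels = 1
        return list(model["n_iter"][:labels])

    print(get_n_iter({"nr_class": 2, "n_iter": [13]}))       # [13]
    print(get_n_iter({"nr_class": 3, "n_iter": [7, 9, 8]}))  # [7, 9, 8]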

sklearn/svm/src/liblinear/linear.h

Lines changed: 2 additions & 0 deletions
@@ -43,6 +43,7 @@ struct model
     double *w;
     int *label;        /* label of each class */
     double bias;
+    int *n_iter;       /* no. of iterations of each class */
 };

 struct model* train(const struct problem *prob, const struct parameter *param);
@@ -58,6 +59,7 @@ struct model *load_model(const char *model_file_name);
 int get_nr_feature(const struct model *model_);
 int get_nr_class(const struct model *model_);
 void get_labels(const struct model *model_, int* label);
+void get_n_iter(const struct model *model_, int* n_iter);

 void free_model_content(struct model *model_ptr);
 void free_and_destroy_model(struct model **model_ptr_ptr);

sklearn/svm/src/liblinear/tron.cpp

Lines changed: 2 additions & 1 deletion
@@ -44,7 +44,7 @@ TRON::~TRON()
 {
 }

-void TRON::tron(double *w)
+int TRON::tron(double *w)
 {
     // Parameters for updating the iterates.
     double eta0 = 1e-4, eta1 = 0.25, eta2 = 0.75;
@@ -146,6 +146,7 @@ void TRON::tron(double *w)
     delete[] r;
     delete[] w_new;
     delete[] s;
+    return --iter;
 }

 int TRON::trcg(double delta, double *g, double *s, double *r)
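The `return --iter;` is easy to misread: TRON starts its counter at one and advances it as steps are accepted, so on exit the counter sits one past the number of completed Newton iterations, and the pre-decrement corrects for that. A toy counter with the same convention (a deliberate simplification of the real trust-region loop):

    def tron_like_count(max_iter=1000, converges_after=6):
        # The counter starts at 1 and is advanced after each accepted
        # step, mirroring tron.cpp; on exit it is one past the number
        # of completed iterations.
        it = 1
        steps = 0
        while it <= max_iter:
            steps += 1                    # take a Newton step
            it += 1
            if steps >= converges_after:  # stand-in for the gradient test
                break
        return it - 1                     # == steps; like C++ `return --iter;`

    print(tron_like_count())  # 6 completed iterations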

0 commit comments