MAINT: compute residuals inside the ufunc · numpy/numpy@12114c7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 12114c7

Browse files
committed
MAINT: compute residuals inside the ufunc
This prevents an overly large output array being allocated. It also means the residuals can be handled as a separate out argument in future.
1 parent 3ef55be commit 12114c7

File tree

2 files changed

+102
-29
lines changed

2 files changed

+102
-29
lines changed

numpy/linalg/linalg.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -2036,17 +2036,9 @@ def lstsq(a, b, rcond="warn"):
20362036
else:
20372037
gufunc = _umath_linalg.lstsq_n
20382038

2039-
signature = 'DDd->Did' if isComplexType(t) else 'ddd->did'
2039+
signature = 'DDd->Ddid' if isComplexType(t) else 'ddd->ddid'
20402040
extobj = get_linalg_error_extobj(_raise_linalgerror_lstsq)
2041-
b_out, rank, s = gufunc(a, b, rcond, signature=signature, extobj=extobj)
2042-
2043-
# b_out contains both the solution and the components of the residuals
2044-
x = b_out[...,:n,:]
2045-
r_parts = b_out[...,n:,:]
2046-
if isComplexType(t):
2047-
resids = sum(abs(r_parts)**2, axis=-2)
2048-
else:
2049-
resids = sum(r_parts**2, axis=-2)
2041+
x, resids, rank, s = gufunc(a, b, rcond, signature=signature, extobj=extobj)
20502042

20512043
# remove the axis we added
20522044
if is_1d:

numpy/linalg/umath_linalg.c.src

Lines changed: 100 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -765,6 +765,10 @@ fortran_int_max(fortran_int x, fortran_int y) {
765765
INIT_OUTER_LOOP_5\
766766
npy_intp s5 = *steps++;
767767

768+
#define INIT_OUTER_LOOP_7 \
769+
INIT_OUTER_LOOP_6\
770+
npy_intp s6 = *steps++;
771+
768772
#define BEGIN_OUTER_LOOP_2 \
769773
for (N_ = 0;\
770774
N_ < dN;\
@@ -805,6 +809,17 @@ fortran_int_max(fortran_int x, fortran_int y) {
805809
args[4] += s4,\
806810
args[5] += s5) {
807811

812+
#define BEGIN_OUTER_LOOP_7 \
813+
for (N_ = 0;\
814+
N_ < dN;\
815+
N_++, args[0] += s0,\
816+
args[1] += s1,\
817+
args[2] += s2,\
818+
args[3] += s3,\
819+
args[4] += s4,\
820+
args[5] += s5,\
821+
args[6] += s6) {
822+
808823
#define END_OUTER_LOOP }
809824

810825
static NPY_INLINE void
@@ -836,6 +851,7 @@ update_pointers(npy_uint8** bases, ptrdiff_t* offsets, size_t count)
836851
#typ = float, double, COMPLEX_t, DOUBLECOMPLEX_t#
837852
#copy = scopy, dcopy, ccopy, zcopy#
838853
#nan = s_nan, d_nan, c_nan, z_nan#
854+
#zero = s_zero, d_zero, c_zero, z_zero#
839855
*/
840856
static NPY_INLINE void *
841857
linearize_@TYPE@_matrix(void *dst_in,
@@ -949,6 +965,23 @@ nan_@TYPE@_matrix(void *dst_in, const LINEARIZE_DATA_t* data)
949965
}
950966
}
951967

968+
static NPY_INLINE void
969+
zero_@TYPE@_matrix(void *dst_in, const LINEARIZE_DATA_t* data)
970+
{
971+
@typ@ *dst = (@typ@ *) dst_in;
972+
973+
int i, j;
974+
for (i = 0; i < data->rows; i++) {
975+
@typ@ *cp = dst;
976+
ptrdiff_t cs = data->column_strides/sizeof(@typ@);
977+
for (j = 0; j < data->columns; ++j) {
978+
*cp = @zero@;
979+
cp += cs;
980+
}
981+
dst += data->row_strides/sizeof(@typ@);
982+
}
983+
}
984+
952985
/**end repeat**/
953986

954987
/* identity square matrix generation */
@@ -3196,6 +3229,12 @@ init_@lapack_func@(GELSD_PARAMS_t *params,
31963229
#TYPE=FLOAT,DOUBLE,CFLOAT,CDOUBLE#
31973230
#REALTYPE=FLOAT,DOUBLE,FLOAT,DOUBLE#
31983231
#lapack_func=sgelsd,dgelsd,cgelsd,zgelsd#
3232+
#dot_func=sdot,ddot,cdotc,zdotc#
3233+
#typ = npy_float, npy_double, npy_cfloat, npy_cdouble#
3234+
#basetyp = npy_float, npy_double, npy_float, npy_double#
3235+
#ftyp = fortran_real, fortran_doublereal,
3236+
fortran_complex, fortran_doublecomplex#
3237+
#cmplx = 0, 0, 1, 1#
31993238
*/
32003239
static inline void
32013240
release_@lapack_func@(GELSD_PARAMS_t* params)
@@ -3206,42 +3245,84 @@ release_@lapack_func@(GELSD_PARAMS_t* params)
32063245
memset(params, 0, sizeof(*params));
32073246
}
32083247

3248+
/** Compute the squared l2 norm of a contiguous vector */
3249+
static @basetyp@
3250+
@TYPE@_abs2(@typ@ *p, npy_intp n) {
3251+
npy_intp i;
3252+
@basetyp@ res = 0;
3253+
for (i = 0; i < n; i++) {
3254+
@typ@ el = p[i];
3255+
#if @cmplx@
3256+
res += el.real*el.real + el.imag*el.imag;
3257+
#else
3258+
res += el*el;
3259+
#endif
3260+
}
3261+
return res;
3262+
}
3263+
32093264
static void
32103265
@TYPE@_lstsq(char **args, npy_intp *dimensions, npy_intp *steps,
32113266
void *NPY_UNUSED(func))
32123267
{
32133268
GELSD_PARAMS_t params;
32143269
int error_occurred = get_fp_invalid_and_clear();
32153270
fortran_int n, m, nrhs;
3216-
INIT_OUTER_LOOP_6
3271+
fortran_int excess;
3272+
3273+
INIT_OUTER_LOOP_7
32173274

32183275
m = (fortran_int)dimensions[0];
32193276
n = (fortran_int)dimensions[1];
32203277
nrhs = (fortran_int)dimensions[2];
3278+
excess = m - n;
32213279

32223280
if (init_@lapack_func@(&params, m, n, nrhs)) {
3223-
LINEARIZE_DATA_t a_in, b_in, x_out, s_out;
3281+
LINEARIZE_DATA_t a_in, b_in, x_out, s_out, r_out;
32243282

32253283
init_linearize_data(&a_in, n, m, steps[1], steps[0]);
32263284
init_linearize_data_ex(&b_in, nrhs, m, steps[3], steps[2], fortran_int_max(n, m));
3227-
init_linearize_data(&x_out, nrhs, fortran_int_max(n, m), steps[5], steps[4]);
3228-
init_linearize_data(&s_out, 1, fortran_int_min(n, m), 1, steps[6]);
3285+
init_linearize_data_ex(&x_out, nrhs, n, steps[5], steps[4], fortran_int_max(n, m));
3286+
init_linearize_data(&r_out, 1, nrhs, 1, steps[6]);
3287+
init_linearize_data(&s_out, 1, fortran_int_min(n, m), 1, steps[7]);
32293288

3230-
BEGIN_OUTER_LOOP_6
3289+
BEGIN_OUTER_LOOP_7
32313290
int not_ok;
32323291
linearize_@TYPE@_matrix(params.A, args[0], &a_in);
32333292
linearize_@TYPE@_matrix(params.B, args[1], &b_in);
32343293
params.RCOND = args[2];
32353294
not_ok = call_@lapack_func@(&params);
32363295
if (!not_ok) {
32373296
delinearize_@TYPE@_matrix(args[3], params.B, &x_out);
3238-
*(npy_int*) args[4] = params.RANK;
3239-
delinearize_@REALTYPE@_matrix(args[5], params.S, &s_out);
3297+
*(npy_int*) args[5] = params.RANK;
3298+
delinearize_@REALTYPE@_matrix(args[6], params.S, &s_out);
3299+
3300+
/* Note that linalg.lstsq discards this when excess == 0 */
3301+
if (excess >= 0 && params.RANK == n) {
3302+
/* Compute the residuals as the square sum of each column */
3303+
int i;
3304+
char *resid = args[4];
3305+
@ftyp@ *components = (@ftyp@ *)params.B + n;
3306+
for (i = 0; i < nrhs; i++) {
3307+
@ftyp@ *vector = components + i*m;
3308+
/* Numpy and fortran floating types are the same size,
3309+
* so this cast is safe */
3310+
@basetyp@ abs2 = @TYPE@_abs2((@typ@ *)vector, excess);
3311+
memcpy(
3312+
resid + i*r_out.column_strides,
3313+
&abs2, sizeof(abs2));
3314+
}
3315+
}
3316+
else {
3317+
/* Note that this is always discarded by linalg.lstsq */
3318+
nan_@REALTYPE@_matrix(args[4], &r_out);
3319+
}
32403320
} else {
32413321
error_occurred = 1;
32423322
nan_@TYPE@_matrix(args[3], &x_out);
3243-
*(npy_int*) args[4] = -1;
3244-
nan_@REALTYPE@_matrix(args[5], &s_out);
3323+
nan_@REALTYPE@_matrix(args[4], &r_out);
3324+
*(npy_int*) args[5] = -1;
3325+
nan_@REALTYPE@_matrix(args[6], &s_out);
32453326
}
32463327
END_OUTER_LOOP
32473328

@@ -3389,12 +3470,12 @@ static char svd_1_3_types[] = {
33893470
NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE
33903471
};
33913472

3392-
/* A, b, rcond, x, rank, s */
3473+
/* A, b, rcond, x, resid, rank, s */
33933474
static char lstsq_types[] = {
3394-
NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_INT, NPY_FLOAT,
3395-
NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE,
3396-
NPY_CFLOAT, NPY_CFLOAT, NPY_FLOAT, NPY_CFLOAT, NPY_INT, NPY_FLOAT,
3397-
NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE, NPY_INT, NPY_DOUBLE
3475+
NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_FLOAT, NPY_INT, NPY_FLOAT,
3476+
NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE,
3477+
NPY_CFLOAT, NPY_CFLOAT, NPY_FLOAT, NPY_CFLOAT, NPY_FLOAT, NPY_INT, NPY_FLOAT,
3478+
NPY_CDOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_CDOUBLE, NPY_DOUBLE, NPY_INT, NPY_DOUBLE,
33983479
};
33993480

34003481
typedef struct gufunc_descriptor_struct {
@@ -3590,19 +3671,19 @@ GUFUNC_DESCRIPTOR_t gufunc_descriptors [] = {
35903671
},
35913672
{
35923673
"lstsq_m",
3593-
"(m,n),(m,nrhs),()->(n,nrhs),(),(m)",
3674+
"(m,n),(m,nrhs),()->(n,nrhs),(nrhs),(),(m)",
35943675
"least squares on the last two dimensions and broadcast to the rest. \n"\
35953676
"For m <= n. \n",
3596-
4, 3, 3,
3677+
4, 3, 4,
35973678
FUNC_ARRAY_NAME(lstsq),
35983679
lstsq_types
35993680
},
36003681
{
36013682
"lstsq_n",
3602-
"(m,n),(m,nrhs),()->(m,nrhs),(),(n)",
3683+
"(m,n),(m,nrhs),()->(n,nrhs),(nrhs),(),(n)",
36033684
"least squares on the last two dimensions and broadcast to the rest. \n"\
3604-
"For m >= n. \n",
3605-
4, 3, 3,
3685+
"For m >= n, meaning that residuals are produced. \n",
3686+
4, 3, 4,
36063687
FUNC_ARRAY_NAME(lstsq),
36073688
lstsq_types
36083689
}

0 commit comments

Comments
 (0)
0