8000 MAINT: use more conservative integer types for umath linalg by argriffing · Pull Request #5899 · numpy/numpy · GitHub
[go: up one dir, main page]

Skip to content

MAINT: use more conservative integer types for umath linalg #5899

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 2 commits into from
May 21, 2015
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
8000
Diff view
100 changes: 63 additions & 37 deletions numpy/linalg/umath_linalg.c.src
Original file line number Diff line number Diff line change
Expand Up @@ -1128,6 +1128,7 @@ static void
npy_uint8 *tmp_buff = NULL;
size_t matrix_size;
size_t pivot_size;
size_t safe_m;
/* notes:
* matrix will need to be copied always, as factorization in lapack is
* made inplace
Expand All @@ -1138,8 +1139,9 @@ static void
*/
INIT_OUTER_LOOP_3
m = (fortran_int) dimensions[0];
matrix_size = m*m*sizeof(@typ@);
pivot_size = m*sizeof(fortran_int);
safe_m = m;
matrix_size = safe_m * safe_m * sizeof(@typ@);
pivot_size = safe_m * sizeof(fortran_int);
tmp_buff = (npy_uint8 *)malloc(matrix_size + pivot_size);

if (tmp_buff)
Expand Down Expand Up @@ -1172,6 +1174,7 @@ static void
npy_uint8 *tmp_buff;
size_t matrix_size;
size_t pivot_size;
size_t safe_m;
/* notes:
* matrix will need to be copied always, as factorization in lapack is
* made inplace
Expand All @@ -1182,8 +1185,9 @@ static void
*/
INIT_OUTER_LOOP_2
m = (fortran_int) dimensions[0];
matrix_size = m*m*sizeof(@typ@);
pivot_size = m*sizeof(fortran_int);
safe_m = m;
matrix_size = safe_m * safe_m * sizeof(@typ@);
pivot_size = safe_m * sizeof(fortran_int);
tmp_buff = (npy_uint8 *)malloc(matrix_size + pivot_size);

if (tmp_buff)
Expand Down Expand Up @@ -1252,14 +1256,15 @@ init_@lapack_func@(EIGH_PARAMS_t* params, char JOBZ, char UPLO,
fortran_int liwork = -1;
fortran_int info;
npy_uint8 *a, *w, *work, *iwork;
size_t alloc_size = N*(N+1)*sizeof(@typ@);
size_t safe_N = N;
size_t alloc_size = safe_N * (safe_N + 1) * sizeof(@typ@);

mem_buff = malloc(alloc_size);

if (!mem_buff)
goto error;
a = mem_buff;
w = mem_buff + N*N*sizeof(@typ@);
w = mem_buff + safe_N * safe_N * sizeof(@typ@);
LAPACK(@lapack_func@)(&JOBZ, &UPLO, &N,
(@ftyp@*)a, &N, (@ftyp@*)w,
&query_work_size, &lwork,
Expand Down Expand Up @@ -1344,12 +1349,14 @@ init_@lapack_func@(EIGH_PARAMS_t *params,
fortran_int liwork = -1;
npy_uint8 *a, *w, *work, *rwork, *iwork;
fortran_int info;
size_t safe_N = N;

mem_buff = malloc(N*N*sizeof(@typ@)+N*sizeof(@basetyp@));
mem_buff = malloc(safe_N * safe_N * sizeof(@typ@) +
safe_N * sizeof(@basetyp@));
if (!mem_buff)
goto error;
a = mem_buff;
w = mem_buff+N*N*sizeof(@typ@);
w = mem_buff + safe_N * safe_N * sizeof(@typ@);

LAPACK(@lapack_func@)(&JOBZ, &UPLO, &N,
(@ftyp@*)a, &N, (@fbasetyp@*)w,
Expand Down Expand Up @@ -1581,14 +1588,16 @@ init_@lapack_func@(GESV_PARAMS_t *params, fortran_int N, fortran_int NRHS)
{
npy_uint8 *mem_buff = NULL;
npy_uint8 *a, *b, *ipiv;
mem_buff = malloc(N*N*sizeof(@ftyp@) +
N*NRHS*sizeof(@ftyp@) +
N*sizeof(fortran_int));
size_t safe_N = N;
size_t safe_NRHS = NRHS;
mem_buff = malloc(safe_N * safe_N * sizeof(@ftyp@) +
safe_N * safe_NRHS*sizeof(@ftyp@) +
safe_N * sizeof(fortran_int));
if (!mem_buff)
goto error;
a = mem_buff;
b = a + N*N*sizeof(@ftyp@);
ipiv = b + N*NRHS*sizeof(@ftyp@);
b = a + safe_N * safe_N * sizeof(@ftyp@);
ipiv = b + safe_N * safe_NRHS * sizeof(@ftyp@);

params->A = a;
params->B = b;
Expand Down Expand Up @@ -1759,8 +1768,9 @@ init_@lapack_func@(POTR_PARAMS_t *params, char UPLO, fortran_int N)
{
npy_uint8 *mem_buff = NULL;
npy_uint8 *a;
size_t safe_N = N;

mem_buff = malloc(N*N*sizeof(@ftyp@));
mem_buff = malloc(safe_N * safe_N * sizeof(@ftyp@));
if (!mem_buff)
goto error;

Expand Down Expand Up @@ -1924,11 +1934,12 @@ init_@lapack_func@(GEEV_PARAMS_t *params, char jobvl, char jobvr, fortran_int n)
npy_uint8 *mem_buff=NULL;
npy_uint8 *mem_buff2=NULL;
npy_uint8 *a, *wr, *wi, *vlr, *vrr, *work, *w, *vl, *vr;
size_t a_size = n*n*sizeof(@typ@);
size_t wr_size = n*sizeof(@typ@);
size_t wi_size = n*sizeof(@typ@);
size_t vlr_size = jobvl=='V' ? n*n*sizeof(@typ@) : 0;
size_t vrr_size = jobvr=='V' ? n*n*sizeof(@typ@) : 0;
size_t safe_n = n;
size_t a_size = safe_n * safe_n * sizeof(@typ@);
size_t wr_size = safe_n * sizeof(@typ@);
size_t wi_size = safe_n * sizeof(@typ@);
size_t vlr_size = jobvl=='V' ? safe_n * safe_n * sizeof(@typ@) : 0;
size_t vrr_size = jobvr=='V' ? safe_n * safe_n * sizeof(@typ@) : 0;
size_t w_size = wr_size*2;
size_t vl_size = vlr_size*2;
size_t vr_size = vrr_size*2;
Expand Down Expand Up @@ -2120,11 +2131,12 @@ init_@lapack_func@(GEEV_PARAMS_t* params,
npy_uint8 *mem_buff = NULL;
npy_uint8 *mem_buff2 = NULL;
npy_uint8 *a, *w, *vl, *vr, *work, *rwork;
size_t a_size = n*n*sizeof(@ftyp@);
size_t w_size = n*sizeof(@ftyp@);
size_t vl_size = jobvl=='V'? n*n*sizeof(@ftyp@) : 0;
size_t vr_size = jobvr=='V'? n*n*sizeof(@ftyp@) : 0;
size_t rwork_size = 2*n*sizeof(@realtyp@);
size_t safe_n = n;
size_t a_size = safe_n * safe_n * sizeof(@ftyp@);
size_t w_size = safe_n * sizeof(@ftyp@);
size_t vl_size = jobvl=='V'? safe_n * safe_n * sizeof(@ftyp@) : 0;
size_t vr_size = jobvr=='V'? safe_n * safe_n * sizeof(@ftyp@) : 0;
size_t rwork_size = 2 * safe_n * sizeof(@realtyp@);
size_t work_count = 0;
@typ@ work_size_query;
fortran_int do_size_query = -1;
Expand Down Expand Up @@ -2446,20 +2458,27 @@ init_@lapack_func@(GESDD_PARAMS_t *params,
npy_uint8 *mem_buff = NULL;
npy_uint8 *mem_buff2 = NULL;
npy_uint8 *a, *s, *u, *vt, *work, *iwork;
size_t a_size = (size_t)m*(size_t)n*sizeof(@ftyp@);
size_t safe_m = m;
size_t safe_n = n;
size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
fortran_int min_m_n = m<n?m:n;
size_t s_size = ((size_t)min_m_n)*sizeof(@ftyp@);
fortran_int u_row_count, vt_column_count;
size_t safe_min_m_n = min_m_n;
size_t s_size = safe_min_m_n * sizeof(@ftyp@);
fortran_int u_row_count, vt_column_count;
size_t safe_u_row_count, safe_vt_column_count;
size_t u_size, vt_size;
fortran_int work_count;
size_t work_size;
size_t iwork_size = 8*((size_t)min_m_n)*sizeof(fortran_int);
size_t iwork_size = 8 * safe_min_m_n * sizeof(fortran_int);

if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count))
goto error;

u_size = ((size_t)u_row_count)*m*sizeof(@ftyp@);
vt_size = n*((size_t)vt_column_count)*sizeof(@ftyp@);
safe_u_row_count = u_row_count;
safe_vt_column_count = vt_column_count;

u_size = safe_u_row_count * safe_m * sizeof(@ftyp@);
vt_size = safe_n * safe_vt_column_count * sizeof(@ftyp@);

mem_buff = malloc(a_size + s_size + u_size + vt_size + iwork_size);

Expand Down Expand Up @@ -2557,21 +2576,28 @@ init_@lapack_func@(GESDD_PARAMS_t *params,
npy_uint8 *mem_buff = NULL, *mem_buff2 = NULL;
npy_uint8 *a,*s, *u, *vt, *work, *rwork, *iwork;
size_t a_size, s_size, u_size, vt_size, work_size, rwork_size, iwork_size;
size_t safe_u_row_count, safe_vt_column_count;
fortran_int u_row_count, vt_column_count, work_count;
size_t safe_m = m;
size_t safe_n = n;
fortran_int min_m_n = m<n?m:n;
size_t safe_min_m_n = min_m_n;

if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count))
goto error;

a_size = ((size_t)m)*((size_t)n)*sizeof(@ftyp@);
s_size = ((size_t)min_m_n)*sizeof(@frealtyp@);
u_size = ((size_t)u_row_count)*m*sizeof(@ftyp@);
vt_size = n*((size_t)vt_column_count)*sizeof(@ftyp@);
safe_u_row_count = u_row_count;
safe_vt_column_count = vt_column_count;

a_size = safe_m * safe_n * sizeof(@ftyp@);
s_size = safe_min_m_n * sizeof(@frealtyp@);
u_size = safe_u_row_count * safe_m * sizeof(@ftyp@);
vt_size = safe_n * safe_vt_column_count * sizeof(@ftyp@);
rwork_size = 'N'==jobz?
7*((size_t)min_m_n) :
(5*(size_t)min_m_n*(size_t)min_m_n + 5*(size_t)min_m_n);
(7 * safe_min_m_n) :
(5*safe_min_m_n * safe_min_m_n + 5*safe_min_m_n);
rwork_size *= sizeof(@ftyp@);
iwork_size = 8*((size_t)min_m_n)*sizeof(fortran_int);
iwork_size = 8 * safe_min_m_n* sizeof(fortran_int);

mem_buff = malloc(a_size +
s_size +
Expand Down
0