8000 Merge pull request #5899 from argriffing/improve-umath-linalg · numpy/numpy@9dba7a4 · GitHub
[go: up one dir, main page]

Skip to content

Commit 9dba7a4

Browse files
committed
Merge pull request #5899 from argriffing/improve-umath-linalg
MAINT: use more conservative integer types for umath linalg
2 parents a79d9d3 + b9f5e85 commit 9dba7a4

File tree

1 file changed

+63
-37
lines changed

1 file changed

+63
-37
lines changed

numpy/linalg/umath_linalg.c.src

Lines changed: 63 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1128,6 +1128,7 @@ static void
11281128
npy_uint8 *tmp_buff = NULL;
11291129
size_t matrix_size;
11301130
size_t pivot_size;
1131+
size_t safe_m;
11311132
/* notes:
11321133
* matrix will need to be copied always, as factorization in lapack is
11331134
* made inplace
@@ -1138,8 +1139,9 @@ static void
11381139
*/
11391140
INIT_OUTER_LOOP_3
11401141
m = (fortran_int) dimensions[0];
1141-
matrix_size = m*m*sizeof(@typ@);
1142-
pivot_size = m*sizeof(fortran_int);
1142+
safe_m = m;
1143+
matrix_size = safe_m * safe_m * sizeof(@typ@);
1144+
pivot_size = safe_m * sizeof(fortran_int);
11431145
tmp_buff = (npy_uint8 *)malloc(matrix_size + pivot_size);
11441146

11451147
if (tmp_buff)
@@ -1172,6 +1174,7 @@ static void
11721174
npy_uint8 *tmp_buff;
11731175
size_t matrix_size;
11741176
size_t pivot_size;
1177+
size_t safe_m;
11751178
/* notes:
11761179
* matrix will need to be copied always, as factorization in lapack is
11771180
* made inplace
@@ -1182,8 +1185,9 @@ static void
11821185
*/
11831186
INIT_OUTER_LOOP_2
11841187
m = (fortran_int) dimensions[0];
1185-
matrix_size = m*m*sizeof(@typ@);
1186-
pivot_size = m*sizeof(fortran_int);
1188+
safe_m = m;
1189+
matrix_size = safe_m * safe_m * sizeof(@typ@);
1190+
pivot_size = safe_m * sizeof(fortran_int);
11871191
tmp_buff = (npy_uint8 *)malloc(matrix_size + pivot_size);
11881192

11891193
if (tmp_buff)
@@ -1252,14 +1256,15 @@ init_@lapack_func@(EIGH_PARAMS_t* params, char JOBZ, char UPLO,
12521256
fortran_int liwork = -1;
12531257
fortran_int info;
12541258
npy_uint8 *a, *w, *work, *iwork;
1255-
size_t alloc_size = N*(N+1)*sizeof(@typ@);
1259+
size_t safe_N = N;
1260+
size_t alloc_size = safe_N * (safe_N + 1) * sizeof(@typ@);
12561261

12571262
mem_buff = malloc(alloc_size);
12581263

12591264
if (!mem_buff)
12601265
goto error;
12611266
a = mem_buff;
1262-
w = mem_buff + N*N*sizeof(@typ@);
1267+
w = mem_buff + safe_N * safe_N * sizeof(@typ@);
12631268
LAPACK(@lapack_func@)(&JOBZ, &UPLO, &N,
12641269
(@ftyp@*)a, &N, (@ftyp@*)w,
12651270
&query_work_size, &lwork,
@@ -1344,12 +1349,14 @@ init_@lapack_func@(EIGH_PARAMS_t *params,
13441349
fortran_int liwork = -1;
13451350
npy_uint8 *a, *w, *work, *rwork, *iwork;
13461351
fortran_int info;
1352+
size_t safe_N = N;
13471353

1348-
mem_buff = malloc(N*N*sizeof(@typ@)+N*sizeof(@basetyp@));
1354+
mem_buff = malloc(safe_N * safe_N * sizeof(@typ@) +
1355+
safe_N * sizeof(@basetyp@));
13491356
if (!mem_buff)
13501357
goto error;
13511358
a = mem_buff;
1352-
w = mem_buff+N*N*sizeof(@typ@);
1359+
w = mem_buff + safe_N * safe_N * sizeof(@typ@);
13531360

13541361
LAPACK(@lapack_func@)(&JOBZ, &UPLO, &N,
13551362
(@ftyp@*)a, &N, (@fbasetyp@*)w,
@@ -1581,14 +1588,16 @@ init_@lapack_func@(GESV_PARAMS_t *params, fortran_int N, fortran_int NRHS)
15811588
{
15821589
npy_uint8 *mem_buff = NULL;
15831590
npy_uint8 *a, *b, *ipiv;
1584-
mem_buff = malloc(N*N*sizeof(@ftyp@) +
1585-
N*NRHS*sizeof(@ftyp@) +
1586-
N*sizeof(fortran_int));
1591+
size_t safe_N = N;
1592+
size_t safe_NRHS = NRHS;
1593+
mem_buff = malloc(safe_N * safe_N * sizeof(@ftyp@) +
1594+
safe_N * safe_NRHS*sizeof(@ftyp@) +
1595+
safe_N * sizeof(fortran_int));
15871596
if (!mem_buff)
15881597
goto error;
15891598
a = mem_buff;
1590-
b = a + N*N*sizeof(@ftyp@);
1591-
ipiv = b + N*NRHS*sizeof(@ftyp@);
1599+
b = a + safe_N * safe_N * sizeof(@ftyp@);
1600+
ipiv = b + safe_N * safe_NRHS * sizeof(@ftyp@);
15921601

15931602
params->A = a;
15941603
params->B = b;
@@ -1759,8 +1768,9 @@ init_@lapack_func@(POTR_PARAMS_t *params, char UPLO, fortran_int N)
17591768
{
17601769
npy_uint8 *mem_buff = NULL;
17611770
npy_uint8 *a;
1771+
size_t safe_N = N;
17621772

1763-
mem_buff = malloc(N*N*sizeof(@ftyp@));
1773+
mem_buff = malloc(safe_N * safe_N * sizeof(@ftyp@));
17641774
if (!mem_buff)
17651775
goto error;
17661776

@@ -1924,11 +1934,12 @@ init_@lapack_func@(GEEV_PARAMS_t *params, char jobvl, char jobvr, fortran_int n)
19241934
npy_uint8 *mem_buff=NULL;
19251935
npy_uint8 *mem_buff2=NULL;
19261936
npy_uint8 *a, *wr, *wi, *vlr, *vrr, *work, *w, *vl, *vr;
1927-
size_t a_size = n*n*sizeof(@typ@);
1928-
size_t wr_size = n*sizeof(@typ@);
1929-
size_t wi_size = n*sizeof(@typ@);
1930-
size_t vlr_size = jobvl=='V' ? n*n*sizeof(@typ@) : 0;
1931-
size_t vrr_size = jobvr=='V' ? n*n*sizeof(@typ@) : 0;
1937+
size_t safe_n = n;
1938+
size_t a_size = safe_n * safe_n * sizeof(@typ@);
1939+
size_t wr_size = safe_n * sizeof(@typ@);
1940+
size_t wi_size = safe_n * sizeof(@typ@);
1941+
size_t vlr_size = jobvl=='V' ? safe_n * safe_n * sizeof(@typ@) : 0;
1942+
size_t vrr_size = jobvr=='V' ? safe_n * safe_n * sizeof(@typ@) : 0;
19321943
size_t w_size = wr_size*2;
19331944
size_t vl_size = vlr_size*2;
19341945
size_t vr_size = vrr_size*2;
@@ -2120,11 +2131,12 @@ init_@lapack_func@(GEEV_PARAMS_t* params,
21202131
npy_uint8 *mem_buff = NULL;
21212132
npy_uint8 *mem_buff2 = NULL;
21222133
npy_uint8 *a, *w, *vl, *vr, *work, *rwork;
2123-
size_t a_size = n*n*sizeof(@ftyp@);
2124-
size_t w_size = n*sizeof(@ftyp@);
2125-
size_t vl_size = jobvl=='V'? n*n*sizeof(@ftyp@) : 0;
2126-
size_t vr_size = jobvr=='V'? n*n*sizeof(@ftyp@) : 0;
2127-
size_t rwork_size = 2*n*sizeof(@realtyp@);
2134+
size_t safe_n = n;
2135+
size_t a_size = safe_n * safe_n * sizeof(@ftyp@);
2136+
size_t w_size = safe_n * sizeof(@ftyp@);
2137+
size_t vl_size = jobvl=='V'? safe_n * safe_n * sizeof(@ftyp@) : 0;
2138+
size_t vr_size = jobvr=='V'? safe_n * safe_n * sizeof(@ftyp@) : 0;
2139+
size_t rwork_size = 2 * safe_n * sizeof(@realtyp@);
21282140
size_t work_count = 0;
21292141
@typ@ work_size_query;
21302142
fortran_int do_size_query = -1;
@@ -2446,20 +2458,27 @@ init_@lapack_func@(GESDD_PARAMS_t *params,
24462458
npy_uint8 *mem_buff = NULL;
24472459
npy_uint8 *mem_buff2 = NULL;
24482460
npy_uint8 *a, *s, *u, *vt, *work, *iwork;
2449-
size_t a_size = (size_t)m*(size_t)n*sizeof(@ftyp@);
2461+
size_t safe_m = m;
2462+
size_t safe_n = n;
2463+
size_t a_size = safe_m * safe_n * sizeof(@ftyp@);
24502464
fortran_int min_m_n = m<n?m:n;
2451-
size_t s_size = ((size_t)min_m_n)*sizeof(@ftyp@);
2452-
fortran_int u_row_count, vt_column_count;
2465+
size_t safe_min_m_n = min_m_n;
2466+
size_t s_size = safe_min_m_n * sizeof(@ftyp@);
2467+
fortran_int u_row_count, vt_column_count;
2468+
size_t safe_u_row_count, safe_vt_column_count;
24532469
size_t u_size, vt_size;
24542470
fortran_int work_count;
24552471
size_t work_size;
2456-
size_t iwork_size = 8*((size_t)min_m_n)*sizeof(fortran_int);
2472+
size_t iwork_size = 8 * safe_min_m_n * sizeof(fortran_int);
24572473

24582474
if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count))
24592475
goto error;
24602476

2461-
u_size = ((size_t)u_row_count)*m*sizeof(@ftyp@);
2462-
vt_size = n*((size_t)vt_column_count)*sizeof(@ftyp@);
2477+
safe_u_row_count = u_row_count;
2478+
safe_vt_column_count = vt_column_count;
2479+
2480+
u_size = safe_u_row_count * safe_m * sizeof(@ftyp@);
2481+
vt_size = safe_n * safe_vt_column_count * sizeof(@ftyp@);
24632482

24642483
mem_buff = malloc(a_size + s_size + u_size + vt_size + iwork_size);
24652484

@@ -2557,21 +2576,28 @@ init_@lapack_func@(GESDD_PARAMS_t *params,
25572576
npy_uint8 *mem_buff = NULL, *mem_buff2 = NULL;
25582577
npy_uint8 *a,*s, *u, *vt, *work, *rwork, *iwork;
25592578
size_t a_size, s_size, u_size, vt_size, work_size, rwork_size, iwork_size;
2579+
size_t safe_u_row_count, safe_vt_column_count;
25602580
fortran_int u_row_count, vt_column_count, work_count;
2581+
size_t safe_m = m;
2582+
size_t safe_n = n;
25612583
fortran_int min_m_n = m<n?m:n;
2584+
size_t safe_min_m_n = min_m_n;
25622585

25632586
if (!compute_urows_vtcolumns(jobz, m, n, &u_row_count, &vt_column_count))
25642587
goto error;
25652588

2566-
a_size = ((size_t)m)*((size_t)n)*sizeof(@ftyp@);
2567-
s_size = ((size_t)min_m_n)*sizeof(@frealtyp@);
2568-
u_size = ((size_t)u_row_count)*m*sizeof(@ftyp@);
2569-
vt_size = n*((size_t)vt_column_count)*sizeof(@ftyp@);
2589+
safe_u_row_count = u_row_count;
2590+
safe_vt_column_count = vt_column_count;
2591+
2592+
a_size = safe_m * safe_n * sizeof(@ftyp@);
2593+
s_size = safe_min_m_n * sizeof(@frealtyp@);
2594+
u_size = safe_u_row_count * safe_m * sizeof(@ftyp@);
2595+
vt_size = safe_n * safe_vt_column_count * sizeof(@ftyp@);
25702596
rwork_size = 'N'==jobz?
2571-
7*((size_t)min_m_n) :
2572-
(5*(size_t)min_m_n*(size_t)min_m_n + 5*(size_t)min_m_n);
2597+
(7 * safe_min_m_n) :
2598+
(5*safe_min_m_n * safe_min_m_n + 5*safe_min_m_n);
25732599
rwork_size *= sizeof(@ftyp@);
2574-
iwork_size = 8*((size_t)min_m_n)*sizeof(fortran_int);
2600+
iwork_size = 8 * safe_min_m_n* sizeof(fortran_int);
25752601

25762602
mem_buff = malloc(a_size +
25772603
s_size +

0 commit comments

Comments
 (0)
0