8000 Merge pull request #28144 from ngoldbaum/rm-update-flags · numpy/numpy@bbf4836 · GitHub
[go: up one dir, main page]

Skip to content

Commit bbf4836

Browse files
authored
Merge pull request #28144 from ngoldbaum/rm-update-flags
BUG: remove unnecessary call to PyArray_UpdateFlags
2 parents 6bedb61 + bb75e6e commit bbf4836

File tree

3 files changed

+38
-19
lines changed

3 files changed

+38
-19
lines changed

numpy/_core/src/multiarray/iterators.c

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,6 @@ PyArray_RawIterBaseInit(PyArrayIterObject *it, PyArrayObject *ao)
136136
nd = PyArray_NDIM(ao);
137137
/* The legacy iterator only supports 32 dimensions */
138138
assert(nd <= NPY_MAXDIMS_LEGACY_ITERS);
139-
PyArray_UpdateFlags(ao, NPY_ARRAY_C_CONTIGUOUS);
140139
if (PyArray_ISCONTIGUOUS(ao)) {
141140
it->contiguous = 1;
142141
}

numpy/_core/tests/test_multithreading.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ def func(seed):
1818

1919
run_threaded(func, 500, pass_count=True)
2020

21+
2122
def test_parallel_ufunc_execution():
2223
# if the loop data cache or dispatch cache are not thread-safe
2324
# computing ufuncs simultaneously in multiple threads leads
@@ -31,18 +32,14 @@ def func():
3132
# see gh-26690
3233
NUM_THREADS = 50
3334

34-
b = threading.Barrier(NUM_THREADS)
35-
3635
a = np.ones(1000)
3736

38-
def f():
37+
def f(b):
3938
b.wait()
4039
return a.sum()
4140

42-
threads = [threading.Thread(target=f) for _ in range(NUM_THREADS)]
41+
run_threaded(f, NUM_THREADS, max_workers=NUM_THREADS, pass_barrier=True)
4342

44-
[t.start() for t in threads]
45-
[t.join() for t in threads]
4643

4744
def test_temp_elision_thread_safety():
4845
amid = np.ones(50000)
@@ -121,16 +118,27 @@ def legacy_125():
121118
task1.start()
122119
task2.start()
123120

121+
124122
def test_parallel_reduction():
125123
# gh-28041
126124
NUM_THREADS = 50
127125

128-
b = threading.Barrier(NUM_THREADS)
129-
130126
x = np.arange(1000)
131127

132-
def closure():
128+
def closure(b):
133129
b.wait()
134130
np.sum(x)
135131

136-
run_threaded(closure, NUM_THREADS, max_workers=NUM_THREADS)
132+
run_threaded(closure, NUM_THREADS, max_workers=NUM_THREADS,
133+
pass_barrier=True)
134+
135+
136+
def test_parallel_flat_iterator():
137+
x = np.arange(20).reshape(5, 4).T
138+
139+
def closure(b):
140+
b.wait()
141+
for _ in range(100):
142+
list(x.flat)
143+
144+
run_threaded(closure, outer_iterations=100, pass_barrier=True)

numpy/testing/_private/utils.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
import pprint
1919
import sysconfig
2020
import concurrent.futures
21+
import threading
2122

2223
import numpy as np
2324
from numpy._core import (
@@ -2685,12 +2686,23 @@ def _get_glibc_version():
26852686
_glibc_older_than = lambda x: (_glibcver != '0.0' and _glibcver < x)
26862687

26872688

2688-
def run_threaded(func, iters, pass_count=False, max_workers=8):
2689+
def run_threaded(func, iters=8, pass_count=False, max_workers=8,
2690+
pass_barrier=False, outer_iterations=1):
26892691
"""Runs a function many times in parallel"""
2690-
with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as tpe:
2691-
if pass_count:
2692-
futures = [tpe.submit(func, i) for i in range(iters)]
2693-
else:
2694-
futures = [tpe.submit(func) for _ in range(iters)]
2695-
for f in futures:
2696-
f.result()
2692+
for _ in range(outer_iterations):
2693+
with (concurrent.futures.ThreadPoolExecutor(max_workers=max_workers)
2694+
as tpe):
2695+
args = []
2696+
if pass_barrier:
2697+
if max_workers != iters:
2698+
raise RuntimeError(
2699+
"Must set max_workers equal to the number of "
2700+
"iterations to avoid deadlocks.")
2701+
barrier = threading.Barrier(max_workers)
2702+
args.append(barrier)
2703+
if pass_count:
2704+
futures = [tpe.submit(func, i, *args) for i in range(iters)]
2705+
else:
2706+
futures = [tpe.submit(func, *args) for _ in range(iters)]
2707+
for f in futures:
2708+
f.result()

0 commit comments

Comments
 (0)
0