8000 MNT: improve process ordering for spawned workers. · RocketPy-Team/RocketPy@d22c957 · GitHub
[go: up one dir, main page]

Skip to content

Commit d22c957

Browse files
committed
MNT: improve process ordering for spawned workers.
1 parent 1e24643 commit d22c957

File tree

1 file changed

+13
-10
lines changed

1 file changed

+13
-10
lines changed

rocketpy/simulation/monte_carlo.py

Lines changed: 13 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,7 @@ def __run_in_parallel(self, n_workers=None):
298298
None
299299
"""
300300
if n_workers is None or n_workers > os.cpu_count():
301-
# For Windows, the number of workers must be at most os.cpu_count() - 1
302-
n_workers = os.cpu_count() - 1
301+
n_workers = os.cpu_count()
303302

304303
if n_workers < 2:
305304
raise ValueError("Number of workers must be at least 2 for parallel mode.")
@@ -321,6 +320,13 @@ def __run_in_parallel(self, n_workers=None):
321320
processes = []
322321
seeds = np.random.SeedSequence().spawn(n_workers - 1)
323322

323+
sim_consumer = multiprocess.Process(
324+
target=self.__sim_consumer,
325+
args=(export_queue, mutex, consumer_stop_event, simulation_error_event),
326+
)
327+
328+
sim_consumer.start()
329+
324330
for seed in seeds:
325331
sim_producer = multiprocess.Process(
326332
target=self.__sim_producer,
@@ -337,13 +343,6 @@ def __run_in_parallel(self, n_workers=None):
337343
for sim_producer in processes:
338344
sim_producer.start()
339345

340-
sim_consumer = multiprocess.Process(
341-
target=self.__sim_consumer,
342-
args=(export_queue, mutex, consumer_stop_event, simulation_error_event),
343-
)
344-
345-
sim_consumer.start()
346-
347346
try:
348347
for sim_producer in processes:
349348
sim_producer.join()
@@ -455,14 +454,18 @@ def __sim_consumer(
455454
The event indicating that an error occurred during the simulation.
456455
"""
457456
trials = 0
458-
while not stop_event.is_set() and not error_event.is_set():
457+
458+
while not error_event.is_set():
459459
try:
460460
mutex.acquire()
461461
inputs_dict, outputs_dict = export_queue.get(timeout=3)
462462

463463
self.__export_flight_data(inputs_dict, outputs_dict)
464464

465465
except queue.Empty as exc:
466+
if stop_event.is_set():
467+
break
468+
466469
trials += 1
467470

468471
if trials > 10:

0 commit comments

Comments
 (0)
0