ENH: Grid optimization with mp.Pool & mp.shm.SharedMemory · kernc/backtesting.py@9a314b3 · GitHub

Commit 9a314b3
committed

ENH: Grid optimization with mp.Pool & mp.shm.SharedMemory

Includes backported SharedMemory from Python 3.13; see python/cpython#82300 (comment)

1 parent 8edc53b · commit 9a314b3
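Background for the backport: before Python 3.13, attaching to an existing SharedMemory block registers it with the multiprocessing resource_tracker, which then unlinks the block as soon as any attaching process exits (the bug discussed in python/cpython#82300). Python 3.13 added a `track` parameter; this commit vendors that behavior for older interpreters. A minimal sketch of the intended pattern, assuming the backported class added to `backtesting/_util.py` in this commit (the `view` name is illustrative):

import numpy as np
from backtesting._util import SharedMemory  # stdlib class on 3.13+, backport below

# Parent process: create a tracked segment and publish an array through it.
shm = SharedMemory(create=True, size=4 * 8, track=True)
np.ndarray((4,), dtype=np.float64, buffer=shm.buf)[:] = [1, 2, 3, 4]

# Worker process: attach with track=False so the resource tracker will not
# unlink the segment when this process exits.
view = SharedMemory(name=shm.name, create=False, track=False)
print(np.ndarray((4,), dtype=np.float64, buffer=view.buf).sum())  # 10.0
view.close()

# Only the creating side unlinks the segment.
shm.close()
shm.unlink()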

File tree

3 files changed: +108 -55 lines


backtesting/_util.py

Lines changed: 48 additions & 0 deletions
@@ -1,8 +1,12 @@
 from __future__ import annotations
 
+import sys
 import warnings
 from contextlib import contextmanager
+from multiprocessing import resource_tracker as _mprt
+from multiprocessing import shared_memory as _mpshm
 from numbers import Number
+from threading import Lock
 from typing import Dict, List, Optional, Sequence, Union, cast
 
 import numpy as np
@@ -225,3 +229,47 @@ def __getstate__(self):
 
     def __setstate__(self, state):
         self.__dict__ = state
+
+
+if sys.version_info >= (3, 13):
+    SharedMemory = _mpshm.SharedMemory
+    from multiprocessing.managers import SharedMemoryManager  # noqa: F401
+else:
+    class SharedMemory(_mpshm.SharedMemory):
+        # From https://github.com/python/cpython/issues/82300#issuecomment-2169035092
+        __lock = Lock()
+
+        def __init__(self, *args, track: bool = True, **kwargs):
+            self._track = track
+            if track:
+                return super().__init__(*args, **kwargs)
+            with self.__lock:
+                with patch(_mprt, 'register', lambda *a, **kw: None):  # TODO lambda
+                    super().__init__(*args, **kwargs)
+
+        def unlink(self):
+            if _mpshm._USE_POSIX and self._name:
+                _mpshm._posixshmem.shm_unlink(self._name)
+                if self._track:
+                    _mprt.unregister(self._name, "shared_memory")
+
+    class SharedMemoryManager:
+        def __init__(self) -> None:
+            self._shms: list[SharedMemory] = []
+
+        def SharedMemory(self, size):
+            shm = SharedMemory(create=True, size=size, track=True)
+            self._shms.append(shm)
+            return shm
+
+        def __enter__(self):
+            return self
+
+        def __exit__(self, *args, **kwargs):
+            for shm in self._shms:
+                try:
+                    shm.close()
+                    shm.unlink()
+                except Exception:
+                    warnings.warn(f'Failed to unlink shared memory {shm.name!r}',
+                                  category=ResourceWarning, stacklevel=2)
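For reference, a minimal usage sketch of the fallback `SharedMemoryManager` defined above; unlike the stdlib `multiprocessing.managers.SharedMemoryManager`, which delegates to a separate manager process, this one keeps the bookkeeping in-process and cleans up on context exit:

from backtesting._util import SharedMemoryManager

with SharedMemoryManager() as smm:
    shm = smm.SharedMemory(size=1024)  # created with track=True; the manager owns it
    shm.buf[:5] = b'hello'
    assert bytes(shm.buf[:5]) == b'hello'
# __exit__ closes and unlinks every handed-out segment,
# emitting a ResourceWarning for any that fail to unlink.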

backtesting/backtesting.py

Lines changed: 57 additions & 40 deletions
@@ -13,13 +13,12 @@
 import sys
 import warnings
 from abc import ABCMeta, abstractmethod
-from concurrent.futures import ProcessPoolExecutor, as_completed
 from copy import copy
 from functools import lru_cache, partial
 from itertools import chain, product, repeat
 from math import copysign
 from numbers import Number
-from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import Callable, List, Optional, Sequence, Tuple, Type, Union
 
 import numpy as np
 import pandas as pd
@@ -34,7 +33,10 @@ def _tqdm(seq, **_):
 
 from ._plotting import plot  # noqa: I001
 from ._stats import compute_stats
-from ._util import _as_str, _Indicator, _Data, _indicator_warmup_nbars, _strategy_indicators, try_
+from ._util import (
+    SharedMemory, SharedMemoryManager, _as_str, _Indicator, _Data, _indicator_warmup_nbars,
+    _strategy_indicators, patch, try_,
+)
 
 __pdoc__ = {
     'Strategy.__init__': False,
@@ -1498,40 +1500,40 @@ def _optimize_grid() -> Union[pd.Series, Tuple[pd.Series, pd.Series]]:
             names=next(iter(param_combos)).keys()))
 
         def _batch(seq):
+            # XXX: Replace with itertools.batched
             n = np.clip(int(len(seq) // (os.cpu_count() or 1)), 1, 300)
             for i in range(0, len(seq), n):
                 yield seq[i:i + n]
 
-        # Save necessary objects into "global" state; pass into concurrent executor
-        # (and thus pickle) nothing but two numbers; receive nothing but numbers.
-        # With start method "fork", children processes will inherit parent address space
-        # in a copy-on-write manner, achieving better performance/RAM benefit.
-        backtest_uuid = np.random.random()
-        param_batches = list(_batch(param_combos))
-        Backtest._mp_backtests[backtest_uuid] = (self, param_batches, maximize)
-        try:
-            # If multiprocessing start method is 'fork' (i.e. on POSIX), use
-            # a pool of processes to compute results in parallel.
-            # Otherwise (i.e. on Windows), sequential computation will be "faster".
-            if mp.get_start_method(allow_none=False) == 'fork':
-                with ProcessPoolExecutor() as executor:
-                    futures = [executor.submit(Backtest._mp_task, backtest_uuid, i)
-                               for i in range(len(param_batches))]
-                    for future in _tqdm(as_completed(futures), total=len(futures),
-                                        desc='Backtest.optimize'):
-                        batch_index, values = future.result()
-                        for value, params in zip(values, param_batches[batch_index]):
-                            heatmap[tuple(params.values())] = value
-            else:
-                if os.name == 'posix':
-                    warnings.warn("For multiprocessing support in `Backtest.optimize()` "
-                                  "set multiprocessing start method to 'fork'.")
-                for batch_index in _tqdm(range(len(param_batches))):
-                    _, values = Backtest._mp_task(backtest_uuid, batch_index)
-                    for value, params in zip(values, param_batches[batch_index]):
-                        heatmap[tuple(params.values())] = value
-        finally:
-            del Backtest._mp_backtests[backtest_uuid]
+        with mp.Pool() as pool, \
+                SharedMemoryManager() as smm:
+
+            def arr2shm(vals):
+                nonlocal smm
+                shm = smm.SharedMemory(size=vals.nbytes)
+                buf = np.ndarray(vals.shape, dtype=vals.dtype, buffer=shm.buf)
+                buf[:] = vals[:]  # Copy into shared memory
+                assert vals.ndim == 1, (vals.ndim, vals.shape, vals)
+                return shm.name, vals.shape, vals.dtype
+
+            data_shm = tuple((
+                (column, *arr2shm(values))
+                for column, values in chain([(Backtest._mp_task_INDEX_COL, self._data.index)],
+                                            self._data.items())
+            ))
+            with patch(self, '_data', None):
+                bt = copy(self)  # bt._data will be reassigned in _mp_task worker
+            results = _tqdm(
+                pool.imap(Backtest._mp_task,
+                          ((bt, data_shm, params_batch)
+                           for params_batch in _batch(param_combos))),
+                total=len(param_combos),
+                desc='Backtest.optimize'
+            )
+            for param_batch, result in zip(_batch(param_combos), results):
+                for params, stats in zip(param_batch, result):
+                    if stats is not None:
+                        heatmap[tuple(params.values())] = maximize(stats)
 
         if pd.isnull(heatmap).all():
             # No trade was made in any of the runs. Just make a random
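The point of the block above is that each pickled task carries only the lightweight `(name, shape, dtype)` triples produced by `arr2shm`; the data columns are copied into shared memory once and re-materialized in workers as zero-copy NumPy views. A standalone sketch of that round trip, using the classes from `_util.py` (the `handle` and `worker` names are illustrative):

import numpy as np
from backtesting._util import SharedMemory, SharedMemoryManager

with SharedMemoryManager() as smm:
    vals = np.arange(5, dtype=float)
    shm = smm.SharedMemory(size=vals.nbytes)
    np.ndarray(vals.shape, dtype=vals.dtype, buffer=shm.buf)[:] = vals  # one copy in
    handle = (shm.name, vals.shape, vals.dtype)  # all that needs pickling per task

    # Worker side: attach untracked and rebuild a read-only, zero-copy view.
    name, shape, dtype = handle
    worker = SharedMemory(name=name, create=False, track=False)
    arr = np.ndarray(shape, dtype=dtype, buffer=worker.buf)
    arr.setflags(write=False)
    assert (arr == vals).all()
    worker.close()  # close the mapping; the manager unlinks on exit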
@@ -1625,13 +1627,28 @@ def cons(x):
         return output
 
     @staticmethod
-    def _mp_task(backtest_uuid, batch_index):
-        bt, param_batches, maximize_func = Backtest._mp_backtests[backtest_uuid]
-        return batch_index, [maximize_func(stats) if stats['# Trades'] else np.nan
-                             for stats in (bt.run(**params)
-                                           for params in param_batches[batch_index])]
-
-    _mp_backtests: Dict[float, Tuple['Backtest', List, Callable]] = {}
+    def _mp_task(arg):
+        bt, data_shm, params_batch = arg
+        shm = [SharedMemory(name=shm_name, create=False, track=False)
+               for _, shm_name, *_ in data_shm]
+        try:
+            def shm2arr(shm, shape, dtype):
+                arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
+                arr.setflags(write=False)
+                return arr
+
+            bt._data = df = pd.DataFrame({
+                col: shm2arr(shm, shape, dtype)
+                for shm, (col, _, shape, dtype) in zip(shm, data_shm)})
+            df.set_index(Backtest._mp_task_INDEX_COL, drop=True, inplace=True)
+            return [stats.filter(regex='^[^_]') if stats['# Trades'] else None
+                    for stats in (bt.run(**params)
+                                  for params in params_batch)]
+        finally:
+            for shmem in shm:
+                shmem.close()
+
+    _mp_task_INDEX_COL = '__bt_index'
 
     def plot(self, *, results: pd.Series = None, filename=None, plot_width=None,
              plot_equity=True, plot_return=False, plot_pl=True,
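One convention worth noting: `Pool.imap` passes exactly one picklable item per task to the worker function, which is why the new `_mp_task` takes a single `arg` tuple and unpacks it rather than taking three parameters. A toy stand-in (the `task` function and its inputs are hypothetical):

import multiprocessing as mp

def task(arg):
    x, y = arg  # one tuple per task, unpacked inside the worker
    return x * y

if __name__ == '__main__':
    with mp.Pool() as pool:
        print(list(pool.imap(task, [(2, 3), (4, 5)])))  # [6, 20]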

backtesting/test/_test.py

Lines changed: 3 additions & 15 deletions
@@ -1,5 +1,4 @@
 import inspect
-import multiprocessing
 import os
 import sys
 import time
@@ -621,18 +620,6 @@ def test_max_tries(self):
                            **OPT_PARAMS)
         self.assertEqual(len(heatmap), 6)
 
-    def test_multiprocessing_windows_spawn(self):
-        df = GOOG.iloc[:100]
-        kw = {'fast': [10]}
-
-        stats1 = Backtest(df, SmaCross).optimize(**kw)
-        with patch(multiprocessing, 'get_start_method', lambda **_: 'spawn'):
-            with self.assertWarns(UserWarning) as cm:
-                stats2 = Backtest(df, SmaCross).optimize(**kw)
-
-        self.assertIn('multiprocessing support', cm.warning.args[0])
-        assert stats1.filter(chars := tuple('[^_]')).equals(stats2.filter(chars)), (stats1, stats2)
-
     def test_optimize_invalid_param(self):
         bt = Backtest(GOOG.iloc[:100], SmaCross)
         self.assertRaises(AttributeError, bt.optimize, foo=range(3))
@@ -646,9 +633,10 @@ def test_optimize_no_trades(self):
     def test_optimize_speed(self):
         bt = Backtest(GOOG.iloc[:100], SmaCross)
         start = time.process_time()
-        bt.optimize(fast=(2, 5, 7), slow=[10, 15, 20, 30])
+        bt.optimize(fast=range(2, 20, 2), slow=range(10, 40, 2))
         end = time.process_time()
-        self.assertLess(end - start, .2)
+        print(end - start)
+        self.assertLess(end - start, .3)
 
 
 class TestPlot(TestCase):
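A note on the loosened speed test: `time.process_time()` counts CPU time of the current process only, not of its children, so the pool workers' computation does not inflate the measured figure; the grid can grow from 12 to 135 parameter combinations while the asserted budget rises only from 0.2 s to 0.3 s.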
