import sys
import warnings
from abc import ABCMeta, abstractmethod
-from concurrent.futures import ProcessPoolExecutor, as_completed
from copy import copy
from functools import lru_cache, partial
from itertools import chain, product, repeat
from math import copysign
from numbers import Number
-from typing import Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
+from typing import Callable, List, Optional, Sequence, Tuple, Type, Union

import numpy as np
import pandas as pd
@@ -34,7 +33,10 @@ def _tqdm(seq, **_):

from ._plotting import plot  # noqa: I001
from ._stats import compute_stats
-from ._util import _as_str, _Indicator, _Data, _indicator_warmup_nbars, _strategy_indicators, try_
+from ._util import (
+    SharedMemory, SharedMemoryManager, _as_str, _Indicator, _Data, _indicator_warmup_nbars,
+    _strategy_indicators, patch, try_,
+)

__pdoc__ = {
    'Strategy.__init__': False,
@@ -1498,40 +1500,40 @@ def _optimize_grid() -> Union[pd.Series, Tuple[pd.Series, pd.Series]]:
                                    names=next(iter(param_combos)).keys()))

            def _batch(seq):
+                # XXX: Replace with itertools.batched
                n = np.clip(int(len(seq) // (os.cpu_count() or 1)), 1, 300)
                for i in range(0, len(seq), n):
                    yield seq[i:i + n]

-            # Save necessary objects into "global" state; pass into concurrent executor
-            # (and thus pickle) nothing but two numbers; receive nothing but numbers.
-            # With start method "fork", children processes will inherit parent address space
-            # in a copy-on-write manner, achieving better performance/RAM benefit.
-            backtest_uuid = np.random.random()
-            param_batches = list(_batch(param_combos))
-            Backtest._mp_backtests[backtest_uuid] = (self, param_batches, maximize)
-            try:
-                # If multiprocessing start method is 'fork' (i.e. on POSIX), use
-                # a pool of processes to compute results in parallel.
-                # Otherwise (i.e. on Windows), sequential computation will be "faster".
-                if mp.get_start_method(allow_none=False) == 'fork':
-                    with ProcessPoolExecutor() as executor:
-                        futures = [executor.submit(Backtest._mp_task, backtest_uuid, i)
-                                   for i in range(len(param_batches))]
-                        for future in _tqdm(as_completed(futures), total=len(futures),
-                                            desc='Backtest.optimize'):
-                            batch_index, values = future.result()
-                            for value, params in zip(values, param_batches[batch_index]):
-                                heatmap[tuple(params.values())] = value
-                else:
-                    if os.name == 'posix':
-                        warnings.warn("For multiprocessing support in `Backtest.optimize()` "
-                                      "set multiprocessing start method to 'fork'.")
-                    for batch_index in _tqdm(range(len(param_batches))):
-                        _, values = Backtest._mp_task(backtest_uuid, batch_index)
-                        for value, params in zip(values, param_batches[batch_index]):
-                            heatmap[tuple(params.values())] = value
-            finally:
-                del Backtest._mp_backtests[backtest_uuid]
+            with mp.Pool() as pool, \
+                    SharedMemoryManager() as smm:
+
+                def arr2shm(vals):
+                    nonlocal smm
+                    shm = smm.SharedMemory(size=vals.nbytes)
+                    buf = np.ndarray(vals.shape, dtype=vals.dtype, buffer=shm.buf)
+                    buf[:] = vals[:]  # Copy into shared memory
+                    assert vals.ndim == 1, (vals.ndim, vals.shape, vals)
+                    return shm.name, vals.shape, vals.dtype
+
+                data_shm = tuple((
+                    (column, *arr2shm(values))
+                    for column, values in chain([(Backtest._mp_task_INDEX_COL, self._data.index)],
+                                                self._data.items())
+                ))
+                with patch(self, '_data', None):
+                    bt = copy(self)  # bt._data will be reassigned in _mp_task worker
+                results = _tqdm(
+                    pool.imap(Backtest._mp_task,
+                              ((bt, data_shm, params_batch)
+                               for params_batch in _batch(param_combos))),
+                    total=len(param_combos),
+                    desc='Backtest.optimize'
+                )
+                for param_batch, result in zip(_batch(param_combos), results):
+                    for params, stats in zip(param_batch, result):
+                        if stats is not None:
+                            heatmap[tuple(params.values())] = maximize(stats)

            if pd.isnull(heatmap).all():
                # No trade was made in any of the runs. Just make a random
@@ -1625,13 +1627,28 @@ def cons(x):
        return output

    @staticmethod
-    def _mp_task(backtest_uuid, batch_index):
-        bt, param_batches, maximize_func = Backtest._mp_backtests[backtest_uuid]
-        return batch_index, [maximize_func(stats) if stats['# Trades'] else np.nan
-                             for stats in (bt.run(**params)
-                                           for params in param_batches[batch_index])]
-
-    _mp_backtests: Dict[float, Tuple['Backtest', List, Callable]] = {}
+    def _mp_task(arg):
+        bt, data_shm, params_batch = arg
+        shm = [SharedMemory(name=shm_name, create=False, track=False)
+               for _, shm_name, *_ in data_shm]
+        try:
+            def shm2arr(shm, shape, dtype):
+                arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
+                arr.setflags(write=False)
+                return arr
+
+            bt._data = df = pd.DataFrame({
+                col: shm2arr(shm, shape, dtype)
+                for shm, (col, _, shape, dtype) in zip(shm, data_shm)})
+            df.set_index(Backtest._mp_task_INDEX_COL, drop=True, inplace=True)
+            return [stats.filter(regex='^[^_]') if stats['# Trades'] else None
+                    for stats in (bt.run(**params)
+                                  for params in params_batch)]
+        finally:
+            for shmem in shm:
+                shmem.close()
+
+    _mp_task_INDEX_COL = '__bt_index'

    def plot(self, *, results: pd.Series = None, filename=None, plot_width=None,
             plot_equity=True, plot_return=False, plot_pl=True,
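
For readers unfamiliar with the shared-memory handoff this diff introduces, below is a minimal, self-contained sketch of the same round trip: the parent copies an array into a named block owned by a SharedMemoryManager, ships only (name, shape, dtype) to a pool worker, and the worker reattaches to the block by name and builds a zero-copy NumPy view. This is only stdlib multiprocessing.shared_memory usage (Python 3.8+); the helper names (worker, etc.) are illustrative and not taken from backtesting.py.

# Illustrative sketch, not part of the diff above.
import multiprocessing as mp
import numpy as np
from multiprocessing.managers import SharedMemoryManager
from multiprocessing.shared_memory import SharedMemory


def worker(args):
    name, shape, dtype = args
    # Attach to the existing block by name; on Python < 3.13 the worker's
    # resource tracker may print a harmless "leaked shared_memory" warning on
    # exit (the track=False argument used in the diff, via a shim in
    # backtesting's _util, exists to avoid exactly that).
    shm = SharedMemory(name=name, create=False)
    try:
        arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)  # Zero-copy view
        return float(arr.sum())
    finally:
        shm.close()  # Detach; the manager in the parent unlinks the block


if __name__ == '__main__':
    values = np.arange(1_000_000, dtype=float)
    with SharedMemoryManager() as smm, mp.Pool(2) as pool:
        shm = smm.SharedMemory(size=values.nbytes)
        # Copy once into shared memory in the parent
        np.ndarray(values.shape, dtype=values.dtype, buffer=shm.buf)[:] = values
        # Only the small (name, shape, dtype) triple is pickled to the worker
        result = pool.apply(worker, ((shm.name, values.shape, values.dtype),))
        print(result)  # 499999500000.0

Unlike the old fork-only path, this pattern works with any multiprocessing start method (spawn included), which is why the diff can drop the 'fork' check and the Windows warning.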