10BC0 升级expr_codegen · wukan1986/alpha_examples@eca9418 · GitHub
[go: up one dir, main page]

Skip to content

Commit eca9418

Browse files
committed
升级expr_codegen
1 parent b15276c commit eca9418

File tree

6 files changed

+25
-25
lines changed

6 files changed

+25
-25
lines changed

codes/features.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
from polars_ta.prefix.ta import * # noqa
2020
from polars_ta.prefix.wq import * # noqa
2121
from polars_ta.prefix.cdl import * # noqa
22+
from polars_ta.prefix.vec import * # noqa
2223

2324
DataFrame = TypeVar("DataFrame", _pl_LazyFrame, _pl_DataFrame)
2425
# ===================================
@@ -91,19 +92,18 @@ def func_0_ts__asset(df: DataFrame) -> DataFrame:
9192
"""
9293

9394

94-
def filter_last(df: DataFrame) -> DataFrame:
95-
"""过滤数据,只取最后一天。实盘时可用于减少计算量
96-
前一个调用的ts,这里可以直接调用,可以认为已经排序好
97-
`df = filter_last(df)`
98-
反之
99-
`df = filter_last(df.sort(_DATE_))`
100-
"""
101-
return df.filter(pl.col(_DATE_) >= df.select(pl.last(_DATE_))[0, 0])
95+
def _filter_last(df: DataFrame, ge_date_idx: int) -> DataFrame:
96+
"""过滤数据,只取最后几天。实盘时可用于减少计算量"""
97+
if ge_date_idx == 0:
98+
return df
99+
else:
100+
return df.filter(pl.col(_DATE_) >= df.select(pl.col(_DATE_).unique().sort())[ge_date_idx, 0])
102101

103102

104-
def main(df: DataFrame) -> DataFrame:
103+
def main(df: DataFrame, ge_date_idx: int) -> DataFrame:
105104

106105
df = func_0_ts__asset(df.sort(_ASSET_, _DATE_)).drop(*["_x_0"])
106+
df = _filter_last(df, ge_date_idx)
107107

108108
# drop intermediate columns
109109
# df = df.select(pl.exclude(r'^_x_\d+$'))

codes/labels.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,13 @@
1919
from polars_ta.prefix.ta import * # noqa
2020
from polars_ta.prefix.wq import * # noqa
2121
from polars_ta.prefix.cdl import * # noqa
22+
from polars_ta.prefix.vec import * # noqa
2223

2324
DataFrame = TypeVar("DataFrame", _pl_LazyFrame, _pl_DataFrame)
2425
# ===================================
2526

26-
_ = ["DOJI", "HIGH", "CLOSE", "OPEN", "LOW"]
27-
[DOJI, HIGH, CLOSE, OPEN, LOW] = [pl.col(i) for i in _]
27+
_ = ["OPEN", "DOJI", "HIGH", "CLOSE", "LOW"]
28+
[OPEN, DOJI, HIGH, CLOSE, LOW] = [pl.col(i) for i in _]
2829

2930
_ = ["_x_1", "_x_2", "_x_0", "RETURN_CC_1", "_x_3", "RETURN_CO_1", "RETURN_OC_1", "NEXT_DOJI", "RETURN_OO_1", "RETURN_OO_5"]
3031
[_x_1, _x_2, _x_0, RETURN_CC_1, _x_3, RETURN_CO_1, RETURN_OC_1, NEXT_DOJI, RETURN_OO_1, RETURN_OO_5] = [pl.col(i) for i in _]
@@ -118,21 +119,20 @@ def func_1_ts__asset(df: DataFrame) -> DataFrame:
118119
"""
119120

120121

121-
def filter_last(df: DataFrame) -> DataFrame:
122-
"""过滤数据,只取最后一天。实盘时可用于减少计算量
123-
前一个调用的ts,这里可以直接调用,可以认为已经排序好
124-
`df = filter_last(df)`
125-
反之
126-
`df = filter_last(df.sort(_DATE_))`
127-
"""
128-
return df.filter(pl.col(_DATE_) >= df.select(pl.last(_DATE_))[0, 0])
122+
def _filter_last(df: DataFrame, ge_date_idx: int) -> DataFrame:
123+
"""过滤数据,只取最后几天。实盘时可用于减少计算量"""
124+
if ge_date_idx == 0:
125+
return df
126+
else:
127+
return df.filter(pl.col(_DATE_) >= df.select(pl.col(_DATE_).unique().sort())[ge_date_idx, 0])
129128

130129

131-
def main(df: DataFrame) -> DataFrame:
130+
def main(df: DataFrame, ge_date_idx: int) -> DataFrame:
132131

133132
df = func_0_ts__asset(df.sort(_ASSET_, _DATE_)).drop(*[])
134-
df = func_0_cl(df).drop(*["_x_0", "_x_1", "_x_2"])
133+
df = func_0_cl(df).drop(*["_x_2", "_x_0", "_x_1"])
135134
df = func_1_ts__asset(df.sort(_ASSET_, _DATE_)).drop(*["_x_3"])
135+
df = _filter_last(df, ge_date_idx)
136136

137137
# drop intermediate columns
138138
# df = df.select(pl.exclude(r'^_x_\d+$'))

gp_base_cs/helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def batched_exprs(batch_id, exprs_list, gen, label, split_date, df_input):
104104

105105
globals_ = {}
106106
exec(codes, globals_)
107-
df_output = globals_['main'](df_input, filter_last=False)
107+
df_output = globals_['main'](df_input, ge_date_idx=0)
108108

109109
elapsed_time = time.perf_counter() - tic
110110
logger.info("{}代{}批 因子 计算完成。共用时 {:.3f} 秒,平均 {:.3f} 秒/条,或 {:.3f} 条/秒", gen, batch_id, elapsed_time, elapsed_time / cnt, cnt / elapsed_time)

gp_base_ts/helper.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def batched_exprs(batch_id, exprs_list, gen, label, split_date, df_input):
9696

9797
globals_ = {}
9898
exec(codes, globals_)
99-
df_output = globals_['main'](df_input, filter_last=False)
99+
df_output = globals_['main'](df_input, ge_date_idx=0)
100100

101101
elapsed_time = time.perf_counter() - tic
102102
logger.info("{}代{}批 因子 计算完成。共用时 {:.3f} 秒,平均 {:.3f} 秒/条,或 {:.3f} 条/秒", gen, batch_id, elapsed_time, elapsed_time / cnt, cnt / elapsed_time)

gp_run/requirements.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ polars_ta
33
tensorboard
44
tensorboardX
55
more_itertools
6-
expr_codegen>=0.14.0
6+
expr_codegen>=0.15.0

gp_run/requirements_node.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
ray[default]
22
polars_ta
3-
expr_codegen>=0.14.0
3+
expr_codegen>=0.15.0
44
more_itertools

0 commit comments

Comments
 (0)
0