8000 部分改进 · wukan1986/alpha_examples@a491364 · GitHub
[go: up one dir, main page]

Skip to content

Commit a491364

Browse files
committed
部分改进
1 parent 95012c2 commit a491364

File tree

5 files changed

+65
-50
lines changed

5 files changed

+65
-50
lines changed

ml_cs/config.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,36 +5,30 @@
55
"""
66
import polars as pl # noqa
77
import polars.selectors as cs # noqa
8-
from polars_ta.prefix.wq import cs_zscore
98

109
# %%
1110
DATE = "date"
1211
ASSET = "asset"
1312
LABEL = 'LABEL' # 训练用的标签
14-
FWD_RET = 'FWD_RET' # 计算净值必需提供1日收益率
13+
FWD_RET = 'FWD_RET' # 计算净值必需提供日化收益率
1514
DATA_END = '2025-03'
1615
DATA_START = '2025-04'
1716

18-
INPUT1_PATH = r'M:\preprocessing\out1.parquet' # 添加了特征的数据
17+
INPUT1_PATH = r'M:\preprocessing\data5.parquet' # 添加了特征的数据
1918

2019
# %%
2120
MODEL_FILENAME = r'D:\GitHub\alpha_examples\ml_cs\models.pkl' # 训练后保存的模型名
2221
PRED_PATH = 'pred.parquet' # 预测结果
2322
PRED_EXCEL = 'pred.xlsx' # 预测结果导出Excel
2423

2524
# %%
26-
# TODO 丢弃的字段。保留的字段远远多余丢弃的字段,用丢弃法
27-
# 1. 对机器学习无意义的字段
28-
# 2. 留下日期、资产、多个特征、一标签、一未来收益
29-
drop_columns = [
30-
'paused', 'factor',
31-
'high_limit', 'low_limit',
32-
'sw_l1', 'sw_l3', 'sw_l2', 'zjw',
33-
'上海主板', '深圳主板', '科创板', '创业板', '北交所',
34-
'NEXT_DOJI4',
35-
'SSE50', 'CSI300', 'CSI500', 'CSI1000',
36-
'pe_ratio', 'pb_ratio', 'ps_ratio', 'pcf_ratio', 'pe_ratio_lyr',
37-
"ONE", "MC_LOG", "MC_NORM", 'market_cap', 'circulating_market_cap',
25+
# TODO 特征
26+
feature_columns = [
27+
"MC_NEUT", "EP", "BP", "SP", "CFP",
28+
29+
"DOJI4",
30+
31+
"A_0001", "A_0002", "A_0003",
3832
]
3933

4034
# TODO 分类特征。布尔型号和少量的整数型,只在LightGBM中使用
@@ -45,21 +39,24 @@
4539
# '当前价格是否高于10日均线',
4640
]
4741

48-
exclude_columns = [
49-
]
50-
5142

5243
# %%
5344
def load_process():
5445
"""加载数据,然后进行预处理"""
55-
df = pl.read_parquet(INPUT1_PATH)
46+
df: pl.DataFrame = pl.read_parquet(INPUT1_PATH)
5647
print(df.columns)
5748

58-
# 删除不需要的字段。留下日期、资产、多个特征、一标签、一未来收益
59-
df = df.drop(drop_columns)
49+
# 留下日期、资产、多个特征、一标签、一未来收益
50+
df = df.select(DATE, ASSET, LABEL, FWD_RET, *feature_columns)
51+
52+
# 预处理,需要提前在其他地方处理好,这里不再处理
53+
# df = df.with_columns(
54+
# cs_zscore(cs.float() & cs.exclude(DATE, ASSET, LABEL, FWD_RET, *exclude_columns)).over(DATE)
55+
# )
6056

61-
# 预处理
57+
# TODO 回归问题转换成分类问题
6258
df = df.with_columns(
63-
cs_zscore(cs.float() & cs.exclude(DATE, ASSET, LABEL, FWD_RET, *exclude_columns)).over(DATE)
59+
(pl.col(LABEL) > 0.00).cast(pl.UInt8)
6460
)
61+
print(df[LABEL].value_counts())
6562
return df

ml_cs/pred.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
import joblib
2+
import numpy as np
23
import polars as pl
34
from alphainspect.reports import create_3x2_sheet
45
from alphainspect.utils import with_factor_quantile
56
from loguru import logger
67
from matplotlib import pyplot as plt
8+
from sklearn.metrics import roc_auc_score, classification_report
79

810
from ml_cs.config import DATE, ASSET, LABEL, MODEL_FILENAME, INPUT1_PATH, DATA_START, FWD_RET, load_process
911
from ml_cs.utils import load_dates, get_XyOther, walk_forward
@@ -19,24 +21,30 @@
1921
logger.info('加载模型...')
2022
models = joblib.load(MODEL_FILENAME)
2123

24+
# TODO 试验阶段is_test=True
25+
is_test = True
26+
2227

2328
# %% 预测
2429
def predict():
2530
trading_dates = load_dates(INPUT1_PATH, DATE)[DATA_START:]
2631

2732
others = []
2833
for i, train_dt, test_dt in walk_forward(trading_dates,
29-
n_splits=3, max_train_size=None, test_size=None, gap=0):
34+
n_splits=1, max_train_size=None, test_size=None, gap=0):
3035
start, end = train_dt[0], test_dt[-1]
31-
X, y, other = get_XyOther(df, start, end, DATE, ASSET, LABEL, FWD_RET, is_fit=False)
36+
37+
X_test, y_test, other = get_XyOther(df, start, end, DATE, ASSET, LABEL, FWD_RET, is_test=is_test)
3238

3339
y_preds = {}
3440
for i, model in enumerate(models):
35-
# print(f'{i}: {model.__class__.__name__}')
36-
if hasattr(model, 'best_iteration'):
37-
y_preds[f'y_pred_{i}'] = model.predict(X, num_iteration=model.best_iteration)
38-
else:
39-
y_preds[f'y_pred_{i}'] = model.predict(X)
41+
num_iteration = model.best_iteration if hasattr(model, 'best_iteration') else None
42+
pred_proba = model.predict(X_test, num_iteration=num_iteration)
43+
print("预测概率范围:", pred_proba.min(), "~", pred_proba.max())
44+
if is_test:
45+
print("AUC分数:", roc_auc_score(y_test, pred_proba))
46+
print(classification_report(y_test, (pred_proba > 0.5).astype(int), zero_division=np.nan))
47+
y_preds[f'y_pred_{i}'] = pred_proba
4048
# TODO 预测值等权,可以按需进行权重分配
4149
result = other.with_columns(y_pred=pl.from_dict(y_preds).mean_horizontal())
4250
others.append(result)

ml_cs/train_lasso.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ def fit():
2424
for i, train_dt, test_dt in walk_forward(trading_dates,
2525
n_splits=5, max_train_size=None, test_size=30, gap=3):
2626
for start, end in (train_dt, test_dt):
27-
X, y, other = get_XyOther(df, start, end, DATE, ASSET, LABEL, FWD_RET, is_fit=True)
27+
X, y, other = get_XyOther(df, start, end, DATE, ASSET, LABEL, FWD_RET, is_test=True)
2828
break
2929

3030
model = Lasso(

ml_cs/train_lgb.py

Lines changed: 24 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -14,24 +14,33 @@
1414

1515
# %%
1616
params = {
17-
'boosting_type': 'gbdt',
18-
'objective': 'mse', # 损失函数
19-
# 'metric': 'None', # 评估函数,这里用feval来替代
17+
# TODO 分类不平衡
18+
'is_unbalance': True, # 自动平衡正负样本
19+
# 或者使用以下方式手动设置权重
20+
# 'scale_pos_weight': 3, # 假设正样本是少数类,放大10倍权重
21+
# 或者更精确的类别权重
22+
# 'class_weight': {0: 1, 1: 3}, # 为类别1设置更高权重
2023

21-
'max_depth': 8,
24+
# TODO 分类
25+
'objective': 'binary',
26+
'metric': {'binary_logloss'}, # 评价函数选择
27+
28+
# # TODO 回归
29+
# 'objective': 'mse',
30+
# 'metric': {'l2'}, #评价函数选择
31+
32+
# 其他参数
33+
'max_depth': -1,
2234
'num_leaves': 63,
23-
'learning_rate': 0.05,
24-
'min_data_in_leaf': 50,
25-
'feature_fraction': 1.0,
26-
'bagging_fraction': 1.0,
35+
'learning_rate': 0.01,
36+
'feature_fraction': 0.8,
37+
'bagging_fraction': 0.9,
2738
'bagging_freq': 5,
2839
'lambda_l1': 0.0,
2940
'lambda_l2': 0.0,
30-
'max_bin': 127,
3141
'verbose': -1, # 不显示
3242
'device_type': 'cpu',
3343
'seed': 42,
34-
'force_col_wise': True,
3544
}
3645
# %%
3746
df = load_process()
@@ -45,17 +54,17 @@ def fit():
4554

4655
models = []
4756
for i, train_dt, test_dt in walk_forward(trading_dates,
48-
n_splits=3, max_train_size=None, test_size=60, gap=3):
57+
n_splits=1, max_train_size=None, test_size=60, gap=3):
4958
ds = []
5059
for start, end in (train_dt, test_dt):
51-
X, y, other = get_XyOther(df, start, end, DATE, ASSET, LABEL, FWD_RET, is_fit=True)
60+
X, y, other = get_XyOther(df, start, end, DATE, ASSET, LABEL, FWD_RET, is_test=True)
5261
ds.append(lgb.Dataset(X, label=y, categorical_feature=categorical_feature))
5362

5463
evals_result = {} # to record eval results for plotting
5564
model = lgb.train(
5665
params,
57-
ds[0],
58-
num_boost_round=500,
66+
train_set=ds[0],
67+
num_boost_round=300,
5968
valid_sets=ds,
6069
valid_names=['train', 'valid'],
6170
feval=None, # 与早停相配合
@@ -74,7 +83,7 @@ def fit():
7483
# %% 模型评估
7584
def evaluate(models):
7685
_, ax = plt.subplots(1, 1, figsize=(10, 5))
77-
plot_metric_errorbar(models, metric='l2', ax=ax)
86+
plot_metric_errorbar(models, metric=list(params['metric'])[0], ax=ax)
7887
_, ax = plt.subplots(1, 1, figsize=(10, 5))
7988
plot_importance_box(models, ax=ax)
8089
plt.show()

ml_cs/utils.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def load_dates(path: str, date: str) -> pd.Series:
100100

101101

102102
def get_XyOther(df: pl.DataFrame, start: pd.Timestamp, end: pd.Timestamp,
103-
date: str, asset: str, label: str, *fwd_ret: str, is_fit: bool) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
103+
date: str, asset: str, label: str, *fwd_ret: str, is_test: bool) -> Tuple[pl.DataFrame, pl.DataFrame, pl.DataFrame]:
104104
"""获取X y other
105105
106106
Parameters
@@ -112,10 +112,11 @@ def get_XyOther(df: pl.DataFrame, start: pd.Timestamp, end: pd.Timestamp,
112112
asset
113113
label
114114
fwd_ret
115-
is_fit:bool
115+
is_test:bool
116116
是否用于训练。
117117
fit时,X和y都不能出现null
118118
predict时,X不能出现null,y无限制
119+
但要验证predict效果时,y不能为hull
119120
120121
Returns
121122
-------
@@ -128,7 +129,7 @@ def get_XyOther(df: pl.DataFrame, start: pd.Timestamp, end: pd.Timestamp,
128129
"""
129130

130131
df = df.filter(pl.col(date).is_between(start, end))
131-
if is_fit:
132+
if is_test:
132133
df = df.drop_nulls(subset=pl.exclude(*fwd_ret))
133134
else:
134135
df = df.drop_nulls(subset=pl.exclude(*fwd_ret, label))
@@ -137,7 +138,7 @@ def get_XyOther(df: pl.DataFrame, start: pd.Timestamp, end: pd.Timestamp,
137138
_y = df.select(date, asset, label)
138139
_other = df.select(date, asset, label, *fwd_ret)
139140

140-
# 转换成复合索引,还成正常输入到sklearn
141+
# 转换成复合索引,可正常输入到sklearn
141142
_X = _X.to_pandas().set_index([date, asset])
142143
_y = _y.to_pandas().set_index([date, asset])
143144

0 commit comments

Comments
 (0)
0