1414
1515# %%
1616params = {
17- 'boosting_type' : 'gbdt' ,
18- 'objective' : 'mse' , # 损失函数
19- # 'metric': 'None', # 评估函数,这里用feval来替代
17+ # TODO 分类不平衡
18+ 'is_unbalance' : True , # 自动平衡正负样本
19+ # 或者使用以下方式手动设置权重
20+ # 'scale_pos_weight': 3, # assumes the positive class is the minority; up-weights positive samples 3x (do not combine with is_unbalance)
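21+ # or explicit per-class weights; NOTE: class_weight is an argument of the scikit-learn wrapper (LGBMClassifier), not a native lgb.train parameter
22+ # 'class_weight': {0: 1, 1: 3}, # give class 1 a higher weight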
2023
21- 'max_depth' : 8 ,
24+ # TODO 分类
25+ 'objective' : 'binary' ,
26+ 'metric' : {'binary_logloss' }, # 评价函数选择
27+
28+ # # TODO 回归
29+ # 'objective': 'mse',
30+ # 'metric': {'l2'}, #评价函数选择
31+
32+ # 其他参数
33+ 'max_depth' : - 1 ,
2234 'num_leaves' : 63 ,
23- 'learning_rate' : 0.05 ,
24- 'min_data_in_leaf' : 50 ,
25- 'feature_fraction' : 1.0 ,
26- 'bagging_fraction' : 1.0 ,
35+ 'learning_rate' : 0.01 ,
36+ 'feature_fraction' : 0.8 ,
37+ 'bagging_fraction' : 0.9 ,
2738 'bagging_freq' : 5 ,
2839 'lambda_l1' : 0.0 ,
2940 'lambda_l2' : 0.0 ,
30- 'max_bin' : 127 ,
3141 'verbose' : - 1 , # 不显示
3242 'device_type' : 'cpu' ,
3343 'seed' : 42 ,
34- 'force_col_wise' : True ,
3544}
3645# %%
3746df = load_process ()
@@ -45,17 +54,17 @@ def fit():
4554
4655 models = []
4756 for i , train_dt , test_dt in walk_forward (trading_dates ,
48- n_splits = 3 , max_train_size = None , test_size = 60 , gap = 3 ):
57+ n_splits = 1 , max_train_size = None , test_size = 60 , gap = 3 ):
4958 ds = []
5059 for start , end in (train_dt , test_dt ):
51- X , y , other = get_XyOther (df , start , end , DATE , ASSET , LABEL , FWD_RET , is_fit = True )
60+ X , y , other = get_XyOther (df , start , end , DATE , ASSET , LABEL , FWD_RET , is_test = True )
5261 ds .append (lgb .Dataset (X , label = y , categorical_feature = categorical_feature ))
5362
5463 evals_result = {} # to record eval results for plotting
5564 model = lgb .train (
5665 params ,
57- ds [0 ],
58- num_boost_round = 500 ,
66+ train_set = ds [0 ],
67+ num_boost_round = 300 ,
5968 valid_sets = ds ,
6069 valid_names = ['train' , 'valid' ],
6170 feval = None , # 与早停相配合
@@ -74,7 +83,7 @@ def fit():
7483# %% 模型评估
7584def evaluate (models ):
7685 _ , ax = plt .subplots (1 , 1 , figsize = (10 , 5 ))
77- plot_metric_errorbar (models , metric = 'l2' , ax = ax )
86+ plot_metric_errorbar (models , metric = list ( params [ 'metric' ])[ 0 ] , ax = ax )
7887 _ , ax = plt .subplots (1 , 1 , figsize = (10 , 5 ))
7988 plot_importance_box (models , ax = ax )
8089 plt .show ()
0 commit comments