File tree Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Expand file tree Collapse file tree 1 file changed +7
-3
lines changed Original file line number Diff line number Diff line change @@ -439,6 +439,8 @@ def _average_path_length(n_samples_leaf):
439
439
"""
440
440
if isinstance (n_samples_leaf , INTEGER_TYPES ):
441
441
if n_samples_leaf <= 1 :
442
+ return 0.
443
+ if n_samples_leaf <= 2 :
442
444
return 1.
443
445
else :
444
446
return 2. * (np .log (n_samples_leaf - 1. ) + np .euler_gamma ) - 2. * (
@@ -450,10 +452,12 @@ def _average_path_length(n_samples_leaf):
450
452
n_samples_leaf = n_samples_leaf .reshape ((1 , - 1 ))
451
453
average_path_length = np .zeros (n_samples_leaf .shape )
452
454
453
- mask = (n_samples_leaf <= 1 )
454
- not_mask = np .logical_not (mask )
455
+ mask_1 = (n_samples_leaf <= 1 )
456
+ mask_2 = (n_samples_leaf == 2 )
457
+ not_mask = np .logical_not (np .logical_or (mask_1 , mask_2 ))
455
458
456
- average_path_length [mask ] = 1.
459
+ average_path_length [mask_1 ] = 0.
460
+ average_path_length [mask_2 ] = 1.
457
461
average_path_length [not_mask ] = 2. * (
458
462
np .log (n_samples_leaf [not_mask ] - 1. ) + np .euler_gamma ) - 2. * (
459
463
n_samples_leaf [not_mask ] - 1. ) / n_samples_leaf [not_mask ]
You can’t perform that action at this time.
0 commit comments