@@ -1284,12 +1284,14 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
1284
1284
"""Mean Tweedie deviance regression loss."""
1285
1285
xp , _ = get_namespace (y_true , y_pred )
1286
1286
p = power
1287
+ zero = xp .asarray (0 , dtype = y_true .dtype )
1287
1288
if p < 0 :
1288
1289
# 'Extreme stable', y any real number, y_pred > 0
1289
1290
dev = 2 * (
1290
- xp .pow (xp .where (y_true > 0 , y_true , 0 ), 2 - p ) / ((1 - p ) * (2 - p ))
1291
- - y_true * xp .pow (y_pred , 1 - p ) / (1 - p )
1292
- + xp .pow (y_pred , 2 - p ) / (2 - p )
1291
+ xp .pow (xp .where (y_true > 0 , y_true , zero ), xp .asarray (2 - p ))
1292
+ / ((1 - p ) * (2 - p ))
1293
+ - y_true * xp .pow (y_pred , xp .asarray (1 - p )) / (1 - p )
1294
+ + xp .pow (y_pred , xp .asarray (2 - p )) / (2 - p )
1293
1295
)
1294
1296
elif p == 0 :
1295
1297
# Normal distribution, y and y_pred any real number
@@ -1302,9 +1304,9 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
1302
1304
dev = 2 * (xp .log (y_pred / y_true ) + y_true / y_pred - 1 )
1303
1305
else :
1304
1306
dev = 2 * (
1305
- xp .pow (y_true , 2 - p ) / ((1 - p ) * (2 - p ))
1306
- - y_true * xp .pow (y_pred , 1 - p ) / (1 - p )
1307
- + xp .pow (y_pred , 2 - p ) / (2 - p )
1307
+ xp .pow (y_true , xp . asarray ( 2 - p ) ) / ((1 - p ) * (2 - p ))
1308
+ - y_true * xp .pow (y_pred , xp . asarray ( 1 - p ) ) / (1 - p )
1309
+ + xp .pow (y_pred , xp . asarray ( 2 - p ) ) / (2 - p )
1308
1310
)
1309
1311
return float (_average (dev , weights = sample_weight ))
1310
1312
@@ -1384,14 +1386,14 @@ def mean_tweedie_deviance(y_true, y_pred, *, sample_weight=None, power=0):
1384
1386
message = f"Mean Tweedie deviance error with power={ power } can only be used on "
1385
1387
if power < 0 :
1386
1388
# 'Extreme stable', y any real number, y_pred > 0
1387
- if (y_pred <= 0 ). any ( ):
1389
+ if xp . any (y_pred <= 0 ):
1388
1390
raise ValueError (message + "strictly positive y_pred." )
1389
1391
elif power == 0 :
1390
1392
# Normal, y and y_pred can be any real number
1391
1393
pass
1392
1394
elif 1 <= power < 2 :
1393
1395
# Poisson and compound Poisson distribution, y >= 0, y_pred > 0
1394
- if (y_true < 0 ). any () or (y_pred <= 0 ). any ( ):
1396
+ if xp . any (y_true < 0 ) or xp . any (y_pred <= 0 ):
1395
1397
raise ValueError (message + "non-negative y and strictly positive y_pred." )
1396
1398
elif power >= 2 :
1397
1399
# Gamma and Extreme stable distribution, y and y_pred > 0
0 commit comments