@@ -456,8 +456,9 @@ def gradient_hessian(
             # Exit early without computing the hessian.
             return grad, hess, hessian_warning

-        # TODO: This "sandwich product", X' diag(W) X, can be greatly improved by
-        # a dedicated Cython routine.
+        # TODO: This "sandwich product", X' diag(W) X, is the main computational
+        # bottleneck for solvers. A dedicated Cython routine might improve it
+        # by exploiting the symmetry (as opposed to, e.g., BLAS gemm).
         if sparse.issparse(X):
             hess[:n_features, :n_features] = (
                 X.T
@@ -467,9 +468,8 @@ def gradient_hessian(
                 @ X
             ).toarray()
         else:
-            # np.einsum may use less memory but the following is by far faster.
-            # This matrix multiplication (gemm) is most often the most time
-            # consuming step for solvers.
+            # np.einsum may use less memory but the following, using BLAS matrix
+            # multiplication (gemm), is by far faster.
             WX = hess_pointwise[:, None] * X
             hess[:n_features, :n_features] = np.dot(X.T, WX)
             # flattened view on the array
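For reference, here is a minimal, self-contained sketch of the "sandwich product" X' diag(W) X that the comments above describe, covering both branches: a sparse path that applies the per-sample weights through a sparse diagonal matrix (an assumption here, since that line falls between the two hunks) and a dense path that scales the rows of X and calls a single BLAS gemm via np.dot. The helper name `sandwich_product` and the toy data are illustrative only, not part of scikit-learn's API:

```python
import numpy as np
from scipy import sparse


def sandwich_product(X, hess_pointwise):
    """Compute X' diag(W) X with W = hess_pointwise (illustrative helper)."""
    if sparse.issparse(X):
        n_samples = X.shape[0]
        # Sparse path: apply the per-sample weights through a sparse diagonal
        # matrix, then densify the (n_features, n_features) result.
        W = sparse.dia_matrix((hess_pointwise, 0), shape=(n_samples, n_samples))
        return (X.T @ W @ X).toarray()
    else:
        # Dense path: scale the rows of X by the weights (broadcasting), then
        # do a single BLAS matrix multiplication (gemm) via np.dot. This does
        # not exploit the symmetry of the result.
        WX = hess_pointwise[:, None] * X
        return np.dot(X.T, WX)


# Toy usage: 5 samples, 3 features.
rng = np.random.default_rng(0)
X = rng.standard_normal((5, 3))
w = rng.random(5)  # one positive weight per sample

H_dense = sandwich_product(X, w)
H_sparse = sandwich_product(sparse.csr_matrix(X), w)

assert np.allclose(H_dense, H_sparse)
assert np.allclose(H_dense, H_dense.T)  # the sandwich product is symmetric
```

Because the result is symmetric, roughly half of the gemm's work is redundant; that is the saving a dedicated Cython routine could target, as the new TODO comment notes.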