@@ -727,7 +727,7 @@ class Multinomial_LDL_Decomposition:
727
727
728
728
def __init__ (self , * , proba , proba_sum_to_1 = True ):
729
729
self .p = proba
730
- self .q = 1 - np .cumsum (self .p , axis = 1 )
730
+ self .q = 1 - np .cumsum (self .p , axis = 1 ) # contiguity of p
731
731
self .proba_sum_to_1 = proba_sum_to_1
732
732
if self .p .dtype == np .float32 :
733
733
eps = 2 * np .finfo (np .float32 ).resolution
@@ -745,10 +745,13 @@ def __init__(self, *, proba, proba_sum_to_1=True):
745
745
self .q [self .q <= eps ] = 0.0
746
746
self .p [self .p <= eps ] = 0.0
747
747
d = self .p * self .q
748
+ # From now on, self.q is always used in the denominator. We handle q == 0 by
749
+ # setting q to 1 whenever q == 0 so that dividing by q is a no-op in this
750
+ # case.
751
+ self .q [self .q == 0 ] = 1
748
752
if self .proba_sum_to_1 :
749
753
# If q_{i - 1} = 0, then also q_i = 0.
750
- mask = self .q [:, 1 :- 1 ] == 0
751
- d [:, 1 :- 1 ] /= self .q [:, :- 2 ] + mask # d[:, -1] = 0 anyway
754
+ d [:, 1 :- 1 ] /= self .q [:, :- 2 ] # d[:, -1] = 0 anyway.
752
755
else :
753
756
d [:, 1 :] /= self .q [:, :- 1 ]
754
757
self .sqrt_d = np .sqrt (d )
@@ -785,11 +788,7 @@ def sqrt_D_Lt_matmul(self, x):
785
788
for i in range (0 , n_classes - 1 ): # row i
786
789
# L_ij = -p_i / q_j, we need transpose L'
787
790
for j in range (i + 1 , n_classes ): # column j
788
- if self .proba_sum_to_1 :
789
- mask = self .q [:, i ] == 0
790
- x [:, i ] -= self .p [:, j ] / (self .q [:, i ] + mask ) * x [:, j ]
791
- else :
792
- x [:, i ] -= self .p [:, j ] / self .q [:, i ] * x [:, j ]
791
+ x [:, i ] -= self .p [:, j ] / self .q [:, i ] * x [:, j ]
793
792
x *= self .sqrt_d
794
10000
793
return x
795
794
@@ -825,12 +824,7 @@ def L_sqrt_D_matmul(self, x):
825
824
for i in range (n_classes - 1 , 0 , - 1 ): # row i
826
825
# L_ij = -p_i / q_j
827
826
for j in range (0 , i ): # column j
828
- if self .proba_sum_to_1 :
829
- term = self .p [:, i ] * x [:, j ]
830
- mask = term == 0
831
- x [:, i ] -= term / (self .q [:, j ] + mask )
832
- else :
833
- x [:, i ] -= self .p [:, i ] / self .q [:, j ] * x [:, j ]
827
+ x [:, i ] -= self .p [:, i ] / self .q [:, j ] * x [:, j ]
834
828
return x
835
829
836
830
def inverse_L_sqrt_D_matmul (self , x ):
@@ -861,12 +855,7 @@ def inverse_L_sqrt_D_matmul(self, x):
861
855
n_classes = self .p .shape [1 ]
862
856
for i in range (n_classes - 1 , 0 , - 1 ): # row i
863
857
if i > 0 :
864
- if self .proba_sum_to_1 :
865
- # 0 / something = 0
866
- mask = self .p [:, i ] == 0
867
- fj = self .p [:, i ] / (self .q [:, i - 1 ] + mask )
868
- else :
869
- fj = self .p [:, i ] / self .q [:, i - 1 ]
858
+ fj = self .p [:, i ] / self .q [:, i - 1 ]
870
859
else :
871
860
fj = self .p [:, i ]
872
861
for j in range (0 , i ): # column j
0 commit comments