@@ -718,21 +718,28 @@ module internal RawTensorCPU =
718
718
|> Array.map ( fun v -> t.MakeLike( v, [| t.Shape.[ 1 ]; t.Shape.[ 2 ]|]))
719
719
t.StackTs( tinvs, 0 ) :?> RawTensorCPU< 'T>
720
720
721
+ let inline diagonal ( square : ^T [,]) =
722
+ let n = square.GetLength( 0 )
723
+ if n <> square.GetLength( 1 ) then failwith " Expecting a square array"
724
+ Array.init n ( fun i -> square.[ i, i])
725
+
726
+ let inline prod ( t : ^T []) =
727
+ Array.fold ( fun s x -> s * x) LanguagePrimitives.GenericOne< 'T> t
728
+
721
729
let inline DetT ( t : RawTensorCPU < ^T >) : RawTensorCPU < ^T > =
722
730
Shape.checkCanDet t.Shape
723
731
let dim = t.Shape.Length
724
732
if dim = 2 then
725
733
let lu , _ , toggle = LUDecomposition( t.ToArray() :?> ^T [,])
726
- let n = t.Shape.[ 1 ]
727
- let luDiagonal = Array.init n ( fun i -> lu.[ i, i])
728
- let d : ^T = toggle * ( Array.fold ( fun s x -> s * x) LanguagePrimitives.GenericOne< 'T> ( luDiagonal))
734
+ let d : ^T = toggle * ( prod ( diagonal lu))
729
735
t.MakeLike([| d|], [||]) :?> RawTensorCPU< 'T>
730
736
else
731
- // let tdets =
732
- // t.UnstackT(0)
733
- // |> Array.map (fun v -> DetT(v :?> RawTensorCPU< ^T > ) :> RawTensor)
734
- // t.StackTs(tdets, 0) :?> RawTensorCPU<'T>
735
- failwith " Not implemented"
737
+ let tdets =
738
+ t.UnstackT( 0 )
739
+ |> Array.map ( fun v -> let lu , _ , toggle = LUDecomposition( v.ToArray() :?> ^T [,]) in lu, toggle)
740
+ |> Array.map ( fun ( lu , toggle ) -> toggle * ( prod ( diagonal lu)))
741
+ |> Array.map ( fun v -> t.MakeLike([| v|], [| t.Shape.[ 0 ]|]))
742
+ t.StackTs( tdets, 0 ) :?> RawTensorCPU< 'T>
736
743
737
744
let inline SolveTT ( a : RawTensorCPU < ^T >, b : RawTensor ) : RawTensorCPU < ^T > =
738
745
let newShape = Shape.checkCanSolve a.Shape b.Shape
0 commit comments