8000 det · DiffSharp/DiffSharp@55f0dbe · GitHub
[go: up one dir, main page]

Skip to content

Commit 55f0dbe

Browse files
committed
det
1 parent 4a96600 commit 55f0dbe

File tree

2 files changed

+17
-9
lines changed

2 files changed

+17
-9
lines changed

src/DiffSharp.Backends.Reference/Reference.RawTensor.fs

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -718,21 +718,28 @@ module internal RawTensorCPU =
718718
|> Array.map (fun v -> t.MakeLike(v, [|t.Shape.[1]; t.Shape.[2]|]))
719719
t.StackTs(tinvs, 0) :?> RawTensorCPU<'T>
720720

721+
let inline diagonal(square: ^T[,]) =
722+
let n = square.GetLength(0)
723+
if n <> square.GetLength(1) then failwith "Expecting a square array"
724+
Array.init n (fun i -> square.[i, i])
725+
726+
let inline prod(t: ^T[]) =
727+
Array.fold (fun s x -> s * x) LanguagePrimitives.GenericOne<'T> t
728+
721729
let inline DetT(t: RawTensorCPU< ^T >) : RawTensorCPU< ^T > =
722730
Shape.checkCanDet t.Shape
723731
let dim = t.Shape.Length
724732
if dim = 2 then
725733
let lu, _, toggle = LUDecomposition(t.ToArray() :?> ^T[,])
726-
let n = t.Shape.[1]
727-
let luDiagonal = Array.init n (fun i -> lu.[i, i])
728-
let d:^T = toggle * (Array.fold (fun s x -> s * x) LanguagePrimitives.GenericOne<'T> (luDiagonal))
734+
let d:^T = toggle * (prod (diagonal lu))
729735
t.MakeLike([|d|], [||]) :?> RawTensorCPU<'T>
730736
else
731-
// let tdets =
732-
// t.UnstackT(0)
733-
// |> Array.map (fun v -> DetT(v :?> RawTensorCPU< ^T > ) :> RawTensor)
734-
// t.StackTs(tdets, 0) :?> RawTensorCPU<'T>
735-
failwith "Not implemented"
737+
let tdets =
738+
t.UnstackT(0)
739+
|> Array.map (fun v -> let lu, _, toggle = LUDecomposition(v.ToArray() :?> ^T[,]) in lu, toggle)
740+
|> Array.map (fun (lu, toggle) -> toggle * (prod (diagonal lu)))
741+
|> Array.map (fun v -> t.MakeLike([|v|], [|t.Shape.[0]|]))
742+
t.StackTs(tdets, 0) :?> RawTensorCPU<'T>
736743

737744
let inline SolveTT(a: RawTensorCPU< ^T >, b: RawTensor) : RawTensorCPU< ^T > =
738745
let newShape = Shape.checkCanSolve a.Shape b.Shape

src/DiffSharp.Backends.Torch/Torch.RawTensor.fs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,8 @@ type TorchRawTensor(tt: torch.Tensor, shape: Shape, dtype: Dtype, device: Device
322322
override t.DetT() =
323323
Shape.checkCanDet t.Shape
324324
let result = torch.linalg.det(tt)
325-
t.MakeLike(result)
325+
let shape = result.shape |> Array.map int32
326+
t.MakeLike(result, shape=shape)
326327

327328
override t1.SolveTT(t2) =
328329
let newShape = Shape.checkCanSolve t1.Shape t2.Shape

0 commit comments

Comments
 (0)
0