8000 Take extra check from double precision addition and apply it to singl… · DiffSharp/DiffSharp@042d2a7 · GitHub
[go: up one dir, main page]

Skip to content

Commit 042d2a7

Browse files
committed
Take extra check from double precision addition and apply it to single precision case
1 parent 35d8925 commit 042d2a7

File tree

2 files changed

+15
-1
lines changed

2 files changed

+15
-1
lines changed

src/DiffSharp/Backend.OpenBLAS.fs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -715,7 +715,8 @@ module OpenBLAS =
715715
let xl2 = Array2D.length2 x
716716
let yl1 = Array2D.length1 y
717717
let yl2 = Array2D.length2 y
718-
if (xl1 <> yl1) || (xl2 <> yl2) then
718+
if xl1 * xl2 = 0 then ()
719+
elif (xl1 <> yl1) || (xl2 <> yl2) then
719720
ErrorMessages.InvalidArgMM()
720721
else
721722
Stats.InplaceOp(yl1 * yl2)

tests/DiffSharp.Tests/AD.Float32.fs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,3 +24,16 @@ let ``AD.32.R.D.FixedPoint``() =
2424
let g (a:D) (b:D) = (a + b / a) / (D 2.f)
2525
let p, t = jacobianTv' (D.FixedPoint g (D 1.2f)) (D 25.f) (D 1.f)
2626
Util.(=~)(p, D 5.f) && Util.(=~)(t, D 0.1f)
27+
28+
[<Property>]
29+
let ``Compute Adjoint``() =
30+
let tag = DiffSharp.Util.GlobalTagger.Next
31+
32+
let Wt = toDM [[0.0f; 1.0f]]
33+
let Wt' = Wt |> makeReverse tag
34+
let loss (weights:DM) : D = cos (weights.Item(0,0))
35+
36+
let L = loss Wt'
37+
let A = computeAdjoints L //Smoke test computeAdjoints, was an issue with single precision
38+
39+
()

0 commit comments

Comments
 (0)
0