8000 Simpler extension API by gbaydin · Pull Request #311 · DiffSharp/DiffSharp · GitHub
[go: up one dir, main page]

Skip to content

Simpler extension API #311

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 27 commits into from
Apr 2, 2021
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
constant tensor tests
  • Loading branch information
gbaydin committed Mar 31, 2021
commit 343ffcce597e536b771fb3c5db8425f0490a03d3
78 changes: 56 additions & 22 deletions tests/DiffSharp.Tests/TestExtensions.fs
Original file line number Diff line number Diff line change
Expand Up @@ -105,16 +105,6 @@ module TestOps =
let revxd1 = revx1.derivative
let revxd2 = revx2.derivative

// printfn "x %A" x
// printfn "fwdz1 %A" fwdz1
// printfn "fwdz2 %A" fwdz2
// printfn "fwdzd1 %A" fwdzd1
// printfn "fwdzd2 %A" fwdzd2
// printfn "revz1 %A" revz1
// printfn "revz2 %A" revz2
// printfn "revxd1 %A" revxd1
// printfn "revxd2 %A\n" revxd2

Assert.True(fwdz1.allclose(fwdz2, 0.01))
Assert.True(fwdzd1.allclose(fwdzd2, 0.01))
Assert.True(revz1.allclose(revz2, 0.01))
Expand All @@ -126,6 +116,7 @@ module TestOps =
let xd = dsharp.randnLike(x)
let yd = dsharp.randnLike(y)

// Tensor, Tensor
let fwdx = x.forwardDiff(xd)
let fwdy = y.forwardDiff(yd)
let fwdz1 : Tensor = op1 fwdx fwdy
Expand All @@ -147,23 +138,66 @@ module TestOps =
let revyd1 = revy1.derivative
let revyd2 = revy2.derivative

printfn "x %A" x
printfn "y %A" y
printfn "fwdz1 %A" fwdz1
printfn "fwdz2 %A" fwdz2
printfn "fwdzd1 %A" fwdzd1
printfn "fwdzd2 %A" fwdzd2
printfn "revz1 %A" revz1
printfn "revz2 %A" revz2
printfn "revxd1 %A" revxd1
printfn "revxd2 %A" revxd2
printfn "revyd1 %A" revyd1
printfn "revyd2 %A\n" revyd2
Assert.True(fwdz1.allclose(fwdz2, 0.01))
Assert.True(fwdzd1.allclose(fwdzd2, 0.01))
Assert.True(revz1.allclose(revz2, 0.01))
Assert.True(revxd1.allclose(revxd2, 0.01))
Assert.True(revyd1.allclose(revyd2, 0.01))

// Tensor, constant Tensor
let fwdx = x.forwardDiff(xd)
let fwdy = y
let fwdz1 : Tensor = op1 fwdx fwdy
let fwdz2 : Tensor = op2 fwdx fwdy
let fwdzd1 = fwdz1.derivative
let fwdzd2 = fwdz2.derivative

let zd = dsharp.randnLike(fwdz1)
let revx1 = x.reverseDiff()
let revy1 = y
let revx2 = x.reverseDiff()
let revy2 = y
let revz1 = op1 revx1 revy1
let revz2 = op1 revx2 revy2
revz1.reverse(zd)
revz2.reverse(zd)
let revxd1 = revx1.derivative
let revxd2 = revx2.derivative
let revyd1 = revy1.isNoDiff()
let revyd2 = revy2.isNoDiff()

Assert.True(fwdz1.allclose(fwdz2, 0.01))
Assert.True(fwdzd1.allclose(fwdzd2, 0.01))
Assert.True(revz1.allclose(revz2, 0.01))
Assert.True(revxd1.allclose(revxd2, 0.01))
Assert.CheckEqual(revyd1, revyd2)

// Constant Tensor, Tensor
let fwdx = x
let fwdy = y.forwardDiff(yd)
let fwdz1 : Tensor = op1 fwdx fwdy
let fwdz2 : Tensor = op2 fwdx fwdy
let fwdzd1 = fwdz1.derivative
let fwdzd2 = fwdz2.derivative

let zd = dsharp.randnLike(fwdz1)
let revx1 = x
let revy1 = y.reverseDiff()
let revx2 = x
let revy2 = y.reverseDiff()
let revz1 = op1 revx1 revy1
let revz2 = op1 revx2 revy2
revz1.reverse(zd)
revz2.reverse(zd)
let revxd1 = revx1.isNoDiff()
let revxd2 = revx2.isNoDiff()
let revyd1 = revy1.derivative
let revyd2 = revy2.derivative

Assert.True(fwdz1.allclose(fwdz2, 0.01))
Assert.True(fwdzd1.allclose(fwdzd2, 0.01))
Assert.True(revz1.allclose(revz2, 0.01))
Assert.CheckEqual(revxd1, revxd2)
Assert.True(revyd1.allclose(revyd2, 0.01))

[<Test>]
Expand Down
0