Feature: Batch matmul by dsyme · Pull Request #88 · DiffSharp/DiffSharp · GitHub
[go: up one dir, main page]

Skip to content

Feature: Batch matmul #88

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Oct 26, 2020
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
fdbe45a
batch matmul
dsyme Feb 26, 2020
6cf7ed0
Merge branch 'feature/expand' of https://github.com/DiffSharp/DiffSha…
dsyme Feb 28, 2020
171d662
fix transpose case and add explicit Expand testing
dsyme Feb 28, 2020
5c24caf
organise array helpers
dsyme Feb 28, 2020
f30e516
integrate dev
dsyme Feb 28, 2020
76378d4
integrate dev
dsyme Feb 28, 2020
cd29bc6
integrate dev renaming
dsyme Mar 2, 2020
be905aa
Merge branch 'feature/expand' of https://github.com/DiffSharp/DiffSha…
dsyme Mar 2, 2020
23b3c7d
merge dev and resolve conflicts
dsyme Apr 29, 2020
6a959ce
merge dev
dsyme May 4, 2020
4e86210
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
dsyme May 4, 2020
ab0fddb
merge dev
dsyme May 4, 2020
098b813
improve coverage
dsyme May 5, 2020
a956604
merge dev
gbaydin May 5, 2020
b0d0ca9
improve coverage
dsyme May 6, 2020
77107a6
Merge branch 'feature/batch-matmul' of https://github.com/DiffSharp/D…
dsyme May 6, 2020
d76838f
make batchTranspose internal
dsyme May 6, 2020
4c0c005
make batchTranspose internal
dsyme May 6, 2020
6abcdfd
minor renames
gbaydin May 6, 2020
828dd07
integrate dev
dsyme May 21, 2020
dc046b4
integrate dev
dsyme Sep 9, 2020
8d3dd42
test matmul
dsyme Sep 9, 2020
70d3534
fix tests
dsyme Sep 9, 2020
c9d01db
fix large dimension matmul
dsyme Sep 9, 2020
429b47f
integrate docs
Sep 15, 2020
3060ce5
integrate docs
Sep 15, 2020
3ed86d4
merge dev
Oct 6, 2020
b93b754
merge dev
Oct 15, 2020
3c94672
Merge branch 'dev' into feature/batch-matmul
gbaydin Oct 26, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
fix transpose case and add explicit Expand testing
  • Loading branch information
dsyme committed Feb 28, 2020
commit 171d66289bf3642b65c9086045e448be97f789ef
4 changes: 2 additions & 2 deletions src/DiffSharp.Backend.None/RawTensorCPU.fs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ type RawTensorFloat32CPU(values: float32[], shape:int[]) =
Array.init n (fun i -> Array.init unstackedLength (fun j -> t.Values.[i*unstackedLength+j]))
|> Array.map (fun v -> upcast RawTensorFloat32CPU(v, unstackedShape))

override t.TransposeT2() =
override t.TransposeT() =
if t.Dim < 2 then failwith "Expecting at least a 2D tensor"
let oldShape = t.Shape
let batch = oldShape.[0..oldShape.Length-3]
Expand All @@ -207,7 +207,7 @@ type RawTensorFloat32CPU(values: float32[], shape:int[]) =
for i = 0 to values.Length-1 do
let col = i % ncols
let row = (i / ncols ) % nrows
let j = (i / ncols / nrows) + col*nrows + row
let j = (i / ncols / nrows)*ncols*nrows + col*nrows + row
result.[j] <- values.[i]
upcast RawTensorFloat32CPU(result, newShape)

Expand Down
2 changes: 1 addition & 1 deletion src/DiffSharp.Core/RawTensor.fs
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ and [<AbstractClass>]
abstract member NegT : unit -> RawTensor
abstract member SumT : unit -> RawTensor
abstract member SumT2Dim0 : unit -> RawTensor
abstract member TransposeT2: unit -> RawTensor
abstract member TransposeT: unit -> RawTensor
abstract member SqueezeT: int -> RawTensor
abstract member UnsqueezeT: int -> RawTensor
abstract member FlipT: int[] -> RawTensor
Expand Down
12 changes: 6 additions & 6 deletions src/DiffSharp.Core/Tensor.fs
Original file line number Diff line number Diff line change
Expand Up @@ -577,11 +577,11 @@ type Tensor =
member t.SumT2Dim0() = Tensor.SumT2Dim0(t)

static member Transpose (a:Tensor) =
if a.Dim > 2 then failwithf "Expecting at least a 2d tensor, received Tensor with shape %A" a.Shape
let inline fRaw(a:RawTensor) = a.TransposeT2()
if a.Dim < 2 then failwithf "Expecting at least a 2d tensor, received Tensor with shape %A" a.Shape
let inline fRaw(a:RawTensor) = a.TransposeT()
let inline fTensor(a) = Tensor.Transpose(a)
let inline dfTensorFwd(cp,ap,ad) = Tensor.Transpose(ad)
let inline dfTensorRev(a) = TransposeT2(a)
let inline dfTensorRev(a) = TransposeT(a)
Tensor.OpUnary(a, fRaw, fTensor, dfTensorFwd, dfTensorRev)
member t.Transpose() = Tensor.Transpose(t)

Expand Down Expand Up @@ -964,7 +964,7 @@ type Tensor =
| ExpandT(a) -> reset (a::tt)
| StackTs(a) -> reset (List.append (a |> List.ofSeq) tt)
| UnstackT(a,_) -> reset (a::tt)
| TransposeT2(a) -> reset (a::tt)
| TransposeT(a) -> reset (a::tt)
| SqueezeT(a) -> reset (a::tt)
| UnsqueezeT(a) -> reset (a::tt)
| FlipT(a,_) -> reset (a::tt)
Expand Down Expand Up @@ -1245,7 +1245,7 @@ type Tensor =
if a.Derivative.Dim = 0 then a.Derivative <- Tensor.ZerosLike(a) + a.Derivative
a.Derivative <- Tensor.AddSlice(a.Derivative, Array.init a.Dim (fun j -> if j=0 then i else 0), t.Derivative.Unsqueeze(0))
push ((a.Zero(), a) :: tt)
| TransposeT2(a) -> push ((t.Derivative.Transpose(), a) :: tt)
| TransposeT(a) -> push ((t.Derivative.Transpose(), a) :: tt)
| SqueezeT(a) -> push ((t.Derivative.ViewAs(a), a) :: tt)
| UnsqueezeT(a) -> push ((t.Derivative.ViewAs(a), a) :: tt)
| FlipT(a, dims) -> push ((t.Derivative.Flip(dims), a) :: tt)
Expand Down Expand Up @@ -1353,7 +1353,7 @@ and TensorOp =
| AddTTSlice of Tensor * int[] * Tensor
| AddTTConstSlice of Tensor
| AddTConstTSlice of int[] * Tensor
| TransposeT2 of Tensor
| TransposeT of Tensor
| SqueezeT of Tensor
| UnsqueezeT of Tensor
| FlipT of Tensor * int[]
Expand Down
61 changes: 59 additions & 2 deletions src/DiffSharp.Tests/TestDerivatives.fs
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,39 @@ type TestDerivatives () =
// TODO: add test for AddT2ConstT1

[<Test>]
member this.TestDerivativeAddWithBroadcast () =
member this.TestDerivativeExpand () =
    // Checks forward- and reverse-mode derivatives through Expand, using a
    // 2x1 tensor expanded to 2x2x2.

    // Forward mode: the tangent [[5.]; [6.]] travels with the primal through Expand.
    let t1 = Tensor.Create([[1.]; [2.]]).ForwardDiff(Tensor.Create([[5.]; [6.]])) // 2x1
    let t1Expand = t1.Expand([2;2;2]) // 2x2x2 = [[[1.;1]; [2.;2]]; [[1.;1]; [2.;2]]]
    let fwdz = t1Expand
    let fwdzd = fwdz.Derivative
    // List elements are separated with ';' -- in F# a ',' inside a list literal
    // builds a tuple instead of adding another row/element.
    let fwdzdCorrect = Tensor.Create([[[5.; 5.]; [6.; 6.]]; [[5.; 5.]; [6.; 6.]]])

    (* Python:
    import torch
    t1 = torch.tensor([[1.], [2.]], requires_grad=True)
    revz = t1.expand([2,2,2])
    revz.backward(torch.tensor([[[3.,3.], [6.,6.]], [[3.,3.], [6.,6.]]]))
    t1.grad
    --> tensor([[12.],[24.]])
    *)
    // Reverse mode: push a 2x2x2 gradient back through Expand onto the 2x1 input.
    let revy = t1.ReverseDiff()
    let revz = revy.Expand([2;2;2])
    let revz_grad = Tensor.Create([[[3.;3.]; [6.;6.]]; [[3.;3.]; [6.;6.]]])
    revz.Reverse(revz_grad)
    let revyd = revy.Derivative
    // Note: The 4x'3' accumulate to the first entry, the 4x'6' accumulate to the second entry
    let revydCorrect = Tensor.Create([[12.]; [24.]])

    // Expected value first, actual value second, matching the other tests in this file.
    Assert.AreEqual(fwdzdCorrect, fwdzd)
    Assert.AreEqual(revydCorrect, revyd)

[<Test>]
member this.TestAddWithBroadcastSystematic () =

// This is a somewhat ad hoc extra test to do a whole range of additions
// with broadcast, mainly to check that no problems occur in taking the
// derivatives.
//
// Systematically do all allowed broadcasts into 2x3x4
// 2x3x4 + 1 (broadcast --> 2x3x4)
// 2x3x4 + 4 (broadcast --> 2x3x4)
Expand All @@ -81,7 +113,7 @@ type TestDerivatives () =
let t1a = Tensor.Create([ [ [1.; 2.; 3.; 4.]; [5.; 6.; 7.; 8.]; [9.; 10.; 11.; 12.] ];
[ [13.; 14.; 15.; 16.]; [17.; 18.; 19.; 20.]; [21.; 22.; 23.; 24.] ] ])

// Get all the interesting shapes that broadcast into t1a
// Get all the interesting shapes that expand into t1a
let shapes =
[ for i1 in [0;1;2] do
for i2 in [0;1;3] do
Expand Down Expand Up @@ -2228,15 +2260,40 @@ type TestDerivatives () =
let fwdzd = fwdz.Derivative
let fwdzdCorrect = Tensor.Create([[2.; 10.]; [3.; 20.]; [4.; 30.]])

Assert.AreEqual(fwdzCorrect, fwdz)
Assert.AreEqual(fwdzdCorrect, fwdzd)

let revx = Tensor.Create([[1.; 2.; 3.]; [4.; 5.; 6.]]).ReverseDiff()
let revz = revx.Transpose()
let revzCorrect = Tensor.Create([[1.; 4.]; [2.; 5.]; [3.; 6.]])
revz.Reverse(Tensor.Create([[5.; 5.]; [2.; 5.]; [3.; 7.]]))
let revxd = revx.Derivative
let revxdCorrect = Tensor.Create([[5.; 2.; 3.]; [5.; 5.; 7.]])

Assert.AreEqual(revzCorrect, revz)
Assert.AreEqual(revxdCorrect, revxd)

[<Test>]
member this.TestDerivativeTransposeBatch () =
    // Batched variant of TestDerivativeTransposeT2: every input and every
    // expected result is the 2D case expanded with a leading dimension of 3.
    let primal = Tensor.Create([[1.; 2.; 3.]; [4.; 5.; 6.]]).Expand([| 3;2;3 |])

    // Forward mode: transpose the primal together with its tangent.
    let tangent = Tensor.Create([[2.; 3.; 4.]; [10.; 20.; 30.]]).Expand([| 3;2;3 |])
    let fwdInput = primal.ForwardDiff(tangent)
    let fwdOutput = fwdInput.Transpose()
    let expectedFwdOutput = Tensor.Create([[1.; 4.]; [2.; 5.]; [3.; 6.]]).Expand([| 3;3;2 |])
    let fwdOutputTangent = fwdOutput.Derivative
    let expectedFwdOutputTangent = Tensor.Create([[2.; 10.]; [3.; 20.]; [4.; 30.]]).Expand([| 3;3;2 |])

    Assert.AreEqual(expectedFwdOutput, fwdOutput)
    Assert.AreEqual(expectedFwdOutputTangent, fwdOutputTangent)

    // Reverse mode: push an adjoint back through the transpose.
    let revInput = primal.ReverseDiff()
    let revOutput = revInput.Transpose()
    let expectedRevOutput = Tensor.Create([[1.; 4.]; [2.; 5.]; [3.; 6.]]).Expand([| 3;3;2 |])
    let outputAdjoint = Tensor.Create([[5.; 5.]; [2.; 5.]; [3.; 7.]]).Expand([| 3;3;2 |])
    revOutput.Reverse(outputAdjoint)
    let inputAdjoint = revInput.Derivative
    let expectedInputAdjoint = Tensor.Create([[5.; 2.; 3.]; [5.; 5.; 7.]]).Expand([| 3;2;3 |])

    Assert.AreEqual(expectedRevOutput, revOutput)
    Assert.AreEqual(expectedInputAdjoint, inputAdjoint)

Expand Down
13 changes: 13 additions & 0 deletions src/DiffSharp.Tests/TestTensor.fs
Original file line number Diff line number Diff line change
Expand Up @@ -1255,6 +1255,19 @@ type TestTensor () =
Assert.AreEqual(t1TransposeCorrect, t1Transpose)
Assert.AreEqual(t2TransposeTransposeCorrect, t2TransposeTranspose)

[<Test>]
member this.TestTensorTransposeBatch () =
    // A 2x3 matrix expanded to 3x2x3 should transpose to the 3x2 matrix
    // expanded to 3x3x2.
    let batched = Tensor.Create([[1.; 2.; 3.]; [4.; 5.; 6.]]).Expand([|3;2;3|])
    let transposed = batched.Transpose()
    let expectedTransposed = Tensor.Create([[1.; 4.]; [2.; 5.]; [3.; 6.]]).Expand([|3;3;2|])

    // Transposing twice should give back the starting tensor.
    let squareBatch = Tensor.Create([[1.; 2.]; [3.; 4.]]).Expand([|3;2;2|])
    let roundTrip = squareBatch.Transpose().Transpose()

    Assert.AreEqual(expectedTransposed, transposed)
    Assert.AreEqual(squareBatch, roundTrip)

[<Test>]
member this.TestTensorSignT () =
let t1 = Tensor.Create([-1.; -2.; 0.; 3.])
Expand Down
0