Feature: Batch matmul by dsyme · Pull Request #88 · DiffSharp/DiffSharp · GitHub

Feature: Batch matmul #88


Merged: 29 commits merged from feature/batch-matmul into dev on Oct 26, 2020
Changes from 1 commit

Commits (29)
fdbe45a  batch matmul (dsyme, Feb 26, 2020)
6cf7ed0  Merge branch 'feature/expand' of https://github.com/DiffSharp/DiffSha… (dsyme, Feb 28, 2020)
171d662  fix transpose case and add explicit Expand testing (dsyme, Feb 28, 2020)
5c24caf  organise array helpers (dsyme, Feb 28, 2020)
f30e516  integrate dev (dsyme, Feb 28, 2020)
76378d4  integrate dev (dsyme, Feb 28, 2020)
cd29bc6  integrate dev renaming (dsyme, Mar 2, 2020)
be905aa  Merge branch 'feature/expand' of https://github.com/DiffSharp/DiffSha… (dsyme, Mar 2, 2020)
23b3c7d  merge dev and resolve conflicts (dsyme, Apr 29, 2020)
6a959ce  merge dev (dsyme, May 4, 2020)
4e86210  Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea… (dsyme, May 4, 2020)
ab0fddb  merge dev (dsyme, May 4, 2020)
098b813  improve coverage (dsyme, May 5, 2020)
a956604  merge dev (gbaydin, May 5, 2020)
b0d0ca9  improve coverage (dsyme, May 6, 2020)
77107a6  Merge branch 'feature/batch-matmul' of https://github.com/DiffSharp/D… (dsyme, May 6, 2020)
d76838f  make batchTranspose internal (dsyme, May 6, 2020)
4c0c005  make batchTranspose internal (dsyme, May 6, 2020)
6abcdfd  minor renames (gbaydin, May 6, 2020)
828dd07  integrate dev (dsyme, May 21, 2020)
dc046b4  integrate dev (dsyme, Sep 9, 2020)
8d3dd42  test matmul (dsyme, Sep 9, 2020)
70d3534  fix tests (dsyme, Sep 9, 2020)
c9d01db  fix large dimension matmul (dsyme, Sep 9, 2020)
429b47f  integrate docs (Sep 15, 2020)
3060ce5  integrate docs (Sep 15, 2020)
3ed86d4  merge dev (Oct 6, 2020)
b93b754  merge dev (Oct 15, 2020)
3c94672  Merge branch 'dev' into feature/batch-matmul (gbaydin, Oct 26, 2020)
merge dev
dsyme committed May 4, 2020 · commit 6a959ceba2e997b6c59bdd67402127c514660195
22 changes: 11 additions & 11 deletions src/DiffSharp.Backend.None/RawTensorCPU.fs
@@ -259,7 +259,7 @@ type RawTensorCPU<'T when 'T : equality>(values: 'T[], shape: int[], dtype: DTyp
let row = (i / ncols ) % nrows
let j = (i / ncols / nrows)*ncols*nrows + col*nrows + row
result.[j] <- values.[i]
- upcast RawTensorFloat32CPU(result, newShape)
+ t.CreateShaped(result, newShape)

override t.SqueezeT(dim) =
let result = Array.copy t.Values
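The index remapping in the transpose hunk above can be checked in isolation. Below is a minimal standalone sketch of the same row-major arithmetic on plain arrays; transposeFlat is an illustrative name, not DiffSharp code. Flat index i decomposes into (batch, row, col), and the destination index j places the element at (batch, col, row) of the transposed layout.

// Standalone sketch (not DiffSharp code) of the row-major transpose remapping.
let transposeFlat (values: float[]) (nrows: int) (ncols: int) =
    let result = Array.zeroCreate values.Length
    for i in 0 .. values.Length - 1 do
        let col = i % ncols
        let row = (i / ncols) % nrows
        let batch = i / (ncols * nrows)
        // destination index in the transposed (ncols x nrows) layout
        let j = batch * ncols * nrows + col * nrows + row
        result.[j] <- values.[i]
    result

// A single 2x3 matrix [[1;2;3];[4;5;6]] flattens to [|1..6|] and comes back
// as the 3x2 matrix [[1;4];[2;5];[3;6]]:
// transposeFlat [| 1.; 2.; 3.; 4.; 5.; 6. |] 2 3  // = [| 1.; 4.; 2.; 5.; 3.; 6. |]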
@@ -547,15 +547,15 @@ module internal RawTensorCPU =
let result = Array.map (fun t -> t ** t2value) t1value
(result, t1.Shape)

- override t1.MatMulTT(t2) =
+ let inline MatMulTT(t1: RawTensorCPU< ^T >, t2: RawTensor) : (^T[] * int[]) =
checkCanMatmul t1.Shape t2.Shape
let t1BatchPart, t1MatrixPart = t1.Shape |> Array.splitAt (t1.Shape.Length-2)
let t2BatchPart, t2MatrixPart = t2.Shape |> Array.splitAt (t2.Shape.Length-2)
if t1BatchPart <> t2BatchPart then failwithf "Cannot matrix multiply tensors with shapes %A, %A - mismatch batching" t1.Shape t2.Shape
let t1rows, t1cols = t1MatrixPart.[0], t1MatrixPart.[1]
let t2rows, t2cols = t2MatrixPart.[0], t2MatrixPart.[1]
let t1value = t1.Values
- let t2value = (t2 :?> RawTensorFloat32CPU).Values
+ let t2value = (t2 :?> RawTensorCPU< ^T >).Values
let newShape = Array.append t1BatchPart [| t1rows; t2cols |]
let values =
match t1.Dim with
@@ -570,7 +570,7 @@
Array.init4D nb0 nb1 t1rows t2cols (fun b0 b1 i j -> Array.sumBy (fun k -> t1value.[((b0*nb1+b1)*t1rows+i)*t1cols+k] * t2value.[((b0*nb1+b1)*t2rows+k)*t2cols+j]) [|0..(t2rows-1)|] )
| _ -> failwith "MatMulTT - tensor size > 4 nyi"

- upcast RawTensorFloat32CPU(values, newShape)
+ (values, newShape)

let inline Conv1D(t1: RawTensorCPU< ^T >, t2: RawTensor, stride, padding) : RawTensorCPU< ^T > =
// t1: input, NxCxI (batchSize x inputChannels x inputLength)
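The flat-index arithmetic in the batched cases above follows one pattern: each batch slot is an independent matrix product computed directly over the row-major value arrays. Here is a standalone sketch of the 3D case on plain float arrays; batchMatMul3D is an illustrative name, not part of DiffSharp.

// Sketch: t1 has shape [nb; n; m], t2 has shape [nb; m; p], result [nb; n; p].
// The flat offsets mirror ((b*n + i)*m + k) and ((b*m + k)*p + j) in the hunk above.
let batchMatMul3D (t1: float[]) (t2: float[]) (nb: int) (n: int) (m: int) (p: int) =
    Array.init (nb * n * p) (fun idx ->
        let b = idx / (n * p)
        let i = (idx / p) % n
        let j = idx % p
        Array.sumBy (fun k -> t1.[(b*n + i)*m + k] * t2.[(b*m + k)*p + j]) [| 0 .. m-1 |])

// Two batches of 1x2 times 2x1 matrices: shapes [2;1;2] x [2;2;1] -> [2;1;1]
// batchMatMul3D [| 1.; 2.; 3.; 4. |] [| 10.; 20.; 30.; 40. |] 2 1 2 1
// = [| 1.*10. + 2.*20.; 3.*30. + 4.*40. |] = [| 50.; 250. |]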
@@ -825,7 +825,7 @@ type RawTensorFloat32CPU(values: float32[], shape:int[]) =
override t1.PowTT(t2) = RawTensorCPU.PowTT(t1, t2) |> create
override t1.PowT0T(t2) = RawTensorCPU.PowT0T(t1, t2) |> create
override t1.PowTT0(t2) = RawTensorCPU.PowTT0(t1, t2) |> create
- override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
+ override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
override t1.Conv1D(t2, stride, padding) = RawTensorCPU.Conv1D (t1, t2, stride, padding) :> _
override t1.Conv2D(t2, stride, padding) = RawTensorCPU.Conv2D (t1, t2, stride, padding) :> _
override t1.Conv3D(t2, stride, padding) = RawTensorCPU.Conv3D (t1, t2, stride, padding) :> _
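The repeated RawTensorCPU.MatMulTT(t1, t2) |> create line is the point of the refactor: the generic helper now returns a raw (values, shape) pair, and each dtype class wraps the pair with its own create, so the helper no longer hard-codes RawTensorFloat32CPU. A hedged sketch of the pattern follows; doubleValues, createF32, and createF64 are illustrative names, not DiffSharp definitions.

// Generic helper: stays dtype-agnostic by returning a plain (values, shape) pair.
let inline doubleValues (xs: ^T[]) (shape: int[]) : ^T[] * int[] =
    (Array.map (fun x -> x + x) xs, shape)

// Per-dtype wrappers, analogous to `|> create` in each RawTensor*CPU class
// (here they just render the pair, rather than building a real tensor):
let createF32 (values: float32[], shape: int[]) = sprintf "Float32 tensor %A %A" shape values
let createF64 (values: float[], shape: int[])   = sprintf "Float64 tensor %A %A" shape values

let r32 = doubleValues [| 1.0f; 2.5f |] [| 2 |] |> createF32
let r64 = doubleValues [| 1.0; 2.5 |]   [| 2 |] |> createF64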
@@ -902,7 +902,7 @@ type RawTensorFloat64CPU(values: double[], shape:int[]) =
override t1.PowTT(t2) = RawTensorCPU.PowTT(t1, t2) |> create
override t1.PowT0T(t2) = RawTensorCPU.PowT0T(t1, t2) |> create
override t1.PowTT0(t2) = RawTensorCPU.PowTT0(t1, t2) |> create
- override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
+ override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
override t1.Conv1D(t2, stride, padding) = RawTensorCPU.Conv1D (t1, t2, stride, padding) :> _
override t1.Conv2D(t2, stride, padding) = RawTensorCPU.Conv2D (t1, t2, stride, padding) :> _
override t1.Conv3D(t2, stride, padding) = RawTensorCPU.Conv3D (t1, t2, stride, padding) :> _
@@ -974,7 +974,7 @@ type RawTensorInt8CPU(values: int8[], shape:int[]) =
override t1.DivTT(t2) = RawTensorCPU.DivTT(t1, t2) |> create
override t1.DivT0T(t2) = RawTensorCPU.DivT0T(t1, t2) |> create
override t1.DivTT0(t2) = RawTensorCPU.DivTT0(t1, t2) |> create
- override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
+ override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
override t1.Conv1D(t2, stride, padding) = RawTensorCPU.Conv1D(t1, t2, stride, padding) :> _
override t1.Conv2D(t2, stride, padding) = RawTensorCPU.Conv2D (t1, t2, stride, padding) :> _
override t1.Conv3D(t2, stride, padding) = RawTensorCPU.Conv3D (t1, t2, stride, padding) :> _
@@ -1051,7 +1051,7 @@ type RawTensorInt16CPU(values: int16[], shape:int[]) =
override t1.DivTT(t2) = RawTensorCPU.DivTT(t1, t2) |> create
override t1.DivT0T(t2) = RawTensorCPU.DivT0T(t1, t2) |> create
override t1.DivTT0(t2) = RawTensorCPU.DivTT0(t1, t2) |> create
- override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
+ override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
override t1.Conv1D(t2, stride, padding) = RawTensorCPU.Conv1D(t1, t2, stride, padding) :> _
override t1.Conv2D(t2, stride, padding) = RawTensorCPU.Conv2D (t1, t2, stride, padding) :> _
override t1.Conv3D(t2, stride, padding) = RawTensorCPU.Conv3D (t1, t2, stride, padding) :> _
@@ -1128,7 +1128,7 @@ type RawTensorInt32CPU(values: int32[], shape:int[]) =
override t1.DivTT(t2) = RawTensorCPU.DivTT(t1, t2) |> create
override t1.DivT0T(t2) = RawTensorCPU.DivT0T(t1, t2) |> create
override t1.DivTT0(t2) = RawTensorCPU.DivTT0(t1, t2) |> create
- override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
+ override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
override t1.Conv1D(t2, stride, padding) = RawTensorCPU.Conv1D(t1, t2, stride, padding) :> _
override t1.Conv2D(t2, stride, padding) = RawTensorCPU.Conv2D (t1, t2, stride, padding) :> _
override t1.Conv3D(t2, stride, padding) = RawTensorCPU.Conv3D (t1, t2, stride, padding) :> _
@@ -1205,7 +1205,7 @@ type RawTensorInt64CPU(values: int64[], shape:int[]) =
override t1.DivTT(t2) = RawTensorCPU.DivTT(t1, t2) |> create
override t1.DivT0T(t2) = RawTensorCPU.DivT0T(t1, t2) |> create
override t1.DivTT0(t2) = RawTensorCPU.DivTT0(t1, t2) |> create
- override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
+ override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
override t1.Conv1D(t2, stride, padding) = RawTensorCPU.Conv1D(t1, t2, stride, padding) :> _
override t1.Conv2D(t2, stride, padding) = RawTensorCPU.Conv2D (t1, t2, stride, padding) :> _
override t1.Conv3D(t2, stride, padding) = RawTensorCPU.Conv3D (t1, t2, stride, padding) :> _
@@ -1286,7 +1286,7 @@ type RawTensorBoolCPU(values: bool[], shape:int[]) =
override t1.DivTT(t2) = opNotSupported2 t1.DType t2.DType
override t1.DivT0T(t2) = opNotSupported2 t1.DType t2.DType
override t1.DivTT0(t2) = opNotSupported2 t1.DType t2.DType
- override t1.MatMulT2T2(t2) = opNotSupported2 t1.DType t2.DType
+ override t1.MatMulTT(t2) = opNotSupported2 t1.DType t2.DType
override t1.Conv1D(t2, _stride, _padding) = opNotSupported2 t1.DType t2.DType
override t1.Conv2D(t2, _stride, _padding) = opNotSupported2 t1.DType t2.DType
override t1.Conv3D(t2, _stride, _padding) = opNotSupported2 t1.DType t2.DType
2 changes: 1 addition & 1 deletion src/DiffSharp.Core/Tensor.fs
@@ -663,7 +663,7 @@ type Tensor =
if aBatchPart = bBatchPart then
let inline fRaw(a:RawTensor,b) = a.MatMulTT(b)
let inline fTensor(a:Tensor,b) = a.matmul(b)
- let inline dfTensorFwdTT(cp,ap:Tensor,ad:Tensor,bp,bd) = ad.matmul(bp) + ap.matmul(bd)
+ let inline dfTensorFwdTT(cp,ap:Tensor,ad:Tensor,bp:Tensor,bd:Tensor) = ad.matmul(bp) + ap.matmul(bd)
let inline dfTensorFwdTC(cp,ap,ad:Tensor) = ad.matmul(b)
let inline dfTensorFwdCT(cp,bp,bd) = a.matmul(bd)
let inline dfTensorRevTT(a,b) = MatMulTT(a,b)
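The forward-mode tangent in dfTensorFwdTT above is the matrix product rule: for C = A·B with tangents dA and dB, the tangent of C is dC = dA·B + A·dB. Below is a quick numeric sanity check of that rule on 2x2 row-major arrays; matmul2x2 is a hypothetical helper for the sketch, not DiffSharp test code.

// Hypothetical 2x2 row-major matrix product: element (i,j) = sum_k a[i,k]*b[k,j]
let matmul2x2 (a: float[]) (b: float[]) =
    Array.init 4 (fun idx ->
        let i, j = idx / 2, idx % 2
        a.[i*2] * b.[j] + a.[i*2 + 1] * b.[2 + j])

let a  = [| 1.0; 2.0; 3.0; 4.0 |]   // primal A
let b  = [| 5.0; 6.0; 7.0; 8.0 |]   // primal B
let da = [| 0.1; 0.0; 0.0; 0.2 |]   // tangent of A
let db = [| 0.0; 0.3; 0.4; 0.0 |]   // tangent of B

// product rule, exactly as in dfTensorFwdTT: ad.matmul(bp) + ap.matmul(bd)
let dC = Array.map2 (+) (matmul2x2 da b) (matmul2x2 a db)

// finite-difference check: ((A + eps*dA)(B + eps*dB) - A*B) / eps should match dC
let eps = 1e-6
let perturbed = matmul2x2 (Array.map2 (fun x d -> x + eps*d) a da)
                          (Array.map2 (fun x d -> x + eps*d) b db)
let fd = Array.map2 (fun l c -> (l - c) / eps) perturbed (matmul2x2 a b)
// fd and dC agree to roughly 1e-6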
26 changes: 26 additions & 0 deletions src/DiffSharp.Tests/TestTensor.fs
@@ -4,6 +4,32 @@ open NUnit.Framework
open DiffSharp
open DiffSharp.Util
open DiffSharp.Backend
+ open DiffSharp.Backends
+ open System
+
+ // This captures the expected semantics of different DTypes
+ type DTypeInfo(dtype: DType) =
+     member _.dtype = dtype
+     member _.mkTensor(data: obj) = dsharp.tensor(data, dtype=dtype)
+     member _.arrayCreator1D(arr: double[]) =
+         match dtype with
+         | DType.Float32 -> arr |> Array.map float32 :> Array
+         | DType.Float64 -> arr |> Array.map double :> Array
+         | DType.Int8 -> arr |> Array.map int8 :> Array
+         | DType.Int16 -> arr |> Array.map int16 :> Array
+         | DType.Int32 -> arr |> Array.map int32 :> Array
+         | DType.Int64 -> arr |> Array.map int64 :> Array
+         | DType.Bool -> arr |> Array.map (fun x -> abs x >= 1.0) :> Array
+
+     member _.arrayCreator2D(arr: double[,]) : Array =
+         match dtype with
+         | DType.Float32 -> arr |> Array2D.map float32 :> Array
+         | DType.Float64 -> arr |> Array2D.map double :> Array
+         | DType.Int8 -> arr |> Array2D.map int8 :> Array
+         | DType.Int16 -> arr |> Array2D.map int16 :> Array
+         | DType.Int32 -> arr |> Array2D.map int32 :> Array
+         | DType.Int64 -> arr |> Array2D.map int64 :> Array
+         | DType.Bool -> arr |> Array2D.map (fun x -> abs x >= 1.0) :> Array

[<TestFixture>]
type TestTensor () =
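A hypothetical usage sketch of the DTypeInfo helper added above: one double-valued reference input is converted once per dtype so a single test body can cover all of them. The loop body and the assertion mentioned in the comment are illustrative, not taken from the PR.

// Build one DTypeInfo per dtype under test (constructors are first-class in F#)
let dtypeInfos = [ DType.Float32; DType.Float64; DType.Int32 ] |> List.map DTypeInfo

for info in dtypeInfos do
    let t = info.mkTensor([| 1.0; 2.0; 3.0 |])               // tensor in this dtype
    let expected = info.arrayCreator1D([| 1.0; 2.0; 3.0 |])  // matching .NET array
    // e.g. Assert.AreEqual(expected, ...) against the tensor's values
    ignore (t, expected)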
You are viewing a condensed version of this merge commit.