Feature: Batch matmul by dsyme · Pull Request #88 · DiffSharp/DiffSharp · GitHub
[go: up one dir, main page]

Skip to content

Feature: Batch matmul #88

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 29 commits into from
Oct 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
29 commits
Select commit Hold shift + click to select a range
fdbe45a
batch matmul
dsyme Feb 26, 2020
6cf7ed0
Merge branch 'feature/expand' of https://github.com/DiffSharp/DiffSha…
dsyme Feb 28, 2020
171d662
fix transpose case and add explicit Expand testing
dsyme Feb 28, 2020
5c24caf
organise array helpers
dsyme Feb 28, 2020
f30e516
integrate dev
dsyme Feb 28, 2020
76378d4
integrate dev
dsyme Feb 28, 2020
cd29bc6
integrate dev renaming
dsyme Mar 2, 2020
be905aa
Merge branch 'feature/expand' of https://github.com/DiffSharp/DiffSha…
dsyme Mar 2, 2020
23b3c7d
merge dev and resolve conflicts
dsyme Apr 29, 2020
6a959ce
merge dev
dsyme May 4, 2020
4e86210
Merge branch 'dev' of https://github.com/DiffSharp/DiffSharp into fea…
dsyme May 4, 2020
ab0fddb
merge dev
dsyme May 4, 2020
098b813
improve coverage
dsyme May 5, 2020
a956604
merge dev
gbaydin May 5, 2020
b0d0ca9
improve coverage
dsyme May 6, 2020
77107a6
Merge branch 'feature/batch-matmul' of https://github.com/DiffSharp/D…
dsyme May 6, 2020
d76838f
make batchTranspose internal
dsyme May 6, 2020
4c0c005
make batchTranspose internal
dsyme May 6, 2020
6abcdfd
minor renames
gbaydin May 6, 2020
828dd07
integrate dev
dsyme May 21, 2020
dc046b4
integrate dev
dsyme Sep 9, 2020
8d3dd42
test matmul
dsyme Sep 9, 2020
70d3534
fix tests
dsyme Sep 9, 2020
c9d01db
fix large dimension matmul
dsyme Sep 9, 2020
429b47f
integrate docs
Sep 15, 2020
3060ce5
integrate docs
Sep 15, 2020
3ed86d4
merge dev
Oct 6, 2020
b93b754
merge dev
Oct 15, 2020
3c94672
Merge branch 'dev' into feature/batch-matmul
gbaydin Oct 26, 2020
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 18 additions & 21 deletions src/DiffSharp.Backends.Reference/Reference.RawTensor.fs
Original file line number Diff line number Diff line change
Expand Up @@ -530,20 +530,17 @@ module internal RawTensorCPU =
let result = Array.map (fun t -> t ** t2value) t1value
(result, t1.Shape)

let inline MatMulT2T2(t1: RawTensorCPU< ^T >, t2: RawTensor) : (^T[] * Shape) =
Shape.checkCanMatmul t1.Shape t2.Shape
let t1rows, t1cols = t1.Shape.[0], t1.Shape.[1]
let t2rows, t2cols = t2.Shape.[0], t2.Shape.[1]
let inline MatMulTT(t1: RawTensorCPU< ^T >, t2: RawTensor) : (^T[] * Shape) =
let (t1BatchPart, t1MatrixPart), (t2BatchPart, t2MatrixPart) = Shape.checkCanMatmul t1.Shape t2.Shape
if t1BatchPart <> t2BatchPart then failwithf "Cannot matrix multiply raw tensors with shapes %A, %A - mismatch batching" t1.Shape t2.Shape
let t1rows, t1cols = t1MatrixPart.[0], t1MatrixPart.[1]
let t2rows, t2cols = t2MatrixPart.[0], t2MatrixPart.[1]
let t1value = t1.Values
let t2value = t2.GetTypedValues()
let result = Array.zeroCreate (t1rows*t2cols)
for i in 0 .. t1rows - 1 do
for j in 0 .. t2cols - 1 do
let mutable acc = zero
for k in 0..t2rows-1 do
acc <- acc + t1value.[i*t1cols + k] * t2value.[k*t2cols + j]
result.[i*t2cols + j] <- acc
(result,[| t1rows; t2cols |])
let t2value = (t2 :?> RawTensorCPU< ^T >).Values
let newShape = Array.append t1BatchPart [| t1rows; t2cols |]
let nb = shapeLength t1BatchPart
let values = Array.initFlat3D nb t1rows t2cols (fun b i j -> Array.sumBy (fun k -> t1value.[b*t1cols*t1rows + i*t1cols + k] * t2value.[b*t2cols*t2rows + k*t2cols + j]) [|0..(t2rows-1)|] )
(values, newShape)

let inline MaxPool1D(t1: RawTensorCPU< ^T >, kernelSize, stride, padding) : RawTensorCPU< ^T > * RawTensorCPU< int > =
let batchSize, channels, inputSize, outputSize, outputShape =
Expand Down Expand Up @@ -895,7 +892,7 @@ type RawTensorFloat32(values: float32[], shape:Shape, device) =
override t1.PowTT(t2) = RawTensorCPU.PowTT(t1, t2) |> create
override t1.PowT0T(t2) = RawTensorCPU.PowT0T(t1, t2) |> create
override t1.PowTT0(t2) = RawTensorCPU.PowTT0(t1, t2) |> create
override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
override t1.MaxPool1D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool1D(t1, kernelSize, stride, padding) in result :> _, indices :> _
override t1.MaxPool2D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool2D(t1, kernelSize, stride, padding) in result :> _, indices :> _
override t1.MaxPool3D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool3D(t1, kernelSize, stride, padding) in result :> _, indices :> _
Expand Down Expand Up @@ -980,7 +977,7 @@ type RawTensorFloat64(values: double[], shape:Shape, device) =
override t1.PowTT(t2) = RawTensorCPU.PowTT(t1, t2) |> create
override t1.PowT0T(t2) = RawTensorCPU.PowT0T(t1, t2) |> create
override t1.PowTT0(t2) = RawTensorCPU.PowTT0(t1, t2) |> create
override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
override t1.MaxPool1D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool1D(t1, kernelSize, stride, padding) in result :> _, indices :> _
override t1.MaxPool2D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool2D(t1, kernelSize, stride, padding) in result :> _, indices :> _
override t1.MaxPool3D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool3D(t1, kernelSize, stride, padding) in result :> _, indices :> _
Expand Down Expand Up @@ -1061,7 +1058,7 @@ type RawTensorInt8(values: int8[], shape:Shape, device) =
override t1.DivTT(t2) = RawTensorCPU.DivTT(t1, t2) |> create
override t1.DivT0T(t2) = RawTensorCPU.DivT0T(t1, t2) |> create
override t1.DivTT0(t2) = RawTensorCPU.DivTT0(t1, t2) |> create
override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
override t1.MaxPool1D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool1D(t1, kernelSize, stride, padding) in result :> _, indices :> _
override t1.MaxPool2D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool2D(t1, kernelSize, stride, padding) in result :> _, indices :> _
override t1.MaxPool3D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool3D(t1, kernelSize, stride, padding) in result :> _, indices :> _
Expand Down Expand Up @@ -1143,7 +1140,7 @@ type RawTensorByte(values: byte[], shape:Shape, device) =
override t1.DivTT(t2) = RawTensorCPU.DivTT(t1, t2) |> create
override t1.DivT0T(t2) = RawTensorCPU.DivT0T(t1, t2) |> create
override t1.DivTT0(t2) = RawTensorCPU.DivTT0(t1, t2) |> create
override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
override t1.MaxPool1D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool1D(t1, kernelSize, stride, padding) in result :> _, indices :> _
override t1.MaxPool2D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool2D(t1, kernelSize, stride, padding) in result :> _, indices :> _
override t1.MaxPool3D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool3D(t1, kernelSize, stride, padding) in result :> _, indices :> _
Expand Down Expand Up @@ -1225,7 +1222,7 @@ type RawTensorInt16(values: int16[], shape:Shape, device) =
override t1.DivTT(t2) = RawTensorCPU.DivTT(t1, t2) |> create
override t1.DivT0T(t2) = RawTensorCPU.DivT0T(t1, t2) |> create
override t1.DivTT0(t2) = RawTensorCPU.DivTT0(t1, t2) |> create
override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
override t1.MaxPool1D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool1D(t1, kernelSize, stride, padding) in result :> _, indices :> _
override t1.MaxPool2D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool2D(t1, kernelSize, stride, padding) in result :> _, indices :> _
override t1.MaxPool3D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool3D(t1, kernelSize, stride, padding) in result :> _, indices :> _
Expand Down Expand Up @@ -1307,7 +1304,7 @@ type RawTensorInt32(values: int32[], shape:Shape, device) =
override t1.DivTT(t2) = RawTensorCPU.DivTT(t1, t2) |> create
override t1.DivT0T(t2) = RawTensorCPU.DivT0T(t1, t2) |> create
override t1.DivTT0(t2) = RawTensorCPU.DivTT0(t1, t2) |> create
override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
override t1.MaxPool1D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool1D(t1, kernelSize, stride, padding) in result :> _, indices :> _
override t1.MaxPool2D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool2D(t1, kernelSize, stride, padding) in result :> _, indices :> _
override t1.MaxPool3D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool3D(t1, kernelSize, stride, padding) in result :> _, indices :> _
Expand Down Expand Up @@ -1389,7 +1386,7 @@ type RawTensorInt64(values: int64[], shape:Shape, device) =
override t1.DivTT(t2) = RawTensorCPU.DivTT(t1, t2) |> create
override t1.DivT0T(t2) = RawTensorCPU.DivT0T(t1, t2) |> create
override t1.DivTT0(t2) = RawTensorCPU.DivTT0(t1, t2) |> create
override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
override t1.MaxPool1D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool1D(t1, kernelSize, stride, padding) in result :> _, indices :> _
override t1.MaxPool2D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool2D(t1, kernelSize, stride, padding) in result :> _, indices :> _
override t1.MaxPool3D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool3D(t1, kernelSize, stride, padding) in result :> _, indices :> _
Expand Down Expand Up @@ -1478,7 +1475,7 @@ type RawTensorBool(values: bool[], shape:Shape, device) =
override t1.DivTT(t2) = opNotSupported2 "DivTT" t1.Dtype t2.Dtype
override t1.DivT0T(t2) = opNotSupported2 "DivT0T" t1.Dtype t2.Dtype
override t1.DivTT0(t2) = opNotSupported2 "DivTT0" t1.Dtype t2.Dtype
override t1.MatMulT2T2(t2) = opNotSupported2 "MatMulT2T2" t1.Dtype t2.Dtype
override t1.MatMulTT(t2) = opNotSupported2 "MatMulTT" t1.Dtype t2.Dtype
override t1.MaxPool1D(_kernelSize, _stride, _padding) = opNotSupported "MaxPool1D" t1.Dtype
override t1.MaxPool2D(_kernelSize, _stride, _padding) = opNotSupported "MaxPool2D" t1.Dtype
override t1.MaxPool3D(_kernelSize, _stride, _padding) = opNotSupported "MaxPool3D" t1.Dtype
Expand Down
6 changes: 3 additions & 3 deletions src/DiffSharp.Backends.Torch/Torch.RawTensor.fs
Original file line number Diff line number Diff line change
Expand Up @@ -523,11 +523,11 @@ type TorchRawTensor(tt: TorchTensor, shape: Shape, dtype: Dtype, device: Device)
let result = tt.Pow(t2v)
t1.MakeLike(result)

override t1.MatMulT2T2(t2) =
override t1.MatMulTT(t2) =
match dtype with
| Dtype.Bool -> opNotSupported2 "MatMulT2T2" dtype t2.Dtype
| Dtype.Bool -> opNotSupported2 "MatMulTT" dtype t2.Dtype
| _ ->
Shape.checkCanMatmul t1.Shape t2.Shape
let _, _ = Shape.checkCanMatmul t1.Shape t2.Shape
let result =
// "addmm for CUDA tensors only supports floating-point types. Try converting the tensors with .float()" | const char *
match t1.DeviceType, dtype with
Expand Down
6 changes: 6 additions & 0 deletions src/DiffSharp.Core/Extensions.fs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,12 @@ module Array =
else
counts |> Array.ofSeq |> Array.map (fun (KeyValue(k, v)) -> k, v) |> Array.unzip

// Create a 2D array using a flat representation
let initFlat2D i j f = Array.init (i*j) (fun ij -> f (ij/j) (ij%j))

// Create a 3D array using a flat representation
let initFlat3D i j k f = Array.init (i*j*k) (fun ijk -> f (ijk/j/k) ((ijk/k)%j) (ijk%k))

/// Contains extensions to the F# Seq module.
module Seq =

Expand Down
2 changes: 1 addition & 1 deletion src/DiffSharp.Core/RawTensor.fs
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ type RawTensor() =
abstract member PowTT0: t2: RawTensor -> RawTensor

/// Returns the matrix multiplication of two tensors
abstract member MatMulT2T2: t2: RawTensor -> RawTensor
abstract member MatMulTT: t2: RawTensor -> RawTensor

/// Returns the 1D maxpool of a tensor and its chosen maximum indices
abstract member MaxPool1D: kernelSize: int * stride: int * padding: int -> RawTensor * RawTensor
Expand Down
11 changes: 7 additions & 4 deletions src/DiffSharp.Core/Shape.fs
Original file line number Diff line number Diff line change
Expand Up @@ -459,10 +459,13 @@ module rec Shape =
if not (contains shape1 shape2) then failwithf "Expecting shape1 to contain shape2, received %A, %A" shape1 shape2
if location.Length <> shape1.Length then failwithf "Expecting location of the same length as shape1, received %A, %A" (location.Length) shape1

/// Checks if the given shape is appropriate for a matmul operation.
let checkCanMatmul (shape1: Shape) (shape2: Shape) =
if shape1.Length <> 2 || shape2.Length <> 2 then failwithf "Expecting two 2d Tensors, received Tensors with shapes %A, %A" shape1 shape2
if shape1.[1] <> shape2.[0] then failwithf "Cannot multiply Tensors with shapes %A, %A" shape1 shape2
/// Checks if the given shape is appropriate for a matmul operation.
let checkCanMatmul (shape1:int[]) (shape2:int[]) =
if shape1.Length < 2 || shape2.Length < 2 then failwithf "Expecting two 2d Tensors, received Tensors with shapes %A, %A" shape1 shape2
let aBatchPart, aMatrixPart = Array.splitAt (shape1.Length-2) shape1
let bBatchPart, bMatrixPart = Array.splitAt (shape2.Length-2) shape2
if aMatrixPart.[1] <> bMatrixPart.[0] then failwithf "Cannot matrix multiply tensors with shapes %A, %A - mismatch in matrix dimension" shape1 shape2
(aBatchPart, aMatrixPart), (bBatchPart, bMatrixPart)

/// Checks if the given shape is appropriate for a dot product operation.
let checkCanDot (shape1: Shape) (shape2: Shape) =
Expand Down
Loading
0