Merge pull request #230 from dsyme/mm2 · DiffSharp/DiffSharp@e3139bc · GitHub

Commit e3139bc

Merge pull request #230 from dsyme/mm2
Redo #88 batch matmul
2 parents 9500d66 + 818c682 commit e3139bc

File tree

7 files changed (+264, -55 lines)


src/DiffSharp.Backends.Reference/Reference.RawTensor.fs

Lines changed: 18 additions & 21 deletions
@@ -530,20 +530,17 @@ module internal RawTensorCPU =
         let result = Array.map (fun t -> t ** t2value) t1value
         (result, t1.Shape)
 
-    let inline MatMulT2T2(t1: RawTensorCPU< ^T >, t2: RawTensor) : (^T[] * Shape) =
-        Shape.checkCanMatmul t1.Shape t2.Shape
-        let t1rows, t1cols = t1.Shape.[0], t1.Shape.[1]
-        let t2rows, t2cols = t2.Shape.[0], t2.Shape.[1]
+    let inline MatMulTT(t1: RawTensorCPU< ^T >, t2: RawTensor) : (^T[] * Shape) =
+        let (t1BatchPart, t1MatrixPart), (t2BatchPart, t2MatrixPart) = Shape.checkCanMatmul t1.Shape t2.Shape
+        if t1BatchPart <> t2BatchPart then failwithf "Cannot matrix multiply raw tensors with shapes %A, %A - mismatch batching" t1.Shape t2.Shape
+        let t1rows, t1cols = t1MatrixPart.[0], t1MatrixPart.[1]
+        let t2rows, t2cols = t2MatrixPart.[0], t2MatrixPart.[1]
         let t1value = t1.Values
-        let t2value = t2.GetTypedValues()
-        let result = Array.zeroCreate (t1rows*t2cols)
-        for i in 0 .. t1rows - 1 do
-            for j in 0 .. t2cols - 1 do
-                let mutable acc = zero
-                for k in 0..t2rows-1 do
-                    acc <- acc + t1value.[i*t1cols + k] * t2value.[k*t2cols + j]
-                result.[i*t2cols + j] <- acc
-        (result, [| t1rows; t2cols |])
+        let t2value = (t2 :?> RawTensorCPU< ^T >).Values
+        let newShape = Array.append t1BatchPart [| t1rows; t2cols |]
+        let nb = shapeLength t1BatchPart
+        let values = Array.initFlat3D nb t1rows t2cols (fun b i j -> Array.sumBy (fun k -> t1value.[b*t1cols*t1rows + i*t1cols + k] * t2value.[b*t2cols*t2rows + k*t2cols + j]) [|0..(t2rows-1)|])
+        (values, newShape)
 
     let inline MaxPool1D(t1: RawTensorCPU< ^T >, kernelSize, stride, padding) : RawTensorCPU< ^T > * RawTensorCPU< int > =
         let batchSize, channels, inputSize, outputSize, outputShape =
@@ -895,7 +892,7 @@ type RawTensorFloat32(values: float32[], shape:Shape, device) =
     override t1.PowTT(t2) = RawTensorCPU.PowTT(t1, t2) |> create
     override t1.PowT0T(t2) = RawTensorCPU.PowT0T(t1, t2) |> create
     override t1.PowTT0(t2) = RawTensorCPU.PowTT0(t1, t2) |> create
-    override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
+    override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
     override t1.MaxPool1D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool1D(t1, kernelSize, stride, padding) in result :> _, indices :> _
     override t1.MaxPool2D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool2D(t1, kernelSize, stride, padding) in result :> _, indices :> _
     override t1.MaxPool3D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool3D(t1, kernelSize, stride, padding) in result :> _, indices :> _
@@ -980,7 +977,7 @@ type RawTensorFloat64(values: double[], shape:Shape, device) =
     override t1.PowTT(t2) = RawTensorCPU.PowTT(t1, t2) |> create
     override t1.PowT0T(t2) = RawTensorCPU.PowT0T(t1, t2) |> create
     override t1.PowTT0(t2) = RawTensorCPU.PowTT0(t1, t2) |> create
-    override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
+    override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
     override t1.MaxPool1D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool1D(t1, kernelSize, stride, padding) in result :> _, indices :> _
     override t1.MaxPool2D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool2D(t1, kernelSize, stride, padding) in result :> _, indices :> _
     override t1.MaxPool3D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool3D(t1, kernelSize, stride, padding) in result :> _, indices :> _
@@ -1061,7 +1058,7 @@ type RawTensorInt8(values: int8[], shape:Shape, device) =
     override t1.DivTT(t2) = RawTensorCPU.DivTT(t1, t2) |> create
     override t1.DivT0T(t2) = RawTensorCPU.DivT0T(t1, t2) |> create
     override t1.DivTT0(t2) = RawTensorCPU.DivTT0(t1, t2) |> create
-    override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
+    override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
     override t1.MaxPool1D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool1D(t1, kernelSize, stride, padding) in result :> _, indices :> _
     override t1.MaxPool2D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool2D(t1, kernelSize, stride, padding) in result :> _, indices :> _
     override t1.MaxPool3D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool3D(t1, kernelSize, stride, padding) in result :> _, indices :> _
@@ -1143,7 +1140,7 @@ type RawTensorByte(values: byte[], shape:Shape, device) =
     override t1.DivTT(t2) = RawTensorCPU.DivTT(t1, t2) |> create
     override t1.DivT0T(t2) = RawTensorCPU.DivT0T(t1, t2) |> create
     override t1.DivTT0(t2) = RawTensorCPU.DivTT0(t1, t2) |> create
-    override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
+    override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
     override t1.MaxPool1D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool1D(t1, kernelSize, stride, padding) in result :> _, indices :> _
     override t1.MaxPool2D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool2D(t1, kernelSize, stride, padding) in result :> _, indices :> _
     override t1.MaxPool3D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool3D(t1, kernelSize, stride, padding) in result :> _, indices :> _
@@ -1225,7 +1222,7 @@ type RawTensorInt16(values: int16[], shape:Shape, device) =
     override t1.DivTT(t2) = RawTensorCPU.DivTT(t1, t2) |> create
     override t1.DivT0T(t2) = RawTensorCPU.DivT0T(t1, t2) |> create
     override t1.DivTT0(t2) = RawTensorCPU.DivTT0(t1, t2) |> create
-    override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
+    override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
     override t1.MaxPool1D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool1D(t1, kernelSize, stride, padding) in result :> _, indices :> _
     override t1.MaxPool2D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool2D(t1, kernelSize, stride, padding) in result :> _, indices :> _
     override t1.MaxPool3D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool3D(t1, kernelSize, stride, padding) in result :> _, indices :> _
@@ -1307,7 +1304,7 @@ type RawTensorInt32(values: int32[], shape:Shape, device) =
     override t1.DivTT(t2) = RawTensorCPU.DivTT(t1, t2) |> create
     override t1.DivT0T(t2) = RawTensorCPU.DivT0T(t1, t2) |> create
     override t1.DivTT0(t2) = RawTensorCPU.DivTT0(t1, t2) |> create
-    override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
+    override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
     override t1.MaxPool1D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool1D(t1, kernelSize, stride, padding) in result :> _, indices :> _
     override t1.MaxPool2D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool2D(t1, kernelSize, stride, padding) in result :> _, indices :> _
     override t1.MaxPool3D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool3D(t1, kernelSize, stride, padding) in result :> _, indices :> _
@@ -1389,7 +1386,7 @@ type RawTensorInt64(values: int64[], shape:Shape, device) =
     override t1.DivTT(t2) = RawTensorCPU.DivTT(t1, t2) |> create
     override t1.DivT0T(t2) = RawTensorCPU.DivT0T(t1, t2) |> create
     override t1.DivTT0(t2) = RawTensorCPU.DivTT0(t1, t2) |> create
-    override t1.MatMulT2T2(t2) = RawTensorCPU.MatMulT2T2(t1, t2) |> create
+    override t1.MatMulTT(t2) = RawTensorCPU.MatMulTT(t1, t2) |> create
     override t1.MaxPool1D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool1D(t1, kernelSize, stride, padding) in result :> _, indices :> _
     override t1.MaxPool2D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool2D(t1, kernelSize, stride, padding) in result :> _, indices :> _
     override t1.MaxPool3D(kernelSize, stride, padding) = let result, indices = RawTensorCPU.MaxPool3D(t1, kernelSize, stride, padding) in result :> _, indices :> _
@@ -1478,7 +1475,7 @@ type RawTensorBool(values: bool[], shape:Shape, device) =
     override t1.DivTT(t2) = opNotSupported2 "DivTT" t1.Dtype t2.Dtype
     override t1.DivT0T(t2) = opNotSupported2 "DivT0T" t1.Dtype t2.Dtype
     override t1.DivTT0(t2) = opNotSupported2 "DivTT0" t1.Dtype t2.Dtype
-    override t1.MatMulT2T2(t2) = opNotSupported2 "MatMulT2T2" t1.Dtype t2.Dtype
+    override t1.MatMulTT(t2) = opNotSupported2 "MatMulTT" t1.Dtype t2.Dtype
     override t1.MaxPool1D(_kernelSize, _stride, _padding) = opNotSupported "MaxPool1D" t1.Dtype
     override t1.MaxPool2D(_kernelSize, _stride, _padding) = opNotSupported "MaxPool2D" t1.Dtype
     override t1.MaxPool3D(_kernelSize, _stride, _padding) = opNotSupported "MaxPool3D" t1.Dtype
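
The new MatMulTT kernel treats each operand as a flat row-major array: element (b, i, k) of t1 sits at offset b*t1rows*t1cols + i*t1cols + k, and the matching batch offset in t2 is b*t2rows*t2cols. A minimal standalone F# sketch of the same computation with explicit loops (illustrative only; batchedMatMulFlat is a made-up name, not DiffSharp code):

// Naive batched matmul over flat row-major arrays, mirroring the
// Array.initFlat3D/Array.sumBy expression in MatMulTT above.
// a has shape [nb; t1rows; t1cols], b has shape [nb; t1cols; t2cols].
let batchedMatMulFlat (a: float[]) (b: float[]) nb t1rows t1cols t2cols =
    let result = Array.zeroCreate (nb * t1rows * t2cols)
    for bi in 0 .. nb - 1 do
        for i in 0 .. t1rows - 1 do
            for j in 0 .. t2cols - 1 do
                let mutable acc = 0.0
                for k in 0 .. t1cols - 1 do
                    acc <- acc + a.[bi*t1rows*t1cols + i*t1cols + k] * b.[bi*t1cols*t2cols + k*t2cols + j]
                result.[bi*t1rows*t2cols + i*t2cols + j] <- acc
    result

For nb = 1 this reduces to the 2-d loop nest that the old MatMulT2T2 used.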

src/DiffSharp.Backends.Torch/Torch.RawTensor.fs

Lines changed: 3 additions & 3 deletions
@@ -523,11 +523,11 @@ type TorchRawTensor(tt: TorchTensor, shape: Shape, dtype: Dtype, device: Device)
         let result = tt.Pow(t2v)
         t1.MakeLike(result)
 
-    override t1.MatMulT2T2(t2) =
+    override t1.MatMulTT(t2) =
         match dtype with
-        | Dtype.Bool -> opNotSupported2 "MatMulT2T2" dtype t2.Dtype
+        | Dtype.Bool -> opNotSupported2 "MatMulTT" dtype t2.Dtype
         | _ ->
-        Shape.checkCanMatmul t1.Shape t2.Shape
+        let _, _ = Shape.checkCanMatmul t1.Shape t2.Shape
         let result =
             // "addmm for CUDA tensors only supports floating-point types. Try converting the tensors with .float()" | const char *
             match t1.DeviceType, dtype with
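
Because checkCanMatmul now returns the split shape parts, the Torch override binds and discards them; the call is kept only for its validation side effect before the multiply is delegated to the Torch backend. As a sketch of that intent (assuming t1 and t2 in scope as in the override above), the discarded binding is equivalent to:

// Validation only: raise on incompatible shapes, ignore the returned
// (batchPart, matrixPart) pairs since the backend computes the product itself.
Shape.checkCanMatmul t1.Shape t2.Shape |> ignore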

src/DiffSharp.Core/Extensions.fs

Lines changed: 6 additions & 0 deletions
@@ -40,6 +40,12 @@ module Array =
         else
             counts |> Array.ofSeq |> Array.map (fun (KeyValue(k, v)) -> k, v) |> Array.unzip
 
+    // Create a 2D array using a flat representation
+    let initFlat2D i j f = Array.init (i*j) (fun ij -> f (ij/j) (ij%j))
+
+    // Create a 3D array using a flat representation
+    let initFlat3D i j k f = Array.init (i*j*k) (fun ijk -> f (ijk/j/k) ((ijk/k)%j) (ijk%k))
+
 /// Contains extensions to the F# Seq module.
 module Seq =
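
These helpers build a multi-dimensional array in flat row-major layout: for initFlat3D i j k f, flat index ijk decodes to (ijk/(j*k), (ijk/k) % j, ijk % k), so the last coordinate varies fastest. A quick usage sketch (the 2×2×3 sizes are arbitrary, and it assumes this extended Array module is in scope):

// Encoding the coordinates themselves makes the row-major layout visible.
let cube = Array.initFlat3D 2 2 3 (fun a b c -> (a, b, c))
// cube.[0]  = (0, 0, 0)
// cube.[3]  = (0, 1, 0)   // the last index varies fastest
// cube.[11] = (1, 1, 2)   // element (a, b, c) lives at flat offset a*2*3 + b*3 + c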

src/DiffSharp.Core/RawTensor.fs

Lines changed: 1 addition & 1 deletion
@@ -408,7 +408,7 @@ type RawTensor() =
     abstract member PowTT0: t2: RawTensor -> RawTensor
 
     /// Returns the matrix multiplication of two tensors
-    abstract member MatMulT2T2: t2: RawTensor -> RawTensor
+    abstract member MatMulTT: t2: RawTensor -> RawTensor
 
     /// Returns the 1D maxpool of a tensor and its chosen maximum indices
     abstract member MaxPool1D: kernelSize: int * stride: int * padding: int -> RawTensor * RawTensor

src/DiffSharp.Core/Shape.fs

Lines changed: 7 additions & 4 deletions
@@ -459,10 +459,13 @@ module rec Shape =
         if not (contains shape1 shape2) then failwithf "Expecting shape1 to contain shape2, received %A, %A" shape1 shape2
         if location.Length <> shape1.Length then failwithf "Expecting location of the same length as shape1, received %A, %A" (location.Length) shape1
 
-    /// Checks if the given shape is appropriate for a matmul operation.
-    let checkCanMatmul (shape1: Shape) (shape2: Shape) =
-        if shape1.Length <> 2 || shape2.Length <> 2 then failwithf "Expecting two 2d Tensors, received Tensors with shapes %A, %A" shape1 shape2
-        if shape1.[1] <> shape2.[0] then failwithf "Cannot multiply Tensors with shapes %A, %A" shape1 shape2
+    /// Check if the given shape is appropriate for a matmul operation.
+    let checkCanMatmul (shape1:int[]) (shape2:int[]) =
+        if shape1.Length < 2 || shape2.Length < 2 then failwithf "Expecting two 2d Tensors, received Tensors with shapes %A, %A" shape1 shape2
+        let aBatchPart, aMatrixPart = Array.splitAt (shape1.Length-2) shape1
+        let bBatchPart, bMatrixPart = Array.splitAt (shape2.Length-2) shape2
+        if aMatrixPart.[1] <> bMatrixPart.[0] then failwithf "Cannot matrix multiply tensors with shapes %A, %A - mismatch in matrix dimension" shape1 shape2
+        (aBatchPart, aMatrixPart), (bBatchPart, bMatrixPart)
 
     /// Checks if the given shape is appropriate for a dot product operation.
     let checkCanDot (shape1: Shape) (shape2: Shape) =
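
For concreteness, a usage sketch of the new return value (the shapes are arbitrary examples): each shape is split into a batch part and a trailing 2-d matrix part, and a plain 2-d call simply yields empty batch parts. Equality of the batch parts is checked by the callers, as in the Reference backend above.

// Batched case: [| 5; 2; 3 |] x [| 5; 3; 4 |] would produce a [| 5; 2; 4 |] result.
let (aBatch, aMat), (bBatch, bMat) = Shape.checkCanMatmul [| 5; 2; 3 |] [| 5; 3; 4 |]
// aBatch = [| 5 |], aMat = [| 2; 3 |], bBatch = [| 5 |], bMat = [| 3; 4 |]
let resultShape = Array.append aBatch [| aMat.[0]; bMat.[1] |]   // [| 5; 2; 4 |]

// Plain 2-d case: batch parts are empty.
// Shape.checkCanMatmul [| 2; 3 |] [| 3; 4 |] = (([||], [| 2; 3 |]), ([||], [| 3; 4 |]))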

0 commit comments
