@@ -530,20 +530,17 @@ module internal RawTensorCPU =
530
530
let result = Array.map ( fun t -> t ** t2value) t1value
531
531
( result, t1.Shape)
532
532
533
- let inline MatMulT2T2 ( t1 : RawTensorCPU < ^T >, t2 : RawTensor ) : ( ^T [] * Shape ) =
534
- Shape.checkCanMatmul t1.Shape t2.Shape
535
- let t1rows , t1cols = t1.Shape.[ 0 ], t1.Shape.[ 1 ]
536
- let t2rows , t2cols = t2.Shape.[ 0 ], t2.Shape.[ 1 ]
533
+ let inline MatMulTT ( t1 : RawTensorCPU < ^T >, t2 : RawTensor ) : ( ^T [] * Shape ) =
534
+ let ( t1BatchPart , t1MatrixPart ), ( t2BatchPart , t2MatrixPart ) = Shape.checkCanMatmul t1.Shape t2.Shape
535
+ if t1BatchPart <> t2BatchPart then failwithf " Cannot matrix multiply raw tensors with shapes %A , %A - mismatch batching" t1.Shape t2.Shape
536
+ let t1rows , t1cols = t1MatrixPart.[ 0 ], t1MatrixPart.[ 1 ]
537
+ let t2rows , t2cols = t2MatrixPart.[ 0 ], t2MatrixPart.[ 1 ]
537
538
let t1value = t1.Values
538
- let t2value = t2.GetTypedValues()
539
- let result = Array.zeroCreate ( t1rows* t2cols)
540
- for i in 0 .. t1rows - 1 do
541
- for j in 0 .. t2cols - 1 do
542
- let mutable acc = zero
543
- for k in 0 .. t2rows-1 do
544
- acc <- acc + t1value.[ i* t1cols + k] * t2value.[ k* t2cols + j]
545
- result.[ i* t2cols + j] <- acc
546
- ( result,[| t1rows; t2cols |])
539
+ let t2value = ( t2 :?> RawTensorCPU< ^T >) .Values
540
+ let newShape = Array.append t1BatchPart [| t1rows; t2cols |]
541
+ let nb = shapeLength t1BatchPart
542
+ let values = Array.initFlat3D nb t1rows t2cols ( fun b i j -> Array.sumBy ( fun k -> t1value.[ b* t1cols* t1rows + i* t1cols + k] * t2value.[ b* t2cols* t2rows + k* t2cols + j]) [| 0 ..( t2rows-1 )|] )
543
+ ( values, newShape)
547
544
548
545
let inline MaxPool1D ( t1 : RawTensorCPU < ^T >, kernelSize , stride , padding ) : RawTensorCPU < ^T > * RawTensorCPU < int > =
549
546
let batchSize , channels , inputSize , outputSize , outputShape =
@@ -895,7 +892,7 @@ type RawTensorFloat32(values: float32[], shape:Shape, device) =
895
892
override t1.PowTT ( t2 ) = RawTensorCPU.PowTT( t1, t2) |> create
896
893
override t1.PowT0T ( t2 ) = RawTensorCPU.PowT0T( t1, t2) |> create
897
894
override t1.PowTT0 ( t2 ) = RawTensorCPU.PowTT0( t1, t2) |> create
898
- override t1.MatMulT2T2 ( t2 ) = RawTensorCPU.MatMulT2T2 ( t1, t2) |> create
895
+ override t1.MatMulTT ( t2 ) = RawTensorCPU.MatMulTT ( t1, t2) |> create
899
896
override t1.MaxPool1D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool1D( t1, kernelSize, stride, padding) in result :> _, indices :> _
900
897
override t1.MaxPool2D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool2D( t1, kernelSize, stride, padding) in result :> _, indices :> _
901
898
override t1.MaxPool3D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool3D( t1, kernelSize, stride, padding) in result :> _, indices :> _
@@ -980,7 +977,7 @@ type RawTensorFloat64(values: double[], shape:Shape, device) =
980
977
override t1.PowTT ( t2 ) = RawTensorCPU.PowTT( t1, t2) |> create
981
978
override t1.PowT0T ( t2 ) = RawTensorCPU.PowT0T( t1, t2) |> create
982
979
override t1.PowTT0 ( t2 ) = RawTensorCPU.PowTT0( t1, t2) |> create
983
- override t1.MatMulT2T2 ( t2 ) = RawTensorCPU.MatMulT2T2 ( t1, t2) |> create
980
+ override t1.MatMulTT ( t2 ) = RawTensorCPU.MatMulTT ( t1, t2) |> create
984
981
override t1.MaxPool1D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool1D( t1, kernelSize, stride, padding) in result :> _, indices :> _
985
982
override t1.MaxPool2D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool2D( t1, kernelSize, stride, padding) in result :> _, indices :> _
986
983
override t1.MaxPool3D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool3D( t1, kernelSize, stride, padding) in result :> _, indices :> _
@@ -1061,7 +1058,7 @@ type RawTensorInt8(values: int8[], shape:Shape, device) =
1061
1058
override t1.DivTT ( t2 ) = RawTensorCPU.DivTT( t1, t2) |> create
1062
1059
override t1.DivT0T ( t2 ) = RawTensorCPU.DivT0T( t1, t2) |> create
1063
1060
override t1.DivTT0 ( t2 ) = RawTensorCPU.DivTT0( t1, t2) |> create
1064
- override t1.MatMulT2T2 ( t2 ) = RawTensorCPU.MatMulT2T2 ( t1, t2) |> create
1061
+ override t1.MatMulTT ( t2 ) = RawTensorCPU.MatMulTT ( t1, t2) |> create
1065
1062
override t1.MaxPool1D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool1D( t1, kernelSize, stride, padding) in result :> _, indices :> _
1066
1063
override t1.MaxPool2D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool2D( t1, kernelSize, stride, padding) in result :> _, indices :> _
1067
1064
override t1.MaxPool3D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool3D( t1, kernelSize, stride, padding) in result :> _, indices :> _
@@ -1143,7 +1140,7 @@ type RawTensorByte(values: byte[], shape:Shape, device) =
1143
1140
override t1.DivTT ( t2 ) = RawTensorCPU.DivTT( t1, t2) |> create
1144
1141
override t1.DivT0T ( t2 ) = RawTensorCPU.DivT0T( t1, t2) |> create
1145
1142
override t1.DivTT0 ( t2 ) = RawTensorCPU.DivTT0( t1, t2) |> create
1146
- override t1.MatMulT2T2 ( t2 ) = RawTensorCPU.MatMulT2T2 ( t1, t2) |> create
1143
+ override t1.MatMulTT ( t2 ) = RawTensorCPU.MatMulTT ( t1, t2) |> create
1147
1144
override t1.MaxPool1D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool1D( t1, kernelSize, stride, padding) in result :> _, indices :> _
1148
1145
override t1.MaxPool2D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool2D( t1, kernelSize, stride, padding) in result :> _, indices :> _
1149
1146
override t1.MaxPool3D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool3D( t1, kernelSize, stride, padding) in result :> _, indices :> _
@@ -1225,7 +1222,7 @@ type RawTensorInt16(values: int16[], shape:Shape, device) =
1225
1222
override t1.DivTT ( t2 ) = RawTensorCPU.DivTT( t1, t2) |> create
1226
1223
override t1.DivT0T ( t2 ) = RawTensorCPU.DivT0T( t1, t2) |> create
1227
1224
override t1.DivTT0 ( t2 ) = RawTensorCPU.DivTT0( t1, t2) |> create
1228
- override t1.MatMulT2T2 ( t2 ) = RawTensorCPU.MatMulT2T2 ( t1, t2) |> create
1225
+ override t1.MatMulTT ( t2 ) = RawTensorCPU.MatMulTT ( t1, t2) |> create
1229
1226
override t1.MaxPool1D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool1D( t1, kernelSize, stride, padding) in result :> _, indices :> _
1230
1227
override t1.MaxPool2D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool2D( t1, kernelSize, stride, padding) in result :> _, indices :> _
1231
1228
override t1.MaxPool3D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool3D( t1, kernelSize, stride, padding) in result :> _, indices :> _
@@ -1307,7 +1304,7 @@ type RawTensorInt32(values: int32[], shape:Shape, device) =
1307
1304
override t1.DivTT ( t2 ) = RawTensorCPU.DivTT( t1, t2) |> create
1308
1305
override t1.DivT0T ( t2 ) = RawTensorCPU.DivT0T( t1, t2) |> create
1309
1306
override t1.DivTT0 ( t2 ) = RawTensorCPU.DivTT0( t1, t2) |> create
1310
- override t1.MatMulT2T2 ( t2 ) = RawTensorCPU.MatMulT2T2 ( t1, t2) |> create
1307
+ override t1.MatMulTT ( t2 ) = RawTensorCPU.MatMulTT ( t1, t2) |> create
1311
1308
override t1.MaxPool1D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool1D( t1, kernelSize, stride, padding) in result :> _, indices :> _
1312
1309
override t1.MaxPool2D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool2D( t1, kernelSize, stride, padding) in result :> _, indices :> _
1313
1310
override t1.MaxPool3D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool3D( t1, kernelSize, stride, padding) in result :> _, indices :> _
@@ -1389,7 +1386,7 @@ type RawTensorInt64(values: int64[], shape:Shape, device) =
1389
1386
override t1.DivTT ( t2 ) = RawTensorCPU.DivTT( t1, t2) |> create
1390
1387
override t1.DivT0T ( t2 ) = RawTensorCPU.DivT0T( t1, t2) |> create
1391
1388
override t1.DivTT0 ( t2 ) = RawTensorCPU.DivTT0( t1, t2) |> create
1392
- override t1.MatMulT2T2 ( t2 ) = RawTensorCPU.MatMulT2T2 ( t1, t2) |> create
1389
+ override t1.MatMulTT ( t2 ) = RawTensorCPU.MatMulTT ( t1, t2) |> create
1393
1390
override t1.MaxPool1D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool1D( t1, kernelSize, stride, padding) in result :> _, indices :> _
1394
1391
override t1.MaxPool2D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool2D( t1, kernelSize, stride, padding) in result :> _, indices :> _
1395
1392
override t1.MaxPool3D ( kernelSize , stride , padding ) = let result , indices = RawTensorCPU.MaxPool3D( t1, kernelSize, stride, padding) in result :> _, indices :> _
@@ -1478,7 +1475,7 @@ type RawTensorBool(values: bool[], shape:Shape, device) =
1478
1475
override t1.DivTT ( t2 ) = opNotSupported2 " DivTT" t1.Dtype t2.Dtype
1479
1476
override t1.DivT0T ( t2 ) = opNotSupported2 " DivT0T" t1.Dtype t2.Dtype
1480
1477
override t1.DivTT0 ( t2 ) = opNotSupported2 " DivTT0" t1.Dtype t2.Dtype
1481
- override t1.MatMulT2T2 ( t2 ) = opNotSupported2 " MatMulT2T2 " t1.Dtype t2.Dtype
1478
+ override t1.MatMulTT ( t2 ) = opNotSupported2 " MatMulTT " t1.Dtype t2.Dtype
1482
1479
override t1.MaxPool1D ( _kernelSize , _stride , _padding ) = opNotSupported " MaxPool1D" t1.Dtype
1483
1480
override t1.MaxPool2D ( _kernelSize , _stride , _padding ) = opNotSupported " MaxPool2D" t1.Dtype
1484
1481
override t1.MaxPool3D ( _kernelSize , _stride , _padding ) = opNotSupported " MaxPool3D" t1.Dtype
0 commit comments