8000 Merge pull request #254 from DiffSharp/feature/permute · DiffSharp/DiffSharp@1213767 · GitHub
[go: up one dir, main page]

Skip to content

Commit 1213767

Browse files
authored
Merge pull request #254 from DiffSharp/feature/permute
feature/permute - permute
2 parents 7f58350 + 334b51f commit 1213767

File tree

10 files changed

+316
-30
lines changed

10 files changed

+316
-30
lines changed

src/DiffSharp.Backends.Reference/Reference.RawTensor.fs

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -203,29 +203,26 @@ type RawTensorCPU<'T when 'T : equality and 'T :> scalar>(values: 'T[], shape: S
203203
(results, outShapes) ||> Array.map2 (fun rvalues outShape ->
204204
t.MakeLike(rvalues, outShape))
205205

206+
override t.PermuteT(permutation) =
207+
let inversePermutation, newShape = Shape.checkCanPermute t.Shape permutation
208+
let result = t.ZerosLike(newShape) :?> RawTensorCPU<'T>
209+
let rec transpose (shape:Shape) externalCoords =
210+
if shape.Length = 1 then
211+ 8000
for i=0 to shape.[0]-1 do
212+
let globalCoords = Array.append externalCoords [|i|]
213+
let transposedCoords = Array.permute (fun i -> inversePermutation.[i]) globalCoords
214+
result.[transposedCoords] <- t.[globalCoords]
215+
else
216+
for i=0 to shape.[0]-1 do
217+
transpose shape.[1..] (Array.append externalCoords [|i|])
218+
transpose t.Shape [||]
219+
upcast result
220+
206221
override t.TransposeT(dim0, dim1) =
207-
Shape.checkCanTranspose t.Shape dim0 dim1
208-
if dim0 = dim1 then
209-
let result = Array.copy t.Values
210-
t.MakeLike(result, t.Shape)
211-
else
212-
let shape = Array.copy t.Shape
213-
shape.[dim0] <- t.Shape.[dim1]
214-
shape.[dim1] <- t.Shape.[dim0]
215-
let result = t.ZerosLike(shape) :?> RawTensorCPU<'T>
216-
let rec transpose (shape:Shape) externalCoords =
217-
if shape.Length = 1 then
218-
for i=0 to shape.[0]-1 do
219-
let globalCoords = Array.append externalCoords [|i|]
220-
let transposedCoords = Array.copy globalCoords
221-
transposedCoords.[dim0] <- globalCoords.[dim1]
222-
transposedCoords.[dim1] <- globalCoords.[dim0]
223-
result.[transposedCoords] <- t.[globalCoords]
224-
else
225-
for i=0 to shape.[0]-1 do
226-
transpose shape.[1..] (Array.append externalCoords [|i|])
227-
transpose t.Shape [||]
228-
upcast result
222+
let permutation = [| 0 .. t.Shape.Length - 1 |]
223+
permutation.[dim0] <- dim1
224+
permutation.[dim1] <- dim0
225+
t.PermuteT(permutation)
229226

230227
override t.TransposeT2() =
231228
Shape.checkCanTranspose2d t.Dim

src/DiffSharp.Backends.Torch/Torch.RawTensor.fs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -292,6 +292,11 @@ type TorchRawTensor(tt: TorchTensor, shape: Shape, dtype: Dtype, device: Device)
292292
(results, outShapes) ||> Array.map2 (fun rvalues outShape ->
293293
t.MakeLike(rvalues, shape=outShape))
294294

295+
override t.PermuteT(permutation) =
296+
let _, newShape = Shape.checkCanPermute t.Shape permutation
297+
let result = tt.Permute(int64s permutation)
298+
t.MakeLike(result, shape=newShape)
299+
295300
override t.TransposeT(dim0, dim1) =
296301
Shape.checkCanTranspose t.Shape dim0 dim1
297302
let result = tt.Transpose(int64 dim0, int64 dim1)

src/DiffSharp.Core/DiffSharp.fs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -648,6 +648,11 @@ type dsharp =
648648
/// <param name="indices">The the indices of elements to gather.</param>
649649
static member gather(input:Tensor, dim:int, indices:Tensor) = input.gather(dim, indices)
650650

651+
/// <summary>Returns the original tensor with its dimensions permuted.</summary>
652+
/// <param name="input">The input tensor.</param>
653+
/// <param name="permutation">The desired ordering of dimensions.</param>
654+
static member permute(input:Tensor, permutation:seq<int>) = input.permute(permutation)
655+
651656
/// <summary>Returns a tensor that is a transposed version of input. The given dimensions dim0 and dim1 are swapped.</summary>
652657
/// <param name="input">The input tensor.</param>
653658
/// <param name="dim0">The first dimension to be transposed.</param>

src/DiffSharp.Core/Extensions.fs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,13 @@ module Array =
5151
// Create a 3D array using a flat representation
5252
let initFlat3D i j k f = Array.init (i*j*k) (fun ijk -> f (ijk/j/k) ((ijk/k)%j) (ijk%k))
5353

54+
let foralli f (arr: 'T[]) =
55+
let mutable i = 0
56+
let n = arr.Length
57+
while i < n && f i arr.[i] do
58+
i <- i + 1
59+
(i = n)
60+
5461
module ArrayND =
5562
/// Initializes an array with a given shape and initializer function.
5663
let init (shape: int[]) (f: int[] -> 'T) : obj =

src/DiffSharp.Core/RawTensor.fs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,9 @@ type RawTensor() =
457457
/// Returns the 3D convolution of the tensor
458458
abstract Conv3D: kernel: RawTensor * strides: int[] * padding: int[] -> RawTensor
459459

460+
/// Returns a view of the original tensor with its dimensions permuted
461+
abstract PermuteT: permutation: int[] -> RawTensor
462+
460463
/// Returns the element-wise negation of the tensor
461464
abstract NegT: unit -> RawTensor
462465

src/DiffSharp.Core/Shape.fs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -426,6 +426,14 @@ module rec Shape =
426426
let checkCanTranspose2d (dim: int) =
427427
if dim <> 2 then failwith "Expecting dim=2 when no specific dimensions are given to transpose. Consider using general transpose(dim0, dim1)."
428428

429+
/// Checks if the given shape is appropriate for a permute operation and returns information related to the resulting shape.
430+
let checkCanPermute (shape: Shape) (permutation: int[]) =
431+
if shape.Length <> permutation.Length then failwithf "Expecting tensor's shape (%A) and permutation (%A) to have the same dims" shape permutation
432+
if Seq.hasDuplicates permutation then failwithf "Expecting permutation (%A) to have no duplicate values" permutation
433+
let inversePermutation = Array.permute (fun i -> permutation.[i]) [| 0.. shape.Length-1 |]
434+
let newShape = Array.permute (fun i -> inversePermutation.[i]) shape
435+
inversePermutation, newShape
436+
429437
/// Checks if the given shape is appropriate for a flip operation.
430438
let checkCanFlip (dim: int) (dims: int[]) =
431439
if dims.Length > dim then failwithf "Expecting dims (list of dimension indices to flip) of length less than Tensor's dimensions, received %A, %A" dims.Length dim

src/DiffSharp.Core/Tensor.fs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1554,6 +1554,20 @@ type Tensor =
15541554
let inline dfTensorRev(a) = TransposeT(a, dim0, dim1)
15551555
Tensor.OpUnary(a, fRaw, fTensor, dfTensorFwd, dfTensorRev)
15561556

1557+
/// <summary>Returns the original tensor with its dimensions permuted.</summary>
1558+
/// <param name="permutation">The desired ordering of dimensions.</param>
1559+
member a.permute(permutation:seq<int>) =
1560+
let permutation = Seq.toArrayQuick permutation
1561+
let inversePermutation, _ = Shape.checkCanPermute a.shape permutation
1562+
if permutation |> Array.foralli (fun i j -> i = j) then
1563+
a
1564+
else
1565+
let inline fRaw(a:RawTensor) = a.PermuteT(permutation)
1566+
let inline fTensor(a:Tensor) = a.permute(permutation)
1567+
let inline dfTensorFwd(cp,ap,ad:Tensor) = ad.permute(permutation)
1568+
let inline dfTensorRev(a) = PermuteT(a, inversePermutation)
1569+
Tensor.OpUnary(a, fRaw, fTensor, dfTensorFwd, dfTensorRev)
1570+
15571571
/// <summary>Returns a tensor that is a transposed version of input with dimensions 0 and 1 swapped.</summary>
15581572
member a.transpose() =
15591573
Shape.checkCanTranspose2d a.dim
@@ -2627,6 +2641,7 @@ type Tensor =
26272641
| CatTs(a,_) -> reset (List.append (a |> List.ofSeq) tt)
26282642
| SplitT(a,_,_,_) -> reset (a::tt)
26292643
| GatherT(a,_,_) -> reset (a::tt)
2644+
| PermuteT(a,_) -> reset (a::tt)
26302645
| TransposeT(a,_,_) -> reset (a::tt)
26312646
| TransposeT2(a) -> reset (a::tt)
26322647
| SqueezeT(a) -> reset (a::tt)
@@ -2784,6 +2799,7 @@ type Tensor =
27842799
loc.[dim] <- j
27852800
a.derivative <- a.derivative.addSlice(loc, t)
27862801
push (check(a.zeroLike(), a) :: tt)
2802+
| PermuteT(a, inversePermutation) -> push (check(td.permute(inversePermutation), a) :: tt)
27872803
| TransposeT(a, dim0, dim1) -> push (check(td.transpose(dim0, dim1), a) :: tt)
27882804
| TransposeT2(a) -> push (check(td.transpose(), a) :: tt)
27892805
| SqueezeT(a) -> push (check(td.viewAs(a), a) :: tt)
@@ -2900,6 +2916,7 @@ and TensorOp =
29002916
| SplitT of Tensor * int[] * dim:int * i:int
29012917
| SliceT of Tensor * int[,]
29022918
| GatherT of Tensor * int * Tensor
2919+
| PermuteT of Tensor * inversePermutation: int[]
29032920
| TransposeT of Tensor * int * int
29042921
| TransposeT2 of Tensor
29052922
| SqueezeT of Tensor

tests/DiffSharp.Tests/Properties/launchSettings.json

Lines changed: 0 additions & 8 deletions
This file was deleted.

tests/DiffSharp.Tests/TestDerivatives.fs

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1596,6 +1596,88 @@ type TestDerivatives () =
15961596
Assert.CheckEqual(revzCorrect, revz)
15971597
Assert.CheckEqual(revxdCorrect, revxd)
15981598

1599+
[<Test>]
1600+
member _.TestDerivativePermuteT () =
1601+
for combo in Combos.AllDevicesAndBackendsFloat32 do
1602+
let fwdx = combo.tensor([[[ 0., 1.],
1603+
[ 2., 3.],
1604+
[ 4., 5.]],
1605+
1606+
[[ 6., 7.],
1607+
[ 8., 9.],
1608+
[10., 11.]]]).forwardDiff(combo.tensor([[[ 0., 10.],
1609+
[ 20., 30.],
1610+
[ 40., 50.]],
1611+
1612+
[[ 60., 70.],
1613+
[ 80., 90.],
1614+
[100., 110.]]]))
1615+
// Note, this is a swap
1616+
let fwdz = fwdx.permute([2;1;0])
1617+
let fwdzCorrect = combo.tensor([[[ 0., 6.],
1618+
[ 2., 8.],
1619+
[ 4., 10.]],
1620+
1621+
[[ 1., 7.],
1622+
[ 3., 9.],
1623+
[ 5., 11.]]])
1624+
let fwdzd = fwdz.derivative
1625+
let fwdzdCorrect = combo.tensor([[[ 0., 60.],
1626+
[ 20., 80.],
1627+
[ 40., 100.]],
1628+
1629+
[[ 10., 70.],
1630+
[ 30., 90.],
1631+
[ 50., 110.]]])
1632+
1633+
Assert.CheckEqual(fwdzCorrect, fwdz)
1634+
Assert.CheckEqual(fwdzdCorrect, fwdzd)
1635+
1636+
// Python:
1637+
(*
1638+
import torch
1639+
revx = torch.tensor([[[ 0., 1.],[ 2., 3.],[ 4., 5.]],[[ 6., 7.],[ 8., 9.],[10., 11.]]], requires_grad=True)
1640+
revz = revx.permute([1,2,0])
1641+
revz.backward(torch.tensor([[[ 0., 1.],[ 2., 3.]],[[ 4., 5.],[ 6., 7.]],[[ 8., 9.],[10., 11.]]]))
1642+
revz
1643+
revx.grad
1644+
*)
1645+
1646+
let revx = combo.tensor([[[ 0., 1.],
1647+
[ 2., 3.],
1648+
[ 4., 5.]],
1649+
1650+
[[ 6., 7.],
1651+
[ 8., 9.],
1652+
[10., 11.]]]).reverseDiff()
1653+
1654+
// Note, this is a rotation
1655+
let revz = revx.permute([1;2;0])
1656+
let revzCorrect = combo.tensor([[[ 0., 6.],
1657+
[ 1., 7.]],
1658+
[[ 2., 8.],
1659+
[ 3., 9.]],
1660+
[[ 4., 10.],
1661+
[ 5., 11.]]])
1662+
1663+
revz.reverse(combo.tensor([[[ 0., 1.],
1664+
A880 [ 2., 3.]],
1665+
[[ 4., 5.],
1666+
[ 6., 7.]],
1667+
[[ 8., 9.],
1668+
[10., 11.]]]))
1669+
let revxd = revx.derivative
1670+
let revxdCorrect = combo.tensor([[[ 0., 2.],
1671+
[ 4., 6.],
1672+
[ 8., 10.]],
1673+
1674+
[[ 1., 3.],
1675+
[ 5., 7.],
1676+
[ 9., 11.]]])
1677+
1678+
Assert.CheckEqual(revzCorrect, revz)
1679+
Assert.CheckEqual(revxdCorrect, revxd)
1680+
15991681
[<Test>]
16001682
member _.TestDerivativeTransposeT () =
16011683
for combo in Combos.AllDevicesAndBackendsFloat32 do

0 commit comments

Comments
 (0)
0