Merge pull request #400 from DiffSharp/gunes/fix24 · DiffSharp/DiffSharp@a618c43 · GitHub
[go: up one dir, main page]

Skip to content

Commit a618c43

Browse files
authored
Merge pull request #400 from DiffSharp/gunes/fix24
Introduce scatter, improve reverse mode of gather, simplify nllLoss
2 parents ef8184b + e8e6c04 commit a618c43

File tree

9 files changed

+232
-89
lines changed

9 files changed

+232
-89
lines changed

src/DiffSharp.Backends.Reference/Reference.RawTensor.fs

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -312,6 +312,23 @@ type RawTensorCPU<'T when 'T : equality and 'T :> scalar>(values: 'T[], shape: S
312312
gather result.Shape [||]
313313
upcast result
314314

315+
override t.ScatterT(dim:int, indices, destinationShape:Shape) =
316+
Shape.checkCanScatter t.Shape dim indices.Shape indices.Dtype destinationShape
317+
let indices = indices :?> RawTensorCPU<int>
318+
let result = t.ZerosLike(destinationShape) :?> RawTensorCPU<'T>
319+
let rec scatter (shape:Shape) externalCoords =
320+
if shape.Length = 1 then
321+
for i=0 to shape.[0]-1 do
322+
let globalCoords = Array.append externalCoords [|i|]
323+
let globalCoordsIndices = Array.copy globalCoords
324+
globalCoordsIndices.[dim] <- indices.[globalCoords]
325+
result.[globalCoordsIndices] <- t.[globalCoords]
326+
else
327+
for i=0 to shape.[0]-1 do
328+
scatter shape.[1..] (Array.append externalCoords [|i|])
329+
scatter t.Shape [||]
330+
upcast result
331+
315332
override t.ViewT(shape:Shape) =
316333
Shape.checkCanView t.Shape shape
317334
let result = Array.copy t.Values

src/DiffSharp.Backends.Torch/Torch.RawTensor.fs

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -373,7 +373,6 @@ type TorchRawTensor(tt: torch.Tensor, shape: Shape, dtype: Dtype, device: Device
373373

374374
override t.GatherT(dim:int, indices) =
375375
Shape.checkCanGather t.Shape dim indices.Shape indices.Dtype
376-
if indices.Dtype <> Dtype.Int32 then opNotSupported "Gather (indices must currently be int32 tensors in DiffSharp" indices.Dtype
377376

378377
// NOTE: DiffSharp currently expects indices as an Int32 tensor, Torch wants Int64
379378
let indices = indices.Cast(Dtype.Int64)
@@ -385,6 +384,20 @@ type TorchRawTensor(tt: torch.Tensor, shape: Shape, dtype: Dtype, device: Device
385384
t.TorchTensor.gather(int64 dim, indices.TorchTensor)
386385
t.MakeLike(res, indices.Shape)
387386

387+
override t.ScatterT(dim:int, indices, destinationShape:Shape) =
388+
Shape.checkCanScatter t.Shape dim indices.Shape indices.Dtype destinationShape
389+
// NOTE: DiffSharp currently expects indices as an Int32 tensor, Torch wants Int64
390+
let indices = indices.Cast(Dtype.Int64)
391+
let res = t.ZerosLike(destinationShape)
392+
// LibTorch Scatter on float16/bfloat16 gives : method_name not implemented for 'BFloat16'
393+
if dtype = Dtype.Float16 || dtype = Dtype.BFloat16 then
394+
let res2 = res.TorchTensor.to_type(torch.ScalarType.Float32)
395+
res2.scatter_(int64 dim, indices.TorchTensor, t.TorchTensor.to_type(torch.ScalarType.Float32)) |> ignore
396+
t.MakeLike(res2.to_type(toTorchType dtype), destinationShape)
397+
else
398+
res.TorchTensor.scatter_(int64 dim, indices.TorchTensor, t.TorchTensor) |> ignore
399+
res
400+
388401
override t.ViewT(shape:Shape) =
389402
Shape.checkCanView t.Shape shape
390403
t.MakeLike(tt.reshape(toTorchShape shape), shape=shape) // Use Reshape instead of View to ensure underlying non-contiguous libtorch tensors can be viewed. Internally Reshape uses View if possible, otherwise it copies data to a contiguous tensor and then views.

src/DiffSharp.Core/DiffSharp.Compose.fs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,9 @@ type dsharp with
160160
/// <summary>TBD</summary>
161161
static member gather(dim:int, indices:Tensor) = fun (a:Tensor) -> a.gather(dim, indices)
162162

163+
/// <summary>TBD</summary>
164+
static member scatter(dim:int, indices:Tensor, destinationShape:seq<int>) = fun (a:Tensor) -> a.scatter(dim, indices, destinationShape)
165+
163166
/// <summary>TBD</summary>
164167
static member transpose(dim0:int, dim1:int) = fun (a:Tensor) -> a.transpose(dim0, dim1)
165168

src/DiffSharp.Core/DiffSharp.fs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -748,6 +748,13 @@ type dsharp =
748748
/// <param name="indices">The indices of elements to gather.</param>
749749
static member gather(input:Tensor, dim:int, indices:Tensor) = input.gather(dim, indices)
750750

751+
/// <summary>Scatters values along an axis specified by dim.</summary>
752+
/// <param name="input">The input tensor.</param>
753+
/// <param name="dim">The axis along which to index.</param>
754+
/// <param name="indices">The indices of the locations to scatter elements to.</param>
755+
/// <param name="destinationShape">The destination shape.</param>
756+
static member scatter(input:Tensor, dim:int, indices:Tensor, destinationShape:seq<int>) = input.scatter(dim, indices, destinationShape)
757+
751758
/// <summary>Returns the original tensor with its dimensions permuted.</summary>
752759
/// <param name="input">The input tensor.</param>
753760
/// <param name="permutation">The desired ordering of dimensions.</param>

src/DiffSharp.Core/RawTensor.fs

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -359,12 +359,15 @@ type RawTensor() =
359359
/// are equal up to the given tolerances.
360360
abstract AllClose: t2: RawTensor * relativeTolerance: float * absoluteTolerance: float -> bool
361361

362-
/// Returns a boolean tensor with values constrained by the corresponding elements in the low/high tensors.
362+
/// Returns a tensor with values constrained by the corresponding elements in the low/high tensors.
363363
abstract ClampT: low: RawTensor * high: RawTensor -> RawTensor
364364

365-
/// Returns a boolean tensor selecting the given indices from the given dimension and stacking those in the order specified.
365+
/// Returns a tensor selecting the given indices from the given dimension and stacking those in the order specified.
366366
abstract GatherT: dim: int * indices: RawTensor -> RawTensor
367367

368+
/// Returns a tensor with given destination shape where values are copied from the current tensor to locations specified by the dimension and indices.
369+
abstract ScatterT: dim: int * indices: RawTensor * destinationShape: Shape -> RawTensor
370+
368371
/// Returns a boolean tensor comparing each element pairwise with the corresponding element in <c>t2</c>
369372
abstract LtTT: t2: RawTensor -> RawTensor
370373

src/DiffSharp.Core/Shape.fs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -523,9 +523,17 @@ module rec Shape =
523523

524524
/// Checks if the given shape is appropriate for a gather operation.
525525
let checkCanGather (shape: Shape) (dim: int) (indicesShape: Shape) (indicesDtype:Dtype) =
526-
if shape.Length <> indicesShape.Length then failwithf "Expecting tensorShape (%A) and indicesShape (%A) to have the same number of dimensions" shape indicesShape
527-
if dim < 0 || dim > shape.Length-1 then failwithf "Expecting 0<= dim (%A) < tensorShape.Length (%A)" dim shape.Length
528-
if indicesShape.[dim] < 1 then failwithf "Expecting indicesShape.[dim] (%A) >= 1" indicesShape.[dim]
526+
if shape.Length <> indicesShape.Length then failwithf "Expecting tensor (%A) and indices (%A) to have the same number of dimensions" shape indicesShape
527+
if dim < 0 || dim > shape.Length-1 then failwithf "Expecting 0<= dim (%A) < tensor dim (%A)" dim shape.Length
528+
if indicesShape.[dim] < 1 then failwithf "Expecting indices shape at dim %A (%A) >= 1" dim indicesShape.[dim]
529+
if indicesDtype <> Dtype.Int32 then failwithf "Expecting indices to have type %A" Dtype.Int32
530+
531+
/// Checks if the given shape is appropriate for a scatter operation.
532+
let checkCanScatter (shape: Shape) (dim: int) (indicesShape: Shape) (indicesDtype:Dtype) (destinationShape: Shape)=
533+
if shape.Length <> indicesShape.Length then failwithf "Expecting tensor (%A) and indices (%A) to have the same number of dimensions" shape indicesShape
534+
if shape.Length <> destinationShape.Length then failwithf "Expecting tensor (%A) and destination (%A) to have the same number of dimensions" shape destinationShape
535+
if not (contains shape indicesShape) then failwithf "Expecting tensor shape (%A) to contain indices shape (%A)" shape indicesShape
536+
if dim < 0 || dim > shape.Length-1 then failwithf "Expecting 0<= dim (%A) < tensor dim (%A)" dim shape.Length
529537
if indicesDtype <> Dtype.Int32 then failwithf "Expecting indices to have type %A" Dtype.Int32
530538

531539
/// Checks if the given shape is appropriate for a view operation.

src/DiffSharp.Core/Tensor.fs

Lines changed: 40 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -1789,13 +1789,28 @@ type Tensor =
17891789
/// <param name="dim">The axis along which to index.</param>
17901790
/// <param name="indices">The indices of elements to gather.</param>
17911791
member a.gather(dim:int, indices:Tensor) =
1792+
let dim = Shape.completeDim a.dim dim // Handles -1 semantics
17921793
Shape.checkCanGather a.shape dim indices.shape indices.dtype
17931794
let inline fRaw(a:RawTensor) = a.GatherT(dim, indices.primalRaw)
17941795
let inline fTensor(a:Tensor) = a.gather(dim, indices)
17951796
let inline dfFwd(ap,ad:Tensor,fp) = ad.gather(dim, indices)
17961797
let inline dfRev(a) = GatherT(a, dim, indices)
17971798
Tensor.OpUnary(a, fRaw, fTensor, dfFwd, dfRev)
17981799

1800+
/// <summary>Scatters values along an axis specified by dim.</summary>
1801+
/// <param name="dim">The axis along which to index.</param>
1802+
/// <param name="indices">The indices of the locations to scatter elements to.</param>
1803+
/// <param name="destinationShape">The destination shape.</param>
1804+
member a.scatter(dim:int, indices:Tensor, destinationShape:seq<int>) =
1805+
let destinationShape = destinationShape|>Shape.create
1806+
let dim = Shape.completeDim a.dim dim // Handles -1 semantics
1807+
Shape.checkCanScatter a.shape dim indices.shape indices.dtype destinationShape
1808+
let inline fRaw(a:RawTensor) = a.ScatterT(dim, indices.primalRaw, destinationShape)
1809+
let inline fTensor(a:Tensor) = a.scatter(dim, indices, destinationShape)
1810+
let inline dfFwd(ap,ad:Tensor,fp) = ad.scatter(dim, indices, destinationShape)
1811+
let inline dfRev(a) = ScatterT(a, dim, indices)
1812+
Tensor.OpUnary(a, fRaw, fTensor, dfFwd, dfRev)
1813+
17991814
/// <summary>Returns a new tensor with the same data as the self tensor but of a different shape.</summary>
18001815
/// <remarks>
18011816
/// The returned tensor shares the same data and must have the same number of elements, but may have a different size.
@@ -2224,44 +2239,26 @@ type Tensor =
22242239
if target.shape.[0] <> n then failwithf "Expecting either: input with shape (N,C) and target with shape (N); or input with shape (N,C,d1,d2,...,dk) and target with shape (N,d1,d2,...,dk). Received input.shape %A and target.shape %A" input.shape target.shape
22252240
if d <> target.shape.[1..] then failwithf "Expecting either: input with shape (N,C) and target with shape (N); or input with shape (N,C,d1,d2,...,dk) and target with shape (N,d1,d2,...,dk). Received input.shape %A and target.shape %A" input.shape target.shape
22262241
n, c, d
2227-
let mutable weightSpecified = false
2228-
let mutable ww = input.zeroLike()
2229-
match weight with
2230-
| Some w -> ww <- w; weightSpecified <- true
2231-
| None -> ww <- input.onesLike([classes]); weightSpecified <- false
2232-
let weight = ww
2242+
let target = target.int()
2243+
let weightSpecified, weight =
2244+
match weight with
2245+
| Some w ->
2246+
if w.dim <> 1 || w.shape.[0] <> classes then failwithf "Expecting weight with shape (C). Received weight.shape %A" w.shape
2247+
let vv = Array.create input.dim 1
2248+
vv.[1] <- classes
2249+
true, w.view(vv).expandAs(input).gather(1, target.unsqueeze(1)).squeeze(1)
2250+
| None -> false, input.zeroLike()
22332251
let reduction = defaultArg reduction "mean"
22342252
if not (reduction = "none" || reduction = "mean" || reduction = "sum") then failwithf "Expecting reduction (%A) to be one of (none, mean, sum)" reduction
2235-
if input.dim = 2 then
2236-
let mutable wacc = input.zeroLike()
2237-
let l = Array.init n (fun i ->
2238-
let target = int target.[i]
2239-
let w = weight.[target]
2240-
wacc <- wacc + w
2241-
-w*input.[i, target]) |> Tensor.stack
2242-
if reduction = "none" then
2243-
l
2244-
elif reduction = "mean" then
2245-
if weightSpecified then l.sum()/wacc else l.mean()
2246-
else // reduction = "sum"
2247-
l.sum()
2248-
else
2249-
let mutable wacc = input.zeroLike()
2250-
let l = Array.init n (fun i ->
2251-
let aa = input.[i].view([classes; -1])
2252-
let bb = target.[i].view(-1)
2253-
let l = Array.init bb.nelement (fun j ->
2254-
let target = int bb.[j]
2255-
let w = weight.[target]
2256-
wacc <- wacc + w
2257-
-w*aa.[target, j]) |> Tensor.stack
2258-
l.view(d)) |> Tensor.stack
2259-
if reduction = "none" then
2260-
l
2261-
elif reduction = "mean" then
2262-
if weightSpecified then l.sum()/wacc else l.mean()
2263-
else // reduction = "sum"
2264-
l.sum()
2253+
let mutable l = input.gather(1, target.unsqueeze(1)).squeeze(1).neg()
2254+
if weightSpecified then
2255+
l <- l * weight
2256+
if reduction = "none" then
2257+
l
2258+
elif reduction = "mean" then
2259+
if weightSpecified then l.sum()/weight.sum() else l.mean()
2260+
else // reduction = "sum"
2261+
l.sum()
22652262

22662263
/// <summary>Add zero padding to each side of a tensor</summary>
22672264
/// <param name="paddings">The implicit paddings on corresponding sides of the input.</param>
@@ -2770,6 +2767,7 @@ type Tensor =
27702767
| CatTs(a,_) -> reset (List.append (a |> List.ofSeq) tt)
27712768
| SplitT(a,_,_,_) -> reset (a::tt)
27722769
| GatherT(a,_,_) -> reset (a::tt)
2770+
| ScatterT(a,_,_) -> reset (a::tt)
27732771
| PermuteT(a,_) -> reset (a::tt)
27742772
| TransposeT(a,_,_) -> reset (a::tt)
27752773
| TransposeT2(a) -> reset (a::tt)
@@ -2915,31 +2913,19 @@ type Tensor =
29152913
| StackTs(a,dim) ->
29162914
push (List.append (Array.zip (td.unstack(dim)) a |> Array.map check |> Array.toList) tt)
29172915
| UnstackT(a,dim,i) ->
2918-
if a.derivative.dim = 0 then a.derivative <- a.zerosLike() + a.derivative
2916+
if a.derivative.dim = 0 then a.derivative <- a.derivative.expandAs(a)
29192917
a.derivative <- a.derivative.addSlice(Array.init a.dim (fun j -> if j=dim then i else 0), td.unsqueeze(dim))
29202918
push (check(a.zeroLike(), a) :: tt)
29212919
| CatTs(a, dim) ->
29222920
let sizes = a |> Array.map (fun x -> x.shape.[dim])
29232921
push (List.append (Array.zip (td.split(sizes, dim=dim)) a |> Array.map check |> Array.toList) tt)
29242922
| SplitT(a,sizes,dim,i) ->
2925-
if a.derivative.dim = 0 then a.derivative <- a.zerosLike() + a.derivative
2923+
if a.derivative.dim = 0 then a.derivative <- a.derivative.expandAs(a)
29262924
let locs = (0,sizes) ||> Array.scan (+)
29272925
a.derivative <- a.derivative.addSlice(Array.init a.dim (fun j -> if j=dim then locs.[i] else 0), td)
29282926
push (check(a.zeroLike(), a) :: tt)
2929-
| GatherT(a,dim,indices) ->
2930-
// TODO: The following is a minimal correct implementation. Faster and more memory efficient implementations should be possible.
2931-
let tflat = td.flatten()
2932-
let iflat = indices.flatten()
2933-
if a.derivative.dim = 0 then a.derivative <- a.zerosLike() + a.derivative
2934-
for i=0 to tflat.nelement-1 do
2935-
let mutable t = tflat.[i]
2936-
for k=0 to a.dim-1 do
2937-
t <- t.unsqueeze(0)
2938-
let j = iflat.[i].toScalar() :?> int
2939-
let loc = flatIndexToIndex a.shape i
2940-
loc.[dim] <- j
2941-
a.derivative <- a.derivative.addSlice(loc, t)
2942-
push (check(a.zeroLike(), a) :: tt)
2927+
| GatherT(a,dim,indices) -> push (check(td.scatter(dim, indices, a.shape), a) :: tt)
2928+
| ScatterT(a,dim,indices) -> push (check(td.gather(dim, indices), a) :: tt)
29432929
| PermuteT(a, inversePermutation) -> push (check(td.permute(inversePermutation), a) :: tt)
29442930
| TransposeT(a, dim0, dim1) -> push (check(td.transpose(dim0, dim1), a) :: tt)
29452931
| TransposeT2(a) -> push (check(td.transpose(), a) :: tt)
@@ -2952,7 +2938,7 @@ type Tensor =
29522938
| ClampT(a, mask) -> push (check(td * mask, a) :: tt)
29532939
| SliceT(a,bounds) ->
29542940
// TODO: a.derivative.expandAs(a) below is to handle non-scalar TensorRs with a scalar derivative Tensor(0.) (representing the initialization before accumulation). This is correct but can be changed to eliminate the extra op.
2955-
if a.derivative.dim = 0 then a.derivative <- a.zerosLike() + a.derivative
2941+
if a.derivative.dim = 0 then a.derivative <- a.derivative.expandAs(a)
29562942
a.derivative <- a.derivative.addSlice(boundsToLocation bounds, td.view(boundsToShape bounds))
29572943
push (check(a.zeroLike(), a) :: tt)
29582944
| AddTTSlice(a,location,b) ->
@@ -3061,6 +3047,7 @@ and TensorOp =
30613047
| SplitT of Tensor * int[] * dim:int * i:int
30623048
| SliceT of Tensor * int[,]
30633049
| GatherT of Tensor * int * Tensor
3050+
| ScatterT of Tensor * int * Tensor
30643051
| PermuteT of Tensor * inversePermutation: int[]
30653052
| TransposeT of Tensor * int * int
30663053
| TransposeT2 of Tensor

0 commit comments

Comments
 (0)
0