8000 Merge pull request #414 from DiffSharp/gunes/fix28 · DiffSharp/DiffSharp@dd49c85 · GitHub
[go: up one dir, main page]

Skip to content

Commit dd49c85

Browse files
authored
Merge pull request #414 from DiffSharp/gunes/fix28
Add dsharp.argmax
2 parents 2ba66fe + 443a0be commit dd49c85

File tree

2 files changed

+14
-2
lines changed

2 files changed

+14
-2
lines changed

src/DiffSharp.Core/DiffSharp.fs

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -542,10 +542,22 @@ type dsharp =
542542
/// <param name="input">The input tensor.</param>
543543
static member argmax(input:Tensor) = input.argmax()
544544

545+
/// <summary>Returns the indices of the maximum value of all elements in the input tensor.</summary>
546+
/// <param name="input">The input tensor.</param>
547+
/// <param name="dim">The dimension.</param>
548+
/// <param name="keepDim">Whether the output tensor has dim retained or not.</param>
549+
static member argmax(input:Tensor, dim:int, ?keepDim:bool) = input.argmax(dim=dim, ?keepDim=keepDim)
550+
545551
/// <summary>Returns the indices of the minimum value of all elements in the input tensor.</summary>
546552
/// <param name="input">The input tensor.</param>
547553
static member argmin(input:Tensor) = input.argmin()
548554

555+
/// <summary>Returns the indices of the minimum value of all elements in the input tensor.</summary>
556+
/// <param name="input">The input tensor.</param>
557+
/// <param name="dim">The dimension.</param>
558+
/// <param name="keepDim">Whether the output tensor has dim retained or not.</param>
559+
static member argmin(input:Tensor, dim:int, ?keepDim:bool) = input.argmin(dim=dim, ?keepDim=keepDim)
560+
549561
/// <summary>Returns the maximum value of all elements in the input tensor.</summary>
550562
/// <param name="input">The input tensor.</param>
551563
static member max(input:Tensor) = input.max()

tests/DiffSharp.Tests/TestTensor.fs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4359,7 +4359,7 @@ type TestTensor () =
43594359
let t1Argmax = t1.argmax(0)
43604360
let t1ArgmaxCorrect = combo.tensor(2, dtype=Dtype.Int32)
43614361

4362-
let t1ArgmaxKeepDim = t1.argmax(0, keepDim=true)
4362+
let t1ArgmaxKeepDim = dsharp.argmax(t1, 0, keepDim=true)
43634363
let t1ArgmaxKeepDimCorrect = combo.tensor([2], dtype=Dtype.Int32)
43644364

43654365
let t2 = combo.tensor([[1.;4.];[2.;3.]])
@@ -4452,7 +4452,7 @@ type TestTensor () =
44524452
let t1Argmin = t1.argmin(0)
44534453
let t1ArgminCorrect = combo.tensor(1, dtype=Dtype.Int32)
44544454

4455-
let t1ArgminKeepDim = t1.argmin(0, keepDim=true)
4455+
let t1ArgminKeepDim = dsharp.argmin(t1, 0, keepDim=true)
44564456
let t1ArgminKeepDimCorrect = combo.tensor([1], dtype=Dtype.Int32)
44574457

44584458
let t2 = combo.tensor([[1.;4.];[2.;3.]])

0 commit comments

Comments
 (0)
0