10000 np.permutation · SciSharp/TensorFlow.NET@a4a4da9 · GitHub
[go: up one dir, main page]

Skip to content

Commit a4a4da9

Browse files
committed
np.permutation
1 parent 2001619 commit a4a4da9

File tree

19 files changed

+113
-65
lines changed

19 files changed

+113
-65
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,9 @@ public partial class tensorflow
2121
public MathApi math { get; } = new MathApi();
2222
public class MathApi
2323
{
24+
public Tensor argmax(Tensor input, Axis axis = null, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64)
25+
=> gen_math_ops.arg_max(input, axis, name: name, output_type: output_type);
26+
2427
public Tensor log(Tensor x, string name = null)
2528
=> gen_math_ops.log(x, name);
2629

@@ -539,15 +542,12 @@ public Tensor reduce_mean(Tensor input_tensor, Axis? axis = null, bool keepdims
539542
public Tensor round(Tensor x, string name = null)
540543
=> gen_math_ops.round(x, name: name);
541544

542-
public Tensor cast(Tensor x, TF_DataType dtype = TF_DataType.DtInvalid, string name = null)
545+
public Tensor cast(Tensor x, TF_DataType dtype, string name = null)
543546
=> math_ops.cast(x, dtype, name);
544547

545548
public Tensor cumsum(Tensor x, int axis = 0, bool exclusive = false, bool reverse = false, string name = null)
546549
=> math_ops.cumsum(x, axis: axis, exclusive: exclusive, reverse: reverse, name: name);
547550

548-
public Tensor argmax(Tensor input, int axis = -1, string name = null, int? dimension = null, TF_DataType output_type = TF_DataType.TF_INT64)
549-
=> gen_math_ops.arg_max(input, axis, name: name, output_type: output_type);
550-
551551
public Tensor square(Tensor x, string name = null)
552552
=> gen_math_ops.square(x, name: name);
553553
public Tensor squared_difference(Tensor x, Tensor y, string name = null)

src/TensorFlowNET.Core/Binding.Util.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,8 @@ public static TF_DataType GetDataType(this object data)
549549
return tensors.dtype;
550550
case IEnumerable<Tensor> tensors:
551551
return tensors.First().dtype;
552+
case RefVariable variable:
553+
return variable.dtype;
552554
case ResourceVariable variable:
553555
return variable.dtype;
554556
default:

src/TensorFlowNET.Core/Contexts/ExecuteOpArgs.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using System;
22
using System.Collections.Generic;
3+
using System.Diagnostics;
34
using System.Text;
45
using static Tensorflow.Binding;
56

@@ -11,11 +12,13 @@ public class ExecuteOpArgs
1112
public object[] OpInputArgs { get; set; }
1213
public Dictionary<string, object> OpAttrs { get; set; }
1314

15+
[DebuggerStepThrough]
1416
public ExecuteOpArgs(params object[] inputArgs)
1517
{
1618
OpInputArgs = inputArgs;
1719
}
1820

21+
[DebuggerStepThrough]
1922
public ExecuteOpArgs SetAttributes(object attrs)
2023
{
2124
OpAttrs = ConvertToDict(attrs);

src/TensorFlowNET.Core/NumPy/Axis.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ public static implicit operator Tensor(Axis axis)
5656
=> constant_op.constant(axis);
5757

5858
public override string ToString()
59-
=> $"({string.Join(", ", axis)})";
59+
=> IsScalar ? $"{axis[0]}" : $"({string.Join(", ", axis)})";
6060
}
6161
}
6262

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Runtime.InteropServices;
4+
using System.Text;
5+
6+
namespace Tensorflow.NumPy
7+
{
8+
public class RandomizedImpl
9+
{
10+
[AutoNumPy]
11+
public NDArray permutation(int x) => new NDArray(random_ops.random_shuffle(math_ops.range(0, x)));
12+
13+
[AutoNumPy]
14+
public NDArray permutation(NDArray x) => new NDArray(random_ops.random_shuffle(x));
15+
16+
[AutoNumPy]
17+
public void shuffle(NDArray x)
18+
{
19+
var y = random_ops.random_shuffle(x);
20+
Marshal.Copy(y.BufferToArray(), 0, x.TensorDataPointer, (int)x.bytesize);
21+
}
22+
23+
public NDArray rand(params int[] shape)
24+
=> throw new NotImplementedException("");
25+
26+
public NDArray randint(long x)
27+
=> throw new NotImplementedException("");
28+
}
29+
}

src/TensorFlowNET.Core/NumPy/NDArray.Index.cs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ public NDArray this[NDArray mask]
4545
{
4646
if(mask.dtype == TF_DataType.TF_INT32)
4747
return GetData(mask.ToArray<int>());
48+
else if (mask.dtype == TF_DataType.TF_INT64)
49+
return GetData(mask.ToArray<long>().Select(x => Convert.ToInt32(x)).ToArray());
4850

4951
throw new NotImplementedException("");
5052
}

src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@ public static NDArray argsort(NDArray a, Axis axis = null)
1717
=> new NDArray(math_ops.argmax(a, axis ?? -1));
1818

1919
[AutoNumPy]
20-
public static NDArray unique(NDArray a)
21-
=> throw new NotImplementedException("");
20+
public static (NDArray, NDArray) unique(NDArray a)
21+
{
22+
var(u, indice) = array_ops.unique(a);
23+
return (new NDArray(u), new NDArray(indice));
24+
}
25+
26+
[AutoNumPy]
27+
public static void shuffle(NDArray x) => np.random.shuffle(x);
2228
}
2329
}

src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,6 @@ public partial class np
1313
public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.arg_min(x, axis));
1414

1515
[AutoNumPy]
16-
public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.arg_max(x, axis));
16+
public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.math.argmax(x, axis));
1717
}
1818
}

src/TensorFlowNET.Core/Numpy/NDArray.Creation.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,9 @@ public NDArray(Shape shape, TF_DataType dtype = TF_DataType.TF_DOUBLE)
2525
public NDArray(byte[] bytes, Shape shape, TF_DataType dtype)
2626
: base(bytes, shape, dtype) { NewEagerTensorHandle(); }
2727

28+
public NDArray(long[] value, Shape? shape = null)
29+
: base(value, shape) { NewEagerTensorHandle(); }
30+
2831
public NDArray(IntPtr address, Shape shape, TF_DataType dtype)
2932
: base(address, shape, dtype) { NewEagerTensorHandle(); }
3033

src/TensorFlowNET.Core/Numpy/NDArray.cs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -42,11 +42,9 @@ public NDIterator<T> AsIterator<T>(bool autoreset = false) where T : unmanaged
4242
public NDArray reshape(Shape newshape) => new NDArray(tf.reshape(this, newshape));
4343
public NDArray astype(TF_DataType dtype) => new NDArray(math_ops.cast(this, dtype));
4444
public NDArray ravel() => throw new NotImplementedException("");
45-
public void shuffle(NDArray nd) => throw new NotImplementedException("");
45+
public void shuffle(NDArray nd) => np.random.shuffle(nd);
4646
public Array ToMuliDimArray<T>() => throw new NotImplementedException("");
4747
public byte[] ToByteArray() => BufferToArray();
48-
public static string[] AsStringArray(NDArray arr) => throw new NotImplementedException("");
49-
5048
public override string ToString() => NDArrayRender.ToString(this);
5149
}
5250
}

0 commit comments

Comments
 (0)
0