8000 np.randint · SciSharp/TensorFlow.NET@9c5692b · GitHub
[go: up one dir, main page]

Skip to content

Commit 9c5692b

Browse files
committed
np.randint
1 parent 432ae20 commit 9c5692b

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

src/TensorFlowNET.Core/APIs/tf.random.cs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ public Tensor uniform(Shape shape,
7171
string name = null)
7272
{
7373
if (dtype.is_integer())
74-
return random_ops.random_uniform_int(shape, (int)minval, (int)maxval, dtype, seed, name);
74+
return random_ops.random_uniform_int(shape, (int)minval, (int)maxval, seed, name);
7575
else
7676
return random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);
7777
}

src/TensorFlowNET.Core/NumPy/Implementation/RandomizedImpl.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,18 @@ public void shuffle(NDArray x)
2323
public NDArray rand(params int[] shape)
2424
=> throw new NotImplementedException("");
2525

26+
[AutoNumPy]
2627
public NDArray randint(int low, int? high = null, Shape size = null, TF_DataType dtype = TF_DataType.TF_INT32)
27-
=> throw new NotImplementedException("");
28+
{
29+
if(high == null)
30+
{
31+
high = low;
32+
low = 0;
33+
}
34+
size = size ?? Shape.Scalar;
35+
var tensor = random_ops.random_uniform_int(shape: size, minval: low, maxval: (int)high);
36+
return new NDArray(tensor);
37+
}
2838

2939
public NDArray normal(float loc = 0.0f, float scale = 1.0f, Shape size = null)
3040
=> throw new NotImplementedException("");

src/TensorFlowNET.Core/Operations/random_ops.cs

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,6 @@ public static Tensor random_uniform(int[] shape,
9494
public static Tensor random_uniform_int(int[] shape,
9595
int minval = 0,
9696
int maxval = 1,
97-
TF_DataType dtype = TF_DataType.TF_FLOAT,
9897
int? seed = null,
9998
string name = null)
10099
{
@@ -103,8 +102,8 @@ public static Tensor random_uniform_int(int[] shape,
103102
name = scope;
104103
var (seed1, seed2) = random_seed.get_seed(seed);
105104
var tensorShape = tensor_util.shape_tensor(shape);
106-
var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min");
107-
var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max");
105+
var minTensor = ops.convert_to_tensor(minval, name: "min");
106+
var maxTensor = ops.convert_to_tensor(maxval, name: "max");
108107
return gen_random_ops.random_uniform_int(tensorShape, minTensor, maxTensor, seed: seed1, seed2: seed2);
109108
});
110109
}

0 commit comments

Comments
 (0)
0