8000 tf.zeros for dtype uint8 · SciSharp/TensorFlow.NET@99fc016 · GitHub
[go: up one dir, main page]

Skip to content

Commit 99fc016

Browse files
committed
tf.zeros for dtype uint8
1 parent 45e1365 commit 99fc016

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

src/TensorFlowNET.Core/Data/MnistModelLoader.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ private NDArray ExtractImages(string file, int? limit = null)
123123

124124
bytestream.Read(buf, 0, buf.Length);
125125

126-
var data = np.frombuffer(buf, (num_images, rows * cols), np.@byte);
126+
var data = np.frombuffer(buf, (num_images, rows * cols), np.uint8);
127127
return data;
128128
}
129129
}
@@ -146,7 +146,7 @@ private NDArray ExtractLabels(string file, bool one_hot = false, int num_classes
146146

147147
bytestream.Read(buf, 0, buf.Length);
148148

149-
var labels = np.frombuffer(buf, new Shape(num_items), np.@byte);
149+
var labels = np.frombuffer(buf, new Shape(num_items), np.uint8);
150150

151151
if (one_hot)
152152
return DenseToOneHot(labels, num_classes);

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,9 @@ public static Tensor zeros(Shape shape, TF_DataType dtype = TF_DataType.TF_FLOAT
9191
zeros = constant(0f);
9292
break;
9393
case TF_DataType.TF_INT8:
94+
zeros = constant((sbyte)0);
95+
break;
96+
case TF_DataType.TF_UINT8:
9497
zeros = constant((byte)0);
9598
break;
9699
default:

0 commit comments

Comments
 (0)
0